我编写了一个BP网络,需要将权值和网络节点数等信息保存到文件,再重新读出调用。
现在n1,n2,n3(网络节点数)读写正确,但是权值都不对。哪位牛人帮我解决一下这个问题呀?急啊,不胜感激
n1,n2,n3是整型,权值是double型下面是我的代码:
#define fastcopy(to,from,len)\
{\
register char *_to,*_from;\
register int _i,_l;\
_to = (char *)(to);\
_from = (char *)(from);\
_l = (len);\
for (_i = 0; _i < _l; _i++) *_to++ = *_from++;\
}
/* 保存BP网络 */
void bpnn_save(BPNN *net, char *filename)
{
int n1, n2, n3, i, j, memcnt;
double dvalue, **w;
char *mem;
FILE *fd;
if ((fd = fopen(filename, "w")) == NULL) {
printf("BPNN_SAVE: Cannot create '%s'\n", filename);
return;
} n1 = net->input_n; n2 = net->hidden_n; n3 = net->output_n;
printf("Saving %dx%dx%d network to '%s'\n", n1, n2, n3, filename);
fflush(stdout);
fwrite((char *) &n1, sizeof(int), 1, fd);
fwrite((char *) &n2, sizeof(int), 1, fd);
fwrite((char *) &n3, sizeof(int), 1, fd); memcnt = 0;
w = net->input_weights;
mem = (char *) malloc ((unsigned) ((n1+1) * (n2+1) * sizeof(double)));
for (i = 0; i <= n1; i++) {
for (j = 0; j <= n2; j++) {
dvalue = w[i][j];
fastcopy(&mem[memcnt], &dvalue, sizeof(double));
memcnt += sizeof(double);
}
} fwrite(mem, (n1+1)*(n2+1)*sizeof(double), 1, fd);
free(mem);
// mem = (char *) malloc ((unsigned) ((n1+1) * (n2+1) * sizeof(double)));
memcnt = 0;
w = net->hidden_weights;
mem = (char *) malloc ((unsigned) ((n2+1) * (n3+1) * sizeof(double)));
for (i = 0; i <= n2; i++) {
for (j = 0; j <= n3; j++) {
dvalue = w[i][j];
fastcopy(&mem[memcnt], &dvalue, sizeof(double));
memcnt += sizeof(double);
}
} fwrite(mem, (n2+1) * (n3+1) * sizeof(double), 1, fd);
free(mem); fclose(fd);
return;
}
/* 从文件中读取BP网络 */
BPNN *bpnn_read(char *filename)
{
char *mem;
BPNN *new1;
int n1, n2, n3, i, j, memcnt;
FILE *fd; if ((fd = fopen(filename, "r")) == NULL) {
return (NULL);
} printf("Reading '%s'\n", filename); fflush(stdout); fread((char *) &n1, sizeof(int), 1, fd);
printf("the input nods is:%d\n",n1);
fread((char *) &n2, sizeof(int), 1, fd);
printf("the hidden nods is:%d\n",n2);
fread((char *) &n3, sizeof(int), 1, fd);
printf("the output nods is:%d\n",n3); new1 = bpnn_internal_create(n1, n2, n3); printf("'%s' contains a %dx%dx%d network\n", filename, n1, n2, n3);
printf("Reading input weights..."); fflush(stdout); memcnt = 0;
mem = (char *) malloc ((unsigned) ((n1+1) * (n2+1) * sizeof(double))); fread( mem, (n1+1) * (n2+1) * sizeof(double), 1, fd);
for (i = 0; i <= n1; i++) {
for (j = 0; j <= n2; j++) {
fastcopy(&(new1->input_weights[i][j]), &mem[memcnt], sizeof(double));
//printf("%f\n",new1->input_weights[i][j]);
memcnt += sizeof(double);
}
}
free(mem); printf("Done\nReading hidden weights..."); fflush(stdout); memcnt = 0;
mem = (char *) malloc ((unsigned) ((n2+1) * (n3+1) * sizeof(double))); fread( mem, (n2+1) * (n3+1) * sizeof(double), 1, fd);
for (i = 0; i <= n2; i++) {
for (j = 0; j <= n3; j++) {
fastcopy(&(new1->hidden_weights[i][j]), &mem[memcnt], sizeof(double));
memcnt += sizeof(double);
}
}
free(mem);
fclose(fd); printf("Done\n"); fflush(stdout); bpnn_zero_weights(new1->input_prev_weights, n1, n2);
bpnn_zero_weights(new1->hidden_prev_weights, n2, n3); return (new1);
}
现在n1,n2,n3(网络节点数)读写正确,但是权值都不对。哪位牛人帮我解决一下这个问题呀?急啊,不胜感激
n1,n2,n3是整型,权值是double型下面是我的代码:
#define fastcopy(to,from,len)\
{\
register char *_to,*_from;\
register int _i,_l;\
_to = (char *)(to);\
_from = (char *)(from);\
_l = (len);\
for (_i = 0; _i < _l; _i++) *_to++ = *_from++;\
}
/* 保存BP网络 */
void bpnn_save(BPNN *net, char *filename)
{
int n1, n2, n3, i, j, memcnt;
double dvalue, **w;
char *mem;
FILE *fd;
if ((fd = fopen(filename, "w")) == NULL) {
printf("BPNN_SAVE: Cannot create '%s'\n", filename);
return;
} n1 = net->input_n; n2 = net->hidden_n; n3 = net->output_n;
printf("Saving %dx%dx%d network to '%s'\n", n1, n2, n3, filename);
fflush(stdout);
fwrite((char *) &n1, sizeof(int), 1, fd);
fwrite((char *) &n2, sizeof(int), 1, fd);
fwrite((char *) &n3, sizeof(int), 1, fd); memcnt = 0;
w = net->input_weights;
mem = (char *) malloc ((unsigned) ((n1+1) * (n2+1) * sizeof(double)));
for (i = 0; i <= n1; i++) {
for (j = 0; j <= n2; j++) {
dvalue = w[i][j];
fastcopy(&mem[memcnt], &dvalue, sizeof(double));
memcnt += sizeof(double);
}
} fwrite(mem, (n1+1)*(n2+1)*sizeof(double), 1, fd);
free(mem);
// mem = (char *) malloc ((unsigned) ((n1+1) * (n2+1) * sizeof(double)));
memcnt = 0;
w = net->hidden_weights;
mem = (char *) malloc ((unsigned) ((n2+1) * (n3+1) * sizeof(double)));
for (i = 0; i <= n2; i++) {
for (j = 0; j <= n3; j++) {
dvalue = w[i][j];
fastcopy(&mem[memcnt], &dvalue, sizeof(double));
memcnt += sizeof(double);
}
} fwrite(mem, (n2+1) * (n3+1) * sizeof(double), 1, fd);
free(mem); fclose(fd);
return;
}
/* 从文件中读取BP网络 */
BPNN *bpnn_read(char *filename)
{
char *mem;
BPNN *new1;
int n1, n2, n3, i, j, memcnt;
FILE *fd; if ((fd = fopen(filename, "r")) == NULL) {
return (NULL);
} printf("Reading '%s'\n", filename); fflush(stdout); fread((char *) &n1, sizeof(int), 1, fd);
printf("the input nods is:%d\n",n1);
fread((char *) &n2, sizeof(int), 1, fd);
printf("the hidden nods is:%d\n",n2);
fread((char *) &n3, sizeof(int), 1, fd);
printf("the output nods is:%d\n",n3); new1 = bpnn_internal_create(n1, n2, n3); printf("'%s' contains a %dx%dx%d network\n", filename, n1, n2, n3);
printf("Reading input weights..."); fflush(stdout); memcnt = 0;
mem = (char *) malloc ((unsigned) ((n1+1) * (n2+1) * sizeof(double))); fread( mem, (n1+1) * (n2+1) * sizeof(double), 1, fd);
for (i = 0; i <= n1; i++) {
for (j = 0; j <= n2; j++) {
fastcopy(&(new1->input_weights[i][j]), &mem[memcnt], sizeof(double));
//printf("%f\n",new1->input_weights[i][j]);
memcnt += sizeof(double);
}
}
free(mem); printf("Done\nReading hidden weights..."); fflush(stdout); memcnt = 0;
mem = (char *) malloc ((unsigned) ((n2+1) * (n3+1) * sizeof(double))); fread( mem, (n2+1) * (n3+1) * sizeof(double), 1, fd);
for (i = 0; i <= n2; i++) {
for (j = 0; j <= n3; j++) {
fastcopy(&(new1->hidden_weights[i][j]), &mem[memcnt], sizeof(double));
memcnt += sizeof(double);
}
}
free(mem);
fclose(fd); printf("Done\n"); fflush(stdout); bpnn_zero_weights(new1->input_prev_weights, n1, n2);
bpnn_zero_weights(new1->hidden_prev_weights, n2, n3); return (new1);
}
应该是fwrite和fread出了问题,怎么解决啊?哪位牛人来帮帮忙吧
发现了三个bug:
1、&mem[memcnt]应改为&(mem[memcnt])
2、"w"和"r"分别改为"wb"和"rb"
3、读写函数改为fread/fwrite(mem,sizeof(double),(n1+1)*(n2+1),fd)
之后就好用了