我编写了一个BP网络,需要将权值和网络节点数等信息保存到文件,再重新读出调用。
现在n1,n2,n3(网络节点数)读写正确,但是权值都不对。哪位牛人帮我解决一下这个问题呀?急啊,不胜感激
n1,n2,n3是整型,权值是double型下面是我的代码:
#define fastcopy(to,from,len)\
{\
        register char *_to,*_from;\
        register int _i,_l;\
        _to = (char *)(to);\
        _from = (char *)(from);\
        _l = (len);\
        for (_i = 0; _i < _l; _i++) *_to++ = *_from++;\
}
/* 保存BP网络 */
void bpnn_save(BPNN *net, char *filename)
{
        int n1, n2, n3, i, j, memcnt;
        double dvalue, **w;
        char *mem;
        FILE *fd;
        if ((fd = fopen(filename, "w")) == NULL) {
                printf("BPNN_SAVE: Cannot create '%s'\n", filename);
                return;
        }        n1 = net->input_n;  n2 = net->hidden_n;  n3 = net->output_n;
        printf("Saving %dx%dx%d network to '%s'\n", n1, n2, n3, filename);
        fflush(stdout);
        fwrite((char *) &n1, sizeof(int), 1, fd);
        fwrite((char *) &n2, sizeof(int), 1, fd);
        fwrite((char *) &n3, sizeof(int), 1, fd);        memcnt = 0;
        w = net->input_weights;
        mem = (char *) malloc ((unsigned) ((n1+1) * (n2+1) * sizeof(double)));
        for (i = 0; i <= n1; i++) {
                for (j = 0; j <= n2; j++) {
                        dvalue = w[i][j];
                        fastcopy(&mem[memcnt], &dvalue, sizeof(double));
                        memcnt += sizeof(double);
                }
        }       fwrite(mem, (n1+1)*(n2+1)*sizeof(double), 1, fd); 
       free(mem);
       // mem = (char *) malloc ((unsigned) ((n1+1) * (n2+1) * sizeof(double)));
        memcnt = 0;
        w = net->hidden_weights;
        mem = (char *) malloc ((unsigned) ((n2+1) * (n3+1) * sizeof(double)));
        for (i = 0; i <= n2; i++) {
                for (j = 0; j <= n3; j++) {
                        dvalue = w[i][j];
                        fastcopy(&mem[memcnt], &dvalue, sizeof(double));
                        memcnt += sizeof(double);
                }
        }        fwrite(mem, (n2+1) * (n3+1) * sizeof(double), 1, fd);
        free(mem);        fclose(fd);
        return;
}
/* 从文件中读取BP网络 */
BPNN *bpnn_read(char *filename)
{
        char *mem;
        BPNN *new1;
        int n1, n2, n3, i, j, memcnt;
        FILE *fd;        if ((fd = fopen(filename, "r")) == NULL) {
                return (NULL);
        }        printf("Reading '%s'\n", filename);  fflush(stdout);        fread((char *) &n1, sizeof(int), 1, fd);
printf("the input nods is:%d\n",n1);
        fread((char *) &n2, sizeof(int), 1, fd);
printf("the hidden nods is:%d\n",n2);
        fread((char *) &n3, sizeof(int), 1, fd);
printf("the output nods is:%d\n",n3);        new1 = bpnn_internal_create(n1, n2, n3);        printf("'%s' contains a %dx%dx%d network\n", filename, n1, n2, n3);
        printf("Reading input weights...");  fflush(stdout);        memcnt = 0;
        mem = (char *) malloc ((unsigned) ((n1+1) * (n2+1) * sizeof(double)));        fread( mem, (n1+1) * (n2+1) * sizeof(double), 1, fd);
        for (i = 0; i <= n1; i++) {
                for (j = 0; j <= n2; j++) {
                        fastcopy(&(new1->input_weights[i][j]), &mem[memcnt], sizeof(double));
//printf("%f\n",new1->input_weights[i][j]);
                        memcnt += sizeof(double);
                }
        }
        free(mem);        printf("Done\nReading hidden weights...");  fflush(stdout);        memcnt = 0;
        mem = (char *) malloc ((unsigned) ((n2+1) * (n3+1) * sizeof(double)));        fread( mem, (n2+1) * (n3+1) * sizeof(double), 1, fd);
        for (i = 0; i <= n2; i++) {
                for (j = 0; j <= n3; j++) {
                        fastcopy(&(new1->hidden_weights[i][j]), &mem[memcnt], sizeof(double));
                        memcnt += sizeof(double);
                }
        }
        free(mem);
        fclose(fd);        printf("Done\n");  fflush(stdout);        bpnn_zero_weights(new1->input_prev_weights, n1, n2);
        bpnn_zero_weights(new1->hidden_prev_weights, n2, n3);        return (new1);
}

解决方案 »

  1.   

    急啊,自己测试了一下,好像fastcopy并没有问题
    应该是fwrite和fread出了问题,怎么解决啊?哪位牛人来帮帮忙吧
      

  2.   

    你加上二进制标志试试:fd = fopen(filename, "wb") && fd = fopen(filename, "rb"))  
      

  3.   

    多谢楼上二位
    发现了三个bug:
    1、&mem[memcnt]应改为&(mem[memcnt])
    2、"w"和"r"分别改为"wb"和"rb"
    3、读写函数改为fread/fwrite(mem,sizeof(double),(n1+1)*(n2+1),fd)
    之后就好用了