diff options
Diffstat (limited to 'src/CuTNetLib')
| -rw-r--r-- | src/CuTNetLib/cuMisc.h | 20 | 
1 files changed, 12 insertions, 8 deletions
| diff --git a/src/CuTNetLib/cuMisc.h b/src/CuTNetLib/cuMisc.h index 137dbcf..3447739 100644 --- a/src/CuTNetLib/cuMisc.h +++ b/src/CuTNetLib/cuMisc.h @@ -116,7 +116,7 @@ namespace TNet {    {      public:      CuDistrib(size_t nInputs, size_t nOutputs, CuComponent* pPred) -      : CuComponent(nInputs,nOutputs,pPred),size(0),ErrInputVec() +      : CuComponent(nInputs,nOutputs,pPred),size(0),ErrorInputVec()      {      } @@ -132,9 +132,9 @@ namespace TNet {      void ReadFromStream(std::istream& rIn)      {        rIn >> std::ws >> size; -      ErrInputVec.clear(); +      ErrorInputVec.clear();        for (int i=0; i<size;++i) -        ErrInputVec.push_back(NULL); +        ErrorInputVec.push_back(NULL);      }      void WriteToStream(std::ostream& rOut)   @@ -156,8 +156,8 @@ namespace TNet {      const CuMatrix<BaseFloat>& GetErrorInput(int pos=0)      {        if (pos>=0 && pos<size) -        return *ErrInputVec[pos]; -      return *ErrInputVec[0]; +        return *ErrorInputVec[pos]; +      return *ErrorInputVec[0];      }      void SetErrorInput(const CuMatrix<BaseFloat>& rErrorInput,int pos=0) @@ -165,7 +165,7 @@ namespace TNet {        if (pos==0)          mpErrorInput=&rErrorInput;        if (pos>=0 && pos<size) -        ErrInputVec[pos]=&rErrorInput; +        ErrorInputVec[pos]=&rErrorInput;      }       protected: @@ -177,11 +177,11 @@ namespace TNet {      {        Y.SetZero();        for (int i=0;i<size;++i) -        Y.AddScaled(1.0,*ErrInputVec[i],1.0); +        Y.AddScaled(1.0,*ErrorInputVec[i],1.0);      }      int size; -    ConstMatrixPtrVec ErrInputVec; +    ConstMatrixPtrVec ErrorInputVec;      Vector<BaseFloat> Scale;    }; @@ -300,10 +300,12 @@ namespace TNet {          delete OutputVec[i];        rIn >> std::ws >> size;        OutputVec.clear(); +      ErrorInputVec.clear();        for (int i=0; i<size;++i)        {          rIn>>len;          OutputVec.push_back(new CuMatrix<BaseFloat>()); +        ErrorInputVec.push_back(NULL);          SectLen.push_back(len);        }      } @@ -409,10 +411,12 @@ namespace TNet {          delete ErrorOutputVec[i];        rIn >> std::ws >> size;        ErrorOutputVec.clear(); +      InputVec.clear();        for (int i=0; i<size;++i)        {          rIn>>len;          ErrorOutputVec.push_back(new CuMatrix<BaseFloat>()); +        InputVec.push_back(NULL);          SectLen.push_back(len);        }      } | 
