From cccccbf6cca94a3eaf813b4468453160e91c332b Mon Sep 17 00:00:00 2001 From: Joe Zhao Date: Mon, 14 Apr 2014 08:14:45 +0800 Subject: First commit --- .../.svn/text-base/cuRecurrent.h.svn-base | 101 +++++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 src/CuTNetLib/.svn/text-base/cuRecurrent.h.svn-base (limited to 'src/CuTNetLib/.svn/text-base/cuRecurrent.h.svn-base') diff --git a/src/CuTNetLib/.svn/text-base/cuRecurrent.h.svn-base b/src/CuTNetLib/.svn/text-base/cuRecurrent.h.svn-base new file mode 100644 index 0000000..e487b27 --- /dev/null +++ b/src/CuTNetLib/.svn/text-base/cuRecurrent.h.svn-base @@ -0,0 +1,101 @@ +#ifndef _CU_RECURRENT_H_ +#define _CU_RECURRENT_H_ + + +#include "cuComponent.h" +#include "cumatrix.h" + + +#include "Matrix.h" +#include "Vector.h" + + +namespace TNet { + + class CuRecurrent : public CuUpdatableComponent + { + public: + + CuRecurrent(size_t nInputs, size_t nOutputs, CuComponent *pPred); + ~CuRecurrent(); + + ComponentType GetType() const; + const char* GetName() const; + + //CuUpdatableComponent API + void PropagateFnc(const CuMatrix& X, CuMatrix& Y); + void BackpropagateFnc(const CuMatrix& X, CuMatrix& Y); + + void Update(); + + //Recurrent training API + void BpttOrder(int ord) { + mBpttOrder = ord; + mInputHistory.Init(ord+1,GetNInputs()+GetNOutputs()); + } + void ClearHistory() { + mInputHistory.SetConst(0.0); + if(mOutput.MSize() > 0) { + mOutput.SetConst(0.0); + } + } + + //I/O + void ReadFromStream(std::istream& rIn); + void WriteToStream(std::ostream& rOut); + + protected: + CuMatrix mLinearity; + CuVector mBias; + + CuMatrix mLinearityCorrection; + CuVector mBiasCorrection; + + CuMatrix mInputHistory; + + int mBpttOrder; + }; + + + + + //////////////////////////////////////////////////////////////////////////// + // INLINE FUNCTIONS + // CuRecurrent:: + inline + CuRecurrent:: + CuRecurrent(size_t nInputs, size_t nOutputs, CuComponent *pPred) + : CuUpdatableComponent(nInputs, nOutputs, pPred), + mLinearity(nInputs+nOutputs,nOutputs), + mBias(nOutputs), + mLinearityCorrection(nInputs+nOutputs,nOutputs), + mBiasCorrection(nOutputs) + { } + + + inline + CuRecurrent:: + ~CuRecurrent() + { } + + inline CuComponent::ComponentType + CuRecurrent:: + GetType() const + { + return CuComponent::RECURRENT; + } + + inline const char* + CuRecurrent:: + GetName() const + { + return ""; + } + + + +} //namespace + + + +#endif -- cgit v1.2.3-70-g09d2