summaryrefslogtreecommitdiff
path: root/src/CuTNetLib/.svn/text-base/cuDiscreteLinearity.h.svn-base
diff options
context:
space:
mode:
Diffstat (limited to 'src/CuTNetLib/.svn/text-base/cuDiscreteLinearity.h.svn-base')
-rw-r--r--src/CuTNetLib/.svn/text-base/cuDiscreteLinearity.h.svn-base90
1 files changed, 90 insertions, 0 deletions
diff --git a/src/CuTNetLib/.svn/text-base/cuDiscreteLinearity.h.svn-base b/src/CuTNetLib/.svn/text-base/cuDiscreteLinearity.h.svn-base
new file mode 100644
index 0000000..06c8d74
--- /dev/null
+++ b/src/CuTNetLib/.svn/text-base/cuDiscreteLinearity.h.svn-base
@@ -0,0 +1,90 @@
+#ifndef _CUDISCRETE_LINEARITY_H_
+#define _CUDISCRETE_LINEARITY_H_
+
+
+#include "cuComponent.h"
+#include "cumatrix.h"
+
+
+#include "Matrix.h"
+#include "Vector.h"
+
+#include <vector>
+
+
+namespace TNet {
+
+ class CuDiscreteLinearity : public CuUpdatableComponent
+ {
+ public:
+
+ CuDiscreteLinearity(size_t nInputs, size_t nOutputs, CuComponent *pPred);
+ ~CuDiscreteLinearity();
+
+ ComponentType GetType() const;
+ const char* GetName() const;
+
+ void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
+ void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
+
+ void Update();
+
+ void ReadFromStream(std::istream& rIn);
+ void WriteToStream(std::ostream& rOut);
+
+ protected:
+ std::vector<CuMatrix<BaseFloat> > mLinearity; ///< Matrix with neuron weights
+ CuVector<BaseFloat> mBias; ///< Vector with biases
+
+ std::vector<CuMatrix<BaseFloat> > mLinearityCorrection; ///< Matrix for linearity updates
+ CuVector<BaseFloat> mBiasCorrection; ///< Vector for bias updates
+
+ size_t mNBlocks;
+
+ };
+
+
+
+
+ ////////////////////////////////////////////////////////////////////////////
+ // INLINE FUNCTIONS
+ // CuDiscreteLinearity::
+ inline
+ CuDiscreteLinearity::
+ CuDiscreteLinearity(size_t nInputs, size_t nOutputs, CuComponent *pPred)
+ : CuUpdatableComponent(nInputs, nOutputs, pPred),
+ //mLinearity(nInputs,nOutputs), mBias(nOutputs),
+ //mLinearityCorrection(nInputs,nOutputs), mBiasCorrection(nOutputs)
+ mNBlocks(0)
+ {
+ //mLinearityCorrection.SetConst(0.0);
+ //mBiasCorrection.SetConst(0.0);
+ }
+
+
+ inline
+ CuDiscreteLinearity::
+ ~CuDiscreteLinearity()
+ { }
+
+ inline CuComponent::ComponentType
+ CuDiscreteLinearity::
+ GetType() const
+ {
+ return CuComponent::DISCRETE_LINEARITY;
+ }
+
+ inline const char*
+ CuDiscreteLinearity::
+ GetName() const
+ {
+ return "<discretelinearity>";
+ }
+
+
+
+} //namespace
+
+
+
+#endif