From cccccbf6cca94a3eaf813b4468453160e91c332b Mon Sep 17 00:00:00 2001 From: Joe Zhao Date: Mon, 14 Apr 2014 08:14:45 +0800 Subject: First commit --- .../.svn/text-base/cuActivation.h.svn-base | 123 +++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 src/CuTNetLib/.svn/text-base/cuActivation.h.svn-base (limited to 'src/CuTNetLib/.svn/text-base/cuActivation.h.svn-base') diff --git a/src/CuTNetLib/.svn/text-base/cuActivation.h.svn-base b/src/CuTNetLib/.svn/text-base/cuActivation.h.svn-base new file mode 100644 index 0000000..9fb2862 --- /dev/null +++ b/src/CuTNetLib/.svn/text-base/cuActivation.h.svn-base @@ -0,0 +1,123 @@ + +#ifndef _CUACT_FUN_I_ +#define _CUACT_FUN_I_ + + +#include "cuComponent.h" +#include "cumatrix.h" + + +namespace TNet +{ + + /** + * Common interface for activation functions + */ + class CuActivation : public CuComponent + { + public: + CuActivation(size_t nInputs, size_t nOutputs, CuComponent *pPred); + + protected: + }; + + + /** + * Sigmoid activation function + */ + class CuSigmoid : public CuActivation + { + public: + CuSigmoid(size_t nInputs, size_t nOutputs, CuComponent *pPred); + + ComponentType GetType() const; + const char* GetName() const; + + protected: + void PropagateFnc(const CuMatrix& X, CuMatrix& Y); + void BackpropagateFnc(const CuMatrix& X, CuMatrix& Y); + }; + + + /** + * Softmax activation function + */ + class CuSoftmax : public CuActivation + { + public: + CuSoftmax(size_t nInputs, size_t nOutputs, CuComponent *pPred); + + ComponentType GetType() const; + const char* GetName() const; + + protected: + void PropagateFnc(const CuMatrix& X, CuMatrix& Y); + void BackpropagateFnc(const CuMatrix& X, CuMatrix& Y); + }; + + + ////////////////////////////////////////////////////////////////////////// + // Inline functions + // Activation:: + inline + CuActivation:: + CuActivation(size_t nInputs, size_t nOutputs, CuComponent *pPred) + : CuComponent(nInputs,nOutputs, pPred) + { + assert(nInputs == nOutputs); + } + + + ////////////////////////////////////////////////////////////////////////// + // Inline functions + // Sigmoid:: + inline + CuSigmoid:: + CuSigmoid(size_t nInputs, size_t nOutputs, CuComponent *pPred) + : CuActivation(nInputs,nOutputs, pPred) + { } + + inline CuComponent::ComponentType + CuSigmoid:: + GetType() const + { + return CuComponent::SIGMOID; + } + + inline const char* + CuSigmoid:: + GetName() const + { + return ""; + } + + + + ////////////////////////////////////////////////////////////////////////// + // Inline functions + // Softmax:: + inline + CuSoftmax:: + CuSoftmax(size_t nInputs, size_t nOutputs, CuComponent *pPred) + : CuActivation(nInputs,nOutputs, pPred) + { } + + inline CuComponent::ComponentType + CuSoftmax:: + GetType() const + { + return CuComponent::SOFTMAX; + } + + inline const char* + CuSoftmax:: + GetName() const + { + return ""; + } + + +} //namespace + + +#endif -- cgit v1.2.3-70-g09d2