diff options
Diffstat (limited to 'src/CuTNetLib/cuActivation.h')
-rw-r--r-- | src/CuTNetLib/cuActivation.h | 132 |
1 files changed, 132 insertions, 0 deletions
diff --git a/src/CuTNetLib/cuActivation.h b/src/CuTNetLib/cuActivation.h new file mode 100644 index 0000000..c66640c --- /dev/null +++ b/src/CuTNetLib/cuActivation.h @@ -0,0 +1,132 @@ + +#ifndef _CUACT_FUN_I_ +#define _CUACT_FUN_I_ + + +#include "cuComponent.h" +#include "cumatrix.h" + + +namespace TNet +{ + + /** + * \brief Common interface for activation functions + */ + class CuActivation : public CuComponent + { + public: + CuActivation(size_t nInputs, size_t nOutputs, CuComponent *pPred); + + protected: + }; + + + /** + * \brief CuSigmoid activation function + * + * \ingroup CuNNActivation + * Implements forward pass: \f[ Y_i=\frac{1}{1+e^{-X_i}} \f] + * Error propagation: \f[ E_i=Y_i(1-Y_i)e_i \f] + */ + class CuSigmoid : public CuActivation + { + public: + CuSigmoid(size_t nInputs, size_t nOutputs, CuComponent *pPred); + + ComponentType GetType() const; + const char* GetName() const; + + protected: + void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y); + void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y); + }; + + + /** + * \brief CuSoftmax activation function + * + * \ingroup CuNNActivation + * Implements forward pass: \f[ Y_i=\frac{1}{Z} e^{X_i} \f] + * where \f$ Z=\Sigma_{i=0}^{i=N-1} e^{X_i} \f$ + * Error Propagation: \f[ E_i=Y_i - \Sigma_{j=0}^{j=N-1} Y_i Y_j e_j \f] + */ + class CuSoftmax : public CuActivation + { + public: + CuSoftmax(size_t nInputs, size_t nOutputs, CuComponent *pPred); + + ComponentType GetType() const; + const char* GetName() const; + + protected: + void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y); + void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y); + }; + + + ////////////////////////////////////////////////////////////////////////// + // Inline functions + // Activation:: + inline + CuActivation:: + CuActivation(size_t nInputs, size_t nOutputs, CuComponent *pPred) + : CuComponent(nInputs,nOutputs, pPred) + { + assert(nInputs == nOutputs); + } + + + ////////////////////////////////////////////////////////////////////////// + // Inline functions + // Sigmoid:: + inline + CuSigmoid:: + CuSigmoid(size_t nInputs, size_t nOutputs, CuComponent *pPred) + : CuActivation(nInputs,nOutputs, pPred) + { } + + inline CuComponent::ComponentType + CuSigmoid:: + GetType() const + { + return CuComponent::SIGMOID; + } + + inline const char* + CuSigmoid:: + GetName() const + { + return "<sigmoid>"; + } + + + + ////////////////////////////////////////////////////////////////////////// + // Inline functions + // Softmax:: + inline + CuSoftmax:: + CuSoftmax(size_t nInputs, size_t nOutputs, CuComponent *pPred) + : CuActivation(nInputs,nOutputs, pPred) + { } + + inline CuComponent::ComponentType + CuSoftmax:: + GetType() const + { + return CuComponent::SOFTMAX; + } + + inline const char* + CuSoftmax:: + GetName() const + { + return "<softmax>"; + } + + +} //namespace + + +#endif |