summaryrefslogtreecommitdiff
path: root/src/CuTNetLib/cuActivation.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/CuTNetLib/cuActivation.h')
-rw-r--r--src/CuTNetLib/cuActivation.h132
1 files changed, 132 insertions, 0 deletions
diff --git a/src/CuTNetLib/cuActivation.h b/src/CuTNetLib/cuActivation.h
new file mode 100644
index 0000000..c66640c
--- /dev/null
+++ b/src/CuTNetLib/cuActivation.h
@@ -0,0 +1,132 @@
+
+#ifndef _CUACT_FUN_I_
+#define _CUACT_FUN_I_
+
+
+#include "cuComponent.h"
+#include "cumatrix.h"
+
+
+namespace TNet
+{
+
+ /**
+ * \brief Common interface for activation functions
+ */
+ class CuActivation : public CuComponent
+ {
+ public:
+ CuActivation(size_t nInputs, size_t nOutputs, CuComponent *pPred);
+
+ protected:
+ };
+
+
+ /**
+ * \brief CuSigmoid activation function
+ *
+ * \ingroup CuNNActivation
+ * Implements forward pass: \f[ Y_i=\frac{1}{1+e^{-X_i}} \f]
+ * Error propagation: \f[ E_i=Y_i(1-Y_i)e_i \f]
+ */
+ class CuSigmoid : public CuActivation
+ {
+ public:
+ CuSigmoid(size_t nInputs, size_t nOutputs, CuComponent *pPred);
+
+ ComponentType GetType() const;
+ const char* GetName() const;
+
+ protected:
+ void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
+ void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
+ };
+
+
+ /**
+ * \brief CuSoftmax activation function
+ *
+ * \ingroup CuNNActivation
+ * Implements forward pass: \f[ Y_i=\frac{1}{Z} e^{X_i} \f]
+ * where \f$ Z=\Sigma_{i=0}^{i=N-1} e^{X_i} \f$
+ * Error Propagation: \f[ E_i=Y_i - \Sigma_{j=0}^{j=N-1} Y_i Y_j e_j \f]
+ */
+ class CuSoftmax : public CuActivation
+ {
+ public:
+ CuSoftmax(size_t nInputs, size_t nOutputs, CuComponent *pPred);
+
+ ComponentType GetType() const;
+ const char* GetName() const;
+
+ protected:
+ void PropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
+ void BackpropagateFnc(const CuMatrix<BaseFloat>& X, CuMatrix<BaseFloat>& Y);
+ };
+
+
+ //////////////////////////////////////////////////////////////////////////
+ // Inline functions
+ // Activation::
+ inline
+ CuActivation::
+ CuActivation(size_t nInputs, size_t nOutputs, CuComponent *pPred)
+ : CuComponent(nInputs,nOutputs, pPred)
+ {
+ assert(nInputs == nOutputs);
+ }
+
+
+ //////////////////////////////////////////////////////////////////////////
+ // Inline functions
+ // Sigmoid::
+ inline
+ CuSigmoid::
+ CuSigmoid(size_t nInputs, size_t nOutputs, CuComponent *pPred)
+ : CuActivation(nInputs,nOutputs, pPred)
+ { }
+
+ inline CuComponent::ComponentType
+ CuSigmoid::
+ GetType() const
+ {
+ return CuComponent::SIGMOID;
+ }
+
+ inline const char*
+ CuSigmoid::
+ GetName() const
+ {
+ return "<sigmoid>";
+ }
+
+
+
+ //////////////////////////////////////////////////////////////////////////
+ // Inline functions
+ // Softmax::
+ inline
+ CuSoftmax::
+ CuSoftmax(size_t nInputs, size_t nOutputs, CuComponent *pPred)
+ : CuActivation(nInputs,nOutputs, pPred)
+ { }
+
+ inline CuComponent::ComponentType
+ CuSoftmax::
+ GetType() const
+ {
+ return CuComponent::SOFTMAX;
+ }
+
+ inline const char*
+ CuSoftmax::
+ GetName() const
+ {
+ return "<softmax>";
+ }
+
+
+} //namespace
+
+
+#endif