summaryrefslogtreecommitdiff
path: root/src/TNetLib/Activation.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/TNetLib/Activation.h')
-rw-r--r--src/TNetLib/Activation.h104
1 files changed, 104 insertions, 0 deletions
diff --git a/src/TNetLib/Activation.h b/src/TNetLib/Activation.h
new file mode 100644
index 0000000..90263d0
--- /dev/null
+++ b/src/TNetLib/Activation.h
@@ -0,0 +1,104 @@
+
+#ifndef _ACT_FUN_I_
+#define _ACT_FUN_I_
+
+
+#include "Component.h"
+
+
+namespace TNet
+{
+
+ /**
+ * Sigmoid activation function
+ */
+ class Sigmoid : public Component
+ {
+ public:
+ Sigmoid(size_t nInputs, size_t nOutputs, Component *pPred)
+ : Component(nInputs,nOutputs,pPred)
+ { }
+
+ ComponentType GetType() const
+ { return SIGMOID; }
+
+ const char* GetName() const
+ { return "<sigmoid>"; }
+
+ Component* Clone() const
+ { return new Sigmoid(GetNInputs(),GetNOutputs(),NULL); }
+
+ protected:
+ void PropagateFnc(const BfMatrix& X, BfMatrix& Y);
+ void BackpropagateFnc(const BfMatrix& X, BfMatrix& Y);
+ };
+
+
+ /**
+ * Softmax activation function
+ */
+ class Softmax : public Component
+ {
+ public:
+ Softmax(size_t nInputs, size_t nOutputs, Component *pPred)
+ : Component(nInputs,nOutputs,pPred)
+ { }
+
+ ComponentType GetType() const
+ { return SOFTMAX; }
+
+ const char* GetName() const
+ { return "<softmax>"; }
+
+ Component* Clone() const
+ { return new Softmax(GetNInputs(),GetNOutputs(),NULL); }
+
+ protected:
+ void PropagateFnc(const BfMatrix& X, BfMatrix& Y);
+ void BackpropagateFnc(const BfMatrix& X, BfMatrix& Y);
+ };
+
+
+ /**
+ * BlockSoftmax activation function.
+ * It is several softmaxes in one.
+ * The dimensions of softmaxes are given by integer vector.
+ * During backpropagation:
+ * If the derivatives sum up to 0, they are backpropagated.
+ * If the derivatives sup up to 1, they are discarded
+ * (like this we know that the softmax was 'inactive').
+ */
+ class BlockSoftmax : public Component
+ {
+ public:
+ BlockSoftmax(size_t nInputs, size_t nOutputs, Component *pPred)
+ : Component(nInputs,nOutputs,pPred)
+ { }
+
+ ComponentType GetType() const
+ { return BLOCK_SOFTMAX; }
+
+ const char* GetName() const
+ { return "<blocksoftmax>"; }
+
+ Component* Clone() const
+ { return new BlockSoftmax(*this); }
+
+ void ReadFromStream(std::istream& rIn);
+ void WriteToStream(std::ostream& rOut);
+
+ protected:
+ void PropagateFnc(const BfMatrix& X, BfMatrix& Y);
+ void BackpropagateFnc(const BfMatrix& X, BfMatrix& Y);
+
+ private:
+ Vector<int> mDim;
+ Vector<int> mDimOffset;
+ };
+
+
+
+} //namespace
+
+
+#endif