summaryrefslogtreecommitdiff
path: root/src/CuTNetLib/cuCache.h
blob: 42d9b4d42b301829f473cb6b44aafda3fd4f8bfb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#ifndef _CUCACHE_H_
#define _CUCACHE_H_

#include "cumatrix.h"

namespace TNet {


  /**
   * \brief The feature-target pair cache
   *
   * \ingroup CuNNComp
   * Preloads mCachesize features and labels to GRAM.
   *
   * During every iteration, mBunchsize rows of data are read out to the network.
   *
   * When the cache is filled with more data than it can hold, the
   * extras are stored in the LeftOver buffers, and when data is filled
   * again the LeftOvers are moved to the cache.
   *
   * Note:
   *  - Cache size must be divisible by bunch size to ensure proper functionality.
   *  - Once data has been extracted, the cache must be depleted before filling
   *    can begin again; otherwise every extraction starts at location zero.
   *  - The cache must be filled before extraction of data can begin; otherwise
   *    filling cannot start and it is harder to avoid discarding data.
   *  - @todo Why not implement CuCache as a Stack instead of a Queue?
   *  .
   */
  class CuCache {
    /// Internal fill/drain state of the cache.
    typedef enum { EMPTY, INTAKE, FULL, EXHAUST } State;
    public:
      CuCache();
      ~CuCache();

      /// Initialize the cache
      /// \param[in] cachesize Number of rows the cache holds (should be a multiple of bunchsize)
      /// \param[in] bunchsize Number of rows returned by each GetBunch() call
      void Init(size_t cachesize, size_t bunchsize);

      /// Add data to the cache.
      /// Rows that do not fit are kept as leftovers (see class description).
      /// \param[in] rFeatures CuNN Input features data
      /// \param[in] rDesired CuNN features data label
      void AddData(const CuMatrix<BaseFloat>& rFeatures, const CuMatrix<BaseFloat>& rDesired);
      /// Randomizes the cache
      void Randomize();
      /// Get the bunch of training data
      /// \param[out] rFeatures Bunchsize of CuNN Input features data
      /// \param[out] rDesired  Bunchsize of CuNN features data label
      void GetBunch(CuMatrix<BaseFloat>& rFeatures, CuMatrix<BaseFloat>& rDesired);


      /// Returns true if the cache was completely filled
      bool Full() const
      { return (mState == FULL); }

      /// Returns true if the cache is empty, or holds fewer rows than one
      /// full bunch (mIntakePos < mBunchsize)
      bool Empty() const
      { return (mState == EMPTY || mIntakePos < mBunchsize); }

      /// Number of discarded frames
      int Discarded() const
      { return mDiscarded; }

      /// Set the trace message level
      void Trace(int trace)
      { mTrace = trace; }

    private:

      /// Returns a pseudo-random integer in [0, max) drawn from lrand48().
      /// For max <= 0 the range is empty; return 0 instead of evaluating
      /// `x % 0`, which is undefined behavior.
      static long int GenerateRandom(int max)
      {
        if (max <= 0) return 0;
        return lrand48() % max;
      }

      State mState; ///< Current state of the cache

      size_t mIntakePos; ///< Number of intaken vectors by AddData
      size_t mExhaustPos; ///< Number of exhausted vectors by GetBunch

      size_t mCachesize; ///< Size of cache
      size_t mBunchsize; ///< Size of bunch
      int mDiscarded; ///< Number of discarded frames

      CuMatrix<BaseFloat> mFeatures; ///< Feature cache
      CuMatrix<BaseFloat> mFeaturesRandom; ///< Feature cache, presumably the shuffled copy produced by Randomize() (verify in cuCache.cc)
      CuMatrix<BaseFloat> mFeaturesLeftover; ///< Feature rows that did not fit into the cache (LeftOver)

      CuMatrix<BaseFloat> mDesired;  ///< Desired vector cache
      CuMatrix<BaseFloat> mDesiredRandom;  ///< Desired vector cache, presumably the shuffled copy produced by Randomize() (verify in cuCache.cc)
      CuMatrix<BaseFloat> mDesiredLeftover;  ///< Desired rows that did not fit into the cache (LeftOver)

      bool mRandomized; ///< Presumably set once Randomize() has run (verify in cuCache.cc)

      int mTrace; ///< Trace message level, set via Trace()
  };

}

#endif