summaryrefslogtreecommitdiff
path: root/src/CuTNetLib/.svn/text-base/cuCache.h.svn-base
blob: cef2dd970560618d699033b193997d1a37d2770a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#ifndef _CUCACHE_H_
#define _CUCACHE_H_

#include "cumatrix.h"

namespace TNet {


  /**
   * The feature-target pair cache.
   *
   * Feature vectors and their matching desired (target) vectors are taken in
   * through AddData(), can be randomized with Randomize(), and are handed back
   * bunch by bunch through GetBunch(). The current phase is tracked in mState
   * (EMPTY, INTAKE, FULL, EXHAUST).
   */
  class CuCache {
    /// Lifecycle state of the cache (private: declared before `public:`).
    typedef enum { EMPTY, INTAKE, FULL, EXHAUST } State;
    public:
      CuCache();
      ~CuCache();
     
      /// Initialize the cache.
      /// @param cachesize size of the cache (kept as mCachesize; presumably
      ///                  counted in vectors -- confirm in the implementation)
      /// @param bunchsize size of one bunch returned by GetBunch()
      ///                  (kept as mBunchsize)
      void Init(size_t cachesize, size_t bunchsize);

      /// Add feature vectors and the corresponding desired vectors to the cache.
      /// Returns nothing; the number of vectors taken in so far is tracked
      /// internally (mIntakePos).
      /// NOTE(review): presumably rFeatures and rDesired must have the same
      /// number of rows -- confirm against the implementation.
      void AddData(const CuMatrix<BaseFloat>& rFeatures, const CuMatrix<BaseFloat>& rDesired);
      /// Randomizes the cache
      void Randomize();
      /// Get the next bunch of training data into rFeatures / rDesired
      void GetBunch(CuMatrix<BaseFloat>& rFeatures, CuMatrix<BaseFloat>& rDesired);


      /// Returns true if the cache is in the FULL state (completely filled)
      bool Full() const
      { return (mState == FULL); }
      
      /// Returns true if the cache is in the EMPTY state, or if fewer vectors
      /// than one bunch (mBunchsize) have been taken in so far.
      bool Empty() const
      { return (mState == EMPTY || mIntakePos < mBunchsize); }
      
      /// Number of discarded frames
      int Discarded() const
      { return mDiscarded; }
      
      /// Set the trace message level
      void Trace(int trace)
      { mTrace = trace; }

    private:
    
      /// Pseudo-random integer in [0, max), drawn from POSIX lrand48().
      /// Precondition: max > 0 -- max == 0 would be a division by zero.
      static long int GenerateRandom(int max)
      { return lrand48() % max; }
      
      State mState; ///< Current state of the cache

      size_t mIntakePos; ///< Number of intaken vectors by AddData
      size_t mExhaustPos; ///< Number of exhausted vectors by GetBunch
      
      size_t mCachesize; ///< Size of cache
      size_t mBunchsize; ///< Size of bunch
      int mDiscarded; ///< Number of discarded frames

      CuMatrix<BaseFloat> mFeatures;         ///< Feature cache (filled by AddData)
      CuMatrix<BaseFloat> mFeaturesRandom;   ///< Feature cache, presumably the randomized copy produced by Randomize() -- confirm
      CuMatrix<BaseFloat> mFeaturesLeftover; ///< Feature vectors presumably left over / carried between fills -- confirm
      
      CuMatrix<BaseFloat> mDesired;          ///< Desired vector cache (filled by AddData)
      CuMatrix<BaseFloat> mDesiredRandom;    ///< Desired vector cache, presumably the randomized copy produced by Randomize() -- confirm
      CuMatrix<BaseFloat> mDesiredLeftover;  ///< Desired vectors presumably left over / carried between fills -- confirm

      bool mRandomized; ///< Presumably whether the current cache content has been randomized -- confirm

      int mTrace; ///< Trace message level, set by Trace()
  }; 

}

#endif