blob: 42d9b4d42b301829f473cb6b44aafda3fd4f8bfb (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
|
#ifndef _CUCACHE_H_
#define _CUCACHE_H_
#include "cumatrix.h"
namespace TNet {
/**
 * \brief The feature-target pair cache
 *
 * \ingroup CuNNComp
 * Preloads up to mCachesize features and labels to GRAM.
 *
 * During every iteration, mBunchsize vectors are read out and fed to the network.
 *
 * When the cache is given more data than it can hold, the extra vectors are
 * stored in the leftover buffers; when the cache is filled again, the
 * leftovers are moved into the cache first.
 *
 * Note:
 * - The cache size must be divisible by the bunch size to ensure proper functionality.
 * - Once data extraction has started, the cache must be depleted before filling
 *   can begin again; otherwise every extraction starts at position zero.
 * - The cache must be filled before data extraction can begin; otherwise filling
 *   cannot start and it is harder to avoid discarding data.
 * - @todo Why not implement CuCache as a Stack instead of a Queue?
 * .
 */
class CuCache {
  typedef enum { EMPTY, INTAKE, FULL, EXHAUST } State;

 public:
  CuCache();
  ~CuCache();

  /// Initialize the cache
  /// \param[in] cachesize Capacity of the cache in vectors (should be divisible by bunchsize)
  /// \param[in] bunchsize Number of vectors returned by each GetBunch call
  void Init(size_t cachesize, size_t bunchsize);

  /// Add data to the cache; vectors that do not fit are kept as leftovers
  /// \param[in] rFeatures CuNN Input features data
  /// \param[in] rDesired CuNN features data label
  void AddData(const CuMatrix<BaseFloat>& rFeatures, const CuMatrix<BaseFloat>& rDesired);

  /// Randomizes the cache
  void Randomize();

  /// Get the bunch of training data
  /// \param[out] rFeatures Bunchsize of CuNN Input features data
  /// \param[out] rDesired Bunchsize of CuNN features data label
  void GetBunch(CuMatrix<BaseFloat>& rFeatures, CuMatrix<BaseFloat>& rDesired);

  /// Returns true if the cache was completely filled
  bool Full() const
  { return (mState == FULL); }

  /// Returns true if the cache is in the EMPTY state, or if fewer than one
  /// bunch of vectors has been taken in (so a full bunch cannot be served)
  bool Empty() const
  { return (mState == EMPTY || mIntakePos < mBunchsize); }

  /// Number of discarded frames
  int Discarded() const
  { return mDiscarded; }

  /// Set the trace message level
  void Trace(int trace)
  { mTrace = trace; }

 private:
  /// Pseudo-random integer in [0, max) drawn from lrand48().
  /// Returns 0 when max is not positive: `x % 0` is undefined behavior,
  /// and a non-positive range has no valid value to draw.
  static long int GenerateRandom(int max)
  {
    if (max <= 0) {
      return 0;
    }
    return lrand48() % max;
  }

  State mState;                           ///< Current state of the cache
  size_t mIntakePos;                      ///< Number of intaken vectors by AddData
  size_t mExhaustPos;                     ///< Number of exhausted vectors by GetBunch
  size_t mCachesize;                      ///< Size of cache
  size_t mBunchsize;                      ///< Size of bunch
  int mDiscarded;                         ///< Number of discarded frames

  CuMatrix<BaseFloat> mFeatures;          ///< Feature cache
  CuMatrix<BaseFloat> mFeaturesRandom;    ///< Feature cache, presumably the shuffled copy built by Randomize() (confirm in implementation)
  CuMatrix<BaseFloat> mFeaturesLeftover;  ///< Features that did not fit into the cache, moved in on the next fill

  CuMatrix<BaseFloat> mDesired;           ///< Desired vector cache
  CuMatrix<BaseFloat> mDesiredRandom;     ///< Desired vector cache, presumably the shuffled copy built by Randomize() (confirm in implementation)
  CuMatrix<BaseFloat> mDesiredLeftover;   ///< Desired vectors that did not fit into the cache, moved in on the next fill

  bool mRandomized;                       ///< Presumably set once Randomize() has run (TODO confirm in implementation)
  int mTrace;                             ///< Trace message level, set by Trace()
};
}
#endif
|