summaryrefslogtreecommitdiff
path: root/src/TNetLib/Cache.cc
diff options
context:
space:
mode:
authorJoe Zhao <ztuowen@gmail.com>2014-04-14 08:14:45 +0800
committerJoe Zhao <ztuowen@gmail.com>2014-04-14 08:14:45 +0800
commitcccccbf6cca94a3eaf813b4468453160e91c332b (patch)
tree23418cb73a10ae3b0688681a7f0ba9b06424583e /src/TNetLib/Cache.cc
downloadtnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.gz
tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.tar.bz2
tnet-cccccbf6cca94a3eaf813b4468453160e91c332b.zip
First commit
Diffstat (limited to 'src/TNetLib/Cache.cc')
-rw-r--r--src/TNetLib/Cache.cc248
1 files changed, 248 insertions, 0 deletions
diff --git a/src/TNetLib/Cache.cc b/src/TNetLib/Cache.cc
new file mode 100644
index 0000000..f498318
--- /dev/null
+++ b/src/TNetLib/Cache.cc
@@ -0,0 +1,248 @@
+
+#include <sys/time.h>
+
+#include "Cache.h"
+#include "Matrix.h"
+#include "Vector.h"
+
+
+namespace TNet {
+
+ Cache::
+ Cache()
+ : mState(EMPTY), mIntakePos(0), mExhaustPos(0), mDiscarded(0),
+ mRandomized(false), mTrace(0)
+ { }
+
+ Cache::
+ ~Cache()
+ { }
+
+ void
+ Cache::
+ Init(size_t cachesize, size_t bunchsize, long int seed)
+ {
+ if((cachesize % bunchsize) != 0) {
+ KALDI_ERR << "Non divisible cachesize" << cachesize
+ << " by bunchsize" << bunchsize;
+ }
+
+ mCachesize = cachesize;
+ mBunchsize = bunchsize;
+
+ mState = EMPTY;
+
+ mIntakePos = 0;
+ mExhaustPos = 0;
+
+ mRandomized = false;
+
+ if(seed == 0) {
+ //generate seed
+ struct timeval tv;
+ if (gettimeofday(&tv, 0) == -1) {
+ Error("gettimeofday does not work.");
+ exit(-1);
+ }
+ seed = (int)(tv.tv_sec) + (int)tv.tv_usec + (int)(tv.tv_usec*tv.tv_usec);
+ }
+
+ srand48(seed);
+
+ }
+
+ void
+ Cache::
+ AddData(const Matrix<BaseFloat>& rFeatures, const Matrix<BaseFloat>& rDesired)
+ {
+ assert(rFeatures.Rows() == rDesired.Rows());
+
+ //lazy buffers allocation
+ if(mFeatures.Rows() != mCachesize) {
+ mFeatures.Init(mCachesize,rFeatures.Cols());
+ mDesired.Init(mCachesize,rDesired.Cols());
+ }
+
+ //warn if segment longer than half-cache
+ if(rFeatures.Rows() > mCachesize/2) {
+ std::ostringstream os;
+ os << "Too long segment and small feature cache! "
+ << " cachesize: " << mCachesize
+ << " segmentsize: " << rFeatures.Rows();
+ Warning(os.str());
+ }
+
+ //change state
+ if(mState == EMPTY) {
+ if(mTrace&3) std::cout << "/" << std::flush;
+ mState = INTAKE; mIntakePos = 0;
+
+ //check for leftover from previous segment
+ int leftover = mFeaturesLeftover.Rows();
+ //check if leftover is not bigger than cachesize
+ if(leftover > mCachesize) {
+ std::ostringstream os;
+ os << "Too small feature cache: " << mCachesize
+ << ", truncating: "
+ << leftover - mCachesize << " frames from previous segment leftover";
+ //Error(os.str());
+ Warning(os.str());
+ leftover = mCachesize;
+ }
+ //prefill cache with leftover
+ if(leftover > 0) {
+ memcpy(mFeatures.pData(),mFeaturesLeftover.pData(),
+ (mFeaturesLeftover.MSize() < mFeatures.MSize()?
+ mFeaturesLeftover.MSize() : mFeatures.MSize())
+ );
+ memcpy(mDesired.pData(),mDesiredLeftover.pData(),
+ (mDesiredLeftover.MSize() < mDesired.MSize()?
+ mDesiredLeftover.MSize() : mDesired.MSize())
+ );
+ mFeaturesLeftover.Destroy();
+ mDesiredLeftover.Destroy();
+ mIntakePos += leftover;
+ }
+ }
+
+ assert(mState == INTAKE);
+ assert(rFeatures.Rows() == rDesired.Rows());
+ if(mTrace&2) std::cout << "F" << std::flush;
+
+ int cache_space = mCachesize - mIntakePos;
+ int feature_length = rFeatures.Rows();
+ int fill_rows = (cache_space<feature_length)? cache_space : feature_length;
+ int leftover = feature_length - fill_rows;
+
+ assert(cache_space > 0);
+ assert(mFeatures.Stride()==rFeatures.Stride());
+ assert(mDesired.Stride()==rDesired.Stride());
+
+ //copy the data to cache
+ memcpy(mFeatures.pData()+mIntakePos*mFeatures.Stride(),
+ rFeatures.pData(),
+ fill_rows*mFeatures.Stride()*sizeof(BaseFloat));
+
+ memcpy(mDesired.pData()+mIntakePos*mDesired.Stride(),
+ rDesired.pData(),
+ fill_rows*mDesired.Stride()*sizeof(BaseFloat));
+
+ //copy leftovers
+ if(leftover > 0) {
+ mFeaturesLeftover.Init(leftover,mFeatures.Cols());
+ mDesiredLeftover.Init(leftover,mDesired.Cols());
+
+ memcpy(mFeaturesLeftover.pData(),
+ rFeatures.pData()+fill_rows*rFeatures.Stride(),
+ mFeaturesLeftover.MSize());
+
+ memcpy(mDesiredLeftover.pData(),
+ rDesired.pData()+fill_rows*rDesired.Stride(),
+ mDesiredLeftover.MSize());
+ }
+
+ //update cursor
+ mIntakePos += fill_rows;
+
+ //change state
+ if(mIntakePos == mCachesize) {
+ if(mTrace&3) std::cout << "\\" << std::flush;
+ mState = FULL;
+ }
+ }
+
+
+
+ void
+ Cache::
+ Randomize()
+ {
+ assert(mState == FULL || mState == INTAKE);
+
+ if(mTrace&3) std::cout << "R" << std::flush;
+
+ //lazy initialization of the output buffers
+ mFeaturesRandom.Init(mCachesize,mFeatures.Cols());
+ mDesiredRandom.Init(mCachesize,mDesired.Cols());
+
+ //generate random series of integers
+ Vector<int> randmask(mIntakePos);
+ for(unsigned int i=0; i<mIntakePos; i++) {
+ randmask[i]=i;
+ }
+ int* ptr = randmask.pData();
+ std::random_shuffle(ptr, ptr+mIntakePos, GenerateRandom);
+
+ //randomize
+ for(int i=0; i<randmask.Dim(); i++) {
+ mFeaturesRandom[i].Copy(mFeatures[randmask[i]]);
+ mDesiredRandom[i].Copy(mDesired[randmask[i]]);
+ }
+
+ mRandomized = true;
+ }
+
+ void
+ Cache::
+ GetBunch(Matrix<BaseFloat>& rFeatures, Matrix<BaseFloat>& rDesired)
+ {
+ if(mState == EMPTY) {
+ Error("GetBunch on empty cache!!!");
+ }
+
+ //change state if full...
+ if(mState == FULL) {
+ if(mTrace&3) std::cout << "\\" << std::flush;
+ mState = EXHAUST; mExhaustPos = 0;
+ }
+
+ //final cache is not completely filled
+ if(mState == INTAKE) {
+ if(mTrace&3) std::cout << "\\-LAST_CACHE\n" << std::flush;
+ mState = EXHAUST; mExhaustPos = 0;
+ }
+
+ assert(mState == EXHAUST);
+
+ //init the output
+ if(rFeatures.Rows()!=mBunchsize || rFeatures.Cols()!=mFeatures.Cols()) {
+ rFeatures.Init(mBunchsize,mFeatures.Cols());
+ }
+ if(rDesired.Rows()!=mBunchsize || rDesired.Cols()!=mDesired.Cols()) {
+ rDesired.Init(mBunchsize,mDesired.Cols());
+ }
+
+ //copy the output
+ if(mRandomized) {
+ memcpy(rFeatures.pData(),
+ mFeaturesRandom.pData()+mExhaustPos*mFeatures.Stride(),
+ rFeatures.MSize());
+
+ memcpy(rDesired.pData(),
+ mDesiredRandom.pData()+mExhaustPos*mDesired.Stride(),
+ rDesired.MSize());
+ } else {
+ memcpy(rFeatures.pData(),
+ mFeatures.pData()+mExhaustPos*mFeatures.Stride(),
+ rFeatures.MSize());
+
+ memcpy(rDesired.pData(),
+ mDesired.pData()+mExhaustPos*mDesired.Stride(),
+ rDesired.MSize());
+ }
+
+
+ //update cursor
+ mExhaustPos += mBunchsize;
+
+ //change state to EMPTY
+ if(mExhaustPos > mIntakePos-mBunchsize) {
+ //we don't have more complete bunches...
+ mDiscarded += mIntakePos - mExhaustPos;
+
+ mState = EMPTY;
+ }
+ }
+
+
+}