#include "cuCache.h"
#include "cumath.h"

#include <cassert>
#include <cstdlib>
#include <iostream>
#include <sstream>

// NOTE(review): this file reached review collapsed onto three lines and with
// every "<...>" span stripped (template arguments, and two stretches of code
// that sat between a '<' and a later '>'). Line structure has been restored
// and the lost spans rebuilt from the surrounding code. Every rebuilt spot is
// tagged "restored" below -- please confirm those against cuCache.h and the
// upstream revision before merging.

namespace TNet {

  /// Creates an empty cache with all cursors and counters zeroed.
  /// mCachesize and mBunchsize are not set here; Init() assigns them.
  CuCache::
  CuCache()
    : mState(EMPTY), mIntakePos(0), mExhaustPos(0), mDiscarded(0),
      mRandomized(false), mTrace(0)
  { }


  /// Nothing to release explicitly in this translation unit.
  CuCache::
  ~CuCache()
  { }


  /// Sets the cache geometry and resets the cache to the EMPTY state.
  ///
  /// @param cachesize number of rows the cache holds
  /// @param bunchsize number of rows handed out per GetBunch() call;
  ///                  must be non-zero and divide cachesize
  ///
  /// Violations are reported through Error().
  void
  CuCache::
  Init(size_t cachesize, size_t bunchsize)
  {
    //a zero bunchsize would make the modulo below a division by zero
    if(bunchsize == 0) {
      Error("Zero bunchsize");
    } else if((cachesize % bunchsize) != 0) {
      Error("Non divisible cachesize by bunchsize");
    }

    mCachesize = cachesize;
    mBunchsize = bunchsize;

    mState = EMPTY;

    mIntakePos = 0;
    mExhaustPos = 0;

    mRandomized = false;
  }


  /// Appends one segment (features plus matching targets) to the cache.
  ///
  /// Rows that do not fit are kept in the leftover buffers and are
  /// prepended the next time data is added to an EMPTY cache. The state
  /// becomes FULL once mIntakePos reaches mCachesize.
  ///
  /// @param rFeatures input rows
  /// @param rDesired  target rows; must have as many rows as rFeatures
  ///
  /// Must not be called while the cache is FULL or EXHAUST (asserted).
  ///
  /// NOTE(review): the <BaseFloat> template arguments are restored --
  /// confirm they match the declaration in cuCache.h.
  void
  CuCache::
  AddData(const CuMatrix<BaseFloat>& rFeatures, const CuMatrix<BaseFloat>& rDesired)
  {
    assert(rFeatures.Rows() == rDesired.Rows());

    //lazy buffers allocation
    if(mFeatures.Rows() != mCachesize) {
      mFeatures.Init(mCachesize,rFeatures.Cols());
      mDesired.Init(mCachesize,rDesired.Cols());
    }

    //warn if segment longer than half-cache
    if(rFeatures.Rows() > mCachesize/2) {
      std::ostringstream os;
      os << "Too long segment and small feature cache! "
         << " cachesize: " << mCachesize
         << " segmentsize: " << rFeatures.Rows();
      Warning(os.str());
    }

    //change state
    if(mState == EMPTY) {
      if(mTrace&3) std::cout << "/" << std::flush;
      mState = INTAKE;
      mIntakePos = 0;

      //check for leftover from previous segment
      int carried = mFeaturesLeftover.Rows();

      //check if leftover is not bigger than cachesize
      if(carried > mCachesize) {
        std::ostringstream os;
        os << "Too small feature cache: " << mCachesize
           << ", truncating: "
           << carried - mCachesize
           << " frames from previous segment leftover";
        //Error(os.str());
        Warning(os.str());
        carried = mCachesize;
      }

      //prefill cache with leftover
      if(carried > 0) {
        mFeatures.CopyRows(carried,0,mFeaturesLeftover,0);
        mDesired.CopyRows(carried,0,mDesiredLeftover,0);
        mFeaturesLeftover.Destroy();
        mDesiredLeftover.Destroy();
        mIntakePos += carried;
      }
    }

    assert(mState == INTAKE);
    if(mTrace&2) std::cout << "F" << std::flush;

    //restored: only "int fill_rows = (cache_space" and "0);" survived here.
    //The code below still needs fill_rows and an outer-scope leftover, so
    //the lost text must have computed both -- TODO confirm against upstream.
    int cache_space = mCachesize - mIntakePos;
    int feature_length = rFeatures.Rows();
    int fill_rows = (cache_space < feature_length) ? cache_space : feature_length;
    int leftover = feature_length - fill_rows;

    //NOTE(review): this fires if the carried-over rows alone already filled
    //the cache (e.g. after the truncation warning above) -- confirm intent.
    assert(cache_space > 0);

    //copy the data to cache
    //(argument order as used in this file: row count, source start row,
    // source matrix, destination start row -- presumably; see CuMatrix)
    mFeatures.CopyRows(fill_rows,0,rFeatures,mIntakePos);
    mDesired.CopyRows(fill_rows,0,rDesired,mIntakePos);

    //copy leftovers
    if(leftover > 0) {
      mFeaturesLeftover.Init(leftover,mFeatures.Cols());
      mDesiredLeftover.Init(leftover,mDesired.Cols());
      mFeaturesLeftover.CopyRows(leftover,fill_rows,rFeatures,0);
      mDesiredLeftover.CopyRows(leftover,fill_rows,rDesired,0);
    }

    //update cursor
    mIntakePos += fill_rows;

    //change state
    if(mIntakePos == mCachesize) {
      if(mTrace&3) std::cout << "\\" << std::flush;
      mState = FULL;
    }
  }


  /// Builds a random permutation of the row indices 0..mIntakePos-1 and
  /// hands it to CuMath::Randomize for the feature and the target buffer,
  /// writing into mFeaturesRandom / mDesiredRandom (presumably a row
  /// reordering -- see CuMath). Sets mRandomized so that GetBunch() reads
  /// from the randomized buffers.
  ///
  /// Only valid while the cache is FULL or INTAKE (asserted).
  void
  CuCache::
  Randomize()
  {
    assert(mState == FULL || mState == INTAKE);

    if(mTrace&3) std::cout << "R" << std::flush;

    //(re)initialization of the output buffers
    mFeaturesRandom.Init(mCachesize,mFeatures.Cols());
    mDesiredRandom.Init(mCachesize,mDesired.Cols());

    //generate random series of integers
    //restored: everything between "i" of the loop condition and
    //"cu_randmask;" was lost, including the original shuffle call. The
    //Fisher-Yates pass below is a stand-in that only relies on
    //Vector::operator[]; swap the project's own RNG/shuffle back in if
    //upstream used one -- TODO confirm.
    const size_t count = mIntakePos;
    Vector<int> randmask(count);
    for(size_t i = 0; i < count; ++i) {
      randmask[i] = static_cast<int>(i);
    }
    for(size_t i = count; i > 1; --i) {
      const size_t j = static_cast<size_t>(std::rand()) % i;
      const int tmp = randmask[i-1];
      randmask[i-1] = randmask[j];
      randmask[j] = tmp;
    }

    CuVector<int> cu_randmask;
    cu_randmask.CopyFrom(randmask);

    //randomize
    CuMath<BaseFloat>::Randomize(mFeaturesRandom,mFeatures,cu_randmask);
    CuMath<BaseFloat>::Randomize(mDesiredRandom,mDesired,cu_randmask);

    mRandomized = true;
  }


  /// Copies the next mBunchsize rows out of the cache.
  ///
  /// The first call after filling moves the cache to EXHAUST (from FULL,
  /// or from INTAKE when the last cache was only partly filled). Rows come
  /// from the randomized buffers when Randomize() was called for this
  /// fill, otherwise from the plain buffers. When fewer than mBunchsize
  /// unread rows remain, they are counted in mDiscarded and the cache
  /// returns to EMPTY.
  ///
  /// @param rFeatures output; re-initialized to mBunchsize rows
  /// @param rDesired  output; re-initialized to mBunchsize rows
  ///
  /// Calling this on an EMPTY cache is reported through Error().
  void
  CuCache::
  GetBunch(CuMatrix<BaseFloat>& rFeatures, CuMatrix<BaseFloat>& rDesired)
  {
    if(mState == EMPTY) {
      Error("GetBunch on empty cache!!!");
    }

    //change state if full...
    if(mState == FULL) {
      if(mTrace&3) std::cout << "\\" << std::flush;
      mState = EXHAUST;
      mExhaustPos = 0;
    }

    //final cache is not completely filled
    if(mState == INTAKE) { //&& mpFeatures->EndOfList()
      if(mTrace&3) std::cout << "\\-LAST\n" << std::flush;
      mState = EXHAUST;
      mExhaustPos = 0;
    }

    assert(mState == EXHAUST);

    //init the output
    rFeatures.Init(mBunchsize,mFeatures.Cols());
    rDesired.Init(mBunchsize,mDesired.Cols());

    //copy the output
    if(mRandomized) {
      rFeatures.CopyRows(mBunchsize,mExhaustPos,mFeaturesRandom,0);
      rDesired.CopyRows(mBunchsize,mExhaustPos,mDesiredRandom,0);
    } else {
      rFeatures.CopyRows(mBunchsize,mExhaustPos,mFeatures,0);
      rDesired.CopyRows(mBunchsize,mExhaustPos,mDesired,0);
    }

    //update cursor
    mExhaustPos += mBunchsize;

    //change state to EMPTY
    //The test is written as an addition: "mIntakePos - mBunchsize" would
    //wrap around if the counters are unsigned and mIntakePos < mBunchsize.
    if(mExhaustPos + mBunchsize > mIntakePos) {
      //we don't have more complete bunches...
      if(mIntakePos > mExhaustPos) {
        mDiscarded += mIntakePos - mExhaustPos;
      }
      mState = EMPTY;
      //the randomized buffers belong to this fill only; without the reset
      //a later fill that skips Randomize() would be served stale rows
      mRandomized = false;
    }
  }

}