#ifndef _CU_RAND_H_ #define _CU_RAND_H_ #include "cumatrix.h" namespace TNet { template class CuRand { public: CuRand(size_t rows, size_t cols) { SeedGpu(rows,cols); } ~CuRand() { } void SeedGpu(size_t rows, size_t cols); void Rand(CuMatrix& tgt); void GaussRand(CuMatrix& tgt); void BinarizeProbs(const CuMatrix& probs, CuMatrix& states); void AddGaussNoise(CuMatrix& tgt, T gscale = 1.0); private: static void SeedRandom(Matrix& mat); private: CuMatrix z1, z2, z3, z4; CuMatrix tmp; }; } #include "curand.tcc" #endif