#ifndef _CU_RAND_H_ #define _CU_RAND_H_ #include "cumatrix.h" namespace TNet { template<typename T> class CuRand { public: CuRand(size_t rows, size_t cols) { SeedGpu(rows,cols); } ~CuRand() { } void SeedGpu(size_t rows, size_t cols); void Rand(CuMatrix<T>& tgt); void GaussRand(CuMatrix<T>& tgt); void BinarizeProbs(const CuMatrix<T>& probs, CuMatrix<T>& states); void AddGaussNoise(CuMatrix<T>& tgt, T gscale = 1.0); private: static void SeedRandom(Matrix<unsigned>& mat); private: CuMatrix<unsigned> z1, z2, z3, z4; CuMatrix<T> tmp; }; } #include "curand.tcc" #endif