blob: 8aa66d5f9f83082b05760861867728ec6fc88a5a (
plain)
| 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
 | #ifndef _CU_RAND_H_
#define _CU_RAND_H_
#include "cumatrix.h"
namespace TNet {

  /// @brief Random-number helper that operates on CuMatrix objects.
  ///
  /// Holds four CuMatrix<unsigned> state matrices (z1..z4) and one scratch
  /// CuMatrix<T> (tmp). The constructor seeds the state by calling
  /// SeedGpu(rows, cols). All member functions are defined out of line in
  /// curand.tcc.
  /// NOTE(review): the generator algorithm, and whether target matrices must
  /// match the seeded rows x cols shape, are defined in curand.tcc. Confirm
  /// there before relying on either.
  ///
  /// The class deliberately declares no destructor, copy, or move operations
  /// (Rule of Zero). All of its state lives in CuMatrix members, so the
  /// compiler-generated special members are correct. A user-declared
  /// destructor, even an empty one, would suppress the implicit move
  /// operations. Copy and move semantics are whatever CuMatrix provides.
  ///
  /// @tparam T element type of the non-state matrices (tgt, probs, states, tmp).
  template<typename T>
  class CuRand {
   public:
    /// @brief Constructs the generator and seeds it via SeedGpu(rows, cols).
    /// @param rows row count passed to SeedGpu
    /// @param cols column count passed to SeedGpu
    CuRand(size_t rows, size_t cols)
    { SeedGpu(rows, cols); }

    /// @brief Seeds the generator state for a rows x cols shape.
    /// Also called by the constructor. Presumably (re)initialises z1..z4;
    /// see curand.tcc.
    void SeedGpu(size_t rows, size_t cols);

    /// @brief Presumably fills tgt with uniformly distributed values
    /// (see curand.tcc).
    void Rand(CuMatrix<T>& tgt);

    /// @brief Presumably fills tgt with Gaussian-distributed values
    /// (see curand.tcc).
    void GaussRand(CuMatrix<T>& tgt);

    /// @brief Writes into states a binarized version of probs. Presumably each
    /// element of probs is treated as a probability and sampled to 0/1; see
    /// curand.tcc.
    /// @param probs input matrix (not modified)
    /// @param states output matrix
    void BinarizeProbs(const CuMatrix<T>& probs, CuMatrix<T>& states);

    /// @brief Presumably adds Gaussian noise scaled by gscale to tgt in place
    /// (see curand.tcc).
    /// @param tgt matrix that is modified
    /// @param gscale noise scale factor, default 1.0
    void AddGaussNoise(CuMatrix<T>& tgt, T gscale = 1.0);

   private:
    /// @brief Static helper that fills mat. Presumably it generates the seed
    /// values used by SeedGpu; confirm in curand.tcc.
    static void SeedRandom(Matrix<unsigned>& mat);

    // Generator state. Declaration order (and so construction order) is
    // kept unchanged.
    CuMatrix<unsigned> z1, z2, z3, z4;
    // Scratch buffer of element type T.
    CuMatrix<T> tmp;
  };
}
#include "curand.tcc"
#endif
 |