summaryrefslogtreecommitdiff
path: root/src/CuBaseLib/.svn/text-base/curand.h.svn-base
blob: 8aa66d5f9f83082b05760861867728ec6fc88a5a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#ifndef _CU_RAND_H_
#define _CU_RAND_H_


#include "cumatrix.h"


namespace TNet {

  /**
   * @brief Random-number helper whose generator state is held in CuMatrix
   *        members. Only declarations live here; the member definitions are
   *        provided by "curand.tcc", which this header includes at its end.
   *
   * @tparam T element type of the CuMatrix<T> objects the public methods
   *           operate on.
   *
   * NOTE(review): no copy constructor / copy assignment is declared, so the
   * compiler-generated ones copy z1..z4 and tmp member-wise. Whether a copy
   * becomes an independent generator, a duplicate of the same random stream,
   * or shares underlying storage depends on CuMatrix's copy semantics —
   * confirm before copying CuRand instances.
   */
  template<typename T>
  class CuRand
  {
  public:
    /**
     * @brief Construct the object and seed it right away.
     * @param rows forwarded unchanged to SeedGpu()
     * @param cols forwarded unchanged to SeedGpu()
     */
    CuRand(size_t rows, size_t cols)
    {
      SeedGpu(rows, cols);
    }

    /// Intentionally empty: no cleanup beyond the members' own destructors.
    ~CuRand() {}

    /// (Re)seed the generator for a rows x cols shape; presumably sizes and
    /// fills z1..z4 — see curand.tcc.
    void SeedGpu(size_t rows, size_t cols);

    /// Takes @p tgt by non-const reference; presumably fills it with uniform
    /// random values — confirm in curand.tcc.
    void Rand(CuMatrix<T>& tgt);

    /// Takes @p tgt by non-const reference; presumably fills it with
    /// Gaussian random values — confirm in curand.tcc.
    void GaussRand(CuMatrix<T>& tgt);

    /// Reads @p probs (const) and writes @p states; presumably samples 0/1
    /// states from per-element probabilities — confirm in curand.tcc.
    void BinarizeProbs(const CuMatrix<T>& probs, CuMatrix<T>& states);

    /// Modifies @p tgt in place using scale @p gscale (defaults to 1.0);
    /// presumably adds Gaussian noise — confirm in curand.tcc.
    void AddGaussNoise(CuMatrix<T>& tgt, T gscale = 1.0);

  private:
    /// Static helper that fills a Matrix<unsigned> with seed values
    /// (Matrix rather than CuMatrix, so presumably host-side — confirm).
    static void SeedRandom(Matrix<unsigned>& mat);

    // Generator state matrices. Declaration order (z1, z2, z3, z4, then tmp)
    // matches the original, so construction/destruction order is unchanged.
    CuMatrix<unsigned> z1;
    CuMatrix<unsigned> z2;
    CuMatrix<unsigned> z3;
    CuMatrix<unsigned> z4;

    // Scratch matrix of element type T; its use is defined in curand.tcc.
    CuMatrix<T> tmp;
  };

} // namespace TNet


#include "curand.tcc"


#endif