blob: 8aa66d5f9f83082b05760861867728ec6fc88a5a (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
|
#ifndef _CU_RAND_H_
#define _CU_RAND_H_
#include "cumatrix.h"
namespace TNet {
/**
 * @brief Random-number helper whose generator state is held in CuMatrix objects.
 *
 * Out-of-line member definitions are expected in "curand.tcc", which this
 * header includes after the class.
 *
 * @tparam T element type of the matrices being filled (presumably float or
 *           double -- TODO confirm against the instantiations in curand.tcc).
 */
template<typename T>
class CuRand {
 public:
  /// Builds the generator and seeds it right away through SeedGpu(rows, cols).
  /// NOTE(review): rows/cols presumably fix the shape of the generator state,
  /// so target matrices are expected to match it -- confirm in curand.tcc.
  CuRand(size_t rows, size_t cols) {
    SeedGpu(rows, cols);
  }

  /// Empty on purpose: the member matrices are destroyed implicitly.
  ~CuRand() { }

  /// (Re)initialises the generator state for a rows x cols layout.
  void SeedGpu(size_t rows, size_t cols);

  /// Fills @p tgt with random values (presumably uniform -- TODO confirm).
  void Rand(CuMatrix<T>& tgt);

  /// Fills @p tgt with random values (presumably normally distributed -- TODO confirm).
  void GaussRand(CuMatrix<T>& tgt);

  /// Writes a random sample driven by @p probs into @p states (presumably
  /// states = 1 with probability probs, else 0 -- TODO confirm).
  void BinarizeProbs(const CuMatrix<T>& probs, CuMatrix<T>& states);

  /// Adds random noise scaled by @p gscale (default 1.0) to @p tgt
  /// (presumably Gaussian noise -- TODO confirm).
  void AddGaussNoise(CuMatrix<T>& tgt, T gscale = 1.0);

 private:
  /// Fills a host-side Matrix<unsigned>; presumably the source of seeds for
  /// SeedGpu -- TODO confirm in curand.tcc.
  static void SeedRandom(Matrix<unsigned>& mat);

  // Four unsigned state matrices; presumably the words of a combined
  // (e.g. hybrid Tausworthe) generator -- verify against curand.tcc.
  CuMatrix<unsigned> z1;
  CuMatrix<unsigned> z2;
  CuMatrix<unsigned> z3;
  CuMatrix<unsigned> z4;
  // Scratch buffer of element type T; presumably reused by the sampling methods.
  CuMatrix<T> tmp;
};
}
#include "curand.tcc"
#endif
|