#include <cfloat>

#include "cukernels.h"


/*****************
 * CUDA kernels
 *
 * Conventions shared by the kernels below:
 *  - A matrix is described by MatrixDim d: d.rows x d.cols elements, with
 *    element (row r, col c) stored at offset c + r*d.stride.
 *  - 2-D kernels map the .x thread coordinate to the column index i and the
 *    .y coordinate to the row index j; threads outside the matrix do nothing.
 *  - Kernels with a "_reduce" suffix run one block per row (or column) and use
 *    the block-wide helpers below. Every thread of such a block must reach the
 *    helper's __syncthreads(), so those kernels only exit early on conditions
 *    that are uniform across the whole block.
 */


/*
 * Block-wide reduction helpers.
 *
 * Each helper reduces buffer[0 .. blockDim.x-1] (the caller fills one entry
 * per thread, normally in __shared__ memory) with a pairwise tree, overwriting
 * the buffer, and returns the result to every calling thread. They execute
 * __syncthreads(), so all threads of the block must call them. Only
 * threadIdx.x is used, i.e. a 1-D block (blockDim.y == 1) is assumed.
 * Callers must __syncthreads() again before overwriting the buffer.
 *
 * They are defined first so that the kernels using them can find them by
 * ordinary (non-ADL) lookup at template definition time.
 */

// Largest value in buffer[0 .. blockDim.x-1].
template<typename T>
__device__
static T _max_reduce(T buffer[]) {
  int nTotalThreads = blockDim.x;   // entries still taking part
  __syncthreads();                  // make the caller's writes visible
  while(nTotalThreads > 1) {
    int halfPoint = ((1+nTotalThreads) >> 1);   // ceil(n/2)
    // Thread t folds in entry t+halfPoint. For odd n the middle entry has no
    // partner and simply survives to the next round.
    if(threadIdx.x < halfPoint) {
      int other = threadIdx.x + halfPoint;
      if(other < nTotalThreads && buffer[other] > buffer[threadIdx.x])
        buffer[threadIdx.x] = buffer[other];
    }
    __syncthreads();
    nTotalThreads = halfPoint;
  }
  return buffer[0];
}

// Sum of buffer[0 .. blockDim.x-1].
template<typename T>
__device__
static T _sum_reduce(T buffer[]) {
  int nTotalThreads = blockDim.x;   // entries still taking part
  __syncthreads();                  // make the caller's writes visible
  while(nTotalThreads > 1) {
    int halfPoint = ((1+nTotalThreads) >> 1);   // ceil(n/2)
    // Thread t adds entry t+halfPoint; an unpaired middle entry is kept.
    if(threadIdx.x < halfPoint) {
      int other = threadIdx.x + halfPoint;
      if(other < nTotalThreads)
        buffer[threadIdx.x] += buffer[other];
    }
    __syncthreads();
    nTotalThreads = halfPoint;
  }
  return buffer[0];
}

// Index into val[] of its largest element. idx[t] must initially hold the
// candidate index contributed by thread t; only idx[] is overwritten, val[]
// is left unchanged. On ties the candidate already held by the lower slot is
// kept, so which of several equal maxima is returned depends on tree order.
template<typename T>
__device__
static int _max_id_reduce(T val[], int idx[]) {
  int nTotalThreads = blockDim.x;   // entries still taking part
  __syncthreads();                  // make the caller's writes visible
  while(nTotalThreads > 1) {
    int halfPoint = ((1+nTotalThreads) >> 1);   // ceil(n/2)
    if(threadIdx.x < halfPoint) {
      int other = threadIdx.x + halfPoint;
      // Guarding on `other` keeps idx[] reads inside the active range.
      if(other < nTotalThreads && val[idx[other]] > val[idx[threadIdx.x]])
        idx[threadIdx.x] = idx[other];
    }
    __syncthreads();
    nTotalThreads = halfPoint;
  }
  return idx[0];
}


//CuMatrix

// mat[j][i] = value for every element.
template<typename T>
__global__
static void _set_const(T* mat, T value, MatrixDim d) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int j = blockIdx.y * blockDim.y + threadIdx.y;
  if( i < d.cols && j < d.rows )
    mat[i + j*d.stride] = value;
}

// mat = log(mat), element-wise, without clamping (see _log_elem).
template<typename T>
__global__
static void _apply_log(T* mat, MatrixDim d) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int j = blockIdx.y * blockDim.y + threadIdx.y;
  if( i < d.cols && j < d.rows ) {
    int index = i + j*d.stride;
    mat[index] = log(mat[index]);
  }
}

// Zero every element of mat whose counterpart in mask is exactly 0.
// mat and mask are indexed by the same (row, col) but use their own strides.
template<typename T>
__global__
static void _apply_mask(T* mat, const float* mask, MatrixDim dmat, MatrixDim dmask) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int j = blockIdx.y * blockDim.y + threadIdx.y;
  if( i < dmat.cols && j < dmat.rows ) {
    if(mask[i + j*dmask.stride] == 0)
      mat[i + j*dmat.stride] = 0;
  }
}

// Soft-thresholding: values with |v| < l1 become 0, all others are moved
// towards zero by l1.
template<typename T>
__global__
static void _apply_l1(T* mat, T l1, MatrixDim d) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int j = blockIdx.y * blockDim.y + threadIdx.y;
  if( i < d.cols && j < d.rows ) {
    int index = i + j*d.stride;
    T value = mat[index];
    T tgt;
    if(fabs(value) < l1) {
      tgt = 0;
    } else {
      tgt = (value > 0 ? value-l1 : value+l1);
    }
    mat[index] = tgt;
  }
}

// mat[j][i] *= scale[i]  (one factor per column).
template<typename T>
__global__
static void _scale_cols(T* mat, const T* scale, MatrixDim d) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int j = blockIdx.y * blockDim.y + threadIdx.y;
  if( i < d.cols && j < d.rows )
    mat[i + j*d.stride] *= scale[i];
}

// mat[j][i] *= scale[j]  (one factor per row).
template<typename T>
__global__
static void _scale_rows(T* mat, const T* scale, MatrixDim d) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int j = blockIdx.y * blockDim.y + threadIdx.y;
  if( i < d.cols && j < d.rows )
    mat[i + j*d.stride] *= scale[j];
}

// dst = alpha*A + beta*dst, element-wise; A and dst share the layout d.
template<typename T>
__global__
static void _add_scaled(T alpha, const T* A, T beta, T* dst, MatrixDim d) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int j = blockIdx.y * blockDim.y + threadIdx.y;
  if( i < d.cols && j < d.rows ) {
    int index = i + j*d.stride;
    dst[index] = alpha*A[index] + beta*dst[index];
  }
}

// dst[j][i] = alpha*row[i] + beta*dst[j][i]  (row vector added to every row).
// Caching `row` in shared memory was tried and gave no speed-up, so each
// thread reads row[i] directly.
template<typename T>
__global__
static void _add_scaled_row(T alpha, const T* row, T beta, T* dst, MatrixDim d) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int j = blockIdx.y * blockDim.y + threadIdx.y;
  if( i < d.cols && j < d.rows ) {
    int index = i + j*d.stride;
    dst[index] = alpha*row[i] + beta*dst[index];
  }
}

// mat = mat .* A, element-wise; mat and A share the layout d.
template<typename T>
__global__
static void _mul_elem(T* mat, const T* A, MatrixDim d) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int j = blockIdx.y * blockDim.y + threadIdx.y;
  if( i < d.cols && j < d.rows ) {
    int index = i + j*d.stride;
    mat[index] = mat[index] * A[index];
  }
}

// mat = log(max(mat, FLT_MIN)), element-wise; the clamp avoids log(0) = -inf
// and NaN for non-positive inputs.
template<typename T>
__global__
static void _log_elem(T* mat, MatrixDim d) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int j = blockIdx.y * blockDim.y + threadIdx.y;
  if( i < d.cols && j < d.rows ) {
    int index = i + j*d.stride;
    if(mat[index] < FLT_MIN) mat[index] = FLT_MIN;
    mat[index] = log(mat[index]);
  }
}


//CuVector

// vec[c] = alpha * sum_r mat[r][c] + beta * vec[c].
// Intended for a 1-D launch over columns (threads with y > 0 exit); each
// thread walks its column serially, accumulating in double.
template<typename T>
__global__
static void _add_col_sum(T alpha, const T* mat, T beta, T* vec, MatrixDim d) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int j = blockIdx.y * blockDim.y + threadIdx.y;
  if(j > 0) return;

  if(i < d.cols) {
    double sum = 0.0;
    for(int k = 0; k < d.rows; k++) {
      sum += mat[i+k*d.stride];
    }
    vec[i] = alpha*sum + beta*vec[i];
  }
}

// Same result as _add_col_sum, computed with one block per column:
// grid (1, cols), block (B, 1) with d.rows <= B <= 512 (size of aux[]).
// The x/y roles are flipped here: x walks the rows, y selects the column.
template<typename T>
__global__
static void _add_col_sum_reduce(T alpha, const T* mat, T beta, T* vec, MatrixDim d) {
  int j = blockIdx.x * blockDim.x + threadIdx.x;   // row
  int i = blockIdx.y * blockDim.y + threadIdx.y;   // column

  // All exits below are uniform over the block (they precede __syncthreads).
  if(blockIdx.x > 0) return;
  if(blockDim.y != 1) return;
  if(i >= d.cols) return;

  // Threads past the last row contribute 0 to the sum.
  __shared__ T aux[512];
  aux[threadIdx.x] = (j < d.rows) ? mat[i+j*d.stride] : T(0);

  T sum = _sum_reduce(aux);

  // A single thread performs the read-modify-write: if every thread did it,
  // a later warp could read the already-updated vec[i] and apply beta twice.
  if(threadIdx.x == 0)
    vec[i] = alpha*sum + beta*vec[i];
}


//CuMath

// y = 1 / (1 + exp(-x)), element-wise.
template<typename T>
__global__
static void _sigmoid(T*y, const T*x, MatrixDim d) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int j = blockIdx.y * blockDim.y + threadIdx.y;
  if( i < d.cols && j < d.rows ) {
    int index = i + j*d.stride;
    // T(1) keeps float instantiations in single precision.
    y[index] = T(1) / (T(1) + exp(-x[index]));
  }
}

// eout = y .* (1 - y) .* e, i.e. back-propagation through a sigmoid whose
// forward output was y.
template<typename T>
__global__
static void _diff_sigmoid(T*eout, const T*e, const T*y, MatrixDim d) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int j = blockIdx.y * blockDim.y + threadIdx.y;
  if( i < d.cols && j < d.rows ) {
    int index = i + j*d.stride;
    eout[index] = y[index]*(T(1)-y[index]) * e[index];
  }
}

// Row-wise softmax, one thread per row (1-D launch over rows):
// y[r][c] = exp(x[r][c] - max_r) / sum_c exp(x[r][c] - max_r).
// Subtracting the row maximum keeps exp() from overflowing. y may alias x.
template<typename T>
__global__
static void _softmax(T*y, const T*x, MatrixDim d) {
  int j = blockIdx.x * blockDim.x + threadIdx.x;
  if(j >= d.rows) return;

  const T* x_row = x + j*d.stride;
  T* y_row = y + j*d.stride;

  // find the row maximum
  double row_max = -1e20;
  for(int i=0; i<d.cols; i++) {
    if(x_row[i] > row_max) row_max = x_row[i];
  }
  // exponentiate the shifted values and accumulate their sum
  double sum = 0.0;
  for(int i=0; i<d.cols; i++) {
    double e = exp(x_row[i] - row_max);
    y_row[i] = e;
    sum += e;
  }
  // normalize
  for(int i=0; i<d.cols; i++) {
    y_row[i] /= sum;
  }
}

// Row-wise softmax, one block per row: grid (1, rows), block (B, 1) with
// d.cols <= B <= 256 (size of the shared arrays). y may alias x.
template<typename T>
__global__
static void _softmax_reduce(T*y, const T*x, MatrixDim d) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;   // column (== threadIdx.x)
  int j = blockIdx.y * blockDim.y + threadIdx.y;   // row

  // All exits below are uniform over the block (they precede __syncthreads).
  if(blockIdx.x > 0) return;
  if(blockDim.y > 1) return;
  if(j >= d.rows) return;

  __shared__ T row_data[256];
  __shared__ T aux[256];

  // Threads past the last column hold -1e20: it cannot win the maximum and
  // exp(-1e20 - max) underflows to 0, so it adds nothing to the sum.
  bool active = (i < d.cols);
  row_data[i] = active ? x[i+j*d.stride] : T(-1e20);

  aux[i] = row_data[i];
  T row_max = _max_reduce(aux);
  __syncthreads();   // everyone has read aux[0] before aux is reused

  row_data[i] = exp(row_data[i]-row_max);
  aux[i] = row_data[i];
  T sum = _sum_reduce(aux);

  if(active)
    y[i+j*d.stride] = row_data[i] / sum;
}

// y[r][c] = x[clamp(r + off[c / d_in.cols])][c % d_in.cols]: output column
// group k is a copy of the input shifted by off[k] rows, with the source row
// clamped to [0, d_in.rows-1].
template<typename T>
__global__
static void _expand(T* y, const T* x, const int* off, MatrixDim d_out, MatrixDim d_in) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int j = blockIdx.y * blockDim.y + threadIdx.y;
  if( i < d_out.cols && j < d_out.rows ) {
    int src_col = i % d_in.cols;
    int src_row = j + off[i / d_in.cols];
    if(src_row < 0) src_row = 0;
    if(src_row >= d_in.rows) src_row = d_in.rows-1;
    y[i + j*d_out.stride] = x[src_col + src_row*d_in.stride];
  }
}

// y[r][c] = x[r][copy_from[c]] (column gather). Output columns whose source
// index is outside [0, d_in.cols) are set to +inf (1.0/0.0).
template<typename T>
__global__
static void _rearrange(T* y, const T* x, const int* copy_from, MatrixDim d_out, MatrixDim d_in) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int j = blockIdx.y * blockDim.y + threadIdx.y;
  if( i < d_out.cols && j < d_out.rows ) {
    int index = i + j*d_out.stride;
    int src_col = copy_from[i];
    if(src_col >= 0 && src_col < d_in.cols) {
      y[index] = x[src_col + j*d_in.stride];
    } else {
      y[index] = 1.0/0.0;
    }
  }
}

// y[r][c] = x[copy_from[r]][c] (row gather). copy_from[r] is not range
// checked; NOTE(review): assumes 0 <= copy_from[r] < d_in.rows.
template<typename T>
__global__
static void _randomize(T* y, const T* x, const int* copy_from, MatrixDim d_out, MatrixDim d_in) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int j = blockIdx.y * blockDim.y + threadIdx.y;
  if( i < d_out.cols && j < d_out.rows ) {
    int src_row = copy_from[j];
    y[i + j*d_out.stride] = x[i + src_row*d_in.stride];
  }
}

// match[r] = 1 if the arg-max column of row r of `out` equals that of `des`,
// else 0. One thread per row (1-D launch; threads with y > 0 exit); the first
// of several equal maxima wins. Rows are skipped when d.cols == 0.
template<typename T>
__global__
static void _check_class(const T* out, const T* des, int* match, MatrixDim d) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int j = blockIdx.y * blockDim.y + threadIdx.y;
  if(j > 0) return;

  if(i < d.rows && d.cols > 0) {
    const T* out_row = out + i*d.stride;
    const T* des_row = des + i*d.stride;

    int out_id = 0;
    T out_max = out_row[0];
    for(int k=1; k<d.cols; k++) {
      T val = out_row[k];
      if(val > out_max) { out_max = val; out_id = k; }
    }

    int des_id = 0;
    T des_max = des_row[0];
    for(int k=1; k<d.cols; k++) {
      T val = des_row[k];
      if(val > des_max) { des_max = val; des_id = k; }
    }

    match[i] = ((out_id == des_id)?1:0);
  }
}

// Same result as _check_class, one block per row: grid (1, rows),
// block (B, 1) with d.cols <= B <= 256 (size of the shared arrays).
template<typename T>
__global__
static void _check_class_reduce(const T* out, const T* des, int* match, MatrixDim d) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;   // column (== threadIdx.x)
  int j = blockIdx.y * blockDim.y + threadIdx.y;   // row

  // All exits below are uniform over the block (they precede __syncthreads).
  if(blockIdx.x > 0) return;
  if(blockDim.y != 1) return;
  if(j >= d.rows) return;

  __shared__ T value[256];
  __shared__ int index[256];

  // Threads past the last column hold -1e20, which cannot win the arg-max.
  bool active = (i < d.cols);

  value[threadIdx.x] = active ? out[i+j*d.stride] : T(-1e20);
  index[threadIdx.x] = threadIdx.x;
  int out_max = _max_id_reduce(value,index);
  __syncthreads();   // everyone has read index[0] before the arrays are reused

  value[threadIdx.x] = active ? des[i+j*d.stride] : T(-1e20);
  index[threadIdx.x] = threadIdx.x;
  int des_max = _max_id_reduce(value,index);

  if(threadIdx.x == 0) {
    match[j] = ((out_max == des_max)?1:0);
  }
}


/**************
 * C wrappers around CUDA kernels
 *
 * Each wrapper launches its kernel with grid Gr and block Bl on the default
 * stream and returns immediately (asynchronously, without error checking);
 * the launch-layout requirements are documented at each kernel above.
 */

//:FLOAT:

//CuMatrix
void cudaF_set_const(dim3 Gr, dim3 Bl, float* mat, float value, MatrixDim d) {
  _set_const<<<Gr,Bl>>>(mat,value,d);
}

void cudaF_apply_log(dim3 Gr, dim3 Bl, float* mat, MatrixDim d) {
  _apply_log<<<Gr,Bl>>>(mat,d);
}

void cudaF_apply_mask(dim3 Gr, dim3 Bl, float* mat, const float* mask, MatrixDim dmat, MatrixDim dmask) {
  _apply_mask<<<Gr,Bl>>>(mat,mask,dmat,dmask);
}

void cudaF_apply_l1(dim3 Gr, dim3 Bl, float* mat, float l1, MatrixDim d) {
  _apply_l1<<<Gr,Bl>>>(mat,l1,d);
}

void cudaF_scale_cols(dim3 Gr, dim3 Bl, float* mat, const float* scale, MatrixDim d) {
  _scale_cols<<<Gr,Bl>>>(mat,scale,d);
}

void cudaF_scale_rows(dim3 Gr, dim3 Bl, float* mat, const float* scale, MatrixDim d) {
  _scale_rows<<<Gr,Bl>>>(mat,scale,d);
}

void cudaF_add_scaled(dim3 Gr, dim3 Bl, float alpha, const float* A, float beta, float* dst, MatrixDim d) {
  _add_scaled<<<Gr,Bl>>>(alpha,A,beta,dst,d);
}

void cudaF_add_scaled_row(dim3 Gr, dim3 Bl, float alpha, const float* row, float beta, float* dst, MatrixDim d) {
  _add_scaled_row<<<Gr,Bl>>>(alpha,row,beta,dst,d);
}

void cudaF_mul_elem(dim3 Gr, dim3 Bl, float*mat, const float*A, MatrixDim d) {
  _mul_elem<<<Gr,Bl>>>(mat,A,d);
}

void cudaF_log_elem(dim3 Gr, dim3 Bl, float*mat, MatrixDim d) {
  _log_elem<<<Gr,Bl>>>(mat,d);
}

//CuVector
void cudaF_add_col_sum(size_t Gr, size_t Bl, float alpha, const float* mat, float beta, float* vec, MatrixDim d) {
  _add_col_sum<<<Gr,Bl>>>(alpha,mat,beta,vec,d);
}

void cudaF_add_col_sum_reduce(dim3 Gr, dim3 Bl, float alpha, const float* mat, float beta, float* vec, MatrixDim d) {
  _add_col_sum_reduce<<<Gr,Bl>>>(alpha,mat,beta,vec,d);
}

//CuMath
void cudaF_sigmoid (dim3 Gr, dim3 Bl, float *y, const float*x, MatrixDim d) {
  _sigmoid<<<Gr,Bl>>>(y, x, d);
}

void cudaF_diff_sigmoid (dim3 Gr, dim3 Bl, float*eout, const float*e, const float*y, MatrixDim d) {
  _diff_sigmoid<<<Gr,Bl>>>(eout, e, y, d);
}

void cudaF_softmax (size_t Gr, size_t Bl, float*y, const float*x, MatrixDim d) {
  _softmax<<<Gr,Bl>>>(y, x, d);
}

void cudaF_softmax_reduce (dim3 Gr, dim3 Bl, float*y, const float*x, MatrixDim d) {
  _softmax_reduce<<<Gr,Bl>>>(y, x, d);
}

void cudaF_expand(dim3 Gr, dim3 Bl, float* y, const float* x, const int* off, MatrixDim d_out, MatrixDim d_in) {
  _expand<<<Gr,Bl>>>(y,x,off,d_out,d_in);
}

void cudaF_rearrange(dim3 Gr, dim3 Bl, float* y, const float* x, const int* copy_from, MatrixDim d_out, MatrixDim d_in) {
  _rearrange<<<Gr,Bl>>>(y,x,copy_from,d_out,d_in);
}

void cudaF_randomize(dim3 Gr, dim3 Bl, float* y, const float* x, const int* copy_from, MatrixDim d_out, MatrixDim d_in) {
  _randomize<<<Gr,Bl>>>(y,x,copy_from,d_out,d_in);
}

void cudaF_check_class(size_t Gr, size_t Bl, const float* out, const float* des, int* match, MatrixDim d) {
  _check_class<<<Gr,Bl>>>(out,des,match,d);
}

void cudaF_check_class_reduce(dim3 Gr, dim3 Bl, const float* out, const float* des, int* match, MatrixDim d) {
  _check_class_reduce<<<Gr,Bl>>>(out,des,match,d);
}


//:DOUBLE:

//CuMatrix
void cudaD_set_const(dim3 Gr, dim3 Bl, double* mat, double value, MatrixDim d) {
  _set_const<<<Gr,Bl>>>(mat,value,d);
}

void cudaD_apply_log(dim3 Gr, dim3 Bl, double* mat, MatrixDim d) {
  _apply_log<<<Gr,Bl>>>(mat,d);
}

void cudaD_scale_cols(dim3 Gr, dim3 Bl, double* mat, const double* scale, MatrixDim d) {
  _scale_cols<<<Gr,Bl>>>(mat,scale,d);
}

void cudaD_scale_rows(dim3 Gr, dim3 Bl, double* mat, const double* scale, MatrixDim d) {
  _scale_rows<<<Gr,Bl>>>(mat,scale,d);
}

void cudaD_add_scaled(dim3 Gr, dim3 Bl, double alpha, const double* A, double beta, double* dst, MatrixDim d) {
  _add_scaled<<<Gr,Bl>>>(alpha,A,beta,dst,d);
}

void cudaD_add_scaled_row(dim3 Gr, dim3 Bl, double alpha, const double* row, double beta, double* dst, MatrixDim d) {
  _add_scaled_row<<<Gr,Bl>>>(alpha,row,beta,dst,d);
}

void cudaD_mul_elem(dim3 Gr, dim3 Bl, double*mat, const double*A, MatrixDim d) {
  _mul_elem<<<Gr,Bl>>>(mat,A,d);
}

void cudaD_log_elem(dim3 Gr, dim3 Bl, double*mat, MatrixDim d) {
  _log_elem<<<Gr,Bl>>>(mat,d);
}

//CuVector
void cudaD_add_col_sum(size_t Gr, size_t Bl, double alpha, const double* mat, double beta, double* vec, MatrixDim d) {
  _add_col_sum<<<Gr,Bl>>>(alpha,mat,beta,vec,d);
}

//CuMath
void cudaD_sigmoid (dim3 Gr, dim3 Bl, double *y, const double*x, MatrixDim d) {
  _sigmoid<<<Gr,Bl>>>(y, x, d);
}

void cudaD_diff_sigmoid (dim3 Gr, dim3 Bl, double*eout, const double*e, const double*y, MatrixDim d) {
  _diff_sigmoid<<<Gr,Bl>>>(eout, e, y, d);
}

void cudaD_softmax (size_t Gr, size_t Bl, double*y, const double*x, MatrixDim d) {
  _softmax<<<Gr,Bl>>>(y, x, d);
}

void cudaD_expand(dim3 Gr, dim3 Bl, double* y, const double* x, const int* off, MatrixDim d_out, MatrixDim d_in) {
  _expand<<<Gr,Bl>>>(y,x,off,d_out,d_in);
}

void cudaD_rearrange(dim3 Gr, dim3 Bl, double* y, const double* x, const int* copy_from, MatrixDim d_out, MatrixDim d_in) {
  _rearrange<<<Gr,Bl>>>(y,x,copy_from,d_out,d_in);
}

void cudaD_randomize(dim3 Gr, dim3 Bl, double* y, const double* x, const int* copy_from, MatrixDim d_out, MatrixDim d_in) {
  _randomize<<<Gr,Bl>>>(y,x,copy_from,d_out,d_in);
}

void cudaD_check_class(size_t Gr, size_t Bl, const double* out, const double* des, int* match, MatrixDim d) {
  _check_class<<<Gr,Bl>>>(out,des,match,d);
}