353
社区成员
发帖
与我相关
我的任务
分享
template <typename Dtype>
// Tiled im2col kernel.
//
// data_im is indexed as (channels, height, width); data_col is written at
//   (h_out * width_col + w_out) * channels * ksize * ksize
//     + channel * ksize * ksize + ksh * ksize + ksw.
//
// Work is split into tiles identified by a flat index i, decomposed as
// (c_blockidx, h_blockidx, w_blockidx, ksmks_blockidx), where "ksmks" indexes
// the flattened ksize*ksize kernel window. Each block strides over tiles by
// gridDim.x, so any 1-D grid size is valid. Per tile:
//   phase 1: threads laid out as (c, h, w, ksmks) load one input pixel each
//            into the __shared__ tile (0 for padding / out-of-range slots);
//   phase 2: threads re-mapped to (h, w, c, ksmks) copy the tile into data_col.
//
// Assumptions (not enforced here -- confirm against the host launcher):
//   * c_dim, h_dim, w_dim, ksmks_dim are compile-time constants defined
//     elsewhere; they size the static __shared__ tile.
//   * blockDim.x == c_dim * h_dim * w_dim * ksmks_dim (1-D block).
//   * n == c_num * h_num * w_num * ksmks_num, each *_num being the ceil-div
//     of the corresponding extent by its *_dim.
__global__ void aaal(const int n, const Dtype* data_im, const int channels,
    const int height, const int width, const int ksize, const int pad,
    const int stride, const int height_col, const int width_col,
    Dtype* data_col, const int c_num, const int h_num, const int w_num, const int ksmks_num) {
  // One tile, always indexed in the phase-1 layout (c, h, w, ksmks).
  __shared__ Dtype tempmatrix[w_dim * h_dim * c_dim * ksmks_dim];

  const int kernel_area = ksize * ksize;

  // Tile-independent thread coordinates. The kernel-window slot is the
  // fastest-varying thread dimension in both phases.
  const int ksmks_threadidx = threadIdx.x % ksmks_dim;
  const int spatial_tid = threadIdx.x / ksmks_dim;

  // Phase-1 layout: (c, h, w, ksmks).
  const int w_tid_read = spatial_tid % w_dim;
  const int h_tid_read = (spatial_tid / w_dim) % h_dim;
  const int c_tid_read = spatial_tid / (w_dim * h_dim);

  // Phase-2 layout: (h, w, c, ksmks).
  const int c_tid_write = spatial_tid % c_dim;
  const int w_tid_write = (spatial_tid / c_dim) % w_dim;
  const int h_tid_write = spatial_tid / (c_dim * w_dim);

  const int smem_read_idx =
      ((c_tid_read * h_dim + h_tid_read) * w_dim + w_tid_read) * ksmks_dim + ksmks_threadidx;
  const int smem_write_idx =
      ((c_tid_write * h_dim + h_tid_write) * w_dim + w_tid_write) * ksmks_dim + ksmks_threadidx;

  // The loop bound depends only on blockIdx.x, so every thread of a block runs
  // the same number of iterations and the barriers below are uniform.
  for (int i = blockIdx.x; i < n; i += gridDim.x) {
    // Decompose the tile index with exact integer division (__fdividef is an
    // approximate float division and may truncate to the wrong integer).
    const int ksmks_blockidx = i % ksmks_num;
    int blockid = i / ksmks_num;
    const int w_blockidx = blockid % w_num;
    blockid /= w_num;
    const int h_blockidx = blockid % h_num;
    const int c_blockidx = blockid / h_num;

    // Kernel-window offset handled by this thread; identical in both phases.
    const int k_idx = ksmks_blockidx * ksmks_dim + ksmks_threadidx;
    const int ksh = k_idx / ksize;
    const int ksw = k_idx % ksize;

    // ---- Phase 1: global -> shared ----
    {
      const int channel_in = c_blockidx * c_dim + c_tid_read;
      const int h_out = h_blockidx * h_dim + h_tid_read;
      const int w_out = w_blockidx * w_dim + w_tid_read;
      const int h = h_out * stride - pad + ksh;
      const int w = w_out * stride - pad + ksw;
      // Zero-padding positions and tile slots past the data extents load 0,
      // so no out-of-bounds global read happens and phase 2 copies blindly.
      const bool in_range = channel_in < channels && h_out < height_col &&
                            w_out < width_col && k_idx < kernel_area &&
                            h >= 0 && h < height && w >= 0 && w < width;
      tempmatrix[smem_read_idx] =
          in_range ? data_im[(channel_in * height + h) * width + w] : Dtype(0);
    }
    __syncthreads();

    // ---- Phase 2: shared -> global ----
    {
      const int channel_in = c_blockidx * c_dim + c_tid_write;
      const int h_out = h_blockidx * h_dim + h_tid_write;
      const int w_out = w_blockidx * w_dim + w_tid_write;
      if (channel_in < channels && h_out < height_col && w_out < width_col &&
          k_idx < kernel_area) {
        // Index from the unmodified base pointer so the offset does not
        // accumulate across loop iterations.
        data_col[(h_out * width_col + w_out) * channels * kernel_area +
                 channel_in * kernel_area + ksh * ksize + ksw] =
            tempmatrix[smem_write_idx];
      }
    }
    // All threads must finish reading the tile before the next iteration
    // overwrites it (phase 2 reads slots written by other threads).
    __syncthreads();
  }
}