From 5ae0da8484744859e09fad869b44dccdb5f66f2f Mon Sep 17 00:00:00 2001 From: Tuowen Zhao Date: Sun, 10 Nov 2019 23:48:57 -0700 Subject: Add subgroup suggestion to sycl --- main.cpp | 47 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/main.cpp b/main.cpp index 469ebac..059d8d0 100644 --- a/main.cpp +++ b/main.cpp @@ -12,17 +12,17 @@ #define TILEI 16 #define ITER 10 -std::string demangle(const char* name) { +std::string demangle(const char *name) { int status = -4; // some arbitrary value to eliminate the compiler warning // enable c++11 by passing the flag -std=c++11 to g++ - std::unique_ptr res { + std::unique_ptr res{ abi::__cxa_demangle(name, NULL, NULL, &status), std::free }; - return (status==0) ? res.get() : name ; + return (status == 0) ? res.get() : name; } using namespace cl::sycl; @@ -30,26 +30,42 @@ using namespace cl::sycl; template class subgr; +template +class SGfunctor { +public: + SGfunctor(accessor sg_sz, + accessor sg_i) : sg_sz(sg_sz), sg_i(sg_i) {} + + [[cl::intel_reqd_sub_group_size(16)]] + void operator()(nd_item<1> NdItem) { + intel::sub_group SG = NdItem.get_sub_group(); + uint32_t wggid = NdItem.get_global_id(0); + uint32_t sgid = SG.get_local_id().get(0); + if (wggid == 0) + sg_sz[0] = SG.get_max_local_range()[0]; + sgid = SG.shuffle_up(sgid, 2); + sg_i[wggid] = sgid; + } + +private: + accessor sg_i; + accessor sg_sz; +}; + template void subkrnl(device &Device, size_t G = 256, size_t L = 64) { buffer sg_buf{1}; buffer sg_info{G}; queue Queue(Device); nd_range<1> NumOfWorkItems(G, L); - Queue.submit([&](handler &cgh) { + auto kernel = [&](handler &cgh) -> void { auto sg_sz = sg_buf.template get_access(cgh); auto sg_i = sg_info.template get_access(cgh); + SGfunctor sGfunctor(sg_sz, sg_i); - cgh.parallel_for>(NumOfWorkItems, [=](nd_item<1> NdItem) { - intel::sub_group SG = NdItem.get_sub_group(); - uint32_t wggid = NdItem.get_global_id(0); - uint32_t sgid = SG.get_local_id().get(0); - if (wggid == 0) - sg_sz[0] = SG.get_max_local_range()[0]; - sgid = SG.shuffle_up(sgid, 2); - sg_i[wggid] = sgid; - }); - }); + cgh.parallel_for>(NumOfWorkItems, sGfunctor); + }; + Queue.submit(kernel); Queue.wait(); auto sg_sz = sg_buf.template get_access(); @@ -133,7 +149,8 @@ void run27pt(device &Device) { const auto out_h = out_buf.get_access(); const auto in_h = in_buf.get_access(); double ed = omp_get_wtime(); - double elapsed = (ed_event.get_profiling_info() - st_event.get_profiling_info()) * 1e-9; + double elapsed = (ed_event.get_profiling_info() - + st_event.get_profiling_info()) * 1e-9; std::cout << "elapsed: " << (ed - st) / ITER << std::endl; std::cout << "elapsed: " << elapsed << std::endl; std::cout << "flops: " << N * N * N * 53.0 * ITER / elapsed * 1e-9 << std::endl; -- cgit v1.2.3-70-g09d2