summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTuowen Zhao <ztuowen@gmail.com>2019-11-10 23:48:57 -0700
committerTuowen Zhao <ztuowen@gmail.com>2019-11-10 23:48:57 -0700
commit5ae0da8484744859e09fad869b44dccdb5f66f2f (patch)
treeb62667158c83f3b8ab0b68c2726a56157b6f8d71
parentfe22bcb209bde62cf333232c3765d9c4836c37dd (diff)
downloadsycltest-5ae0da8484744859e09fad869b44dccdb5f66f2f.tar.gz
sycltest-5ae0da8484744859e09fad869b44dccdb5f66f2f.tar.bz2
sycltest-5ae0da8484744859e09fad869b44dccdb5f66f2f.zip
Add subgroup suggestion to sycl
-rw-r--r--main.cpp47
1 files changed, 32 insertions, 15 deletions
diff --git a/main.cpp b/main.cpp
index 469ebac..059d8d0 100644
--- a/main.cpp
+++ b/main.cpp
@@ -12,17 +12,17 @@
#define TILEI 16
#define ITER 10
-std::string demangle(const char* name) {
+std::string demangle(const char *name) {
int status = -4; // some arbitrary value to eliminate the compiler warning
// enable c++11 by passing the flag -std=c++11 to g++
- std::unique_ptr<char, void(*)(void*)> res {
+ std::unique_ptr<char, void (*)(void *)> res{
abi::__cxa_demangle(name, NULL, NULL, &status),
std::free
};
- return (status==0) ? res.get() : name ;
+ return (status == 0) ? res.get() : name;
}
using namespace cl::sycl;
@@ -31,25 +31,41 @@ template<typename T>
class subgr;
template<typename T>
+class SGfunctor {
+public:
+ SGfunctor(accessor<uint32_t, 1, access::mode::write, access::target::global_buffer> sg_sz,
+ accessor<T, 1, access::mode::write, access::target::global_buffer> sg_i) : sg_sz(sg_sz), sg_i(sg_i) {}
+
+ [[cl::intel_reqd_sub_group_size(16)]]
+ void operator()(nd_item<1> NdItem) {
+ intel::sub_group SG = NdItem.get_sub_group();
+ uint32_t wggid = NdItem.get_global_id(0);
+ uint32_t sgid = SG.get_local_id().get(0);
+ if (wggid == 0)
+ sg_sz[0] = SG.get_max_local_range()[0];
+ sgid = SG.shuffle_up(sgid, 2);
+ sg_i[wggid] = sgid;
+ }
+
+private:
+ accessor<T, 1, access::mode::write, access::target::global_buffer> sg_i;
+ accessor<uint32_t, 1, access::mode::write, access::target::global_buffer> sg_sz;
+};
+
+template<typename T>
void subkrnl(device &Device, size_t G = 256, size_t L = 64) {
buffer<uint32_t> sg_buf{1};
buffer<T> sg_info{G};
queue Queue(Device);
nd_range<1> NumOfWorkItems(G, L);
- Queue.submit([&](handler &cgh) {
+ auto kernel = [&](handler &cgh) -> void {
auto sg_sz = sg_buf.template get_access<access::mode::write>(cgh);
auto sg_i = sg_info.template get_access<access::mode::write>(cgh);
+ SGfunctor<T> sGfunctor(sg_sz, sg_i);
- cgh.parallel_for<subgr<T>>(NumOfWorkItems, [=](nd_item<1> NdItem) {
- intel::sub_group SG = NdItem.get_sub_group();
- uint32_t wggid = NdItem.get_global_id(0);
- uint32_t sgid = SG.get_local_id().get(0);
- if (wggid == 0)
- sg_sz[0] = SG.get_max_local_range()[0];
- sgid = SG.shuffle_up(sgid, 2);
- sg_i[wggid] = sgid;
- });
- });
+ cgh.parallel_for<subgr<T>>(NumOfWorkItems, sGfunctor);
+ };
+ Queue.submit(kernel);
Queue.wait();
auto sg_sz = sg_buf.template get_access<access::mode::read>();
@@ -133,7 +149,8 @@ void run27pt(device &Device) {
const auto out_h = out_buf.get_access<access::mode::read>();
const auto in_h = in_buf.get_access<access::mode::read>();
double ed = omp_get_wtime();
- double elapsed = (ed_event.get_profiling_info<info::event_profiling::command_end>() - st_event.get_profiling_info<info::event_profiling::command_start>()) * 1e-9;
+ double elapsed = (ed_event.get_profiling_info<info::event_profiling::command_end>() -
+ st_event.get_profiling_info<info::event_profiling::command_start>()) * 1e-9;
std::cout << "elapsed: " << (ed - st) / ITER << std::endl;
std::cout << "elapsed: " << elapsed << std::endl;
std::cout << "flops: " << N * N * N * 53.0 * ITER / elapsed * 1e-9 << std::endl;