aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/random_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/random_op.cc')
-rw-r--r--tensorflow/core/kernels/random_op.cc192
1 file changed, 192 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index 80b1be8d4c..e78f8e2621 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -48,6 +48,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
namespace functor {
using random::PhiloxRandom;
@@ -549,4 +552,193 @@ TF_CALL_int64(REGISTER_INT);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+
+namespace functor {
+
+using namespace cl;
+
+template <class Distribution, bool VariableSamplesPerOutput>
+struct FillPhiloxRandomKernel;
+
+// SYCL kernel that fills a buffer with random values for distributions
+// that consume a fixed number of Philox samples per output group
+// (kVariableSamplesPerOutput == false). Each work item owns every
+// total_item_count-th group of kGroupSize elements, starting at its id.
+template <class Distribution>
+struct FillPhiloxRandomKernel<Distribution, false> {
+  typedef typename Distribution::ResultElementType T;
+  using write_accessor = sycl::accessor<uint8_t, 1, sycl::access::mode::write,
+                                        sycl::access::target::global_buffer>;
+
+  // data: byte accessor over the output buffer; gen: base Philox generator
+  // shared by all items; dist: distribution mapping samples to outputs.
+  FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen,
+                         Distribution& dist)
+      : data_(data), gen_(gen), dist_(dist) {}
+
+  void operator()(sycl::nd_item<1> item) {
+    const size_t kGroupSize = Distribution::kResultElementCount;
+
+    const size_t item_id = item.get_global(0);
+    const size_t total_item_count = item.get_global_range(0);
+    size_t offset = item_id * kGroupSize;
+    // Position this item's generator on its own Philox group.
+    gen_.Skip(item_id);
+
+    const size_t size = data_.get_size() / sizeof(T);
+    T* data = ConvertToActualTypeSycl(T, data_);
+
+    while (offset + kGroupSize <= size) {
+      const typename Distribution::ResultType samples = dist_(&gen_);
+      for (size_t i = 0; i < kGroupSize; ++i) {
+        data[offset + i] = samples[i];
+      }
+
+      // Stride to this item's next group. dist_ already consumed one
+      // Philox group above, so the generator only needs to skip
+      // total_item_count - 1 more — but the write offset must advance by
+      // the full total_item_count groups. (The original advanced offset
+      // by only (total_item_count - 1) * kGroupSize, making items
+      // overwrite each other's output with mismatched sample groups.)
+      offset += total_item_count * kGroupSize;
+      gen_.Skip(total_item_count - 1);
+    }
+
+    // Tail: fewer than kGroupSize elements remain; write element-wise.
+    const typename Distribution::ResultType samples = dist_(&gen_);
+    for (size_t i = 0; i < kGroupSize; ++i) {
+      if (offset >= size) {
+        return;
+      }
+      data[offset] = samples[i];
+      ++offset;
+    }
+  }
+
+ private:
+  write_accessor data_;
+  random::PhiloxRandom gen_;
+  Distribution dist_;
+};
+
+
+// SYCL kernel for distributions that consume a variable number of Philox
+// samples per output (kVariableSamplesPerOutput == true), e.g. rejection
+// sampling. Each output group gets its own realigned generator so the
+// variable consumption of one group cannot shift the sample stream of
+// the next.
+template <class Distribution>
+struct FillPhiloxRandomKernel<Distribution, true> {
+ typedef typename Distribution::ResultElementType T;
+ using write_accessor = sycl::accessor<uint8_t, 1, sycl::access::mode::write, sycl::access::target::global_buffer>;
+
+ // data: byte accessor over the output buffer; gen: base Philox generator
+ // shared by all items; dist: single-sample-adapter distribution.
+ FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen, Distribution& dist)
+ : data_(data),
+ gen_(gen),
+ dist_(dist) {
+ }
+
+ void operator()(sycl::nd_item<1> item) {
+ using random::PhiloxRandom;
+ using random::SingleSampleAdapter;
+
+ // Budget of raw samples reserved per output element; used to space the
+ // per-group generator skips so consecutive groups' sample streams
+ // cannot overlap. NOTE(review): 256 appears to be an assumed upper
+ // bound on samples any distribution draws per output — not verifiable
+ // from this file alone.
+ const size_t kReservedSamplesPerOutput = 256;
+ const size_t kGroupSize = Distribution::kResultElementCount;
+ // Generator skip (in Philox result groups) between consecutive output
+ // groups, derived from the per-output sample budget.
+ const size_t kGeneratorSkipPerOutputGroup = kGroupSize *
+ kReservedSamplesPerOutput /
+ PhiloxRandom::kResultElementCount;
+
+ const size_t item_id = item.get_global(0);
+ const size_t total_item_count = item.get_global_range(0);
+ size_t group_index = item_id;
+ size_t offset = group_index * kGroupSize;
+
+ T* data = ConvertToActualTypeSycl(T, data_);
+ const size_t size = data_.get_size() / sizeof(T);
+
+ while (offset < size) {
+ // Since each output takes a variable number of samples, we need to
+ // realign the generator to the beginning for the current output group
+ PhiloxRandom gen = gen_;
+ gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
+ SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
+
+ const typename Distribution::ResultType samples = dist_(&single_samples);
+
+ // Write this group, bailing out element-wise at the buffer end.
+ for (size_t i = 0; i < kGroupSize; ++i) {
+ if (offset >= size) {
+ return;
+ }
+ data[offset] = samples[i];
+ ++offset;
+ }
+
+ // The loop above already advanced offset by kGroupSize, so adding
+ // (total_item_count - 1) groups yields a full stride of
+ // total_item_count * kGroupSize — consistent with group_index below.
+ offset += (total_item_count - 1) * kGroupSize;
+ group_index += total_item_count;
+ }
+ }
+
+ private:
+ write_accessor data_;
+ random::PhiloxRandom gen_;
+ Distribution dist_;
+};
+
+template <typename T>
+class FillRandomKernel;
+// SYCL specialization: fills the whole output region with random values
+// by launching one nd-range kernel that stripes the work across items.
+template <class Distribution>
+void FillPhiloxRandom<SYCLDevice, Distribution>::operator()(
+    OpKernelContext* context, const SYCLDevice& device, random::PhiloxRandom gen,
+    typename Distribution::ResultElementType* data, int64 size,
+    Distribution dist) {
+  // Round the global range up to a whole number of work groups.
+  const size_t workgroup_size = device.maxSyclThreadsPerBlock();
+  const size_t workgroup_count = (size + workgroup_size - 1) / workgroup_size;
+
+  auto buffer = device.get_sycl_buffer(data);
+
+  device.sycl_queue().submit([&](sycl::handler& cgh) {
+    auto access = buffer.template get_access<sycl::access::mode::write>(cgh);
+
+    // Pick the fixed- or variable-sample kernel at compile time.
+    FillPhiloxRandomKernel<Distribution,
+                           Distribution::kVariableSamplesPerOutput>
+        task(access, gen, dist);
+    cgh.parallel_for<class FillRandomKernel<Distribution>>(
+        sycl::nd_range<1>(sycl::range<1>(workgroup_count * workgroup_size),
+                          sycl::range<1>(workgroup_size)),
+        task);
+  });
+}
+
+}
+
+// Instantiates the SYCL FillPhiloxRandom functor and registers the
+// floating-point random ops (RandomUniform, RandomStandardNormal,
+// TruncatedNormal) for DEVICE_SYCL. The shape input stays in host memory.
+#define REGISTER(TYPE) \
+ template struct functor::FillPhiloxRandom< \
+ SYCLDevice, random::UniformDistribution<random::PhiloxRandom, TYPE> >; \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RandomUniform") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("shape") \
+ .TypeConstraint<TYPE>("dtype"), \
+ PhiloxRandomOp<SYCLDevice, random::UniformDistribution< \
+ random::PhiloxRandom, TYPE> >); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RandomStandardNormal") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("shape") \
+ .TypeConstraint<TYPE>("dtype"), \
+ PhiloxRandomOp<SYCLDevice, random::NormalDistribution< \
+ random::PhiloxRandom, TYPE> >); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("TruncatedNormal") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("shape") \
+ .TypeConstraint<TYPE>("dtype"), \
+ PhiloxRandomOp< \
+ SYCLDevice, \
+ random::TruncatedNormalDistribution< \
+ random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
+
+// Registers the integer RandomUniformInt op for DEVICE_SYCL; shape and
+// min/max bounds are read on the host.
+#define REGISTER_INT(IntType) \
+ REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("shape") \
+ .HostMemory("minval") \
+ .HostMemory("maxval") \
+ .TypeConstraint<IntType>("Tout"), \
+ RandomUniformIntOp<SYCLDevice, IntType>);
+
+// Expand the macros for each supported element type.
+TF_CALL_float(REGISTER);
+TF_CALL_double(REGISTER);
+TF_CALL_int32(REGISTER_INT);
+TF_CALL_int64(REGISTER_INT);
+
+#undef REGISTER
+#undef REGISTER_INT
+
+#endif // TENSORFLOW_USE_SYCL
+
} // end namespace tensorflow