aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoitsteiner@users.noreply.github.com>2017-02-21 11:00:19 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2017-02-21 11:00:19 -0800
commit2c8d0dca978a246f54c506aae4587dbce5d3bcf0 (patch)
tree9efcc4097cce2224d5cd0bb83698d52d5a5a5819
parent43c71a03380d8de18202cc399563814b2f438cd2 (diff)
OpenCL Improvements (#7596)
* OpenCL improvements Added Tile, Transpose and Range Ops double support for SYCL device. Moved gpu_device_name() to test_util.py so now it can be used in force_gpu to pull either GPU or SYCL depending on what is available in the system. * Improvements to the SYCL device support - Registration of Type Traits required for stride slice op - Registration of ConcatOffset, _ListToArray, _ArrayToList Pad, Reverse ( CPU ), ReverseV2 ( CPU ), Size, ExpandDims, Squeeze, StridedSlice, StridedSliceGrad, StridedSliceAssign, TileGrad, InvertPermutation, Transpose - Registration of Sycl kernels only for essential data types - Floor_div_real has been disabled for SYCL device - Device in control_flow_ops_py_test.py needed to be lower cased * SYCL support improvements (#31) * Improvements to the SYCL device support This commit reduces number of failing tests when TensorFlow compiles for OpenCL support. - Registration of Type Traits required for stride slice op - Registration of ConcatOffset, _ListToArray, _ArrayToList Pad, Reverse ( CPU ), ReverseV2 ( CPU ), Size, ExpandDims, Squeeze, StridedSlice, StridedSliceGrad, StridedSliceAssign, TileGrad, InvertPermutation, Transpose - Registration of Sycl kernels only for essential data types - Floor_div_real has been disabled for SYCL device - Device in control_flow_ops_py_test.py needed to be lower cased * Fixes & Version bump (#33) * Fix Unbuntu typo. (#38) unbuntu -> ubuntu * Add problem descriptions and solutions (#35) * Add ComputeCpp lib folder to LD_LIBRARY_PATH * Add ImportError problem + solution If you get the error message "ImportError: libComputeCpp.so: cannot open shared object file: No such file or directory", make sure you have added the path to ComputeCpp's lib folder to your `LD_LIBRARY_PATH`. * Add another ImportError problem + solution If you get the error message "ImportError: cannot import name 'pywrap_tensorflow'" you may be standing in the TensorFlow directory. * Improvements to the SYCL device support * Registers FloorDiv, FloorMod and SoftMax Ops for SYCL device * Workaround for 0 bytes allocation for SYCL device (#42) * Sycl improvements (#44) - Eigen version bump - Extends Cast and Cwise ops benchmark to cover Sycl device - Extends device_lib_test.py to cover Sycl device - Registers int32, string and ResourceHandler to run on host for Enter and RefEnter Sycl Ops - Enables RecudeMax op for Sycl since Eigen implementation is ready - Registers Less op for Sycl device * Improved the formatting of the SYCL code * Fixed compilation error. * Made sure that using test sessions with force_gpu=True forces the placement on a gpu device even if none is detected.
-rw-r--r--tensorflow/core/BUILD4
-rw-r--r--tensorflow/core/common_runtime/sycl/sycl_allocator.cc7
-rw-r--r--tensorflow/core/common_runtime/sycl/sycl_allocator.h11
-rw-r--r--tensorflow/core/common_runtime/sycl/sycl_device.cc2
-rw-r--r--tensorflow/core/common_runtime/sycl/sycl_device.h9
-rw-r--r--tensorflow/core/common_runtime/sycl/sycl_device_context.cc256
-rw-r--r--tensorflow/core/common_runtime/sycl/sycl_device_context.h6
-rw-r--r--tensorflow/core/common_runtime/sycl/sycl_device_factory.cc12
-rw-r--r--tensorflow/core/framework/register_types_traits.h19
-rw-r--r--tensorflow/core/kernels/cast_op.cc56
-rw-r--r--tensorflow/core/kernels/cast_op_impl.h29
-rw-r--r--tensorflow/core/kernels/cast_op_impl_bool.cc10
-rw-r--r--tensorflow/core/kernels/cast_op_impl_double.cc10
-rw-r--r--tensorflow/core/kernels/cast_op_impl_float.cc10
-rw-r--r--tensorflow/core/kernels/cast_op_impl_int32.cc10
-rw-r--r--tensorflow/core/kernels/cast_op_impl_int64.cc12
-rw-r--r--tensorflow/core/kernels/cast_op_test.cc14
-rw-r--r--tensorflow/core/kernels/concat_lib.h8
-rw-r--r--tensorflow/core/kernels/concat_lib_cpu.cc19
-rw-r--r--tensorflow/core/kernels/concat_lib_cpu.h35
-rw-r--r--tensorflow/core/kernels/concat_op.cc50
-rw-r--r--tensorflow/core/kernels/constant_op.cc18
-rw-r--r--tensorflow/core/kernels/control_flow_ops.cc24
-rw-r--r--tensorflow/core/kernels/cwise_op_acos.cc1
-rw-r--r--tensorflow/core/kernels/cwise_op_add_1.cc15
-rw-r--r--tensorflow/core/kernels/cwise_op_asin.cc1
-rw-r--r--tensorflow/core/kernels/cwise_op_atan.cc1
-rw-r--r--tensorflow/core/kernels/cwise_op_ceil.cc1
-rw-r--r--tensorflow/core/kernels/cwise_op_cos.cc1
-rw-r--r--tensorflow/core/kernels/cwise_op_div.cc12
-rw-r--r--tensorflow/core/kernels/cwise_op_equal_to_1.cc12
-rw-r--r--tensorflow/core/kernels/cwise_op_expm1.cc3
-rw-r--r--tensorflow/core/kernels/cwise_op_floor.cc1
-rw-r--r--tensorflow/core/kernels/cwise_op_floor_div.cc21
-rw-r--r--tensorflow/core/kernels/cwise_op_floor_mod.cc10
-rw-r--r--tensorflow/core/kernels/cwise_op_greater.cc14
-rw-r--r--tensorflow/core/kernels/cwise_op_greater_equal.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_isfinite.cc1
-rw-r--r--tensorflow/core/kernels/cwise_op_isinf.cc1
-rw-r--r--tensorflow/core/kernels/cwise_op_isnan.cc1
-rw-r--r--tensorflow/core/kernels/cwise_op_less.cc10
-rw-r--r--tensorflow/core/kernels/cwise_op_less_equal.cc12
-rw-r--r--tensorflow/core/kernels/cwise_op_log.cc1
-rw-r--r--tensorflow/core/kernels/cwise_op_log1p.cc1
-rw-r--r--tensorflow/core/kernels/cwise_op_maximum.cc15
-rw-r--r--tensorflow/core/kernels/cwise_op_minimum.cc12
-rw-r--r--tensorflow/core/kernels/cwise_op_mul_1.cc8
-rw-r--r--tensorflow/core/kernels/cwise_op_pow.cc1
-rw-r--r--tensorflow/core/kernels/cwise_op_reciprocal.cc6
-rw-r--r--tensorflow/core/kernels/cwise_op_round.cc4
-rw-r--r--tensorflow/core/kernels/cwise_op_rsqrt.cc1
-rw-r--r--tensorflow/core/kernels/cwise_op_select.cc59
-rw-r--r--tensorflow/core/kernels/cwise_op_sigmoid.cc6
-rw-r--r--tensorflow/core/kernels/cwise_op_sign.cc13
-rw-r--r--tensorflow/core/kernels/cwise_op_sin.cc1
-rw-r--r--tensorflow/core/kernels/cwise_op_sqrt.cc3
-rw-r--r--tensorflow/core/kernels/cwise_op_square.cc1
-rw-r--r--tensorflow/core/kernels/cwise_op_tan.cc1
-rw-r--r--tensorflow/core/kernels/cwise_op_tanh.cc1
-rw-r--r--tensorflow/core/kernels/cwise_ops_gradients.h15
-rw-r--r--tensorflow/core/kernels/cwise_ops_test.cc51
-rw-r--r--tensorflow/core/kernels/debug_ops.cc2
-rw-r--r--tensorflow/core/kernels/dense_update_ops.cc1
-rw-r--r--tensorflow/core/kernels/fill_functor.cc2
-rw-r--r--tensorflow/core/kernels/function_ops.cc28
-rw-r--r--tensorflow/core/kernels/matmul_op.cc53
-rw-r--r--tensorflow/core/kernels/pack_op.cc1
-rw-r--r--tensorflow/core/kernels/pad_op.cc29
-rw-r--r--tensorflow/core/kernels/reduction_ops_common.h25
-rw-r--r--tensorflow/core/kernels/reduction_ops_max.cc23
-rw-r--r--tensorflow/core/kernels/reduction_ops_mean.cc13
-rw-r--r--tensorflow/core/kernels/reduction_ops_min.cc23
-rw-r--r--tensorflow/core/kernels/reduction_ops_prod.cc24
-rw-r--r--tensorflow/core/kernels/reduction_ops_sum.cc1
-rw-r--r--tensorflow/core/kernels/relu_op.cc29
-rw-r--r--tensorflow/core/kernels/relu_op.h4
-rw-r--r--tensorflow/core/kernels/reverse_op.cc35
-rw-r--r--tensorflow/core/kernels/scatter_op.cc4
-rw-r--r--tensorflow/core/kernels/sequence_ops.cc9
-rw-r--r--tensorflow/core/kernels/shape_ops.cc83
-rw-r--r--tensorflow/core/kernels/softmax_op.cc24
-rw-r--r--tensorflow/core/kernels/stage_op.cc6
-rw-r--r--tensorflow/core/kernels/strided_slice_op.cc67
-rw-r--r--tensorflow/core/kernels/strided_slice_op_impl.h14
-rw-r--r--tensorflow/core/kernels/tile_ops.cc31
-rw-r--r--tensorflow/core/kernels/tile_ops_cpu_impl.h4
-rw-r--r--tensorflow/core/kernels/training_ops.cc2
-rw-r--r--tensorflow/core/kernels/transpose_functor_cpu.cc1
-rw-r--r--tensorflow/core/kernels/transpose_op.cc29
-rw-r--r--tensorflow/core/kernels/transpose_op.h11
-rw-r--r--tensorflow/core/kernels/unpack_op.cc1
-rw-r--r--tensorflow/core/kernels/variable_ops.cc3
-rw-r--r--tensorflow/core/kernels/xent_op.cc26
-rw-r--r--tensorflow/core/ops/math_grad_test.cc8
-rw-r--r--tensorflow/g3doc/get_started/os_setup.md20
-rw-r--r--tensorflow/python/client/device_lib_test.py2
-rw-r--r--tensorflow/python/framework/test_util.py21
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py2
-rw-r--r--tensorflow/python/kernel_tests/stage_op_test.py6
-rw-r--r--tensorflow/python/platform/test.py10
-rwxr-xr-xthird_party/sycl/crosstool/computecpp.tpl33
-rw-r--r--tools/bazel.rc.template2
102 files changed, 1451 insertions, 221 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index a98e0b7fd8..e558e6e80a 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -116,6 +116,8 @@ load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
)
+load("@local_config_sycl//sycl:build_defs.bzl", "if_sycl")
+
# -----------------------------------------------------------------------------
# Public targets
@@ -729,7 +731,7 @@ cc_library(
"//tensorflow/core/kernels:ops_testutil",
"//tensorflow/core/kernels:ops_util",
"//tensorflow/core/platform/default/build_config:gtest",
- ],
+ ] + if_sycl([":sycl_runtime"]),
)
# This is a link-only library to provide a DirectSession
diff --git a/tensorflow/core/common_runtime/sycl/sycl_allocator.cc b/tensorflow/core/common_runtime/sycl/sycl_allocator.cc
index 0d238276f4..b7ef9361e9 100644
--- a/tensorflow/core/common_runtime/sycl/sycl_allocator.cc
+++ b/tensorflow/core/common_runtime/sycl/sycl_allocator.cc
@@ -25,6 +25,9 @@ string SYCLAllocator::Name() { return "device:SYCL"; }
void *SYCLAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
assert(device_);
+ if (num_bytes == 0) {
+ return device_->allocate(1);
+ }
auto p = device_->allocate(num_bytes);
return p;
}
@@ -42,6 +45,6 @@ void SYCLAllocator::EnterLameDuckMode() {
}
}
-} // namespace tensorflow
+} // namespace tensorflow
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/common_runtime/sycl/sycl_allocator.h b/tensorflow/core/common_runtime/sycl/sycl_allocator.h
index c896f7f603..15d9ab41a4 100644
--- a/tensorflow/core/common_runtime/sycl/sycl_allocator.h
+++ b/tensorflow/core/common_runtime/sycl/sycl_allocator.h
@@ -27,8 +27,8 @@ limitations under the License.
namespace tensorflow {
class SYCLAllocator : public Allocator {
-public:
- SYCLAllocator(Eigen::QueueInterface* device) : device_(device) {}
+ public:
+ SYCLAllocator(Eigen::QueueInterface *device) : device_(device) {}
virtual ~SYCLAllocator() override;
string Name() override;
void *AllocateRaw(size_t alignment, size_t num_bytes) override;
@@ -36,11 +36,12 @@ public:
void EnterLameDuckMode();
virtual bool ShouldAllocateEmptyTensors() override final { return true; }
-private:
+
+ private:
Eigen::QueueInterface *device_; // not owned
TF_DISALLOW_COPY_AND_ASSIGN(SYCLAllocator);
};
-} // namespace tensorflow
+} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
+#endif // TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/sycl/sycl_device.cc b/tensorflow/core/common_runtime/sycl/sycl_device.cc
index 0abe25c373..2c2185b2c0 100644
--- a/tensorflow/core/common_runtime/sycl/sycl_device.cc
+++ b/tensorflow/core/common_runtime/sycl/sycl_device.cc
@@ -23,7 +23,7 @@ limitations under the License.
namespace tensorflow {
-static std::unordered_set<SYCLDevice*> live_devices;
+static std::unordered_set<SYCLDevice *> live_devices;
static bool first_time = true;
void ShutdownSycl() {
diff --git a/tensorflow/core/common_runtime/sycl/sycl_device.h b/tensorflow/core/common_runtime/sycl/sycl_device.h
index b5a72d9476..a5c7c5f0ec 100644
--- a/tensorflow/core/common_runtime/sycl/sycl_device.h
+++ b/tensorflow/core/common_runtime/sycl/sycl_device.h
@@ -34,10 +34,11 @@ class SYCLDevice : public LocalDevice {
Bytes memory_limit, const DeviceLocality &locality,
const string &physical_device_desc, SYCLSelector sycl_selector,
Allocator *cpu_allocator)
- : LocalDevice(options, Device::BuildDeviceAttributes(
- name, DEVICE_SYCL, memory_limit, locality,
- physical_device_desc),
- nullptr),
+ : LocalDevice(
+ options,
+ Device::BuildDeviceAttributes(name, DEVICE_SYCL, memory_limit,
+ locality, physical_device_desc),
+ nullptr),
cpu_allocator_(cpu_allocator),
sycl_queue_(new Eigen::QueueInterface(sycl_selector)),
sycl_device_(new Eigen::SyclDevice(sycl_queue_)),
diff --git a/tensorflow/core/common_runtime/sycl/sycl_device_context.cc b/tensorflow/core/common_runtime/sycl/sycl_device_context.cc
index a6be9195d4..1c868f5606 100644
--- a/tensorflow/core/common_runtime/sycl/sycl_device_context.cc
+++ b/tensorflow/core/common_runtime/sycl/sycl_device_context.cc
@@ -17,8 +17,8 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/common_runtime/sycl/sycl_device_context.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/sycl/sycl_device_context.h"
namespace tensorflow {
@@ -31,68 +31,68 @@ void SYCLDeviceContext::CopyCPUTensorToDevice(const Tensor *cpu_tensor,
const void *src_ptr = DMAHelper::base(cpu_tensor);
void *dst_ptr = DMAHelper::base(device_tensor);
switch (cpu_tensor->dtype()) {
- case DT_FLOAT:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<float *>(dst_ptr), static_cast<const float *>(src_ptr),
- total_bytes);
- break;
- case DT_DOUBLE:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<double *>(dst_ptr), static_cast<const double *>(src_ptr),
- total_bytes);
- break;
- case DT_INT32:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<int32 *>(dst_ptr), static_cast<const int32 *>(src_ptr),
- total_bytes);
- break;
- case DT_INT64:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<int64 *>(dst_ptr), static_cast<const int64 *>(src_ptr),
- total_bytes);
- break;
- case DT_HALF:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<Eigen::half *>(dst_ptr),
- static_cast<const Eigen::half *>(src_ptr), total_bytes);
- break;
- case DT_COMPLEX64:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<std::complex<float> *>(dst_ptr),
- static_cast<const std::complex<float> *>(src_ptr), total_bytes);
- break;
- case DT_COMPLEX128:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<std::complex<double> *>(dst_ptr),
- static_cast<const std::complex<double> *>(src_ptr), total_bytes);
- break;
- case DT_INT8:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<int8 *>(dst_ptr), static_cast<const int8 *>(src_ptr),
- total_bytes);
- break;
- case DT_INT16:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<int16 *>(dst_ptr), static_cast<const int16 *>(src_ptr),
- total_bytes);
- break;
- case DT_UINT8:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<uint8 *>(dst_ptr), static_cast<const uint8 *>(src_ptr),
- total_bytes);
- break;
- case DT_UINT16:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<uint16 *>(dst_ptr), static_cast<const uint16 *>(src_ptr),
- total_bytes);
- break;
- case DT_BOOL:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<bool *>(dst_ptr), static_cast<const bool *>(src_ptr),
- total_bytes);
- break;
- default:
- assert(false && "unsupported type");
+ case DT_FLOAT:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<float *>(dst_ptr), static_cast<const float *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_DOUBLE:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<double *>(dst_ptr),
+ static_cast<const double *>(src_ptr), total_bytes);
+ break;
+ case DT_INT32:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<int32 *>(dst_ptr), static_cast<const int32 *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_INT64:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<int64 *>(dst_ptr), static_cast<const int64 *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_HALF:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<Eigen::half *>(dst_ptr),
+ static_cast<const Eigen::half *>(src_ptr), total_bytes);
+ break;
+ case DT_COMPLEX64:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<std::complex<float> *>(dst_ptr),
+ static_cast<const std::complex<float> *>(src_ptr), total_bytes);
+ break;
+ case DT_COMPLEX128:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<std::complex<double> *>(dst_ptr),
+ static_cast<const std::complex<double> *>(src_ptr), total_bytes);
+ break;
+ case DT_INT8:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<int8 *>(dst_ptr), static_cast<const int8 *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_INT16:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<int16 *>(dst_ptr), static_cast<const int16 *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_UINT8:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<uint8 *>(dst_ptr), static_cast<const uint8 *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_UINT16:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<uint16 *>(dst_ptr),
+ static_cast<const uint16 *>(src_ptr), total_bytes);
+ break;
+ case DT_BOOL:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<bool *>(dst_ptr), static_cast<const bool *>(src_ptr),
+ total_bytes);
+ break;
+ default:
+ assert(false && "unsupported type");
}
}
device->eigen_sycl_device()->synchronize();
@@ -106,71 +106,71 @@ void SYCLDeviceContext::CopyDeviceTensorToCPU(const Tensor *device_tensor,
StatusCallback done) {
const int64 total_bytes = device_tensor->TotalBytes();
if (total_bytes > 0) {
- const void* src_ptr = DMAHelper::base(device_tensor);
- void* dst_ptr = DMAHelper::base(cpu_tensor);
+ const void *src_ptr = DMAHelper::base(device_tensor);
+ void *dst_ptr = DMAHelper::base(cpu_tensor);
switch (device_tensor->dtype()) {
- case DT_FLOAT:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<float *>(dst_ptr), static_cast<const float *>(src_ptr),
- total_bytes);
- break;
- case DT_DOUBLE:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<double *>(dst_ptr), static_cast<const double *>(src_ptr),
- total_bytes);
- break;
- case DT_INT32:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<int32 *>(dst_ptr), static_cast<const int32 *>(src_ptr),
- total_bytes);
- break;
- case DT_INT64:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<int64 *>(dst_ptr), static_cast<const int64 *>(src_ptr),
- total_bytes);
- break;
- case DT_HALF:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<Eigen::half *>(dst_ptr),
- static_cast<const Eigen::half *>(src_ptr), total_bytes);
- break;
- case DT_COMPLEX64:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<std::complex<float> *>(dst_ptr),
- static_cast<const std::complex<float> *>(src_ptr), total_bytes);
- break;
- case DT_COMPLEX128:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<std::complex<double> *>(dst_ptr),
- static_cast<const std::complex<double> *>(src_ptr), total_bytes);
- break;
- case DT_INT8:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<int8 *>(dst_ptr), static_cast<const int8 *>(src_ptr),
- total_bytes);
- break;
- case DT_INT16:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<int16 *>(dst_ptr), static_cast<const int16 *>(src_ptr),
- total_bytes);
- break;
- case DT_UINT8:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<uint8 *>(dst_ptr), static_cast<const uint8 *>(src_ptr),
- total_bytes);
- break;
- case DT_UINT16:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<uint16 *>(dst_ptr), static_cast<const uint16 *>(src_ptr),
- total_bytes);
- break;
- case DT_BOOL:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<bool *>(dst_ptr), static_cast<const bool *>(src_ptr),
- total_bytes);
- break;
- default:
- assert(false && "unsupported type");
+ case DT_FLOAT:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<float *>(dst_ptr), static_cast<const float *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_DOUBLE:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<double *>(dst_ptr),
+ static_cast<const double *>(src_ptr), total_bytes);
+ break;
+ case DT_INT32:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<int32 *>(dst_ptr), static_cast<const int32 *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_INT64:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<int64 *>(dst_ptr), static_cast<const int64 *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_HALF:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<Eigen::half *>(dst_ptr),
+ static_cast<const Eigen::half *>(src_ptr), total_bytes);
+ break;
+ case DT_COMPLEX64:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<std::complex<float> *>(dst_ptr),
+ static_cast<const std::complex<float> *>(src_ptr), total_bytes);
+ break;
+ case DT_COMPLEX128:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<std::complex<double> *>(dst_ptr),
+ static_cast<const std::complex<double> *>(src_ptr), total_bytes);
+ break;
+ case DT_INT8:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<int8 *>(dst_ptr), static_cast<const int8 *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_INT16:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<int16 *>(dst_ptr), static_cast<const int16 *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_UINT8:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<uint8 *>(dst_ptr), static_cast<const uint8 *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_UINT16:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<uint16 *>(dst_ptr),
+ static_cast<const uint16 *>(src_ptr), total_bytes);
+ break;
+ case DT_BOOL:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<bool *>(dst_ptr), static_cast<const bool *>(src_ptr),
+ total_bytes);
+ break;
+ default:
+ assert(false && "unsupported type");
}
}
device->eigen_sycl_device()->synchronize();
@@ -178,4 +178,4 @@ void SYCLDeviceContext::CopyDeviceTensorToCPU(const Tensor *device_tensor,
}
} // namespace tensorflow
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/common_runtime/sycl/sycl_device_context.h b/tensorflow/core/common_runtime/sycl/sycl_device_context.h
index 1f7ad543d9..0f8f17b805 100644
--- a/tensorflow/core/common_runtime/sycl/sycl_device_context.h
+++ b/tensorflow/core/common_runtime/sycl/sycl_device_context.h
@@ -26,7 +26,7 @@ limitations under the License.
namespace tensorflow {
class SYCLDeviceContext : public DeviceContext {
-public:
+ public:
SYCLDeviceContext() {}
~SYCLDeviceContext() override {}
@@ -40,6 +40,6 @@ public:
StatusCallback done) override;
};
-} // namespace tensorflow
+} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_DEVICE_CONTEXT_H_
+#endif // TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_DEVICE_CONTEXT_H_
diff --git a/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc b/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc
index 51eb4973d8..a643fc7258 100644
--- a/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc
+++ b/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc
@@ -21,7 +21,7 @@ limitations under the License.
namespace tensorflow {
class SYCLDeviceFactory : public DeviceFactory {
-public:
+ public:
Status CreateDevices(const SessionOptions &options, const string &name_prefix,
std::vector<Device *> *devices) override {
int n = 1;
@@ -31,10 +31,10 @@ public:
}
for (int i = 0; i < n; i++) {
string name = strings::StrCat(name_prefix, "/device:SYCL:", i);
- devices->push_back(new SYCLDevice(options, name, Bytes(256 << 20),
- DeviceLocality(),
- SYCLDevice::GetShortDeviceDescription(),
- cl::sycl::gpu_selector(), cpu_allocator()));
+ devices->push_back(
+ new SYCLDevice(options, name, Bytes(256 << 20), DeviceLocality(),
+ SYCLDevice::GetShortDeviceDescription(),
+ cl::sycl::gpu_selector(), cpu_allocator()));
}
return Status::OK();
}
@@ -43,4 +43,4 @@ public:
REGISTER_LOCAL_DEVICE_FACTORY("SYCL", SYCLDeviceFactory, 200);
}
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/framework/register_types_traits.h b/tensorflow/core/framework/register_types_traits.h
index 8f8d9fd08e..c1fe5517c6 100644
--- a/tensorflow/core/framework/register_types_traits.h
+++ b/tensorflow/core/framework/register_types_traits.h
@@ -21,6 +21,10 @@ limitations under the License.
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
+
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/platform/types.h"
@@ -66,6 +70,17 @@ struct proxy_type_pod<GPUDevice, 2> {
typedef Eigen::half type;
};
+#ifdef TENSORFLOW_USE_SYCL
+template <>
+struct proxy_type_pod<SYCLDevice, 8> {
+ typedef double type;
+};
+template <>
+struct proxy_type_pod<SYCLDevice, 4> {
+ typedef float type;
+};
+#endif // TENSORFLOW_USE_SYCL
+
/// If POD we use proxy_type_pod, otherwise this maps to identiy.
template <typename Device, typename T>
struct proxy_type {
@@ -81,6 +96,10 @@ struct proxy_type {
TF_CALL_int8(m) TF_CALL_complex128(m)
#define TF_CALL_GPU_PROXY_TYPES(m) \
TF_CALL_double(m) TF_CALL_float(m) TF_CALL_half(m) TF_CALL_int32(m)
+#ifdef TENSORFLOW_USE_SYCL
+#define TF_CALL_SYCL_PROXY_TYPES(m) \
+ TF_CALL_double(m) TF_CALL_float(m) TF_CALL_int32(m)
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc
index ab82c247d6..562934ed63 100644
--- a/tensorflow/core/kernels/cast_op.cc
+++ b/tensorflow/core/kernels/cast_op.cc
@@ -34,6 +34,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
#define CURRY_TYPES2(FN, arg0) \
FN(arg0, bool); \
@@ -206,6 +209,52 @@ REGISTER_CAST_GPU(bfloat16, float);
#undef REGISTER_CAST_GPU
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+class SyclCastOp : public CastOpBase {
+ public:
+ explicit SyclCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
+ OP_REQUIRES_OK(ctx, Prepare());
+ }
+
+ private:
+ Status Prepare() {
+ if (src_dtype_ == dst_dtype_) {
+ work_ = nullptr; // Identity
+ return Status::OK();
+ }
+ if (src_dtype_ == DT_BOOL) {
+ work_ = GetSyclCastFromBool(dst_dtype_);
+ } else if (src_dtype_ == DT_INT32) {
+ work_ = GetSyclCastFromInt32(dst_dtype_);
+ } else if (src_dtype_ == DT_INT64) {
+ work_ = GetSyclCastFromInt64(dst_dtype_);
+ } else if (src_dtype_ == DT_FLOAT) {
+ work_ = GetSyclCastFromFloat(dst_dtype_);
+ } else if (src_dtype_ == DT_DOUBLE) {
+ work_ = GetSyclCastFromDouble(dst_dtype_);
+ }
+
+ return work_ == nullptr ? Unimplemented() : Status::OK();
+ }
+};
+
+#define REGISTER_CAST_SYCL(srctype, dsttype) \
+ REGISTER_KERNEL_BUILDER(Name("Cast") \
+ .TypeConstraint<srctype>("SrcT") \
+ .TypeConstraint<dsttype>("DstT") \
+ .Device(DEVICE_SYCL), \
+ SyclCastOp)
+
+CURRY_TYPES2(REGISTER_CAST_SYCL, bool);
+CURRY_TYPES2(REGISTER_CAST_SYCL, int32);
+CURRY_TYPES2(REGISTER_CAST_SYCL, int64);
+CURRY_TYPES2(REGISTER_CAST_SYCL, float);
+CURRY_TYPES2(REGISTER_CAST_SYCL, double);
+
+#undef REGISTER_CAST_SYCL
+
+#endif // TENSORFLOW_USE_SYCL
+
#undef CURRY_TYPES2
// HostCast differs from Cast in that its input and output are in host memory.
@@ -213,5 +262,10 @@ REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp);
REGISTER_KERNEL_BUILDER(
Name("_HostCast").Device(DEVICE_GPU).HostMemory("x").HostMemory("y"),
CpuCastOp);
-
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(
+ Name("_HostCast").Device(DEVICE_SYCL).HostMemory("x").HostMemory("y"),
+ CpuCastOp);
+#endif // TENSORFLOW_USE_SYCL
} // end namespace tensorflow
+
diff --git a/tensorflow/core/kernels/cast_op_impl.h b/tensorflow/core/kernels/cast_op_impl.h
index cb7cc81937..1ee0796ac1 100644
--- a/tensorflow/core/kernels/cast_op_impl.h
+++ b/tensorflow/core/kernels/cast_op_impl.h
@@ -33,6 +33,16 @@ struct CastFunctor<Eigen::ThreadPoolDevice, O, I> {
}
};
+#ifdef TENSORFLOW_USE_SYCL
+template <typename O, typename I>
+struct CastFunctor<Eigen::SyclDevice, O, I> {
+ void operator()(const Eigen::SyclDevice& d, typename TTypes<O>::Flat o,
+ typename TTypes<I>::ConstFlat i) {
+ o.device(d) = i.template cast<O>();
+ }
+};
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace functor
#define CURRY_TYPES3(FN, arg0, arg1) \
@@ -140,6 +150,25 @@ GetGpuCastFromBfloat(DataType dst_dtype);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetSyclCastFromBool(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetSyclCastFromInt32(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetSyclCastFromInt64(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetSyclCastFromFloat(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetSyclCastFromDouble(DataType dst_dtype);
+
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
+
diff --git a/tensorflow/core/kernels/cast_op_impl_bool.cc b/tensorflow/core/kernels/cast_op_impl_bool.cc
index 92fee89a47..a13f163009 100644
--- a/tensorflow/core/kernels/cast_op_impl_bool.cc
+++ b/tensorflow/core/kernels/cast_op_impl_bool.cc
@@ -34,4 +34,14 @@ GetGpuCastFromBool(DataType dst_dtype) {
}
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetSyclCastFromBool(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, SYCLDevice, bool);
+ return nullptr;
+}
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
+
diff --git a/tensorflow/core/kernels/cast_op_impl_double.cc b/tensorflow/core/kernels/cast_op_impl_double.cc
index fd20061d21..fdc8d51158 100644
--- a/tensorflow/core/kernels/cast_op_impl_double.cc
+++ b/tensorflow/core/kernels/cast_op_impl_double.cc
@@ -34,4 +34,14 @@ GetGpuCastFromDouble(DataType dst_dtype) {
}
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetSyclCastFromDouble(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, SYCLDevice, double);
+ return nullptr;
+}
+#endif // TENSORFLOW_USE_SYC
+
} // namespace tensorflow
+
diff --git a/tensorflow/core/kernels/cast_op_impl_float.cc b/tensorflow/core/kernels/cast_op_impl_float.cc
index 71e63fbff0..1241dcd8f2 100644
--- a/tensorflow/core/kernels/cast_op_impl_float.cc
+++ b/tensorflow/core/kernels/cast_op_impl_float.cc
@@ -49,4 +49,14 @@ GetGpuCastFromFloat(DataType dst_dtype) {
}
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetSyclCastFromFloat(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, SYCLDevice, float);
+ return nullptr;
+}
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
+
diff --git a/tensorflow/core/kernels/cast_op_impl_int32.cc b/tensorflow/core/kernels/cast_op_impl_int32.cc
index 0fc6e16afe..fca9cd60ec 100644
--- a/tensorflow/core/kernels/cast_op_impl_int32.cc
+++ b/tensorflow/core/kernels/cast_op_impl_int32.cc
@@ -34,4 +34,14 @@ GetGpuCastFromInt32(DataType dst_dtype) {
}
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetSyclCastFromInt32(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, CPUDevice, int32);
+ return nullptr;
+}
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
+
diff --git a/tensorflow/core/kernels/cast_op_impl_int64.cc b/tensorflow/core/kernels/cast_op_impl_int64.cc
index b5571b19a5..c0a543708d 100644
--- a/tensorflow/core/kernels/cast_op_impl_int64.cc
+++ b/tensorflow/core/kernels/cast_op_impl_int64.cc
@@ -19,6 +19,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
GetCpuCastFromInt64(DataType dst_dtype) {
@@ -34,4 +37,13 @@ GetGpuCastFromInt64(DataType dst_dtype) {
}
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetSyclCastFromInt64(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, SYCLDevice, int64);
+ return nullptr;
+}
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cast_op_test.cc b/tensorflow/core/kernels/cast_op_test.cc
index 5b7529bb8a..a106f287c1 100644
--- a/tensorflow/core/kernels/cast_op_test.cc
+++ b/tensorflow/core/kernels/cast_op_test.cc
@@ -105,7 +105,12 @@ static void BM_gpu_float_int64(int iters, int num) {
testing::BytesProcessed(static_cast<int64>(iters) * num *
(sizeof(float) + sizeof(int64)));
testing::UseRealTime();
+#if GOOGLE_CUDA
test::Benchmark("gpu", Cast<float, int64>(num)).Run(iters);
+#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+ test::Benchmark("sycl", Cast<float, int64>(num)).Run(iters);
+#endif // TENSORFLOW_USE_SYCL
}
BENCHMARK(BM_gpu_float_int64)->Arg(64 << 10)->Arg(32 << 20);
@@ -123,7 +128,12 @@ static void BM_gpu_bool_float(int iters, int num) {
testing::BytesProcessed(static_cast<int64>(iters) * num *
(sizeof(bool) + sizeof(float)));
testing::UseRealTime();
+#if GOOGLE_CUDA
test::Benchmark("gpu", Cast<bool, float>(num)).Run(iters);
+#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+ test::Benchmark("sycl", Cast<bool, float>(num)).Run(iters);
+#endif // TENSORFLOW_USE_SYCL
}
BENCHMARK(BM_gpu_bool_float)->Arg(64 << 10)->Arg(32 << 20);
@@ -168,7 +178,9 @@ static void BM_gpu_float_half(int iters, int num) {
testing::BytesProcessed(static_cast<int64>(iters) * num *
(sizeof(float) + sizeof(Eigen::half)));
testing::UseRealTime();
+#if GOOGLE_CUDA
test::Benchmark("gpu", Cast<float, Eigen::half>(num)).Run(iters);
+#endif // GOOGLE_CUDA
}
BENCHMARK(BM_gpu_float_half)->Arg(64 << 10)->Arg(32 << 20);
@@ -177,7 +189,9 @@ static void BM_gpu_half_float(int iters, int num) {
testing::BytesProcessed(static_cast<int64>(iters) * num *
(sizeof(float) + sizeof(Eigen::half)));
testing::UseRealTime();
+#if GOOGLE_CUDA
test::Benchmark("gpu", Cast<Eigen::half, float>(num)).Run(iters);
+#endif // GOOGLE_CUDA
}
BENCHMARK(BM_gpu_half_float)->Arg(64 << 10)->Arg(32 << 20);
diff --git a/tensorflow/core/kernels/concat_lib.h b/tensorflow/core/kernels/concat_lib.h
index cef873f804..14e6e1bc32 100644
--- a/tensorflow/core/kernels/concat_lib.h
+++ b/tensorflow/core/kernels/concat_lib.h
@@ -38,6 +38,14 @@ void ConcatGPU(
Tensor* output, typename TTypes<T, 2>::Tensor* output_flat);
#endif // GOOGLE_CUDA
+
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+void ConcatSYCL(const Eigen::SyclDevice& d,
+ const std::vector<
+ std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs,
+ typename TTypes<T, 2>::Matrix* output);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_CONCAT_LIB_H_
diff --git a/tensorflow/core/kernels/concat_lib_cpu.cc b/tensorflow/core/kernels/concat_lib_cpu.cc
index f83aed6aef..f89948350c 100644
--- a/tensorflow/core/kernels/concat_lib_cpu.cc
+++ b/tensorflow/core/kernels/concat_lib_cpu.cc
@@ -74,4 +74,23 @@ REGISTER(qint16)
REGISTER(qint32)
REGISTER(bfloat16)
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+void ConcatSYCL(const Eigen::SyclDevice& d,
+ const std::vector<
+ std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs,
+ typename TTypes<T, 2>::Matrix* output) {
+ ConcatSYCLImpl<T>(d, inputs, sizeof(T) /* cost_per_unit */, MemCpyCopier<T>(),
+ output);
+}
+#define REGISTER_SYCL(T) \
+ template void ConcatSYCL<T>( \
+ const Eigen::SyclDevice&, \
+ const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&, \
+ typename TTypes<T, 2>::Matrix* output);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL)
+
+#undef REGISTER_SYCL
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/concat_lib_cpu.h b/tensorflow/core/kernels/concat_lib_cpu.h
index 9d37cafb4e..6a933efde4 100644
--- a/tensorflow/core/kernels/concat_lib_cpu.h
+++ b/tensorflow/core/kernels/concat_lib_cpu.h
@@ -126,4 +126,39 @@ void ConcatCPUImpl(
cost_per_unit, work);
}
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T, typename ElementCopier>
+void ConcatSYCLImpl(
+ const Eigen::SyclDevice& d,
+ const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
+ inputs,
+ int64 cost_per_unit, ElementCopier copier,
+ typename TTypes<T, 2>::Matrix* output) {
+ size_t num_inputs = inputs.size();
+
+ std::vector<ptrdiff_t> sizes;
+ sizes.reserve(num_inputs);
+ int64 row_size = 0;
+ for (const auto& input : inputs) {
+ sizes.push_back(input->dimension(1));
+ row_size += sizes.back();
+ }
+
+ T* out = &(*output)(0, 0);
+ std::vector<const T*> inp;
+ inp.reserve(num_inputs);
+ for (const auto& input : inputs) {
+ inp.push_back(&(*input)(0, 0));
+ }
+ const int64 dim0 = output->dimension(0);
+ for (int64 i = 0; i < dim0; ++i) {
+ for (int64 j = 0; j < num_inputs; ++j) {
+ auto size = sizes[j];
+ d.memcpy(out, inp[j], size * sizeof(T));
+ out += size;
+ inp[j] += size;
+ }
+ }
+}
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc
index e6dae5fa7e..9628a7efa4 100644
--- a/tensorflow/core/kernels/concat_op.cc
+++ b/tensorflow/core/kernels/concat_op.cc
@@ -35,6 +35,9 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
#if GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM };
@@ -134,6 +137,12 @@ class ConcatBaseOp : public OpKernel {
return;
}
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+ if (std::is_same<Device, SYCLDevice>::value) {
+ ConcatSYCL<T>(c->eigen_sycl_device(), inputs_flat, &output_flat);
+ return;
+ }
+#endif // TENSORFLOW_USE_SYCL
ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
}
}
@@ -207,6 +216,39 @@ REGISTER_KERNEL_BUILDER(Name("ConcatV2")
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Concat") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("concat_dim"), \
+ ConcatOp<SYCLDevice, type>) \
+ REGISTER_KERNEL_BUILDER(Name("ConcatV2") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("axis"), \
+ ConcatV2Op<SYCLDevice, type>)
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL);
+REGISTER_KERNEL_BUILDER(Name("Concat")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .HostMemory("concat_dim")
+ .HostMemory("values")
+ .HostMemory("output"),
+ ConcatOp<CPUDevice, int32>);
+REGISTER_KERNEL_BUILDER(Name("ConcatV2")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int32>("Tidx")
+ .HostMemory("values")
+ .HostMemory("axis")
+ .HostMemory("output"),
+ ConcatV2Op<CPUDevice, int32>);
+#undef REGISTER_SYCL
+#endif // TENSORFLOW_USE_SYCL
+
class ConcatOffsetOp : public OpKernel {
public:
explicit ConcatOffsetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
@@ -293,4 +335,12 @@ REGISTER_KERNEL_BUILDER(Name("ConcatOffset")
.HostMemory("offset"),
ConcatOffsetOp);
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("ConcatOffset")
+ .Device(DEVICE_SYCL)
+ .HostMemory("concat_dim")
+ .HostMemory("shape")
+ .HostMemory("offset"),
+ ConcatOffsetOp);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index 306736fe54..a0f89f2abd 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -57,7 +57,10 @@ REGISTER_KERNEL_BUILDER(Name("Const").Device(DEVICE_CPU), ConstantOp);
REGISTER_KERNEL_BUILDER( \
Name("Const").Device(DEVICE_SYCL).TypeConstraint<TYPE>("dtype"), \
ConstantOp);
-TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
+REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
+REGISTER_SYCL_KERNEL(bool);
+REGISTER_SYCL_KERNEL(int64);
#undef REGISTER_SYCL_KERNEL
#endif
@@ -112,6 +115,17 @@ REGISTER_KERNEL_BUILDER(Name("Const")
HostConstantOp);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Const")
+ .Device(DEVICE_SYCL)
+ .HostMemory("output")
+ .TypeConstraint<int32>("dtype"),
+ HostConstantOp);
+#endif // TENSORFLOW_USE_SYCL
+
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
@@ -186,6 +200,7 @@ REGISTER_KERNEL(CPU, quint8);
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL(SYCL, float)
+REGISTER_KERNEL(SYCL, double)
REGISTER_KERNEL_BUILDER(Name("Fill")
.Device(DEVICE_SYCL)
.TypeConstraint<int32>("T")
@@ -245,6 +260,7 @@ TF_CALL_POD_STRING_TYPES(REGISTER_CPU);
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL(float, SYCL);
+REGISTER_KERNEL(bool, SYCL);
REGISTER_KERNEL_BUILDER(Name("ZerosLike")
.Device(DEVICE_SYCL)
.TypeConstraint<int32>("T")
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc
index 1a73a3d0f8..6a79be5a95 100644
--- a/tensorflow/core/kernels/control_flow_ops.cc
+++ b/tensorflow/core/kernels/control_flow_ops.cc
@@ -321,6 +321,30 @@ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);
#undef REGISTER_SYCL_KERNEL
#undef REGISTER_SYCL_REF_KERNEL
+#define REGISTER_SYCL_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Enter") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("data") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ EnterOp)
+
+#define REGISTER_SYCL_HOST_REF_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("RefEnter") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("data") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ EnterOp)
+
+REGISTER_SYCL_HOST_KERNEL(int32);
+REGISTER_SYCL_HOST_REF_KERNEL(int32);
+REGISTER_SYCL_HOST_KERNEL(string);
+REGISTER_SYCL_HOST_REF_KERNEL(string);
+REGISTER_SYCL_HOST_KERNEL(ResourceHandle);
+
+#undef REGISTER_SYCL_HOST_KERNEL
+#undef REGISTER_SYCL_HOST_REF_KERNEL
#endif
// Special GPU kernels for int32 and string.
diff --git a/tensorflow/core/kernels/cwise_op_acos.cc b/tensorflow/core/kernels/cwise_op_acos.cc
index 1d2d815027..65801da3c7 100644
--- a/tensorflow/core/kernels/cwise_op_acos.cc
+++ b/tensorflow/core/kernels/cwise_op_acos.cc
@@ -26,6 +26,7 @@ REGISTER2(UnaryOp, CPU, "Acos", functor::acos, float, double);
.TypeConstraint<TYPE>("T"), \
UnaryOp<SYCLDevice, functor::acos<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/cwise_op_add_1.cc b/tensorflow/core/kernels/cwise_op_add_1.cc
index a6bff78694..f6e9b59cf8 100644
--- a/tensorflow/core/kernels/cwise_op_add_1.cc
+++ b/tensorflow/core/kernels/cwise_op_add_1.cc
@@ -18,7 +18,7 @@ limitations under the License.
namespace tensorflow {
REGISTER5(BinaryOp, CPU, "Add", functor::add, float, Eigen::half, double, int32,
int64);
-
+
#if TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(TYPE) \
REGISTER_KERNEL_BUILDER( \
@@ -26,10 +26,19 @@ REGISTER5(BinaryOp, CPU, "Add", functor::add, float, Eigen::half, double, int32,
.Device(DEVICE_SYCL) \
.TypeConstraint<TYPE>("T"), \
BinaryOp<SYCLDevice, functor::add<TYPE>>);
- REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
+
+REGISTER_KERNEL_BUILDER(Name("Add")
+ .Device(DEVICE_SYCL)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::add<int32>>);
#endif // TENSORFLOW_USE_SYCL
-
+
#if GOOGLE_CUDA
REGISTER3(BinaryOp, GPU, "Add", functor::add, float, Eigen::half, double);
diff --git a/tensorflow/core/kernels/cwise_op_asin.cc b/tensorflow/core/kernels/cwise_op_asin.cc
index 92a22e90c4..c9ebfe759b 100644
--- a/tensorflow/core/kernels/cwise_op_asin.cc
+++ b/tensorflow/core/kernels/cwise_op_asin.cc
@@ -26,6 +26,7 @@ REGISTER2(UnaryOp, CPU, "Asin", functor::asin, float, double);
.TypeConstraint<TYPE>("T"), \
UnaryOp<SYCLDevice, functor::asin<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/cwise_op_atan.cc b/tensorflow/core/kernels/cwise_op_atan.cc
index 825e85283f..72645b303f 100644
--- a/tensorflow/core/kernels/cwise_op_atan.cc
+++ b/tensorflow/core/kernels/cwise_op_atan.cc
@@ -26,6 +26,7 @@ REGISTER2(UnaryOp, CPU, "Atan", functor::atan, float, double);
.TypeConstraint<TYPE>("T"), \
UnaryOp<SYCLDevice, functor::atan<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/cwise_op_ceil.cc b/tensorflow/core/kernels/cwise_op_ceil.cc
index c5a4aaf831..c74e10576d 100644
--- a/tensorflow/core/kernels/cwise_op_ceil.cc
+++ b/tensorflow/core/kernels/cwise_op_ceil.cc
@@ -26,6 +26,7 @@ REGISTER3(UnaryOp, CPU, "Ceil", functor::ceil, float, Eigen::half, double);
.TypeConstraint<TYPE>("T"), \
UnaryOp<SYCLDevice, functor::ceil<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/cwise_op_cos.cc b/tensorflow/core/kernels/cwise_op_cos.cc
index a758da5842..634c90adc6 100644
--- a/tensorflow/core/kernels/cwise_op_cos.cc
+++ b/tensorflow/core/kernels/cwise_op_cos.cc
@@ -27,6 +27,7 @@ REGISTER5(UnaryOp, CPU, "Cos", functor::cos, float, Eigen::half, double,
.TypeConstraint<TYPE>("T"), \
UnaryOp<SYCLDevice, functor::cos<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc
index 74d8faedb5..1e2300832f 100644
--- a/tensorflow/core/kernels/cwise_op_div.cc
+++ b/tensorflow/core/kernels/cwise_op_div.cc
@@ -37,8 +37,18 @@ REGISTER5(BinaryOp, CPU, "RealDiv", functor::div, float, Eigen::half, double,
.TypeConstraint<TYPE>("T"), \
BinaryOp<SYCLDevice, functor::div<TYPE>>);
REGISTER_SYCL_KERNEL(float)
-REGISTER_SYCL_KERNEL(int32)
+REGISTER_SYCL_KERNEL(double)
#undef REGISTER_SYCL_KERNEL
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Div")
+ .Device(DEVICE_SYCL)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::safe_div<int32>>);
#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA
REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
diff --git a/tensorflow/core/kernels/cwise_op_equal_to_1.cc b/tensorflow/core/kernels/cwise_op_equal_to_1.cc
index 7bd44abd39..93ea768836 100644
--- a/tensorflow/core/kernels/cwise_op_equal_to_1.cc
+++ b/tensorflow/core/kernels/cwise_op_equal_to_1.cc
@@ -34,4 +34,16 @@ REGISTER_KERNEL_BUILDER(Name("Equal")
BinaryOp<CPUDevice, functor::equal_to<int32>>);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER2(BinaryOp, SYCL, "Equal", functor::equal_to, float, double);
+
+REGISTER_KERNEL_BUILDER(Name("Equal")
+ .Device(DEVICE_SYCL)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::equal_to<int32>>);
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_expm1.cc b/tensorflow/core/kernels/cwise_op_expm1.cc
index f1c53ca272..5573c2bcc2 100644
--- a/tensorflow/core/kernels/cwise_op_expm1.cc
+++ b/tensorflow/core/kernels/cwise_op_expm1.cc
@@ -21,4 +21,7 @@ REGISTER5(UnaryOp, CPU, "Expm1", functor::expm1, float, Eigen::half, double,
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "Expm1", functor::expm1, float, Eigen::half, double);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER(UnaryOp, SYCL, "Expm1", functor::expm1, float);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_floor.cc b/tensorflow/core/kernels/cwise_op_floor.cc
index 129d754b82..59e32d7f6f 100644
--- a/tensorflow/core/kernels/cwise_op_floor.cc
+++ b/tensorflow/core/kernels/cwise_op_floor.cc
@@ -26,6 +26,7 @@ REGISTER3(UnaryOp, CPU, "Floor", functor::floor, float, Eigen::half, double);
.TypeConstraint<TYPE>("T"), \
UnaryOp<SYCLDevice, functor::floor<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/cwise_op_floor_div.cc b/tensorflow/core/kernels/cwise_op_floor_div.cc
index 8a600f8f95..fa81ef0872 100644
--- a/tensorflow/core/kernels/cwise_op_floor_div.cc
+++ b/tensorflow/core/kernels/cwise_op_floor_div.cc
@@ -21,17 +21,6 @@ REGISTER5(BinaryOp, CPU, "FloorDiv", functor::safe_floor_div, uint8, uint16,
REGISTER3(BinaryOp, CPU, "FloorDiv", functor::floor_div_real, float,
Eigen::half, double);
-#if TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNEL(TYPE) \
- REGISTER_KERNEL_BUILDER( \
- Name("FloorDiv") \
- .Device(DEVICE_SYCL) \
- .TypeConstraint<TYPE>("T"), \
- BinaryOp<SYCLDevice, functor::floor_div_real<TYPE>>);
-REGISTER_SYCL_KERNEL(float)
-#undef REGISTER_SYCL_KERNEL
-#endif // TENSORFLOW_USE_SYCL
-
#if GOOGLE_CUDA
REGISTER4(BinaryOp, GPU, "FloorDiv", functor::floor_div, uint8, uint16, int16,
int64);
@@ -51,4 +40,14 @@ REGISTER_KERNEL_BUILDER(Name("FloorDiv")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::safe_floor_div<int32>>);
#endif
+
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("FloorDiv")
+ .Device(DEVICE_SYCL)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::safe_floor_div<int32>>);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_floor_mod.cc b/tensorflow/core/kernels/cwise_op_floor_mod.cc
index 4e641a8bb3..55f8a30461 100644
--- a/tensorflow/core/kernels/cwise_op_floor_mod.cc
+++ b/tensorflow/core/kernels/cwise_op_floor_mod.cc
@@ -31,4 +31,14 @@ REGISTER_KERNEL_BUILDER(Name("FloorMod")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::safe_floor_mod<int32>>);
#endif
+
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("FloorMod")
+ .Device(DEVICE_SYCL)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::safe_floor_mod<int32>>);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_greater.cc b/tensorflow/core/kernels/cwise_op_greater.cc
index 8c9691d1ea..6b5a806aa2 100644
--- a/tensorflow/core/kernels/cwise_op_greater.cc
+++ b/tensorflow/core/kernels/cwise_op_greater.cc
@@ -33,5 +33,19 @@ REGISTER_KERNEL_BUILDER(Name("Greater")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::greater<int32>>);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER(BinaryOp, SYCL, "Greater", functor::greater, float);
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Greater")
+ .Device(DEVICE_SYCL)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::greater<int32>>);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_greater_equal.cc b/tensorflow/core/kernels/cwise_op_greater_equal.cc
index a6083cb9cd..ac21528256 100644
--- a/tensorflow/core/kernels/cwise_op_greater_equal.cc
+++ b/tensorflow/core/kernels/cwise_op_greater_equal.cc
@@ -34,4 +34,15 @@ REGISTER_KERNEL_BUILDER(Name("GreaterEqual")
BinaryOp<CPUDevice, functor::greater_equal<int32>>);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER(BinaryOp, SYCL, "GreaterEqual", functor::greater_equal, float);
+
+REGISTER_KERNEL_BUILDER(Name("GreaterEqual")
+ .Device(DEVICE_SYCL)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::greater_equal<int32>>);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_isfinite.cc b/tensorflow/core/kernels/cwise_op_isfinite.cc
index 59976141c7..0faeffa95c 100644
--- a/tensorflow/core/kernels/cwise_op_isfinite.cc
+++ b/tensorflow/core/kernels/cwise_op_isfinite.cc
@@ -27,6 +27,7 @@ REGISTER3(UnaryOp, CPU, "IsFinite", functor::isfinite, float, Eigen::half,
.TypeConstraint<TYPE>("T"), \
UnaryOp<SYCLDevice, functor::isfinite<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/cwise_op_isinf.cc b/tensorflow/core/kernels/cwise_op_isinf.cc
index 675cb95b95..df63006b3f 100644
--- a/tensorflow/core/kernels/cwise_op_isinf.cc
+++ b/tensorflow/core/kernels/cwise_op_isinf.cc
@@ -26,6 +26,7 @@ REGISTER3(UnaryOp, CPU, "IsInf", functor::isinf, float, Eigen::half, double);
.TypeConstraint<TYPE>("T"), \
UnaryOp<SYCLDevice, functor::isinf<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/cwise_op_isnan.cc b/tensorflow/core/kernels/cwise_op_isnan.cc
index c394087ed8..e1cf7a8637 100644
--- a/tensorflow/core/kernels/cwise_op_isnan.cc
+++ b/tensorflow/core/kernels/cwise_op_isnan.cc
@@ -26,6 +26,7 @@ REGISTER3(UnaryOp, CPU, "IsNan", functor::isnan, float, Eigen::half, double);
.TypeConstraint<TYPE>("T"), \
UnaryOp<SYCLDevice, functor::isnan<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/cwise_op_less.cc b/tensorflow/core/kernels/cwise_op_less.cc
index 701007d637..a38f1024a9 100644
--- a/tensorflow/core/kernels/cwise_op_less.cc
+++ b/tensorflow/core/kernels/cwise_op_less.cc
@@ -33,5 +33,15 @@ REGISTER_KERNEL_BUILDER(Name("Less")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::less<int32>>);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER3(BinaryOp, SYCL, "Less", functor::less, float, double, int64);
+REGISTER_KERNEL_BUILDER(Name("Less")
+ .Device(DEVICE_SYCL)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::less<int32>>);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_less_equal.cc b/tensorflow/core/kernels/cwise_op_less_equal.cc
index 97fd1ae919..3a2cc2ae0e 100644
--- a/tensorflow/core/kernels/cwise_op_less_equal.cc
+++ b/tensorflow/core/kernels/cwise_op_less_equal.cc
@@ -34,4 +34,16 @@ REGISTER_KERNEL_BUILDER(Name("LessEqual")
BinaryOp<CPUDevice, functor::less_equal<int32>>);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER(BinaryOp, SYCL, "LessEqual", functor::less_equal, float);
+
+REGISTER_KERNEL_BUILDER(Name("LessEqual")
+ .Device(DEVICE_SYCL)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::less_equal<int32>>);
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_log.cc b/tensorflow/core/kernels/cwise_op_log.cc
index 71c4588b3d..5e74e778c7 100644
--- a/tensorflow/core/kernels/cwise_op_log.cc
+++ b/tensorflow/core/kernels/cwise_op_log.cc
@@ -27,6 +27,7 @@ REGISTER5(UnaryOp, CPU, "Log", functor::log, float, Eigen::half, double,
.TypeConstraint<TYPE>("T"), \
UnaryOp<SYCLDevice, functor::log<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/cwise_op_log1p.cc b/tensorflow/core/kernels/cwise_op_log1p.cc
index 03ea3a0a89..edb821318e 100644
--- a/tensorflow/core/kernels/cwise_op_log1p.cc
+++ b/tensorflow/core/kernels/cwise_op_log1p.cc
@@ -27,6 +27,7 @@ REGISTER5(UnaryOp, CPU, "Log1p", functor::log1p, float, Eigen::half, double,
.TypeConstraint<TYPE>("T"), \
UnaryOp<SYCLDevice, functor::log1p<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/cwise_op_maximum.cc b/tensorflow/core/kernels/cwise_op_maximum.cc
index f93b5a8303..7311f25ec0 100644
--- a/tensorflow/core/kernels/cwise_op_maximum.cc
+++ b/tensorflow/core/kernels/cwise_op_maximum.cc
@@ -34,4 +34,19 @@ REGISTER_KERNEL_BUILDER(Name("Maximum")
BinaryOp<CPUDevice, functor::maximum<int32>>);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER(BinaryOp, SYCL, "Maximum", functor::maximum, float);
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Maximum")
+ .Device(DEVICE_SYCL)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::maximum<int32>>);
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_minimum.cc b/tensorflow/core/kernels/cwise_op_minimum.cc
index 36800975a8..99e5a76620 100644
--- a/tensorflow/core/kernels/cwise_op_minimum.cc
+++ b/tensorflow/core/kernels/cwise_op_minimum.cc
@@ -34,4 +34,16 @@ REGISTER_KERNEL_BUILDER(Name("Minimum")
BinaryOp<CPUDevice, functor::minimum<int32>>);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER(BinaryOp, SYCL, "Minimum", functor::minimum, float);
+
+REGISTER_KERNEL_BUILDER(Name("Minimum")
+ .Device(DEVICE_SYCL)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::minimum<int32>>);
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_mul_1.cc b/tensorflow/core/kernels/cwise_op_mul_1.cc
index e23fe6761d..5273522626 100644
--- a/tensorflow/core/kernels/cwise_op_mul_1.cc
+++ b/tensorflow/core/kernels/cwise_op_mul_1.cc
@@ -28,7 +28,15 @@ REGISTER5(BinaryOp, CPU, "Mul", functor::mul, float, Eigen::half, double,
.TypeConstraint<TYPE>("T"), \
BinaryOp<SYCLDevice, functor::mul<TYPE>>);
REGISTER_SYCL_KERNEL(float)
+REGISTER_SYCL_KERNEL(double)
#undef REGISTER_SYCL_KERNEL
+REGISTER_KERNEL_BUILDER(Name("Mul")
+ .Device(DEVICE_SYCL)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::mul<int32>>);
#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA
REGISTER4(BinaryOp, GPU, "Mul", functor::mul, float, Eigen::half, double,
diff --git a/tensorflow/core/kernels/cwise_op_pow.cc b/tensorflow/core/kernels/cwise_op_pow.cc
index 8eeba6ab14..f1780168e4 100644
--- a/tensorflow/core/kernels/cwise_op_pow.cc
+++ b/tensorflow/core/kernels/cwise_op_pow.cc
@@ -27,6 +27,7 @@ REGISTER7(BinaryOp, CPU, "Pow", functor::pow, float, Eigen::half, double, int32,
.TypeConstraint<TYPE>("T"), \
BinaryOp<SYCLDevice, functor::pow<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/cwise_op_reciprocal.cc b/tensorflow/core/kernels/cwise_op_reciprocal.cc
index d858a077f5..8c0e21f9cf 100644
--- a/tensorflow/core/kernels/cwise_op_reciprocal.cc
+++ b/tensorflow/core/kernels/cwise_op_reciprocal.cc
@@ -36,6 +36,9 @@ REGISTER5(UnaryOp, CPU, "Reciprocal", functor::inverse, float, Eigen::half,
REGISTER4(UnaryOp, GPU, "Reciprocal", functor::inverse, float, Eigen::half,
double, int64);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER(UnaryOp, SYCL, "Reciprocal", functor::inverse, float);
+#endif // TENSORFLOW_USE_SYCL
REGISTER5(SimpleBinaryOp, CPU, "ReciprocalGrad", functor::inverse_grad, float,
Eigen::half, double, complex64, complex128);
@@ -43,4 +46,7 @@ REGISTER5(SimpleBinaryOp, CPU, "ReciprocalGrad", functor::inverse_grad, float,
REGISTER3(SimpleBinaryOp, GPU, "ReciprocalGrad", functor::inverse_grad, float,
Eigen::half, double);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER(SimpleBinaryOp, SYCL, "ReciprocalGrad", functor::inverse_grad, float);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_round.cc b/tensorflow/core/kernels/cwise_op_round.cc
index 7a4482dbb2..e192f89782 100644
--- a/tensorflow/core/kernels/cwise_op_round.cc
+++ b/tensorflow/core/kernels/cwise_op_round.cc
@@ -20,9 +20,9 @@ REGISTER5(UnaryOp, CPU, "Round", functor::round, Eigen::half, float, double,
int32, int64);
#ifdef TENSORFLOW_USE_SYCL
-REGISTER(UnaryOp, SYCL, "Round", functor::round, float);
+REGISTER2(UnaryOp, SYCL, "Round", functor::round, float, double);
namespace functor {
-DEFINE_UNARY1(round, float);
+DEFINE_UNARY2(round, float, double);
} // namespace functor
#endif
diff --git a/tensorflow/core/kernels/cwise_op_rsqrt.cc b/tensorflow/core/kernels/cwise_op_rsqrt.cc
index 7dc96d47a6..f23725f48e 100644
--- a/tensorflow/core/kernels/cwise_op_rsqrt.cc
+++ b/tensorflow/core/kernels/cwise_op_rsqrt.cc
@@ -27,6 +27,7 @@ REGISTER5(UnaryOp, CPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double,
.TypeConstraint<TYPE>("T"), \
UnaryOp<SYCLDevice, functor::rsqrt<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc
index 8160fb74c2..b5deffdb85 100644
--- a/tensorflow/core/kernels/cwise_op_select.cc
+++ b/tensorflow/core/kernels/cwise_op_select.cc
@@ -28,6 +28,10 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
+
template <typename Device, typename T>
class SelectOp : public OpKernel {
public:
@@ -163,12 +167,24 @@ REGISTER_SELECT_GPU(complex128);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+// Registration of the SYCL implementations.
+#define REGISTER_SELECT_SYCL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ SelectOp<SYCLDevice, type>);
+
+REGISTER_SELECT_SYCL(float);
+REGISTER_SELECT_SYCL(int32);
+#undef REGISTER_SELECT_SYCL
+#endif // TENSORFLOW_USE_SYCL
+
namespace functor {
// CPU Specializations of Select functors.
-template <typename T>
-struct SelectFunctor<CPUDevice, T> {
- void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
+template <typename Device, typename T>
+struct SelectFunctorBase {
+ void operator()(const Device& d, typename TTypes<T>::Flat out,
typename TTypes<bool>::ConstFlat cond_flat,
typename TTypes<T>::ConstFlat then_flat,
typename TTypes<T>::ConstFlat else_flat) {
@@ -176,10 +192,18 @@ struct SelectFunctor<CPUDevice, T> {
}
};
-// CPU Specializations of Select functors with scalar
template <typename T>
-struct SelectScalarFunctor<CPUDevice, T> {
- void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
+struct SelectFunctor<CPUDevice, T>
+ : SelectFunctorBase<CPUDevice, T> {};
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct SelectFunctor<SYCLDevice, T>
+ : SelectFunctorBase<SYCLDevice, T> {};
+#endif // TENSORFLOW_USE_SYCL
+
+template <typename Device, typename T>
+struct SelectScalarFunctorBase {
+ void operator()(const Device& d, typename TTypes<T>::Flat out,
TTypes<bool>::ConstScalar cond,
typename TTypes<T>::ConstFlat then_flat,
typename TTypes<T>::ConstFlat else_flat) {
@@ -187,9 +211,19 @@ struct SelectScalarFunctor<CPUDevice, T> {
}
};
+// CPU Specializations of Select functors with scalar
template <typename T>
-struct BatchSelectFunctor<CPUDevice, T> {
- void operator()(const CPUDevice& d,
+struct SelectScalarFunctor<CPUDevice, T>
+ : SelectScalarFunctorBase<CPUDevice, T> {};
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct SelectScalarFunctor<SYCLDevice, T>
+ : SelectScalarFunctorBase<SYCLDevice, T> {};
+#endif // TENSORFLOW_USE_SYCL
+
+template <typename Device, typename T>
+struct BatchSelectFunctorBase {
+ void operator()(const Device& d,
typename TTypes<T>::Matrix output_flat_outer_dims,
TTypes<bool>::ConstVec cond_vec,
typename TTypes<T>::ConstMatrix then_flat_outer_dims,
@@ -214,6 +248,15 @@ struct BatchSelectFunctor<CPUDevice, T> {
}
};
+template <typename T>
+struct BatchSelectFunctor<CPUDevice, T>
+ : BatchSelectFunctorBase<CPUDevice, T> {};
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct BatchSelectFunctor<SYCLDevice, T>
+ : BatchSelectFunctorBase<SYCLDevice, T> {};
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_sigmoid.cc b/tensorflow/core/kernels/cwise_op_sigmoid.cc
index cc1f9b8f03..a76a088ac8 100644
--- a/tensorflow/core/kernels/cwise_op_sigmoid.cc
+++ b/tensorflow/core/kernels/cwise_op_sigmoid.cc
@@ -23,6 +23,9 @@ REGISTER5(UnaryOp, CPU, "Sigmoid", functor::sigmoid, float, Eigen::half, double,
REGISTER3(UnaryOp, GPU, "Sigmoid", functor::sigmoid, float, Eigen::half,
double);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER(UnaryOp, SYCL, "Sigmoid", functor::sigmoid, float);
+#endif // TENSORFLOW_USE_SYCL
REGISTER5(SimpleBinaryOp, CPU, "SigmoidGrad", functor::sigmoid_grad, float,
Eigen::half, double, complex64, complex128);
@@ -30,5 +33,8 @@ REGISTER5(SimpleBinaryOp, CPU, "SigmoidGrad", functor::sigmoid_grad, float,
REGISTER3(SimpleBinaryOp, GPU, "SigmoidGrad", functor::sigmoid_grad, float,
Eigen::half, double);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER(SimpleBinaryOp, SYCL, "SigmoidGrad", functor::sigmoid_grad, float);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_sign.cc b/tensorflow/core/kernels/cwise_op_sign.cc
index 568906612a..dedd414db5 100644
--- a/tensorflow/core/kernels/cwise_op_sign.cc
+++ b/tensorflow/core/kernels/cwise_op_sign.cc
@@ -33,4 +33,17 @@ REGISTER_KERNEL_BUILDER(Name("Sign")
UnaryOp<CPUDevice, functor::sign<int32>>);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER(UnaryOp, SYCL, "Sign", functor::sign, float);
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Sign")
+ .Device(DEVICE_SYCL)
+ .HostMemory("x")
+ .HostMemory("y")
+ .TypeConstraint<int32>("T"),
+ UnaryOp<CPUDevice, functor::sign<int32>>);
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_sin.cc b/tensorflow/core/kernels/cwise_op_sin.cc
index 8d0c0959f7..ab54c61b56 100644
--- a/tensorflow/core/kernels/cwise_op_sin.cc
+++ b/tensorflow/core/kernels/cwise_op_sin.cc
@@ -27,6 +27,7 @@ REGISTER5(UnaryOp, CPU, "Sin", functor::sin, float, Eigen::half, double,
.TypeConstraint<TYPE>("T"), \
UnaryOp<SYCLDevice, functor::sin<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYC
diff --git a/tensorflow/core/kernels/cwise_op_sqrt.cc b/tensorflow/core/kernels/cwise_op_sqrt.cc
index 710001517b..55acf648db 100644
--- a/tensorflow/core/kernels/cwise_op_sqrt.cc
+++ b/tensorflow/core/kernels/cwise_op_sqrt.cc
@@ -27,8 +27,9 @@ REGISTER5(UnaryOp, CPU, "Sqrt", functor::sqrt, float, Eigen::half, double,
.TypeConstraint<TYPE>("T"), \
UnaryOp<SYCLDevice, functor::sqrt<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
-#endif // TENSORFLOW_USE_SYC
+#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "Sqrt", functor::sqrt, float, Eigen::half, double);
diff --git a/tensorflow/core/kernels/cwise_op_square.cc b/tensorflow/core/kernels/cwise_op_square.cc
index f867f127a7..afcacfec1c 100644
--- a/tensorflow/core/kernels/cwise_op_square.cc
+++ b/tensorflow/core/kernels/cwise_op_square.cc
@@ -27,6 +27,7 @@ REGISTER7(UnaryOp, CPU, "Square", functor::square, float, Eigen::half, double,
.TypeConstraint<TYPE>("T"), \
UnaryOp<SYCLDevice, functor::square<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYC
diff --git a/tensorflow/core/kernels/cwise_op_tan.cc b/tensorflow/core/kernels/cwise_op_tan.cc
index ac49cad88f..9c850c9420 100644
--- a/tensorflow/core/kernels/cwise_op_tan.cc
+++ b/tensorflow/core/kernels/cwise_op_tan.cc
@@ -26,6 +26,7 @@ REGISTER2(UnaryOp, CPU, "Tan", functor::tan, float, double);
.TypeConstraint<TYPE>("T"), \
UnaryOp<SYCLDevice, functor::tan<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYC
diff --git a/tensorflow/core/kernels/cwise_op_tanh.cc b/tensorflow/core/kernels/cwise_op_tanh.cc
index ae2c473e20..1dbc13061b 100644
--- a/tensorflow/core/kernels/cwise_op_tanh.cc
+++ b/tensorflow/core/kernels/cwise_op_tanh.cc
@@ -28,6 +28,7 @@ REGISTER5(UnaryOp, CPU, "Tanh", functor::tanh, float, Eigen::half, double,
.TypeConstraint<TYPE>("T"), \
UnaryOp<SYCLDevice, functor::tanh<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYC
diff --git a/tensorflow/core/kernels/cwise_ops_gradients.h b/tensorflow/core/kernels/cwise_ops_gradients.h
index 671de380d3..77b330f589 100644
--- a/tensorflow/core/kernels/cwise_ops_gradients.h
+++ b/tensorflow/core/kernels/cwise_ops_gradients.h
@@ -171,6 +171,21 @@ struct SimpleBinaryFunctor<CPUDevice, Functor> {
}
};
+
+#ifdef TENSORFLOW_USE_SYCL
+// Partial specialization of BinaryFunctor for SYCL devices
+typedef Eigen::SyclDevice SYCLDevice;
+template <typename Functor>
+struct SimpleBinaryFunctor<SYCLDevice, Functor> {
+ void operator()(const SYCLDevice& d, typename Functor::tout_type out,
+ typename Functor::tin_type in0,
+ typename Functor::tin_type in1) {
+ out.device(d) = in0.binaryExpr(in1, typename Functor::func());
+ }
+};
+
+#endif // TENSORFLOW_USE_SYCL
+
template <typename T>
struct tanh_grad : base<T, Eigen::internal::scalar_tanh_gradient_op<T>> {};
diff --git a/tensorflow/core/kernels/cwise_ops_test.cc b/tensorflow/core/kernels/cwise_ops_test.cc
index 6250928aca..92018ec871 100644
--- a/tensorflow/core/kernels/cwise_ops_test.cc
+++ b/tensorflow/core/kernels/cwise_ops_test.cc
@@ -51,18 +51,38 @@ static int ColsFromArg(int arg) { return (arg % kRows); }
BENCHMARK(BM_##DEVICE##_##FUNC##_##TYPE)->Range(4 << 10, 1 << 20);
BM_UNARY(cpu, Floor, float, DT_FLOAT);
+#if GOOGLE_CUDA
BM_UNARY(gpu, Floor, float, DT_FLOAT);
+#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+BM_UNARY(sycl, Floor, float, DT_FLOAT);
+#endif // TENSORFLOW_USE_SYCL
+
BM_UNARY(cpu, Floor, double, DT_DOUBLE);
+#if GOOGLE_CUDA
BM_UNARY(gpu, Floor, double, DT_DOUBLE);
+#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+BM_UNARY(sycl, Floor, double, DT_DOUBLE);
+#endif // TENSORFLOW_USE_SYCL
+
BM_UNARY(cpu, Conj, std::complex<float>, DT_COMPLEX64);
+#if GOOGLE_CUDA
BM_UNARY(gpu, Conj, std::complex<float>, DT_COMPLEX64);
+#endif // GOOGLE_CUDA
BM_UNARY(cpu, Conj, std::complex<double>, DT_COMPLEX128);
+#if GOOGLE_CUDA
BM_UNARY(gpu, Conj, std::complex<double>, DT_COMPLEX128);
+#endif // GOOGLE_CUDA
BM_UNARY(cpu, Rint, double, DT_DOUBLE);
+#if GOOGLE_CUDA
BM_UNARY(gpu, Rint, double, DT_DOUBLE);
+#endif // GOOGLE_CUDA
BM_UNARY(cpu, Rint, float, DT_FLOAT);
+#if GOOGLE_CUDA
BM_UNARY(gpu, Rint, float, DT_FLOAT);
+#endif // GOOGLE_CUDA
// data func scalar.
static Graph* BinaryScalar(int num, const string& func) {
@@ -90,9 +110,20 @@ static Graph* BinaryScalar(int num, const string& func) {
->Arg(1048576);
BM_BINARY_SCALAR(cpu, Less);
+#if GOOGLE_CUDA
BM_BINARY_SCALAR(gpu, Less);
+#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+BM_BINARY_SCALAR(sycl, Less);
+#endif // TENSORFLOW_USE_SYCL
+
BM_BINARY_SCALAR(cpu, Add);
+#if GOOGLE_CUDA
BM_BINARY_SCALAR(gpu, Add);
+#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+BM_BINARY_SCALAR(sycl, Add);
+#endif // TENSORFLOW_USE_SYCL
#undef BM_BINARY_SCALAR
template <class T>
@@ -130,9 +161,13 @@ static Graph* BiasAdd(int rows, int cols, DataType type) {
using Eigen::half;
BM_BIAS_ADD_ALL(cpu, float, DT_FLOAT);
+#if GOOGLE_CUDA
BM_BIAS_ADD_ALL(gpu, float, DT_FLOAT);
+#endif // GOOGLE_CUDA
BM_BIAS_ADD_ALL(cpu, half, DT_HALF);
+#if GOOGLE_CUDA
BM_BIAS_ADD_ALL(gpu, half, DT_HALF);
+#endif // GOOGLE_CUDA
#undef BM_BIAS_ADD_ALL
#undef BM_BIAS_ADD
@@ -180,12 +215,18 @@ static Graph* BiasAddGrad(int rows, int cols, int channels, DataType type,
BM_BIAS_ADD_GRAD(DEVICE, FORMAT, C_TYPE, TF_TYPE, 4096, 4096, 1);
using Eigen::half;
+#if GOOGLE_CUDA
BM_BIAS_ADD_GRAD_ALL(gpu, NCHW, float, DT_FLOAT);
BM_BIAS_ADD_GRAD_ALL(gpu, NCHW, half, DT_HALF);
+#endif // GOOGLE_CUDA
BM_BIAS_ADD_GRAD_ALL(cpu, NHWC, float, DT_FLOAT);
+#if GOOGLE_CUDA
BM_BIAS_ADD_GRAD_ALL(gpu, NHWC, float, DT_FLOAT);
+#endif // GOOGLE_CUDA
BM_BIAS_ADD_GRAD_ALL(cpu, NHWC, half, DT_HALF);
+#if GOOGLE_CUDA
BM_BIAS_ADD_GRAD_ALL(gpu, NHWC, half, DT_HALF);
+#endif // GOOGLE_CUDA
#undef BM_BIAS_ADD_GRAD_ALL
#undef BM_BIAS_ADD_GRAD
@@ -223,7 +264,12 @@ static Graph* BcastAdd(int rows, int cols, int dim) {
BM_BCAST_ADD_ROW(DEVICE, 2048, 512); \
BM_BCAST_ADD_ROW(DEVICE, 4096, 512);
BM_BCAST_ADD_ROW_ALL(cpu);
+#if GOOGLE_CUDA
BM_BCAST_ADD_ROW_ALL(gpu);
+#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+BM_BCAST_ADD_ROW_ALL(sycl);
+#endif // TENSORFLOW_USE_SYCL
#undef BM_BCAST_ADD_ROW_ALL
#undef BM_BCAST_ADD_ROW
@@ -244,7 +290,12 @@ BM_BCAST_ADD_ROW_ALL(gpu);
BM_BCAST_ADD_COL(DEVICE, 2048, 512); \
BM_BCAST_ADD_COL(DEVICE, 4096, 512);
BM_BCAST_ADD_COL_ALL(cpu);
+#if GOOGLE_CUDA
BM_BCAST_ADD_COL_ALL(gpu);
+#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+BM_BCAST_ADD_COL_ALL(sycl);
+#endif // TENSORFLOW_USE_SYCL
#undef BM_BCAST_ADD_COL_ALL
#undef BM_BCAST_ADD_COL
diff --git a/tensorflow/core/kernels/debug_ops.cc b/tensorflow/core/kernels/debug_ops.cc
index 0706b72a89..d0f5db3bf2 100644
--- a/tensorflow/core/kernels/debug_ops.cc
+++ b/tensorflow/core/kernels/debug_ops.cc
@@ -97,6 +97,7 @@ REGISTER_GPU_DEBUG_NAN_COUNT(double);
.TypeConstraint<type>("T"), \
DebugNanCountOp<type>);
REGISTER_GPU_DEBUG_NAN_COUNT(float);
+REGISTER_GPU_DEBUG_NAN_COUNT(double);
#endif // TENSORFLOW_USE_SYCL
// Register debug numeric summary ops.
@@ -129,6 +130,7 @@ REGISTER_GPU_DEBUG_NUMERIC_SUMMARY_COUNT(double);
.TypeConstraint<type>("T"), \
DebugNumericSummaryOp<type>);
REGISTER_GPU_DEBUG_NUMERIC_SUMMARY_COUNT(float);
+REGISTER_GPU_DEBUG_NUMERIC_SUMMARY_COUNT(double);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/dense_update_ops.cc b/tensorflow/core/kernels/dense_update_ops.cc
index 42fe6e88c9..767f143727 100644
--- a/tensorflow/core/kernels/dense_update_ops.cc
+++ b/tensorflow/core/kernels/dense_update_ops.cc
@@ -152,6 +152,7 @@ typedef Eigen::SyclDevice SYCLDevice;
DenseUpdateOp<SYCLDevice, type, DenseUpdateType::SUB>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif
diff --git a/tensorflow/core/kernels/fill_functor.cc b/tensorflow/core/kernels/fill_functor.cc
index 08ec4baff3..0df8f9d3ed 100644
--- a/tensorflow/core/kernels/fill_functor.cc
+++ b/tensorflow/core/kernels/fill_functor.cc
@@ -62,6 +62,8 @@ void SetZeroFunctor<Eigen::SyclDevice, T>::operator()(
#define DEFINE_SETZERO_SYCL(T) \
template struct SetZeroFunctor<Eigen::SyclDevice, T>;
DEFINE_SETZERO_SYCL(float);
+DEFINE_SETZERO_SYCL(bool);
+DEFINE_SETZERO_SYCL(double);
#undef DEFINE_SETZERO_SYCL
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index 9aa289c3c9..d08dec46d1 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -185,6 +185,34 @@ REGISTER_KERNEL_BUILDER(Name("_ArrayToList")
.TypeConstraint<int32>("T"),
PassOn);
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("_ListToArray").Device(DEVICE_SYCL).TypeConstraint<type>("T"),\
+ PassOn); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("_ArrayToList").Device(DEVICE_SYCL).TypeConstraint<type>("T"),\
+ PassOn);
+
+REGISTER_SYCL_KERNELS(float);
+REGISTER_SYCL_KERNELS(double);
+
+#undef REGISTER_SYCL_KERNELS
+
+REGISTER_KERNEL_BUILDER(Name("_ListToArray")
+ .Device(DEVICE_SYCL)
+ .HostMemory("input")
+ .HostMemory("output")
+ .TypeConstraint<int32>("T"),
+ PassOn);
+REGISTER_KERNEL_BUILDER(Name("_ArrayToList")
+ .Device(DEVICE_SYCL)
+ .HostMemory("input")
+ .HostMemory("output")
+ .TypeConstraint<int32>("T"),
+ PassOn);
+#endif // TENSORFLOW_USE_SYCL
+
class SymbolicGradientOp : public AsyncOpKernel {
public:
SymbolicGradientOp(OpKernelConstruction* ctx)
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
index a2b0127fac..57c055885c 100644
--- a/tensorflow/core/kernels/matmul_op.cc
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -46,6 +46,9 @@ perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename T, bool USE_CUBLAS>
struct LaunchMatMul;
@@ -118,27 +121,42 @@ bool ExplicitVectorMatrixOptimization<Eigen::half>(
return false;
}
-// On CPUs, we ignore USE_CUBLAS
-template <typename T>
-struct LaunchMatMulCPU {
+template <typename Device, typename T>
+struct LaunchMatMulBase {
static void launch(
OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b,
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
Tensor* out) {
+#ifndef TENSORFLOW_USE_SYCL
// An explicit vector-matrix multiply is much better optimized than an
// implicit one and this is a bottleneck during non-batched inference.
bool was_vector = ExplicitVectorMatrixOptimization<T>(a, b, dim_pair, out);
if (!was_vector) {
- functor::MatMulFunctor<CPUDevice, T>()(ctx->eigen_device<CPUDevice>(),
+#endif // TENSORFLOW_USE_SYCL
+ functor::MatMulFunctor<Device, T>()(ctx->eigen_device<Device>(),
out->matrix<T>(), a.matrix<T>(),
b.matrix<T>(), dim_pair);
+#ifndef TENSORFLOW_USE_SYCL
}
+#endif // TENSORFLOW_USE_SYCL
}
};
+// On CPUs, we ignore USE_CUBLAS
+template <typename T>
+struct LaunchMatMulCPU : LaunchMatMulBase<CPUDevice, T> {};
+
template <typename T, bool USE_CUBLAS>
struct LaunchMatMul<CPUDevice, T, USE_CUBLAS> : public LaunchMatMulCPU<T> {};
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct LaunchMatMulSYCL : LaunchMatMulBase<SYCLDevice, T> {};
+
+template <typename T, bool USE_CUBLAS>
+struct LaunchMatMul<SYCLDevice, T, USE_CUBLAS> : public LaunchMatMulSYCL<T> {};
+#endif // TENSORFLOW_USE_SYCL
+
#if GOOGLE_CUDA
template <typename T>
@@ -256,6 +274,20 @@ struct MatMulFunctor<CPUDevice, T> {
}
};
+#ifdef TENSORFLOW_USE_SYCL
+// Partial specialization MatMulFunctor<Device=SYCLDevice, T>.
+template <typename T>
+struct MatMulFunctor<SYCLDevice, T> {
+ void operator()(
+ const SYCLDevice& d, typename MatMulTypes<T>::out_type out,
+ typename MatMulTypes<T>::in_type in0,
+ typename MatMulTypes<T>::in_type in1,
+ const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
+ MatMul<SYCLDevice>(d, out, in0, in1, dim_pair);
+ }
+};
+#endif // TENSORFLOW_USE_SYCL
+
} // end namespace functor
#define REGISTER_CPU(T) \
@@ -294,4 +326,17 @@ TF_CALL_half(REGISTER_GPU);
#endif
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("MatMul").Device(DEVICE_SYCL).TypeConstraint<T>("T"), \
+ MatMulOp<SYCLDevice, T, false /* xxblas */>); \
+ REGISTER_KERNEL_BUILDER(Name("MatMul") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<T>("T") \
+ .Label("eigen"), \
+ MatMulOp<SYCLDevice, T, false /* xxblas */>)
+TF_CALL_float(REGISTER_SYCL);
+
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/pack_op.cc b/tensorflow/core/kernels/pack_op.cc
index 4977ad1d7c..a6650f369b 100644
--- a/tensorflow/core/kernels/pack_op.cc
+++ b/tensorflow/core/kernels/pack_op.cc
@@ -167,6 +167,7 @@ REGISTER_KERNEL_BUILDER(Name("Pack")
PackOp<SYCLDevice, type>)
REGISTER_SYCL(float);
+REGISTER_SYCL(double);
#undef REGISTER_SYCL
// A special GPU kernel for int32.
diff --git a/tensorflow/core/kernels/pad_op.cc b/tensorflow/core/kernels/pad_op.cc
index bec2d02cb5..91984319c6 100644
--- a/tensorflow/core/kernels/pad_op.cc
+++ b/tensorflow/core/kernels/pad_op.cc
@@ -38,6 +38,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename T>
class PadOp : public OpKernel {
@@ -199,4 +202,30 @@ REGISTER_KERNEL_BUILDER(Name("Pad")
PadOp<CPUDevice, int32>);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+// Registration of the GPU implementations.
+#define REGISTER_SYCL_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("Pad") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int32>("Tpaddings") \
+ .HostMemory("paddings"), \
+ PadOp<SYCLDevice, T>)
+
+REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Pad")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int32>("Tpaddings")
+ .HostMemory("input")
+ .HostMemory("paddings")
+ .HostMemory("output"),
+ PadOp<CPUDevice, int32>);
+#endif // TENSORFLOW_USE_SYCL
+
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h
index 625cea4228..19071b47f1 100644
--- a/tensorflow/core/kernels/reduction_ops_common.h
+++ b/tensorflow/core/kernels/reduction_ops_common.h
@@ -268,6 +268,31 @@ struct ReduceFunctor<CPUDevice, Reducer>
template <typename Reducer>
struct ReduceFunctor<SYCLDevice, Reducer>
: ReduceFunctorBase<SYCLDevice, Reducer>{};
+
+template <typename T>
+struct ReduceFunctor<SYCLDevice, Eigen::internal::MeanReducer<T> > {
+ template <typename OUT_T, typename IN_T, typename ReductionAxes>
+ static void Reduce(const SYCLDevice& d, OUT_T out, IN_T in,
+ const ReductionAxes& reduction_axes,
+ const Eigen::internal::MeanReducer<T>& reducer) {
+ typedef typename IN_T::Index Index;
+ // Eigen sum reductions are much faster on GPU than mean reductions:
+ // Simply trigger them by computing the sum of the weighted inputs.
+ Index num_coeffs_to_reduce = 1;
+ for (int i = 0; i < Eigen::internal::array_size<ReductionAxes>::value;
+ ++i) {
+ num_coeffs_to_reduce *= in.dimension(reduction_axes[i]);
+ }
+ T scale = T(1.0) / num_coeffs_to_reduce;
+ out.device(d) = (in * scale).sum(reduction_axes);
+ }
+
+ template <typename OUT_T>
+ static void FillIdentity(const SYCLDevice& d, OUT_T out,
+ const Eigen::internal::MeanReducer<T>& reducer) {
+ FillIdentityEigenImpl(d, out, reducer);
+ }
+};
#endif // TENSORFLOW_USE_SYCL
} // namespace functor
diff --git a/tensorflow/core/kernels/reduction_ops_max.cc b/tensorflow/core/kernels/reduction_ops_max.cc
index db86157c8e..5ab97d1eee 100644
--- a/tensorflow/core/kernels/reduction_ops_max.cc
+++ b/tensorflow/core/kernels/reduction_ops_max.cc
@@ -57,4 +57,27 @@ REGISTER_KERNEL_BUILDER(
#endif
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Max") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<SYCLDevice, type, Eigen::internal::MaxReducer<type>>);
+REGISTER_SYCL_KERNELS(float);
+#undef REGISTER_SYCL_KERNELS
+
+REGISTER_KERNEL_BUILDER(
+ Name("Max")
+ .Device(DEVICE_SYCL)
+ .HostMemory("reduction_indices")
+ .HostMemory("input")
+ .HostMemory("output")
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int32>("Tidx"),
+ ReductionOp<CPUDevice, int32, Eigen::internal::MaxReducer<int32>>);
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_mean.cc b/tensorflow/core/kernels/reduction_ops_mean.cc
index fef3cd0699..e018cb55dd 100644
--- a/tensorflow/core/kernels/reduction_ops_mean.cc
+++ b/tensorflow/core/kernels/reduction_ops_mean.cc
@@ -44,4 +44,17 @@ REGISTER_GPU_KERNELS(double);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Mean") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<SYCLDevice, type, Eigen::internal::MeanReducer<type>>);
+REGISTER_SYCL_KERNELS(float);
+#undef REGISTER_SYCL_KERNELS
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_min.cc b/tensorflow/core/kernels/reduction_ops_min.cc
index c362bc8867..ec240421b9 100644
--- a/tensorflow/core/kernels/reduction_ops_min.cc
+++ b/tensorflow/core/kernels/reduction_ops_min.cc
@@ -57,4 +57,27 @@ REGISTER_KERNEL_BUILDER(
#endif
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Min") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<SYCLDevice, type, Eigen::internal::MinReducer<type>>);
+REGISTER_SYCL_KERNELS(float);
+#undef REGISTER_SYCL_KERNELS
+
+REGISTER_KERNEL_BUILDER(
+ Name("Min")
+ .Device(DEVICE_SYCL)
+ .HostMemory("reduction_indices")
+ .HostMemory("input")
+ .HostMemory("output")
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int32>("Tidx"),
+ ReductionOp<CPUDevice, int32, Eigen::internal::MinReducer<int32>>);
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_prod.cc b/tensorflow/core/kernels/reduction_ops_prod.cc
index c6aff8c2ed..e04c655dab 100644
--- a/tensorflow/core/kernels/reduction_ops_prod.cc
+++ b/tensorflow/core/kernels/reduction_ops_prod.cc
@@ -45,4 +45,28 @@ REGISTER_GPU_KERNELS(double);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Prod") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<SYCLDevice, type, Eigen::internal::ProdReducer<type>>);
+REGISTER_SYCL_KERNELS(float);
+REGISTER_SYCL_KERNELS(double);
+#undef REGISTER_SYCL_KERNELS
+
+REGISTER_KERNEL_BUILDER(
+ Name("Prod")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int32>("Tidx")
+ .HostMemory("input")
+ .HostMemory("output")
+ .HostMemory("reduction_indices"),
+ ReductionOp<CPUDevice, int32, Eigen::internal::ProdReducer<int32>>);
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_sum.cc b/tensorflow/core/kernels/reduction_ops_sum.cc
index 3aa38f418e..938ca66a0c 100644
--- a/tensorflow/core/kernels/reduction_ops_sum.cc
+++ b/tensorflow/core/kernels/reduction_ops_sum.cc
@@ -74,7 +74,6 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("reduction_indices"), \
ReductionOp<SYCLDevice, type, Eigen::internal::SumReducer<type>>);
REGISTER_SYCL_KERNELS(float);
-REGISTER_SYCL_KERNELS(double);
#undef REGISTER_SYCL_KERNELS
// A special GPU kernel for int32.
diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc
index f24a71ec8c..d70398bea5 100644
--- a/tensorflow/core/kernels/relu_op.cc
+++ b/tensorflow/core/kernels/relu_op.cc
@@ -29,6 +29,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
#define REGISTER_RELU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
@@ -131,4 +134,30 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+// Registration of the GPU implementations.
+#define REGISTER_SYCL_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ ReluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ ReluGradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ Relu6Op<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6Grad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ Relu6GradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Elu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ EluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ EluGradOp<SYCLDevice, type>)
+
+REGISTER_SYCL_KERNELS(float);
+#undef REGISTER_SYCL_KERNELS
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h
index 365c6201a5..e2e0bd48dd 100644
--- a/tensorflow/core/kernels/relu_op.h
+++ b/tensorflow/core/kernels/relu_op.h
@@ -175,6 +175,10 @@ void EluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
} // namespace tensorflow
+#ifdef TENSORFLOW_USE_SYCL
+#undef EIGEN_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
+
#undef EIGEN_USE_THREADS
#endif // TENSORFLOW_KERNELS_RELU_OP_H_
diff --git a/tensorflow/core/kernels/reverse_op.cc b/tensorflow/core/kernels/reverse_op.cc
index 7852499965..596dac9087 100644
--- a/tensorflow/core/kernels/reverse_op.cc
+++ b/tensorflow/core/kernels/reverse_op.cc
@@ -33,6 +33,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
namespace {
@@ -351,4 +354,36 @@ REGISTER_KERNEL_BUILDER(Name("ReverseV2")
ReverseV2Op<CPUDevice, int32>);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER(Name("Reverse") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("dims"), \
+ ReverseOp<SYCLDevice, T>) \
+ REGISTER_KERNEL_BUILDER(Name("ReverseV2") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("axis"), \
+ ReverseV2Op<SYCLDevice, T>)
+TF_CALL_float(REGISTER_SYCL_KERNELS);
+
+REGISTER_KERNEL_BUILDER(Name("Reverse")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .HostMemory("tensor")
+ .HostMemory("dims")
+ .HostMemory("output"),
+ ReverseOp<CPUDevice, int32>);
+REGISTER_KERNEL_BUILDER(Name("ReverseV2")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int32>("Tidx")
+ .HostMemory("tensor")
+ .HostMemory("axis")
+ .HostMemory("output"),
+ ReverseV2Op<CPUDevice, int32>);
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc
index 827eb7dbca..51dad49cfe 100644
--- a/tensorflow/core/kernels/scatter_op.cc
+++ b/tensorflow/core/kernels/scatter_op.cc
@@ -180,8 +180,8 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU);
#define REGISTER_SCATTER_UPDATE_SYCL(type) REGISTER_SCATTER_UPDATE(type, SYCL);
-REGISTER_SCATTER_ARITHEMTIC_SYCL(float);
-REGISTER_SCATTER_UPDATE_SYCL(float);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_SYCL);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_SYCL);
#undef REGISTER_SCATTER_ARITHEMTIC_SYCL
#undef REGISTER_SCATTER_UPDATE_SYCL
diff --git a/tensorflow/core/kernels/sequence_ops.cc b/tensorflow/core/kernels/sequence_ops.cc
index c24ecdf8b9..c8ea923020 100644
--- a/tensorflow/core/kernels/sequence_ops.cc
+++ b/tensorflow/core/kernels/sequence_ops.cc
@@ -92,9 +92,11 @@ class RangeOp : public OpKernel {
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(T) REGISTER_KERNEL(DEVICE_SYCL, T)
TF_CALL_float(REGISTER_SYCL_KERNEL);
+TF_CALL_double(REGISTER_SYCL_KERNEL);
TF_CALL_int32(REGISTER_SYCL_KERNEL);
TF_CALL_int64(REGISTER_SYCL_KERNEL);
-#endif // TENSORFLOW_USE_SYCL
+#undef REGISTER_SYCL_KERNEL
+#endif // TENSORFLOW_USE_SYCL
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
@@ -170,4 +172,9 @@ TF_CALL_double(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(T) REGISTER_KERNEL(DEVICE_SYCL, T)
+TF_CALL_float(REGISTER_SYCL_KERNEL);
+TF_CALL_double(REGISTER_SYCL_KERNEL);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/shape_ops.cc b/tensorflow/core/kernels/shape_ops.cc
index 6bc0b4560b..177a32464b 100644
--- a/tensorflow/core/kernels/shape_ops.cc
+++ b/tensorflow/core/kernels/shape_ops.cc
@@ -201,6 +201,7 @@ REGISTER_KERNEL_BUILDER(Name("Rank").Device(DEVICE_CPU).HostMemory("output"),
.HostMemory("output"), \
RankOp);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
// A special GPU kernel for int32 and bool.
@@ -297,6 +298,43 @@ REGISTER_KERNEL_BUILDER(Name("Size")
SizeOp<int64>);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Size") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type") \
+ .HostMemory("output"), \
+ SizeOp<int32>); \
+ REGISTER_KERNEL_BUILDER(Name("Size") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type") \
+ .HostMemory("output"), \
+ SizeOp<int64>);
+REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
+#undef REGISTER_SYCL_KERNEL
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Size")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int32>("out_type")
+ .HostMemory("input")
+ .HostMemory("output"),
+ SizeOp<int32>);
+REGISTER_KERNEL_BUILDER(Name("Size")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int64>("out_type")
+ .HostMemory("input")
+ .HostMemory("output"),
+ SizeOp<int64>);
+#endif // TENSORFLOW_USE_SYCL
+
// ExpandDims ------------------------------------
REGISTER_KERNEL_BUILDER(Name("ExpandDims")
.Device(DEVICE_CPU)
@@ -323,7 +361,30 @@ REGISTER_KERNEL_BUILDER(Name("ExpandDims")
.HostMemory("dim")
.HostMemory("output"),
ExpandDimsOp);
-#endif
+#endif // GOOGLE_CUDA
+
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("ExpandDims") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tdim") \
+ .HostMemory("dim"), \
+ ExpandDimsOp);
+REGISTER_SYCL_KERNEL(float)
+REGISTER_SYCL_KERNEL(double)
+
+#undef REGISTER_SYCL_KERNEL
+
+REGISTER_KERNEL_BUILDER(Name("ExpandDims")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int32>("Tdim")
+ .HostMemory("input")
+ .HostMemory("dim")
+ .HostMemory("output"),
+ ExpandDimsOp);
+#endif // TENSORFLOW_USE_SYCL
// Squeeze ---------------------------------------
REGISTER_KERNEL_BUILDER(Name("Squeeze").Device(DEVICE_CPU), SqueezeOp);
@@ -347,4 +408,24 @@ REGISTER_KERNEL_BUILDER(Name("Squeeze")
SqueezeOp);
#endif
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Squeeze").Device(DEVICE_SYCL).TypeConstraint<type>("T"),\
+ SqueezeOp);
+REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
+#undef REGISTER_SYCL_KERNEL
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Squeeze")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .HostMemory("input")
+ .HostMemory("output"),
+ SqueezeOp);
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/softmax_op.cc b/tensorflow/core/kernels/softmax_op.cc
index c7ae93852f..de11de32f1 100644
--- a/tensorflow/core/kernels/softmax_op.cc
+++ b/tensorflow/core/kernels/softmax_op.cc
@@ -28,17 +28,27 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
// Partial specialization for a CPUDevice, that uses the Eigen implementation
// from SoftmaxEigenImpl.
namespace functor {
-template <typename T>
-struct SoftmaxFunctor<CPUDevice, T> {
- void operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix logits,
+template <typename Device, typename T>
+struct SoftmaxFunctorBase {
+ void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits,
typename TTypes<T>::Matrix softmax, const bool log) {
- SoftmaxEigenImpl<CPUDevice, T>::Compute(d, logits, softmax, log);
+ SoftmaxEigenImpl<Device, T>::Compute(d, logits, softmax, log);
}
};
+template <typename T>
+struct SoftmaxFunctor<CPUDevice, T> : SoftmaxFunctorBase<CPUDevice, T> {};
+
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct SoftmaxFunctor<SYCLDevice, T> : SoftmaxFunctorBase<SYCLDevice, T> {};
+#endif // TENSORFLOW_USE_SYCL
} // namespace functor
#define REGISTER_CPU(T) \
@@ -76,4 +86,10 @@ REGISTER_KERNEL_BUILDER(
SoftmaxOp<GPUDevice, float>);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(
+ Name("Softmax").Device(DEVICE_SYCL).TypeConstraint<float>("T"),
+ SoftmaxOp<SYCLDevice, float>);
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/stage_op.cc b/tensorflow/core/kernels/stage_op.cc
index c18b992ea1..161ba89212 100644
--- a/tensorflow/core/kernels/stage_op.cc
+++ b/tensorflow/core/kernels/stage_op.cc
@@ -99,6 +99,9 @@ REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_CPU), StageOp);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_GPU), StageOp);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_SYCL), StageOp);
+#endif // TENSORFLOW_USE_SYCL
class UnstageOp : public OpKernel {
public:
@@ -126,5 +129,8 @@ REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_CPU), UnstageOp);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_GPU), UnstageOp);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_SYCL), UnstageOp);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index a373865509..8a3d09f1c1 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -450,4 +450,71 @@ REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")
#undef REGISTER_GPU
#endif // GOOGLE_CUDA
+
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL(type) \
+ REGISTER_KERNEL_BUILDER(Name("StridedSlice") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("begin") \
+ .HostMemory("end") \
+ .HostMemory("strides") \
+ .TypeConstraint<int32>("Index"), \
+ StridedSliceOp<SYCLDevice, type>) \
+ REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("shape") \
+ .HostMemory("begin") \
+ .HostMemory("end") \
+ .HostMemory("strides") \
+ .TypeConstraint<int32>("Index"), \
+ StridedSliceGradOp<SYCLDevice, type>)\
+ REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("begin") \
+ .HostMemory("end") \
+ .HostMemory("strides") \
+ .TypeConstraint<int32>("Index"), \
+ StridedSliceAssignOp<SYCLDevice, type>)
+
+REGISTER_SYCL(float);
+REGISTER_SYCL(double);
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("StridedSlice")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int32>("Index")
+ .HostMemory("input")
+ .HostMemory("begin")
+ .HostMemory("end")
+ .HostMemory("strides")
+ .HostMemory("output"),
+ StridedSliceOp<CPUDevice, int32>);
+REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int32>("Index")
+ .HostMemory("shape")
+ .HostMemory("begin")
+ .HostMemory("end")
+ .HostMemory("strides")
+ .HostMemory("dy")
+ .HostMemory("output"),
+ StridedSliceGradOp<CPUDevice, int32>);
+REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int32>("Index")
+ .HostMemory("ref")
+ .HostMemory("begin")
+ .HostMemory("end")
+ .HostMemory("strides"),
+ StridedSliceAssignOp<CPUDevice, int32>)
+#undef REGISTER_SYCL
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/strided_slice_op_impl.h b/tensorflow/core/kernels/strided_slice_op_impl.h
index e89d1920b9..93cede398a 100644
--- a/tensorflow/core/kernels/strided_slice_op_impl.h
+++ b/tensorflow/core/kernels/strided_slice_op_impl.h
@@ -285,6 +285,20 @@ DECLARE_FOR_N_GPU(int32);
TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU);
DECLARE_FOR_N_CPU(bfloat16);
+#ifdef TENSORFLOW_USE_SYCL
+#define PREVENT_FOR_N_SYCL(T) \
+ PREVENT_INSTANTIATE(T, STRIDED_SLICE_INSTANTIATE_DIM)
+
+#define DECLARE_FOR_N_SYCL(T) \
+ INSTANTIATE(SYCLDevice, T, STRIDED_SLICE_INSTANTIATE_DIM)
+
+TF_CALL_SYCL_PROXY_TYPES(PREVENT_FOR_N_SYCL);
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N_SYCL);
+DECLARE_FOR_N_SYCL(int32);
+
+#undef DECLARE_FOR_N_SYCL
+#endif // TENSORFLOW_USE_SYCL
+
#undef INSTANTIATE
#undef DECLARE_FOR_N_CPU
#undef DECLARE_FOR_N_GPU
diff --git a/tensorflow/core/kernels/tile_ops.cc b/tensorflow/core/kernels/tile_ops.cc
index e55c8679e9..9822b021eb 100644
--- a/tensorflow/core/kernels/tile_ops.cc
+++ b/tensorflow/core/kernels/tile_ops.cc
@@ -260,6 +260,8 @@ TF_CALL_complex128(HANDLE_TYPE_NAME_GPU);
#ifdef TENSORFLOW_USE_SYCL
TF_CALL_float(HANDLE_TYPE_NAME_SYCL);
+TF_CALL_double(HANDLE_TYPE_NAME_SYCL);
+TF_CALL_int32(HANDLE_TYPE_NAME_SYCL);
#endif // TENSORFLOW_USE_SYCL
#undef HANDLE_TYPE_NAME_CPU
@@ -506,6 +508,16 @@ TF_CALL_complex64(HANDLE_TYPE_NAME_GPU);
TF_CALL_complex128(HANDLE_TYPE_NAME_GPU);
#endif // GOOGLE_CUDA
+#if TENSORFLOW_USE_SYCL
+#define HANDLE_TYPE_NAME_SYCL(T) \
+ HANDLE_CASE_DIM(SYCLDevice, T, DataTypeToEnum<T>::value);
+
+TF_CALL_float(HANDLE_TYPE_NAME_SYCL);
+TF_CALL_double(HANDLE_TYPE_NAME_SYCL);
+TF_CALL_int32(HANDLE_TYPE_NAME_SYCL);
+#undef HANDLE_TYPE_NAME_SYCL
+#endif // TENSORFLOW_USE_SYCL
+
#undef HANDLE_TYPE_NAME_CPU
#undef HANDLE_TYPE_NAME_GPU
#undef HANDLE_CASE_DIM
@@ -605,6 +617,25 @@ REGISTER_KERNEL_BUILDER(Name("Tile")
.TypeConstraint<int32>("Tmultiples")
.HostMemory("multiples"),
TileOp<SYCLDevice>);
+REGISTER_KERNEL_BUILDER(Name("Tile")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<double>("T")
+ .TypeConstraint<int32>("Tmultiples")
+ .HostMemory("multiples"),
+ TileOp<SYCLDevice>);
+
+REGISTER_KERNEL_BUILDER(Name("TileGrad")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<float>("T")
+ .TypeConstraint<int32>("Tmultiples")
+ .HostMemory("multiples"),
+ TileGradientOp<SYCLDevice>);
+REGISTER_KERNEL_BUILDER(Name("TileGrad")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<double>("T")
+ .TypeConstraint<int32>("Tmultiples")
+ .HostMemory("multiples"),
+ TileGradientOp<SYCLDevice>);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/tile_ops_cpu_impl.h b/tensorflow/core/kernels/tile_ops_cpu_impl.h
index 650c739ed5..f06cc5514c 100644
--- a/tensorflow/core/kernels/tile_ops_cpu_impl.h
+++ b/tensorflow/core/kernels/tile_ops_cpu_impl.h
@@ -70,6 +70,8 @@ typedef Eigen::SyclDevice SYCLDevice;
#define DEFINE_TYPE(T) DEFINE_DIM(T, CPU_PROVIDED_IXDIM)
TF_CALL_float(DEFINE_TYPE);
+TF_CALL_double(DEFINE_TYPE);
+TF_CALL_int32(DEFINE_TYPE);
#undef DEFINE_DIM
#undef DEFINE_TYPE
@@ -81,6 +83,8 @@ TF_CALL_float(DEFINE_TYPE);
#define DEFINE_TYPE(T) DEFINE_DIM(T, CPU_PROVIDED_IXDIM)
TF_CALL_float(DEFINE_TYPE);
+TF_CALL_double(DEFINE_TYPE);
+TF_CALL_int32(DEFINE_TYPE);
#undef DEFINE_DIM
#undef DEFINE_TYPE
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 336c6b0ccc..5c2d371430 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -423,6 +423,7 @@ TF_CALL_double(REGISTER_CPU_KERNELS);
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T);
TF_CALL_float(REGISTER_SYCL_KERNELS);
+TF_CALL_double(REGISTER_SYCL_KERNELS);
#undef REGISTER_SYCL_KERNELS
#endif
@@ -2355,6 +2356,7 @@ TF_CALL_double(REGISTER_CPU_KERNELS);
#define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T);
TF_CALL_float(REGISTER_SYCL_KERNELS);
+TF_CALL_double(REGISTER_SYCL_KERNELS);
#endif
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/transpose_functor_cpu.cc b/tensorflow/core/kernels/transpose_functor_cpu.cc
index 30b82f1843..3681b9a129 100644
--- a/tensorflow/core/kernels/transpose_functor_cpu.cc
+++ b/tensorflow/core/kernels/transpose_functor_cpu.cc
@@ -127,6 +127,7 @@ Status DoTranspose<SYCLDevice>(const SYCLDevice& d, const Tensor& in,
switch (in.dtype()) {
case DT_FLOAT:
+ case DT_DOUBLE:
case DT_INT32:
internal::Transpose<SYCLDevice, uint32>(d, in, perm, out);
break;
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index 67300c1e96..4d303f0173 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -82,6 +82,15 @@ REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
.HostMemory("y"),
InvertPermutationOp);
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .HostMemory("x")
+ .HostMemory("y"),
+ InvertPermutationOp);
+#endif // TENSORFLOW_USE_SYCL
+
// output = TransposeOp(T<any> input, T<int32> perm) takes a tensor
// of type T and rank N, and a permutation of 0, 1, ..., N-1. It
// shuffles the dimensions of the input tensor according to permutation.
@@ -201,4 +210,24 @@ TF_CALL_POD_TYPES(REGISTER);
#undef REGISTER
#endif
+#ifdef TENSORFLOW_USE_SYCL
+Status TransposeSyclOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
+ gtl::ArraySlice<int32> perm, Tensor* out) {
+ typedef Eigen::SyclDevice SYCLDevice;
+ return ::tensorflow::DoTranspose(ctx->eigen_device<SYCLDevice>(), in, perm,
+ out);
+}
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER(Name("Transpose") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int32>("Tperm") \
+ .HostMemory("perm"), \
+ TransposeSyclOp);
+REGISTER(float);
+REGISTER(bool);
+REGISTER(int32);
+#undef REGISTER
+#endif
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/transpose_op.h b/tensorflow/core/kernels/transpose_op.h
index 3b209c0ccc..5f40bcecc1 100644
--- a/tensorflow/core/kernels/transpose_op.h
+++ b/tensorflow/core/kernels/transpose_op.h
@@ -50,6 +50,17 @@ class TransposeGpuOp : public TransposeOp {
gtl::ArraySlice<int32> perm, Tensor* out) override;
};
+#ifdef TENSORFLOW_USE_SYCL
+class TransposeSyclOp : public TransposeOp {
+ public:
+ explicit TransposeSyclOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {}
+
+ protected:
+ Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
+ gtl::ArraySlice<int32> perm, Tensor* out) override;
+};
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_TRANSPOSE_OP_H_
diff --git a/tensorflow/core/kernels/unpack_op.cc b/tensorflow/core/kernels/unpack_op.cc
index 2a14fa3265..e4c79ae17b 100644
--- a/tensorflow/core/kernels/unpack_op.cc
+++ b/tensorflow/core/kernels/unpack_op.cc
@@ -160,6 +160,7 @@ REGISTER_KERNEL_BUILDER(Name("Unpack")
UnpackOp<SYCLDevice, type>)
REGISTER_SYCL(float);
+REGISTER_SYCL(double);
#undef REGISTER_SYCL
// A special SYCL kernel for int32.
diff --git a/tensorflow/core/kernels/variable_ops.cc b/tensorflow/core/kernels/variable_ops.cc
index 34e227156d..7a4d9dc650 100644
--- a/tensorflow/core/kernels/variable_ops.cc
+++ b/tensorflow/core/kernels/variable_ops.cc
@@ -58,8 +58,9 @@ REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized").Device(DEVICE_CPU),
IsVariableInitializedOp);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
-#endif
+#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA
// Only register 'Variable' on GPU for the subset of types also supported by
diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc
index 639bad5f04..56cad8e9eb 100644
--- a/tensorflow/core/kernels/xent_op.cc
+++ b/tensorflow/core/kernels/xent_op.cc
@@ -28,6 +28,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename T>
class SoftmaxXentWithLogitsOp : public OpKernel {
@@ -74,17 +77,25 @@ class SoftmaxXentWithLogitsOp : public OpKernel {
// Partial specialization for a CPUDevice, that uses the Eigen implementation
// from XentEigenImpl.
namespace functor {
-template <typename T>
-struct XentFunctor<CPUDevice, T> {
- void operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix logits,
+template <typename Device, typename T>
+struct XentFunctorBase {
+ void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits,
typename TTypes<T>::ConstMatrix labels,
typename TTypes<T>::Matrix scratch,
typename TTypes<T>::Vec loss,
typename TTypes<T>::Matrix backprop) {
- XentEigenImpl<CPUDevice, T>::Compute(d, logits, labels, scratch, loss,
+ XentEigenImpl<Device, T>::Compute(d, logits, labels, scratch, loss,
backprop);
}
};
+
+template <typename T>
+struct XentFunctor<CPUDevice, T> : XentFunctorBase<CPUDevice, T> {};
+
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct XentFunctor<SYCLDevice, T> : XentFunctorBase<SYCLDevice, T> {};
+#endif // TENSORFLOW_USE_SYCL
} // namespace functor
#define REGISTER_CPU(T) \
@@ -111,4 +122,11 @@ REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
SoftmaxXentWithLogitsOp<GPUDevice, double>);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<float>("T"),
+ SoftmaxXentWithLogitsOp<SYCLDevice, float>);
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc
index 2def59ff04..8670ca307c 100644
--- a/tensorflow/core/ops/math_grad_test.cc
+++ b/tensorflow/core/ops/math_grad_test.cc
@@ -390,7 +390,7 @@ class TestOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("TestOpWithNoGrad").Device(DEVICE_CPU), TestOp);
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("TestOpWithNoGrad").Device(DEVICE_SYCL), TestOp);
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
TEST_F(MathGradTest, Error_Reporting) {
auto x = test::AsTensor<float>({-3.f});
@@ -707,6 +707,8 @@ TEST_F(MathGradTest, Pow) {
}
}
+//TODO{lukeiwanski}: Implement Complex Pow for SYCL
+#ifndef TENSORFLOW_USE_SYCL
TEST_F(MathGradTest, ComplexPow) {
auto x = test::AsTensor<complex64>({0.f, 2.f, -2.f}, TensorShape({3}));
auto y = test::AsTensor<complex64>({2.f, 2.f, 2.f}, TensorShape({3}));
@@ -725,6 +727,7 @@ TEST_F(MathGradTest, ComplexPow) {
dy, test::AsTensor<complex64>({h(0.f, 2.f), h(2.f, 2.f), h(-2.f, 2.f)},
TensorShape({3})));
}
+#endif // TENSORFLOW_USE_SYCL
TEST_F(MathGradTest, Maximum) {
auto x = test::AsTensor<float>({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f},
@@ -886,6 +889,8 @@ TEST_F(MathGradTest, MatMul_11) {
test::ExpectClose(dy, MatMul(dz, true, x, true));
}
+//TODO{lukeiwanski}: Implement BatchMatMul for SYCL
+#ifndef TENSORFLOW_USE_SYCL
TEST_F(MathGradTest, BatchMatMul_00) {
auto x = test::AsTensor<float>({1.f, 2.f, 3.f, 4.f, 5.f, 6.f},
TensorShape({1, 2, 3}));
@@ -933,6 +938,7 @@ TEST_F(MathGradTest, BatchMatMul_11) {
test::ExpectClose(dx, BatchMatMul(y, true, dz, true));
test::ExpectClose(dy, BatchMatMul(dz, true, x, true));
}
+#endif // TENSORFLOW_USE_SYCL
TEST_F(MathGradTest, Sum_dim0) {
auto x = test::AsTensor<float>({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f},
diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md
index 709dfd53cf..e9315c0750 100644
--- a/tensorflow/g3doc/get_started/os_setup.md
+++ b/tensorflow/g3doc/get_started/os_setup.md
@@ -706,7 +706,7 @@ for production (though it will mature over time).
##### Download and install OpenCL drivers
The exact steps required for a functional OpenCL installation will depend on
-your environment. For Unbuntu 14.04, the following steps are known to work:
+your environment. For Ubuntu 14.04, the following steps are known to work:
```bash
sudo apt-get install ocl-icd-opencl-dev opencl-headers
@@ -727,12 +727,17 @@ and copy the files into e.g. `/usr/local/computecpp`:
```bash
tar -xvzf ComputeCpp-CE-0.1.1-Ubuntu.14.04-64bit.tar.gz
-sudo mkdir /usr/local/computecpp
sudo cp -R ComputeCpp-CE-0.1.1-Linux /usr/local/computecpp
sudo chmod -R a+r /usr/local/computecpp/
sudo chmod -R a+x /usr/local/computecpp/bin
```
+Add the lib folder to your `LD_LIBRARY_PATH` to make Python find `libComputeCpp.so` by adding the following line to your `~/.bash_profile`:
+
+```bash
+export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/computecpp/lib"
+```
+
### Prepare environment for Mac OS X
We recommend using [homebrew](http://brew.sh) to install the bazel dependency,
@@ -1306,6 +1311,12 @@ When importing tensorflow, you may see an "ImportError" raised. Below
are some possible examples and solutions:
```
+ImportError: cannot import name 'pywrap_tensorflow'
+```
+
+This can occur if you try to import tensorflow while your current working directory is in the same directory as TensorFlow is located. If that is the case, change the working directory (i.e. `cd` in bash or `os.chdir` in python) to some folder outside of the TensorFlow directory and try importing tensorflow again.
+
+```
ImportError: /lib64/libc.so.6: version `GLIBC_2.16' not found (required by ..._pywrap_tensorflow.so)
```
@@ -1323,3 +1334,8 @@ directory, and so tries to import directly from the source code instead of
your installed tensorflow package. Solution: don't import tensorflow
from the tensorflow source code root directory, if you are.
+```
+ImportError: libComputeCpp.so: cannot open shared object file: No such file or directory
+```
+
+Make sure you have added the path to ComputeCpp's `lib` folder to your `LD_LIBRARY_PATH` (as mentioned above).
diff --git a/tensorflow/python/client/device_lib_test.py b/tensorflow/python/client/device_lib_test.py
index 561ce09099..7bba10efac 100644
--- a/tensorflow/python/client/device_lib_test.py
+++ b/tensorflow/python/client/device_lib_test.py
@@ -34,7 +34,7 @@ class DeviceLibTest(test_util.TensorFlowTestCase):
# GPU test
if test.is_gpu_available():
self.assertGreater(len(devices), 1)
- self.assertTrue("GPU" in [d.device_type for d in devices])
+ self.assertTrue("GPU" in [d.device_type for d in devices] or "SYCL" in [d.device_type for d in devices])
if __name__ == "__main__":
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index f2fd687adf..3ea7e547ee 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -44,7 +44,14 @@ from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util.protobuf import compare
+from tensorflow.python.client import device_lib
+def gpu_device_name():
+ """Returns the name of a GPU device if available or the empty string."""
+ for x in device_lib.list_local_devices():
+ if x.device_type == 'GPU' or x.device_type == 'SYCL':
+ return x.name
+ return ''
def assert_ops_in_graph(expected_ops, graph):
"""Assert all expected operations are found.
@@ -301,7 +308,12 @@ class TensorFlowTestCase(googletest.TestCase):
sess = self._cached_session
with sess.graph.as_default(), sess.as_default():
if force_gpu:
- with sess.graph.device("/gpu:0"):
+ # Use the name of an actual device if one is detected, or '/gpu:0'
+ # otherwise
+ gpu_name = gpu_device_name()
+ if len(gpu_name) == 0:
+ gpu_name = '/gpu:0'
+ with sess.graph.device(gpu_name):
yield sess
elif use_gpu:
yield sess
@@ -311,7 +323,12 @@ class TensorFlowTestCase(googletest.TestCase):
else:
with session.Session(graph=graph, config=prepare_config(config)) as sess:
if force_gpu:
- with sess.graph.device("/gpu:0"):
+ # Use the name of an actual device if one is detected, or '/gpu:0'
+ # otherwise
+ gpu_name = gpu_device_name()
+ if len(gpu_name) == 0:
+ gpu_name = '/gpu:0'
+ with sess.graph.device(gpu_name):
yield sess
elif use_gpu:
yield sess
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 1e510b2868..04c2c0c9e7 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -1342,7 +1342,7 @@ class ControlFlowTest(test.TestCase):
def _testWhileGrad_ColocateGradients(self, colocate):
gpu_dev_name = test.gpu_device_name() if test.is_gpu_available() else "/gpu:0"
- gpu_short_name = gpu_dev_name.split('/')[-1]
+ gpu_short_name = gpu_dev_name.split('/')[-1].lower()
with self.test_session(graph=ops.Graph()) as sess:
v = constant_op.constant(2.0, name="v")
diff --git a/tensorflow/python/kernel_tests/stage_op_test.py b/tensorflow/python/kernel_tests/stage_op_test.py
index c2d760809b..e5f449fcf2 100644
--- a/tensorflow/python/kernel_tests/stage_op_test.py
+++ b/tensorflow/python/kernel_tests/stage_op_test.py
@@ -31,7 +31,7 @@ class StageTest(test.TestCase):
with ops.device('/cpu:0'):
x = array_ops.placeholder(dtypes.float32)
v = 2. * (array_ops.zeros([1024, 1024]) + x)
- with ops.device('/gpu:0'):
+ with ops.device(test.gpu_device_name()):
stager = data_flow_ops.StagingArea([dtypes.float32])
stage = stager.put([v])
y = stager.get()
@@ -46,7 +46,7 @@ class StageTest(test.TestCase):
with ops.device('/cpu:0'):
x = array_ops.placeholder(dtypes.float32)
v = 2. * (array_ops.zeros([128, 128]) + x)
- with ops.device('/gpu:0'):
+ with ops.device(test.gpu_device_name()):
stager = data_flow_ops.StagingArea([dtypes.float32, dtypes.float32])
stage = stager.put([x, v])
z, y = stager.get()
@@ -62,7 +62,7 @@ class StageTest(test.TestCase):
with ops.device('/cpu:0'):
x = array_ops.placeholder(dtypes.float32)
v = 2. * (array_ops.zeros([128, 128]) + x)
- with ops.device('/gpu:0'):
+ with ops.device(test.gpu_device_name()):
stager = data_flow_ops.StagingArea(
[dtypes.float32, dtypes.float32],
shapes=[[], [128, 128]],
diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py
index 18cf8f7b99..501f0c8b35 100644
--- a/tensorflow/python/platform/test.py
+++ b/tensorflow/python/platform/test.py
@@ -40,6 +40,7 @@ from tensorflow.python.util.all_util import remove_undocumented
# pylint: disable=unused-import
from tensorflow.python.framework.test_util import assert_equal_graph_def
from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase
+from tensorflow.python.framework.test_util import gpu_device_name
from tensorflow.python.ops.gradient_checker import compute_gradient_error
from tensorflow.python.ops.gradient_checker import compute_gradient
@@ -105,15 +106,6 @@ def is_gpu_available(cuda_only=False):
return any((x.device_type == 'GPU' or x.device_type == 'SYCL')
for x in _device_lib.list_local_devices())
-
-def gpu_device_name():
- """Returns the name of a GPU device if available or the empty string."""
- for x in _device_lib.list_local_devices():
- if x.device_type == 'GPU' or x.device_type == 'SYCL':
- return x.name
- return ''
-
-
_allowed_symbols = [
# We piggy-back googletest documentation.
'Benchmark',
diff --git a/third_party/sycl/crosstool/computecpp.tpl b/third_party/sycl/crosstool/computecpp.tpl
index a5e6b9fe93..66dd9aea7b 100755
--- a/third_party/sycl/crosstool/computecpp.tpl
+++ b/third_party/sycl/crosstool/computecpp.tpl
@@ -26,9 +26,7 @@ def main():
if(output_file_index == 1):
# we are linking
- return subprocess.call([CPU_CXX_COMPILER] + compiler_flags)
-
- compiler_flags = compiler_flags + ['-D_GLIBCXX_USE_CXX11_ABI=0', '-DEIGEN_USE_SYCL=1']
+ return subprocess.call([CPU_CXX_COMPILER] + compiler_flags + ['-Wl,--no-undefined'])
# find what we compile
compiling_cpp = 0
@@ -38,6 +36,28 @@ def main():
if(compited_file_name.endswith(('.cc', '.c++', '.cpp', '.CPP', '.C', '.cxx'))):
compiling_cpp = 1;
+ compiler_flags = compiler_flags + ['-D_GLIBCXX_USE_CXX11_ABI=0', '-DEIGEN_USE_SYCL=1', '-DTENSORFLOW_USE_SYCL', '-DEIGEN_HAS_C99_MATH']
+
+ if(compiling_cpp == 1):
+ # create a blacklist of folders that will be skipped when compiling with ComputeCpp
+ _skip = ["external", "llvm", ".cu.cc"]
+ # if compiling external project skip computecpp
+ if any(_folder in _skip for _folder in output_file_name):
+ return subprocess.call([CPU_CXX_COMPILER] + compiler_flags)
+
+ if(compiling_cpp == 1):
+ # this is an optimisation that will check if compiled file has to be compiled with ComputeCpp
+
+ _tmp_flags = [flag for flag in compiler_flags if not flag.startswith(('-o', output_file_name))]
+ # create preprocessed of the file
+ _cmd = " ".join([CPU_CXX_COMPILER] + _tmp_flags + ["-E"])
+ # check if it has parallel_for< in it
+ _cmd += " | grep \".parallel_for\" > /dev/null"
+ ps = subprocess.call(_cmd, shell=True)
+ # if not call CXX compiler
+ if(ps != 0):
+ return subprocess.call([CPU_CXX_COMPILER] + compiler_flags)
+
if(compiling_cpp == 1):
filename, file_extension = os.path.splitext(output_file_name)
bc_out = filename + '.sycl'
@@ -52,9 +72,12 @@ def main():
# dont want that in case of compiling with computecpp first
host_compiler_flags = [flag for flag in compiler_flags
if not flag.startswith(('-MF', '-MD',))
- if not '.d' in flag]
+ if not '.d' in flag
+ ]
+
+ host_compiler_flags[host_compiler_flags.index('-c')] = "--include"
- host_compiler_flags = ['-D_GLIBCXX_USE_CXX11_ABI=0', '-DTENSORFLOW_USE_SYCL', '-Wno-unused-variable', '-I', COMPUTECPP_INCLUDE, '--include', bc_out] + host_compiler_flags
+ host_compiler_flags = ['-xc++', '-D_GLIBCXX_USE_CXX11_ABI=0', '-DTENSORFLOW_USE_SYCL', '-Wno-unused-variable', '-I', COMPUTECPP_INCLUDE, '-c', bc_out] + host_compiler_flags
x = subprocess.call([CPU_CXX_COMPILER] + host_compiler_flags)
return x
else:
diff --git a/tools/bazel.rc.template b/tools/bazel.rc.template
index 48c9f5aa3f..3622b9423c 100644
--- a/tools/bazel.rc.template
+++ b/tools/bazel.rc.template
@@ -7,7 +7,7 @@ build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain
build:sycl --define=using_sycl=true
build:sycl_asan --crosstool_top=@local_config_sycl//crosstool:toolchain
-build:sycl_asan --define=using_sycl=true --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -fsanitize=address --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -lasan
+build:sycl_asan --define=using_sycl=true --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address
build --force_python=py$PYTHON_MAJOR_VERSION
build --host_force_python=py$PYTHON_MAJOR_VERSION