Diffstat (limited to 'tensorflow/core')
32 files changed, 850 insertions, 73 deletions
diff --git a/tensorflow/core/common_runtime/sycl/sycl_device.cc b/tensorflow/core/common_runtime/sycl/sycl_device.cc index 19d39056ff..0abe25c373 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_device.cc +++ b/tensorflow/core/common_runtime/sycl/sycl_device.cc @@ -23,7 +23,8 @@ limitations under the License. namespace tensorflow { -static std::unordered_set<SYCLDevice *> live_devices; +static std::unordered_set<SYCLDevice*> live_devices; +static bool first_time = true; void ShutdownSycl() { for (auto device : live_devices) { @@ -31,7 +32,6 @@ void ShutdownSycl() { } live_devices.clear(); } -bool first_time = true; void SYCLDevice::RegisterDevice() { if (first_time) { @@ -44,17 +44,27 @@ void SYCLDevice::RegisterDevice() { SYCLDevice::~SYCLDevice() { device_context_->Unref(); sycl_allocator_->EnterLameDuckMode(); - delete sycl_device_; - delete sycl_queue_; + if (sycl_device_) { + sycl_device_->synchronize(); + delete sycl_device_; + } + if (sycl_queue_) { + delete sycl_queue_; + } live_devices.erase(this); } void SYCLDevice::EnterLameDuckMode() { sycl_allocator_->EnterLameDuckMode(); - delete sycl_device_; - sycl_device_ = nullptr; - delete sycl_queue_; - sycl_queue_ = nullptr; + if (sycl_device_) { + sycl_device_->synchronize(); + delete sycl_device_; + sycl_device_ = nullptr; + } + if (sycl_queue_) { + delete sycl_queue_; + sycl_queue_ = nullptr; + } } void SYCLDevice::Compute(OpKernel *op_kernel, OpKernelContext *context) { @@ -110,7 +120,11 @@ Status SYCLDevice::FillContextMap(const Graph *graph, Status SYCLDevice::Sync() { sycl_device_->synchronize(); - return Status::OK(); + if (sycl_device_->ok()) { + return Status::OK(); + } else { + return errors::Internal("Unknown error detected on device ", name()); + } } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/sycl/sycl_device_context.cc b/tensorflow/core/common_runtime/sycl/sycl_device_context.cc index b49420b1b5..a6be9195d4 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_device_context.cc +++ b/tensorflow/core/common_runtime/sycl/sycl_device_context.cc @@ -95,6 +95,7 @@ void SYCLDeviceContext::CopyCPUTensorToDevice(const Tensor *cpu_tensor, assert(false && "unsupported type"); } } + device->eigen_sycl_device()->synchronize(); done(Status::OK()); } @@ -172,6 +173,7 @@ void SYCLDeviceContext::CopyDeviceTensorToCPU(const Tensor *device_tensor, assert(false && "unsupported type"); } } + device->eigen_sycl_device()->synchronize(); done(Status::OK()); } diff --git a/tensorflow/core/kernels/aggregate_ops.cc b/tensorflow/core/kernels/aggregate_ops.cc index b41e438b2b..50d0cc1727 100644 --- a/tensorflow/core/kernels/aggregate_ops.cc +++ b/tensorflow/core/kernels/aggregate_ops.cc @@ -28,6 +28,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL template <typename Device, typename T> class AddNOp : public OpKernel { @@ -152,6 +155,21 @@ REGISTER_KERNEL_BUILDER(Name("AddN") AddNOp<CPUDevice, int32>); #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +REGISTER_ADDN(float, SYCL); +REGISTER_ADDN(double, SYCL); + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. 
+REGISTER_KERNEL_BUILDER(Name("AddN") + .Device(DEVICE_SYCL) + .TypeConstraint<int32>("T") + .HostMemory("inputs") + .HostMemory("sum"), + AddNOp<CPUDevice, int32>); +#endif // TENSORFLOW_USE_SYCL + #undef REGISTER_ADDN } // namespace tensorflow diff --git a/tensorflow/core/kernels/aggregate_ops_cpu.h b/tensorflow/core/kernels/aggregate_ops_cpu.h index ba5ebb7f0f..dfa3fe585e 100644 --- a/tensorflow/core/kernels/aggregate_ops_cpu.h +++ b/tensorflow/core/kernels/aggregate_ops_cpu.h @@ -23,6 +23,10 @@ limitations under the License. typedef Eigen::ThreadPoolDevice CPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL + namespace tensorflow { // Partial specializations for a CPUDevice, that uses the Eigen implementation @@ -133,6 +137,115 @@ struct Add9Functor<CPUDevice, T> { } }; +#ifdef TENSORFLOW_USE_SYCL +// Partial specializations for a SYCLDevice, that uses the Eigen implementation +// from AddNEigenImpl. +template <typename T> +struct Add2Functor<SYCLDevice, T> { + void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, + typename TTypes<T>::ConstFlat in1, + typename TTypes<T>::ConstFlat in2) { + Add2EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2); + } +}; +template <typename T> +struct Add3Functor<SYCLDevice, T> { + void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, + typename TTypes<T>::ConstFlat in1, + typename TTypes<T>::ConstFlat in2, + typename TTypes<T>::ConstFlat in3) { + Add3EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3); + } +}; +template <typename T> +struct Add4Functor<SYCLDevice, T> { + void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, + typename TTypes<T>::ConstFlat in1, + typename TTypes<T>::ConstFlat in2, + typename TTypes<T>::ConstFlat in3, + typename TTypes<T>::ConstFlat in4) { + Add4EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4); + } +}; +template <typename T> +struct Add5Functor<SYCLDevice, T> { + void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, + typename TTypes<T>::ConstFlat in1, + typename TTypes<T>::ConstFlat in2, + typename TTypes<T>::ConstFlat in3, + typename TTypes<T>::ConstFlat in4, + typename TTypes<T>::ConstFlat in5) { + Add5EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5); + } +}; +template <typename T> +struct Add6Functor<SYCLDevice, T> { + void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, + typename TTypes<T>::ConstFlat in1, + typename TTypes<T>::ConstFlat in2, + typename TTypes<T>::ConstFlat in3, + typename TTypes<T>::ConstFlat in4, + typename TTypes<T>::ConstFlat in5, + typename TTypes<T>::ConstFlat in6) { + Add6EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6); + } +}; +template <typename T> +struct Add7Functor<SYCLDevice, T> { + void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, + typename TTypes<T>::ConstFlat in1, + typename TTypes<T>::ConstFlat in2, + typename TTypes<T>::ConstFlat in3, + typename TTypes<T>::ConstFlat in4, + typename TTypes<T>::ConstFlat in5, + typename TTypes<T>::ConstFlat in6, + typename TTypes<T>::ConstFlat in7) { + Add7EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, + in7); + } +}; + +template <typename T> +struct Add8Functor<SYCLDevice, T> { + void operator()( + const SYCLDevice& d, typename TTypes<T>::Flat out, + typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, + typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, + typename TTypes<T>::ConstFlat in5, 
typename TTypes<T>::ConstFlat in6, + typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) { + Add8EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, + in7, in8); + } +}; + +template <typename T> +struct Add8pFunctor<SYCLDevice, T> { + void operator()( + const SYCLDevice& d, typename TTypes<T>::Flat out, + typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, + typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, + typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, + typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) { + Add8pEigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, + in7, in8); + } +}; + +template <typename T> +struct Add9Functor<SYCLDevice, T> { + void operator()( + const SYCLDevice& d, typename TTypes<T>::Flat out, + typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, + typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, + typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, + typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8, + typename TTypes<T>::ConstFlat in9) { + Add9EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, + in7, in8, in9); + } +}; +#endif // TENSORFLOW_USE_SYCL + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index e479b97109..3f8717f77f 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -305,4 +305,9 @@ REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE_GPU), PlaceholderOp); REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE_GPU), PlaceholderOp); +#if TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE_SYCL), PlaceholderOp); +REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE_SYCL), + PlaceholderOp); +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc index b01263f288..5241a4d916 100644 --- a/tensorflow/core/kernels/control_flow_ops.cc +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -121,8 +121,20 @@ REGISTER_GPU_HOST_REF_KERNEL(string); SwitchOp) REGISTER_SYCL_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); + +#define REGISTER_SYCL_REF_SWITCH(type) \ + REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ + .Device(DEVICE_SYCL) \ + .HostMemory("pred") \ + .TypeConstraint<type>("T"), \ + SwitchOp) +REGISTER_SYCL_REF_SWITCH(bool); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_SWITCH); + #undef REGISTER_SYCL_KERNEL -#endif +#undef REGISTER_SYCL_REF_SWITCH + +#endif // TENSORFLOW_USE_SYCL class RefSelectOp : public OpKernel { public: @@ -230,8 +242,18 @@ REGISTER_GPU_REF_KERNEL(bool); MergeOp) REGISTER_SYCL_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); + +#define REGISTER_SYCL_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("RefMerge") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint<type>("T") \ + .HostMemory("value_index"), \ + MergeOp) +REGISTER_SYCL_REF_KERNEL(bool); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); #undef REGISTER_SYCL_KERNEL -#endif +#undef REGISTER_SYCL_REF_KERNEL +#endif // TENSORFLOW_USE_SYCL // Special GPU kernels for int32 and string. // TODO(b/25387198): Also enable int32 in device memory. 
This kernel @@ -289,7 +311,15 @@ REGISTER_GPU_REF_KERNEL(bool); Name("Enter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp) REGISTER_SYCL_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); + +#define REGISTER_SYCL_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("RefEnter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp) +REGISTER_SYCL_REF_KERNEL(bool); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); + #undef REGISTER_SYCL_KERNEL +#undef REGISTER_SYCL_REF_KERNEL #endif // Special GPU kernels for int32 and string. @@ -349,8 +379,37 @@ REGISTER_GPU_KERNEL(bool); Name("Exit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp) REGISTER_SYCL_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); + +#define REGISTER_SYCL_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("RefExit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp) +REGISTER_SYCL_REF_KERNEL(bool); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); + #undef REGISTER_SYCL_KERNEL -#endif +#undef REGISTER_SYCL_REF_KERNEL + +// Special GPU kernels for int32 and string. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +#define REGISTER_SYCL_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Exit") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + ExitOp); \ + REGISTER_KERNEL_BUILDER(Name("RefExit") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + ExitOp) + +REGISTER_SYCL_HOST_KERNEL(int32); +REGISTER_SYCL_HOST_KERNEL(string); +#undef REGISTER_SYCL_HOST_KERNEL +#endif // TENSORFLOW_USE_SYCL // Special GPU kernels for int32 and string. // TODO(b/25387198): Also enable int32 in device memory. This kernel @@ -432,8 +491,39 @@ REGISTER_GPU_HOST_KERNEL(string); NextIterationOp) REGISTER_SYCL_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); +#define REGISTER_SYCL_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("RefNextIteration") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + NextIterationOp) + REGISTER_SYCL_REF_KERNEL(bool); + TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); #undef REGISTER_SYCL_KERNEL -#endif +#undef REGISTER_SYCL_REF_KERNEL + +// Special GPU kernels for int32 and string. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +#define REGISTER_SYCL_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("NextIteration") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + NextIterationOp); \ + REGISTER_KERNEL_BUILDER(Name("RefNextIteration") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + NextIterationOp) + +REGISTER_SYCL_HOST_KERNEL(int32); +REGISTER_SYCL_HOST_KERNEL(string); +#undef REGISTER_SYCL_HOST_KERNEL +#endif // TENSORFLOW_USE_SYCL // A LoopCond op has one input and one output. 
The input is a boolean // scalar representing the taken branches of the "pivot" Switch that @@ -461,6 +551,14 @@ REGISTER_KERNEL_BUILDER(Name("LoopCond") .HostMemory("output"), LoopCondOp); +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("LoopCond") + .Device(DEVICE_SYCL) + .HostMemory("input") + .HostMemory("output"), + LoopCondOp); +#endif // TENSORFLOW_USE_SYCL + // ControlTrigger kernels REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_CPU), ControlTriggerOp); @@ -468,6 +566,11 @@ REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_GPU), ControlTriggerOp); +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_SYCL), + ControlTriggerOp); +#endif // TENSORFLOW_USE_SYCL + // When called, abort op will abort the current process. This can be used to // abort remote PSs when needed. class AbortOp : public OpKernel { @@ -493,4 +596,5 @@ class AbortOp : public OpKernel { }; REGISTER_KERNEL_BUILDER(Name("Abort").Device(DEVICE_CPU), AbortOp); + } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_exp.cc b/tensorflow/core/kernels/cwise_op_exp.cc index 0ee47f7dee..2e3a60cf79 100644 --- a/tensorflow/core/kernels/cwise_op_exp.cc +++ b/tensorflow/core/kernels/cwise_op_exp.cc @@ -27,6 +27,7 @@ REGISTER5(UnaryOp, CPU, "Exp", functor::exp, float, Eigen::half, double, .TypeConstraint<TYPE>("T"), \ UnaryOp<SYCLDevice, functor::exp<TYPE>>); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/cwise_op_floor_div.cc b/tensorflow/core/kernels/cwise_op_floor_div.cc index 69dbb70b83..8a600f8f95 100644 --- a/tensorflow/core/kernels/cwise_op_floor_div.cc +++ b/tensorflow/core/kernels/cwise_op_floor_div.cc @@ -27,7 +27,7 @@ REGISTER3(BinaryOp, CPU, "FloorDiv", functor::floor_div_real, float, Name("FloorDiv") \ .Device(DEVICE_SYCL) \ .TypeConstraint<TYPE>("T"), \ - BinaryOp<SYCLDevice, functor::floor_div<TYPE>>); + BinaryOp<SYCLDevice, functor::floor_div_real<TYPE>>); REGISTER_SYCL_KERNEL(float) #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/cwise_op_neg.cc b/tensorflow/core/kernels/cwise_op_neg.cc index 4221fc0710..c4a9b22883 100644 --- a/tensorflow/core/kernels/cwise_op_neg.cc +++ b/tensorflow/core/kernels/cwise_op_neg.cc @@ -27,6 +27,18 @@ REGISTER7(UnaryOp, CPU, "Neg", functor::neg, float, Eigen::half, double, int32, .TypeConstraint<TYPE>("T"), \ UnaryOp<SYCLDevice, functor::neg<TYPE>>); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Neg") + .Device(DEVICE_SYCL) + .HostMemory("x") + .HostMemory("y") + .TypeConstraint<int32>("T"), + UnaryOp<CPUDevice, functor::neg<int32>>); + #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/cwise_op_sub.cc b/tensorflow/core/kernels/cwise_op_sub.cc index e1326dbed1..eab1e2a09c 100644 --- a/tensorflow/core/kernels/cwise_op_sub.cc +++ b/tensorflow/core/kernels/cwise_op_sub.cc @@ -32,6 +32,18 @@ REGISTER(BinaryOp, CPU, "Sub", functor::sub, int32); .TypeConstraint<TYPE>("T"), \ BinaryOp<SYCLDevice, functor::sub<TYPE>>); REGISTER_SYCL_KERNEL(float); + REGISTER_SYCL_KERNEL(double); + +// A special GPU kernel for int32. 
+// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Sub") + .Device(DEVICE_SYCL) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint<int32>("T"), + BinaryOp<CPUDevice, functor::sub<int32>>); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL #if GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_ops_sycl_common.h b/tensorflow/core/kernels/cwise_ops_sycl_common.h index a0decbce87..3f6ff7303d 100644 --- a/tensorflow/core/kernels/cwise_ops_sycl_common.h +++ b/tensorflow/core/kernels/cwise_ops_sycl_common.h @@ -31,14 +31,6 @@ namespace functor { typedef Eigen::SyclDevice SYCLDevice; -template <typename Index, int N> Eigen::array<Index, N> GenerateArrayOfOnes() { - Eigen::array<Index, N> result; - for (int i = 0; i < N; ++i) { - result[i] = 1; - } - return result; -} - template <typename OUT, typename RHS> void Assign(const SYCLDevice& d, OUT out, RHS rhs) { out.device(d) = rhs; @@ -67,11 +59,9 @@ struct BinaryFunctor<SYCLDevice, Functor, NDIMS, has_errors> { typename Functor::tin_type in, bool* error) { typedef typename Functor::func Binary; constexpr int NumDims = Functor::tin_type::NumDimensions; - typedef typename Functor::tin_type::Scalar T; - typedef typename Functor::tin_type::Index Index; - Eigen::array<Index, NumDims> scalar_dim = GenerateArrayOfOnes<Index, NumDims>(); - Eigen::TensorMap<Eigen::Tensor<T, NumDims, Eigen::RowMajor>> tmp(scalar.data(), scalar_dim); - out.device(d) = tmp.broadcast(in.dimensions()).binaryExpr(in, Binary()); + static_assert(NumDims == 1, "Unexpected size"); + Eigen::Sizes<1> scalar_dim; + out.device(d) = scalar.reshape(scalar_dim).broadcast(in.dimensions()).binaryExpr(in, Binary()); } void Right(const SYCLDevice& d, typename Functor::tout_type out, @@ -79,11 +69,9 @@ struct BinaryFunctor<SYCLDevice, Functor, NDIMS, has_errors> { typename Functor::tscalar_type scalar, bool* error) { typedef typename Functor::func Binary; constexpr int NumDims = Functor::tin_type::NumDimensions; - typedef typename Functor::tin_type::Scalar T; - typedef typename Functor::tin_type::Index Index; - Eigen::array<Index, NumDims> scalar_dim = GenerateArrayOfOnes<Index, NumDims>(); - Eigen::TensorMap<Eigen::Tensor<T, NumDims, Eigen::RowMajor>> tmp(scalar.data(), scalar_dim); - out.device(d) = in.binaryExpr(tmp.broadcast(in.dimensions()), Binary()); + static_assert(NumDims == 1, "Unexpected size"); + Eigen::Sizes<1> scalar_dim; + out.device(d) = in.binaryExpr(scalar.reshape(scalar_dim).broadcast(in.dimensions()), Binary()); } void BCast(const SYCLDevice& d, diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h index 848896a64a..aa47315f55 100644 --- a/tensorflow/core/kernels/debug_ops.h +++ b/tensorflow/core/kernels/debug_ops.h @@ -188,7 +188,6 @@ class DebugNumericSummaryOp : public OpKernel { const T* input_flat = input.template flat<T>().data(); element_count = input_shape.num_elements(); - const double element_count_double = static_cast<double>(element_count); for (int64 i = 0; i < element_count; ++i) { T x = input_flat[i]; if (Eigen::numext::isnan(x)) { diff --git a/tensorflow/core/kernels/pack_op.cc b/tensorflow/core/kernels/pack_op.cc index e072eb36b3..4977ad1d7c 100644 --- a/tensorflow/core/kernels/pack_op.cc +++ b/tensorflow/core/kernels/pack_op.cc @@ -34,6 +34,9 @@ typedef Eigen::ThreadPoolDevice CPUDevice; #if GOOGLE_CUDA typedef Eigen::GpuDevice GPUDevice; #endif // GOOGLE_CUDA 
+#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL // -------------------------------------------------------------------------- template <typename Device, typename T> @@ -156,4 +159,26 @@ REGISTER_KERNEL_BUILDER(Name("Pack") #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL + +#define REGISTER_SYCL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Pack").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + PackOp<SYCLDevice, type>) + +REGISTER_SYCL(float); +#undef REGISTER_SYCL + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Pack") + .Device(DEVICE_SYCL) + .HostMemory("values") + .HostMemory("output") + .TypeConstraint<int32>("T"), + PackOp<CPUDevice, int32>); + +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h index 1bb1a9fc50..625cea4228 100644 --- a/tensorflow/core/kernels/reduction_ops_common.h +++ b/tensorflow/core/kernels/reduction_ops_common.h @@ -40,6 +40,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL template <typename Device> struct Constants { @@ -60,13 +63,16 @@ struct Constants { }; #if defined(EIGEN_HAS_INDEX_LIST) -template <> -struct Constants<CPUDevice> { +struct ConstantsBase { const Eigen::IndexList<Eigen::type2index<0>> kZero; const Eigen::IndexList<Eigen::type2index<1>> kOne; const Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<2>> kZeroTwo; }; -#endif +template<> struct Constants<CPUDevice> : ConstantsBase{}; +#ifdef TENSORFLOW_USE_SYCL +template<> struct Constants<SYCLDevice> : ConstantsBase{}; +#endif // TENSORFLOW_USE_SYCL +#endif // EIGEN_HAS_INDEX_LIST class ReductionHelper { public: @@ -239,22 +245,31 @@ class ReductionOp : public OpKernel { namespace functor { -template <typename Reducer> -struct ReduceFunctor<CPUDevice, Reducer> { +template <typename Device, typename Reducer> +struct ReduceFunctorBase { template <typename OUT_T, typename IN_T, typename ReductionAxes> - static void Reduce(const CPUDevice& d, OUT_T out, IN_T in, + static void Reduce(const Device& d, OUT_T out, IN_T in, const ReductionAxes& reduction_axes, const Reducer& reducer) { ReduceEigenImpl(d, out, in, reduction_axes, reducer); } template <typename OUT_T> - static void FillIdentity(const CPUDevice& d, OUT_T out, + static void FillIdentity(const Device& d, OUT_T out, const Reducer& reducer) { FillIdentityEigenImpl(d, out, reducer); } }; +template <typename Reducer> +struct ReduceFunctor<CPUDevice, Reducer> + : ReduceFunctorBase<CPUDevice, Reducer>{}; +#if TENSORFLOW_USE_SYCL +template <typename Reducer> +struct ReduceFunctor<SYCLDevice, Reducer> + : ReduceFunctorBase<SYCLDevice, Reducer>{}; +#endif // TENSORFLOW_USE_SYCL + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/reduction_ops_sum.cc b/tensorflow/core/kernels/reduction_ops_sum.cc index c7c7949fed..3aa38f418e 100644 --- a/tensorflow/core/kernels/reduction_ops_sum.cc +++ b/tensorflow/core/kernels/reduction_ops_sum.cc @@ -64,4 +64,31 @@ REGISTER_KERNEL_BUILDER( #endif +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Sum") \ + .Device(DEVICE_SYCL) \ + 
.TypeConstraint<type>("T") \ + .TypeConstraint<int32>("Tidx") \ + .HostMemory("reduction_indices"), \ + ReductionOp<SYCLDevice, type, Eigen::internal::SumReducer<type>>); +REGISTER_SYCL_KERNELS(float); +REGISTER_SYCL_KERNELS(double); +#undef REGISTER_SYCL_KERNELS + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER( + Name("Sum") + .Device(DEVICE_SYCL) + .TypeConstraint<int32>("T") + .TypeConstraint<int32>("Tidx") + .HostMemory("input") + .HostMemory("output") + .HostMemory("reduction_indices"), + ReductionOp<CPUDevice, int32, Eigen::internal::SumReducer<int32>>); +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/kernels/scatter_functor.h b/tensorflow/core/kernels/scatter_functor.h index a84d89c296..a27cc83e4c 100644 --- a/tensorflow/core/kernels/scatter_functor.h +++ b/tensorflow/core/kernels/scatter_functor.h @@ -25,6 +25,9 @@ namespace tensorflow { class OpKernelContext; typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL namespace scatter_op { @@ -82,10 +85,9 @@ struct ScatterFunctor { typename TTypes<Index>::ConstFlat indices); }; -// Specializations of scatter functor for CPU. -template <typename T, typename Index, scatter_op::UpdateOp op> -struct ScatterFunctor<CPUDevice, T, Index, op> { - Index operator()(OpKernelContext* c, const CPUDevice& d, +template <typename Device, typename T, typename Index, scatter_op::UpdateOp op> +struct ScatterFunctorBase { + Index operator()(OpKernelContext* c, const Device& d, typename TTypes<T>::Matrix params, typename TTypes<T>::ConstMatrix updates, typename TTypes<Index>::ConstFlat indices) { @@ -106,6 +108,15 @@ struct ScatterFunctor<CPUDevice, T, Index, op> { } }; +template <typename T, typename Index, scatter_op::UpdateOp op> +struct ScatterFunctor<CPUDevice, T, Index, op> + : ScatterFunctorBase<CPUDevice, T, Index, op>{}; +#if TENSORFLOW_USE_SYCL +template<typename T, typename Index, scatter_op::UpdateOp op> +struct ScatterFunctor<SYCLDevice, T, Index, op> + : ScatterFunctorBase<SYCLDevice, T, Index, op>{}; +#endif // TENSORFLOW_USE_SYCL + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc index 604f753db1..827eb7dbca 100644 --- a/tensorflow/core/kernels/scatter_op.cc +++ b/tensorflow/core/kernels/scatter_op.cc @@ -27,6 +27,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL // Check whether updates.shape = indices.shape + params.shape[1:] static bool ValidShapes(const Tensor& params, const Tensor& updates, @@ -170,6 +173,20 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU); #endif // GOOGLE_CUDA +// Registers GPU kernels. 
+#if TENSORFLOW_USE_SYCL +#define REGISTER_SCATTER_ARITHEMTIC_SYCL(type) \ + REGISTER_SCATTER_ARITHEMTIC(type, SYCL); + +#define REGISTER_SCATTER_UPDATE_SYCL(type) REGISTER_SCATTER_UPDATE(type, SYCL); + +REGISTER_SCATTER_ARITHEMTIC_SYCL(float); +REGISTER_SCATTER_UPDATE_SYCL(float); + +#undef REGISTER_SCATTER_ARITHEMTIC_SYCL +#undef REGISTER_SCATTER_UPDATE_SYCL +#endif // TENSORFLOW_USE_SYCL + #undef REGISTER_SCATTER_ARITHEMTIC #undef REGISTER_SCATTER_ARITHEMTIC_CPU #undef REGISTER_SCATTER_ARITHEMTIC_GPU diff --git a/tensorflow/core/kernels/session_ops.cc b/tensorflow/core/kernels/session_ops.cc index 3f1538164c..4550115c19 100644 --- a/tensorflow/core/kernels/session_ops.cc +++ b/tensorflow/core/kernels/session_ops.cc @@ -67,6 +67,19 @@ TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL); REGISTER_GPU_KERNEL(bool); #undef REGISTER_GPU_KERNEL +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("GetSessionHandle") \ + .Device(DEVICE_SYCL) \ + .HostMemory("handle") \ + .TypeConstraint<type>("T"), \ + GetSessionHandleOp) + +TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL); +REGISTER_SYCL_KERNEL(bool); +#undef REGISTER_SYCL_KERNEL +#endif // TENSORFLOW_USE_SYCL + class GetSessionTensorOp : public OpKernel { public: explicit GetSessionTensorOp(OpKernelConstruction* context) @@ -97,6 +110,19 @@ TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL); REGISTER_GPU_KERNEL(bool); #undef REGISTER_GPU_KERNEL +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("GetSessionTensor") \ + .Device(DEVICE_SYCL) \ + .HostMemory("handle") \ + .TypeConstraint<type>("dtype"), \ + GetSessionTensorOp) + +TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL); +REGISTER_SYCL_KERNEL(bool); +#undef REGISTER_SYCL_KERNEL +#endif // TENSORFLOW_USE_SYCL + class DeleteSessionTensorOp : public OpKernel { public: explicit DeleteSessionTensorOp(OpKernelConstruction* context) @@ -117,4 +143,9 @@ REGISTER_KERNEL_BUILDER( Name("DeleteSessionTensor").Device(DEVICE_GPU).HostMemory("handle"), DeleteSessionTensorOp); +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER( + Name("DeleteSessionTensor").Device(DEVICE_SYCL).HostMemory("handle"), + DeleteSessionTensorOp); +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/shape_ops.cc b/tensorflow/core/kernels/shape_ops.cc index 7ff812cf27..496865de02 100644 --- a/tensorflow/core/kernels/shape_ops.cc +++ b/tensorflow/core/kernels/shape_ops.cc @@ -210,6 +210,43 @@ REGISTER_KERNEL_BUILDER(Name("ShapeN") ShapeNOp<int64>); #endif +#if TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("ShapeN") \ + .Device(DEVICE_SYCL) \ + .HostMemory("output") \ + .TypeConstraint<int32>("out_type") \ + .TypeConstraint<type>("T"), \ + ShapeNOp<int32>); \ + REGISTER_KERNEL_BUILDER(Name("ShapeN") \ + .Device(DEVICE_SYCL) \ + .HostMemory("output") \ + .TypeConstraint<int64>("out_type") \ + .TypeConstraint<type>("T"), \ + ShapeNOp<int64>) + +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); +#undef REGISTER_SYCL_KERNEL + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. 
+REGISTER_KERNEL_BUILDER(Name("ShapeN") + .Device(DEVICE_SYCL) + .HostMemory("input") + .HostMemory("output") + .TypeConstraint<int32>("T") + .TypeConstraint<int32>("out_type"), + ShapeNOp<int32>); +REGISTER_KERNEL_BUILDER(Name("ShapeN") + .Device(DEVICE_SYCL) + .HostMemory("input") + .HostMemory("output") + .TypeConstraint<int32>("T") + .TypeConstraint<int64>("out_type"), + ShapeNOp<int64>); +#endif // TENSORFLOW_USE_SYCL + class RankOp : public OpKernel { public: explicit RankOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc index dc33a25cec..e2978eccbd 100644 --- a/tensorflow/core/kernels/slice_op.cc +++ b/tensorflow/core/kernels/slice_op.cc @@ -56,6 +56,9 @@ gtl::InlinedVector<int64, 4> IntTensorToInt64Vec(const Tensor& tensor) { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL // Shared code that is not dependent on the type of T. We do this to reduce // code size by not duplicating all this for all T (float, double, int32, etc.) @@ -300,4 +303,58 @@ REGISTER_KERNEL_BUILDER(Name("Slice") #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +// Forward declarations of the functor specializations for SYCL. +namespace functor { +#define DECLARE_SYCL_SPEC(T, NDIM) \ + template <> \ + void Slice<SYCLDevice, T, NDIM>::operator()( \ + const SYCLDevice& d, typename TTypes<T, NDIM>::Tensor output,\ + typename TTypes<T, NDIM>::ConstTensor input, \ + const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \ + const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \ + extern template struct Slice<SYCLDevice, T, NDIM>; + +#define DECLARE_FOR_N(T) \ + DECLARE_SYCL_SPEC(T, 1); \ + DECLARE_SYCL_SPEC(T, 2); \ + DECLARE_SYCL_SPEC(T, 3); \ + DECLARE_SYCL_SPEC(T, 4); \ + DECLARE_SYCL_SPEC(T, 5); \ + DECLARE_SYCL_SPEC(T, 6); + +TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N); +DECLARE_FOR_N(int32); + +#undef DECLARE_FOR_N +#undef DECLARE_SYCL_SPEC +} // namespace functor + +#define REGISTER_SYCL(type) \ + REGISTER_KERNEL_BUILDER(Name("Slice") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint<type>("T") \ + .HostMemory("begin") \ + .HostMemory("size") \ + .TypeConstraint<int32>("Index"), \ + SliceOp<SYCLDevice, type>) + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL); + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. 
+REGISTER_KERNEL_BUILDER(Name("Slice") + .Device(DEVICE_SYCL) + .TypeConstraint<int32>("T") + .TypeConstraint<int32>("Index") + .HostMemory("input") + .HostMemory("begin") + .HostMemory("size") + .HostMemory("output"), + SliceOp<CPUDevice, int32>); + +#undef REGISTER_SYCL + +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/slice_op_cpu_impl.h b/tensorflow/core/kernels/slice_op_cpu_impl.h index 0b0700ec36..a70805658e 100644 --- a/tensorflow/core/kernels/slice_op_cpu_impl.h +++ b/tensorflow/core/kernels/slice_op_cpu_impl.h @@ -34,6 +34,18 @@ DEFINE_CPU_KERNELS(bfloat16); #undef DEFINE_CPU_KERNELS +#ifdef TENSORFLOW_USE_SYCL +using SyclDevice = Eigen::SyclDevice; + +#define DEFINE_SYCL_KERNELS(T) \ + template struct functor::Slice<SyclDevice, T, CPU_PROVIDED_IXDIM>; + +TF_CALL_GPU_NUMBER_TYPES(DEFINE_SYCL_KERNELS); +DEFINE_SYCL_KERNELS(int32); + +#undef DEFINE_SYCL_KERNELS +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SLICE_OP_CPU_IMPL_H_ diff --git a/tensorflow/core/kernels/split_lib.h b/tensorflow/core/kernels/split_lib.h index 240cce46e0..ff92ffeeb3 100644 --- a/tensorflow/core/kernels/split_lib.h +++ b/tensorflow/core/kernels/split_lib.h @@ -48,6 +48,17 @@ struct Split<Eigen::ThreadPoolDevice, T> { const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes); }; +#ifdef TENSORFLOW_USE_SYCL +template <typename T> +struct Split<Eigen::SyclDevice, T> { + void operator()(const Eigen::SyclDevice& d, + typename TTypes<T, 3>::Tensor output, + typename TTypes<T, 3>::ConstTensor input, + const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices, + const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes); +}; +#endif // TENSORFLOW_USE_SYCL + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/split_lib_cpu.cc b/tensorflow/core/kernels/split_lib_cpu.cc index 41b2d6f0f5..e377e4d97a 100644 --- a/tensorflow/core/kernels/split_lib_cpu.cc +++ b/tensorflow/core/kernels/split_lib_cpu.cc @@ -43,5 +43,24 @@ TF_CALL_ALL_TYPES(DEFINE_CPU_KERNELS) DEFINE_CPU_KERNELS(quint8) DEFINE_CPU_KERNELS(bfloat16) +#ifdef TENSORFLOW_USE_SYCL +template <typename T> +void Split<Eigen::SyclDevice, T>::operator()( + const Eigen::SyclDevice& d, typename TTypes<T, 3>::Tensor output, + typename TTypes<T, 3>::ConstTensor input, + const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices, + const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes) { + if (output.size() < 131072) { + output = input.slice(slice_indices, slice_sizes); + } else { + output.device(d) = input.slice(slice_indices, slice_sizes); + } +} + +#define DEFINE_SYCL_KERNELS(T) template struct Split<Eigen::SyclDevice, T>; + +TF_CALL_GPU_NUMBER_TYPES(DEFINE_SYCL_KERNELS) +#endif // TENSORFLOW_USE_SYCL + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc index 4b12e1f995..cca2fc41c2 100644 --- a/tensorflow/core/kernels/split_op.cc +++ b/tensorflow/core/kernels/split_op.cc @@ -36,6 +36,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL template <typename Device, typename T> class SplitOpBase : public OpKernel { @@ -243,6 +246,75 @@ class SplitOpGPU : public SplitOpBase<GPUDevice, T> { }; #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL + +template <typename T> +class SplitOpSYCL : public 
SplitOpBase<SYCLDevice, T> { + public: + typedef SplitOpBase<SYCLDevice, T> Base; + explicit SplitOpSYCL(OpKernelConstruction* c) : Base(c) {} + + void Compute(OpKernelContext* context) override { + bool done = false; + Base::ComputeEasyCases(context, &done); + if (!context->status().ok() || done) { + return; + } + const int32 split_dim = context->input(0).flat<int32>()(0); + const int32 num_split = Base::num_outputs(); + const Tensor& input = context->input(1); + const TensorShape& input_shape = input.shape(); + + // Android also uses int32 indexing, so check here also. + OP_REQUIRES( + context, FastBoundsCheck(input.NumElements(), + std::numeric_limits<Eigen::DenseIndex>::max()), + errors::InvalidArgument("Split requires input size < ", + std::numeric_limits<Eigen::DenseIndex>::max())); + + Eigen::DenseIndex prefix_dim_size; + Eigen::DenseIndex split_dim_size; + Eigen::DenseIndex suffix_dim_size; + + std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) = + Base::template SetDims<Eigen::DenseIndex>(input_shape, split_dim); + auto input_reshaped = + input.shaped<T, 3>({prefix_dim_size, split_dim_size, suffix_dim_size}); + + const int64 split_dim_output_size = split_dim_size / num_split; + TensorShape output_shape(input_shape); + output_shape.set_dim(split_dim, split_dim_output_size); + + Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, 0, 0}; + Eigen::DSizes<Eigen::DenseIndex, 3> sizes{ + prefix_dim_size, split_dim_output_size, suffix_dim_size}; + + for (int i = 0; i < num_split; ++i) { + Tensor* result = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(i, output_shape, &result)); + if (prefix_dim_size * split_dim_output_size * suffix_dim_size > 0) { + Eigen::DSizes<Eigen::DenseIndex, 3> slice_indices; + Eigen::DSizes<Eigen::DenseIndex, 3> slice_sizes; + for (int j = 0; j < 3; ++j) { + slice_indices[j] = indices[j]; + slice_sizes[j] = sizes[j]; + } + + auto result_shaped = result->shaped<T, 3>( + {prefix_dim_size, split_dim_output_size, suffix_dim_size}); + + functor::Split<SYCLDevice, T>()(context->eigen_device<SYCLDevice>(), + result_shaped, input_reshaped, + slice_indices, slice_sizes); + } + indices[1] += split_dim_output_size; + } + } +}; + +#endif // TENSORFLOW_USE_SYCL + #define REGISTER_SPLIT(type) \ REGISTER_KERNEL_BUILDER(Name("Split") \ .Device(DEVICE_CPU) \ @@ -269,4 +341,17 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL(type) \ + REGISTER_KERNEL_BUILDER(Name("Split") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint<type>("T") \ + .HostMemory("split_dim"), \ + SplitOpSYCL<type>) + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL); +#undef REGISTER_SYCL + +#endif // TENSORFLOW_USE_SYCL + } // end namespace tensorflow diff --git a/tensorflow/core/kernels/tile_ops.cc b/tensorflow/core/kernels/tile_ops.cc index e49a319aed..36cabaaf7d 100644 --- a/tensorflow/core/kernels/tile_ops.cc +++ b/tensorflow/core/kernels/tile_ops.cc @@ -40,6 +40,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL // Forward declarations of functors that will be defined in // tile_ops_cpu_impl*.cc and tile_ops_gpu.cu.cc. 
@@ -225,6 +228,11 @@ inline void TileOp<Device>::HandleCase( #define HANDLE_TYPE_NAME_GPU(T) \ HANDLE_CASE_DIM(GPUDevice, T, DataTypeToEnum<T>::value); +#ifdef TENSORFLOW_USE_SYCL +#define HANDLE_TYPE_NAME_SYCL(T) \ + HANDLE_CASE_DIM(SYCLDevice, T, DataTypeToEnum<T>::value); +#endif // TENSORFLOW_USE_SYCL + TF_CALL_bool(HANDLE_TYPE_NAME_CPU); TF_CALL_float(HANDLE_TYPE_NAME_CPU); TF_CALL_double(HANDLE_TYPE_NAME_CPU); @@ -248,8 +256,15 @@ TF_CALL_complex64(HANDLE_TYPE_NAME_GPU); TF_CALL_complex128(HANDLE_TYPE_NAME_GPU); #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +TF_CALL_float(HANDLE_TYPE_NAME_SYCL); +#endif // TENSORFLOW_USE_SYCL + #undef HANDLE_TYPE_NAME_CPU #undef HANDLE_TYPE_NAME_GPU +#ifdef TENSORFLOW_USE_SYCL +#undef HANDLE_TYPE_NAME_SYCL +#endif // TENSORFLOW_USE_SYCL #undef HANDLE_CASE_DIM #undef HANDLE_CASE @@ -578,4 +593,14 @@ REGISTER_KERNEL_BUILDER(Name("TileGrad") TileGradientOp<GPUDevice>); #endif // GOOGLE_CUDA + +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("Tile") + .Device(DEVICE_SYCL) + .TypeConstraint<float>("T") + .TypeConstraint<int32>("Tmultiples") + .HostMemory("multiples"), + TileOp<SYCLDevice>); +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/kernels/tile_ops_cpu_impl.h b/tensorflow/core/kernels/tile_ops_cpu_impl.h index 9cdf69ad0b..650c739ed5 100644 --- a/tensorflow/core/kernels/tile_ops_cpu_impl.h +++ b/tensorflow/core/kernels/tile_ops_cpu_impl.h @@ -62,6 +62,30 @@ TF_CALL_complex128(DEFINE_TYPE); #undef DEFINE_DIM #undef DEFINE_TYPE +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; + +// Register functors used for TileOp. +#define DEFINE_DIM(T, NDIM) template struct Tile<SYCLDevice, T, NDIM>; +#define DEFINE_TYPE(T) DEFINE_DIM(T, CPU_PROVIDED_IXDIM) + +TF_CALL_float(DEFINE_TYPE); + +#undef DEFINE_DIM +#undef DEFINE_TYPE + +// Register functors used for TileGradientOp. +#define DEFINE_DIM(T, NDIM) \ + template struct TileGrad<SYCLDevice, T, NDIM>; \ + template struct ReduceAndReshape<SYCLDevice, T, NDIM, 1>; +#define DEFINE_TYPE(T) DEFINE_DIM(T, CPU_PROVIDED_IXDIM) + +TF_CALL_float(DEFINE_TYPE); + +#undef DEFINE_DIM +#undef DEFINE_TYPE +#endif // TENSORFLOW_USE_SYCL + } // end namespace functor } // end namespace tensorflow diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 172449a998..641c991a7e 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -46,6 +46,17 @@ struct ApplyGradientDescent<CPUDevice, T> { } }; +#ifdef TENSORFLOW_USE_SYCL +template <typename T> +struct ApplyGradientDescent<SYCLDevice, T> { + void operator()(const SYCLDevice& d, typename TTypes<T>::Flat var, + typename TTypes<T>::ConstScalar lr, + typename TTypes<T>::ConstFlat grad) { + var.device(d) -= grad * lr(); + } +}; +#endif + template <typename T> struct ApplyAdadelta<CPUDevice, T> { void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, @@ -357,6 +368,12 @@ TF_CALL_half(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T); +TF_CALL_float(REGISTER_SYCL_KERNELS); +#undef REGISTER_SYCL_KERNELS +#endif + #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. 
namespace functor { diff --git a/tensorflow/core/kernels/transpose_functor_cpu.cc b/tensorflow/core/kernels/transpose_functor_cpu.cc index f8c87e7e2e..30b82f1843 100644 --- a/tensorflow/core/kernels/transpose_functor_cpu.cc +++ b/tensorflow/core/kernels/transpose_functor_cpu.cc @@ -114,4 +114,28 @@ Status DoTranspose<Device>(const Device& d, const Tensor& in, return Status::OK(); } +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; + +template <> +Status DoTranspose<SYCLDevice>(const SYCLDevice& d, const Tensor& in, + const gtl::ArraySlice<int32> perm, Tensor* out) { + CHECK_GE(in.dims(), 2); + CHECK_EQ(in.dims(), out->dims()); + CHECK_EQ(in.dims(), perm.size()); + CHECK_EQ(in.dtype(), out->dtype()); + switch (in.dtype()) { + + case DT_FLOAT: + case DT_INT32: + internal::Transpose<SYCLDevice, uint32>(d, in, perm, out); + break; + + default: + return errors::Unimplemented("Unsupported dtype on SYCL: ", in.dtype()); + } + return Status::OK(); +} +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/kernels/unpack_op.cc b/tensorflow/core/kernels/unpack_op.cc index 29959cb187..2a14fa3265 100644 --- a/tensorflow/core/kernels/unpack_op.cc +++ b/tensorflow/core/kernels/unpack_op.cc @@ -32,6 +32,10 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL + template <typename Device, typename T> class UnpackOp : public OpKernel { public: @@ -149,4 +153,25 @@ REGISTER_KERNEL_BUILDER(Name("Unpack") #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Unpack").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + UnpackOp<SYCLDevice, type>) + +REGISTER_SYCL(float); +#undef REGISTER_SYCL + +// A special SYCL kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Unpack") + .Device(DEVICE_SYCL) + .HostMemory("value") + .HostMemory("output") + .TypeConstraint<int32>("T"), + UnpackOp<CPUDevice, int32>); + +#endif // TENSORFLOW_USE_SYCL + } // end namespace tensorflow diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index a3b0512304..4741bc968a 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -416,24 +416,49 @@ REGISTER_OP("SplitV") .Attr("T: type") .Attr("Tlen: {int32, int64} = DT_INT64") .SetShapeFn([](InferenceContext* c) { - ShapeHandle unused; + DimensionHandle split_dimension; + TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &split_dimension)); int32 num_outputs = c->num_outputs(); - // Return unknown shapes with the same rank as the input - // or unknown rank if input's rank isn't known - // can't determine exact shapes until runtime because - // we don't know where the tensor containing the split sizes - // is located - int32 rank = c->Rank(c->input(0)); + ShapeHandle input = c->input(0); + int32 rank = c->Rank(input); ShapeHandle output_shape; + const Tensor* size_splits = c->input_tensor(1); if (rank == InferenceContext::kUnknownRank) { + // If the rank of input tensor is unknown, then return unkown shapes. output_shape = c->UnknownShape(); + for (int i = 0; i < num_outputs; ++i) { + c->set_output(i, output_shape); + } } else if (rank == 0) { + // Throw error if input is a scalar. 
return errors::InvalidArgument("Can't split scalars"); - } else { + } else if (size_splits == nullptr || !c->ValueKnown(split_dimension)) { + // If split dimension or tensor containing the split sizes is unkown, + // then return unknown shapes of same rank as input. output_shape = c->UnknownShapeOfRank(rank); - } - for (int i = 0; i < num_outputs; ++i) { - c->set_output(i, output_shape); + for (int i = 0; i < num_outputs; ++i) { + c->set_output(i, output_shape); + } + } else { + // Determine the output shape if split dimension and split sizes are known + int64 split_dim = c->Value(split_dimension); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input)); + std::vector<int64> data; + if (size_splits->dtype() == DT_INT32) { + data = AsInt64<int32>(size_splits, size_splits->shape().dim_size(0)); + } else { + data = AsInt64<int64>(size_splits, size_splits->shape().dim_size(0)); + } + if (num_outputs != data.size()) { + return errors::InvalidArgument( + "Length of size_splits should be equal to num_outputs"); + } + for (int i = 0; i < num_outputs; ++i) { + output_shape = c->UnknownShapeOfRank(rank); + TF_RETURN_IF_ERROR( + c->ReplaceDim(input, split_dim, c->MakeDim(data[i]), &output_shape)); + c->set_output(i, output_shape); + } } return Status::OK(); diff --git a/tensorflow/core/platform/default/logging.cc b/tensorflow/core/platform/default/logging.cc index 1d03725c78..56985eec15 100644 --- a/tensorflow/core/platform/default/logging.cc +++ b/tensorflow/core/platform/default/logging.cc @@ -84,39 +84,49 @@ void LogMessage::GenerateLogMessage() { namespace { -int64 MinLogLevel() { - const char* tf_env_var_val = getenv("TF_CPP_MIN_LOG_LEVEL"); +// Parse log level (int64) from environment variable (char*) +int64 LogLevelStrToInt(const char* tf_env_var_val) { if (tf_env_var_val == nullptr) { return 0; } // Ideally we would use env_var / safe_strto64, but it is // hard to use here without pulling in a lot of dependencies, - // so we do a poor-man's parsing. + // so we use std:istringstream instead string min_log_level(tf_env_var_val); - if (min_log_level == "1") { - // Maps to WARNING - return 1; - } else if (min_log_level == "2") { - // Maps to ERROR - return 2; - } else if (min_log_level == "3") { - // Maps to FATAL - return 3; - } else { - // Maps to INFO (the default). - return 0; + std::istringstream ss(min_log_level); + int64 level; + if (!(ss >> level)) { + // Invalid vlog level setting, set level to default (0) + level = 0; } + + return level; +} + +int64 MinLogLevelFromEnv() { + const char* tf_env_var_val = getenv("TF_CPP_MIN_LOG_LEVEL"); + return LogLevelStrToInt(tf_env_var_val); +} + +int64 MinVLogLevelFromEnv() { + const char* tf_env_var_val = getenv("TF_CPP_MIN_VLOG_LEVEL"); + return LogLevelStrToInt(tf_env_var_val); } } // namespace LogMessage::~LogMessage() { // Read the min log level once during the first call to logging. 
- static int64 min_log_level = MinLogLevel(); + static int64 min_log_level = MinLogLevelFromEnv(); if (TF_PREDICT_TRUE(severity_ >= min_log_level)) GenerateLogMessage(); } +int64 LogMessage::MinVLogLevel() { + static int64 min_vlog_level = MinVLogLevelFromEnv(); + return min_vlog_level; +} + LogMessageFatal::LogMessageFatal(const char* file, int line) : LogMessage(file, line, FATAL) {} LogMessageFatal::~LogMessageFatal() { diff --git a/tensorflow/core/platform/default/logging.h b/tensorflow/core/platform/default/logging.h index 961fb8b4ad..04ff9e12b6 100644 --- a/tensorflow/core/platform/default/logging.h +++ b/tensorflow/core/platform/default/logging.h @@ -41,6 +41,11 @@ class LogMessage : public std::basic_ostringstream<char> { LogMessage(const char* fname, int line, int severity); ~LogMessage(); + // Returns the minimum log level for VLOG statements. + // E.g., if MinVLogLevel() is 2, then VLOG(2) statements will produce output, + // but VLOG(3) will not. Defaults to 0. + static int64 MinVLogLevel(); + protected: void GenerateLogMessage(); @@ -71,11 +76,18 @@ class LogMessageFatal : public LogMessage { #define LOG(severity) _TF_LOG_##severity -// TODO(jeff): Define a proper implementation of VLOG_IS_ON +#ifdef IS_MOBILE_PLATFORM +// Turn VLOG off when under mobile devices for considerations of binary size. #define VLOG_IS_ON(lvl) ((lvl) <= 0) +#else +// Otherwise, Set TF_CPP_MIN_VLOG_LEVEL environment to update minimum log level +// of VLOG +#define VLOG_IS_ON(lvl) \ + ((lvl) <= ::tensorflow::internal::LogMessage::MinVLogLevel()) +#endif #define VLOG(lvl) \ - if (VLOG_IS_ON(lvl)) \ + if (TF_PREDICT_FALSE(VLOG_IS_ON(lvl))) \ ::tensorflow::internal::LogMessage(__FILE__, __LINE__, tensorflow::INFO) // CHECK dies with a fatal error if condition is not true. It is *not* |
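For reference, a minimal standalone sketch of the TF_CPP_MIN_VLOG_LEVEL handling that the logging.cc/logging.h hunks above introduce. This is illustrative only, not the TensorFlow sources, and it assumes int64 is the 64-bit integer typedef from tensorflow/core/platform/types.h:

#include <cstdint>
#include <cstdlib>
#include <sstream>
#include <string>

using int64 = std::int64_t;  // stand-in for tensorflow's int64 typedef

// Parse a log level from an environment variable value. Anything that is not
// a plain integer falls back to 0, matching the istringstream-based parsing
// added in logging.cc.
int64 LogLevelStrToInt(const char* tf_env_var_val) {
  if (tf_env_var_val == nullptr) {
    return 0;
  }
  std::string min_log_level(tf_env_var_val);
  std::istringstream ss(min_log_level);
  int64 level = 0;
  if (!(ss >> level)) {
    level = 0;  // invalid setting, keep the default
  }
  return level;
}

int64 MinVLogLevelFromEnv() {
  return LogLevelStrToInt(std::getenv("TF_CPP_MIN_VLOG_LEVEL"));
}

int main() {
  // VLOG_IS_ON(lvl) in the patch expands to lvl <= MinVLogLevel(), so with
  // TF_CPP_MIN_VLOG_LEVEL=2 both VLOG(1) and VLOG(2) produce output while
  // VLOG(3) stays silent; an unset or non-numeric value keeps the default 0.
  const int64 min_vlog = MinVLogLevelFromEnv();
  return min_vlog >= 0 ? 0 : 1;
}

With this change the verbose-log threshold can be raised at runtime via the environment instead of being fixed at compile time, except on mobile builds where VLOG stays compiled out for binary-size reasons.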