Diffstat (limited to 'tensorflow/core')
-rw-r--r--  tensorflow/core/common_runtime/sycl/sycl_device.cc          |  32
-rw-r--r--  tensorflow/core/common_runtime/sycl/sycl_device_context.cc  |   2
-rw-r--r--  tensorflow/core/kernels/aggregate_ops.cc                    |  18
-rw-r--r--  tensorflow/core/kernels/aggregate_ops_cpu.h                 | 113
-rw-r--r--  tensorflow/core/kernels/constant_op.cc                      |   5
-rw-r--r--  tensorflow/core/kernels/control_flow_ops.cc                 | 112
-rw-r--r--  tensorflow/core/kernels/cwise_op_exp.cc                     |   1
-rw-r--r--  tensorflow/core/kernels/cwise_op_floor_div.cc               |   2
-rw-r--r--  tensorflow/core/kernels/cwise_op_neg.cc                     |  12
-rw-r--r--  tensorflow/core/kernels/cwise_op_sub.cc                     |  12
-rw-r--r--  tensorflow/core/kernels/cwise_ops_sycl_common.h             |  24
-rw-r--r--  tensorflow/core/kernels/debug_ops.h                         |   1
-rw-r--r--  tensorflow/core/kernels/pack_op.cc                          |  25
-rw-r--r--  tensorflow/core/kernels/reduction_ops_common.h              |  29
-rw-r--r--  tensorflow/core/kernels/reduction_ops_sum.cc                |  27
-rw-r--r--  tensorflow/core/kernels/scatter_functor.h                   |  19
-rw-r--r--  tensorflow/core/kernels/scatter_op.cc                       |  17
-rw-r--r--  tensorflow/core/kernels/session_ops.cc                      |  31
-rw-r--r--  tensorflow/core/kernels/shape_ops.cc                        |  37
-rw-r--r--  tensorflow/core/kernels/slice_op.cc                         |  57
-rw-r--r--  tensorflow/core/kernels/slice_op_cpu_impl.h                 |  12
-rw-r--r--  tensorflow/core/kernels/split_lib.h                         |  11
-rw-r--r--  tensorflow/core/kernels/split_lib_cpu.cc                    |  19
-rw-r--r--  tensorflow/core/kernels/split_op.cc                         |  85
-rw-r--r--  tensorflow/core/kernels/tile_ops.cc                         |  25
-rw-r--r--  tensorflow/core/kernels/tile_ops_cpu_impl.h                 |  24
-rw-r--r--  tensorflow/core/kernels/training_ops.cc                     |  17
-rw-r--r--  tensorflow/core/kernels/transpose_functor_cpu.cc            |  24
-rw-r--r--  tensorflow/core/kernels/unpack_op.cc                        |  25
-rw-r--r--  tensorflow/core/ops/array_ops.cc                            |  47
-rw-r--r--  tensorflow/core/platform/default/logging.cc                 |  42
-rw-r--r--  tensorflow/core/platform/default/logging.h                  |  16
32 files changed, 850 insertions(+), 73 deletions(-)
diff --git a/tensorflow/core/common_runtime/sycl/sycl_device.cc b/tensorflow/core/common_runtime/sycl/sycl_device.cc
index 19d39056ff..0abe25c373 100644
--- a/tensorflow/core/common_runtime/sycl/sycl_device.cc
+++ b/tensorflow/core/common_runtime/sycl/sycl_device.cc
@@ -23,7 +23,8 @@ limitations under the License.
namespace tensorflow {
-static std::unordered_set<SYCLDevice *> live_devices;
+static std::unordered_set<SYCLDevice*> live_devices;
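+// Set on the first call to RegisterDevice() so that one-time setup runs
+// only once per process.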
+static bool first_time = true;
void ShutdownSycl() {
for (auto device : live_devices) {
@@ -31,7 +32,6 @@ void ShutdownSycl() {
}
live_devices.clear();
}
-bool first_time = true;
void SYCLDevice::RegisterDevice() {
if (first_time) {
@@ -44,17 +44,27 @@ void SYCLDevice::RegisterDevice() {
SYCLDevice::~SYCLDevice() {
device_context_->Unref();
sycl_allocator_->EnterLameDuckMode();
- delete sycl_device_;
- delete sycl_queue_;
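+  // Synchronize before deleting so that no in-flight kernel can touch freed
+  // state; the null checks keep teardown safe after EnterLameDuckMode().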
+ if (sycl_device_) {
+ sycl_device_->synchronize();
+ delete sycl_device_;
+ }
+ if (sycl_queue_) {
+ delete sycl_queue_;
+ }
live_devices.erase(this);
}
void SYCLDevice::EnterLameDuckMode() {
sycl_allocator_->EnterLameDuckMode();
- delete sycl_device_;
- sycl_device_ = nullptr;
- delete sycl_queue_;
- sycl_queue_ = nullptr;
+ if (sycl_device_) {
+ sycl_device_->synchronize();
+ delete sycl_device_;
+ sycl_device_ = nullptr;
+ }
+ if (sycl_queue_) {
+ delete sycl_queue_;
+ sycl_queue_ = nullptr;
+ }
}
void SYCLDevice::Compute(OpKernel *op_kernel, OpKernelContext *context) {
@@ -110,7 +120,11 @@ Status SYCLDevice::FillContextMap(const Graph *graph,
Status SYCLDevice::Sync() {
sycl_device_->synchronize();
- return Status::OK();
+ if (sycl_device_->ok()) {
+ return Status::OK();
+ } else {
+ return errors::Internal("Unknown error detected on device ", name());
+ }
}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/sycl/sycl_device_context.cc b/tensorflow/core/common_runtime/sycl/sycl_device_context.cc
index b49420b1b5..a6be9195d4 100644
--- a/tensorflow/core/common_runtime/sycl/sycl_device_context.cc
+++ b/tensorflow/core/common_runtime/sycl/sycl_device_context.cc
@@ -95,6 +95,7 @@ void SYCLDeviceContext::CopyCPUTensorToDevice(const Tensor *cpu_tensor,
assert(false && "unsupported type");
}
}
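+  // Block until the enqueued copy has finished so that done() never observes
+  // a partially transferred tensor.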
+ device->eigen_sycl_device()->synchronize();
done(Status::OK());
}
@@ -172,6 +173,7 @@ void SYCLDeviceContext::CopyDeviceTensorToCPU(const Tensor *device_tensor,
assert(false && "unsupported type");
}
}
+ device->eigen_sycl_device()->synchronize();
done(Status::OK());
}
diff --git a/tensorflow/core/kernels/aggregate_ops.cc b/tensorflow/core/kernels/aggregate_ops.cc
index b41e438b2b..50d0cc1727 100644
--- a/tensorflow/core/kernels/aggregate_ops.cc
+++ b/tensorflow/core/kernels/aggregate_ops.cc
@@ -28,6 +28,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename T>
class AddNOp : public OpKernel {
@@ -152,6 +155,21 @@ REGISTER_KERNEL_BUILDER(Name("AddN")
AddNOp<CPUDevice, int32>);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_ADDN(float, SYCL);
+REGISTER_ADDN(double, SYCL);
+
+// A special SYCL kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("AddN")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .HostMemory("inputs")
+ .HostMemory("sum"),
+ AddNOp<CPUDevice, int32>);
+#endif // TENSORFLOW_USE_SYCL
+
#undef REGISTER_ADDN
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/aggregate_ops_cpu.h b/tensorflow/core/kernels/aggregate_ops_cpu.h
index ba5ebb7f0f..dfa3fe585e 100644
--- a/tensorflow/core/kernels/aggregate_ops_cpu.h
+++ b/tensorflow/core/kernels/aggregate_ops_cpu.h
@@ -23,6 +23,10 @@ limitations under the License.
typedef Eigen::ThreadPoolDevice CPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
+
namespace tensorflow {
// Partial specializations for a CPUDevice, that uses the Eigen implementation
@@ -133,6 +137,115 @@ struct Add9Functor<CPUDevice, T> {
}
};
+#ifdef TENSORFLOW_USE_SYCL
+// Partial specializations for a SYCLDevice, that uses the Eigen implementation
+// from AddNEigenImpl.
+template <typename T>
+struct Add2Functor<SYCLDevice, T> {
+ void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2) {
+ Add2EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2);
+ }
+};
+template <typename T>
+struct Add3Functor<SYCLDevice, T> {
+ void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3) {
+ Add3EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3);
+ }
+};
+template <typename T>
+struct Add4Functor<SYCLDevice, T> {
+ void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4) {
+ Add4EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4);
+ }
+};
+template <typename T>
+struct Add5Functor<SYCLDevice, T> {
+ void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5) {
+ Add5EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5);
+ }
+};
+template <typename T>
+struct Add6Functor<SYCLDevice, T> {
+ void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5,
+ typename TTypes<T>::ConstFlat in6) {
+ Add6EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6);
+ }
+};
+template <typename T>
+struct Add7Functor<SYCLDevice, T> {
+ void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5,
+ typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7) {
+ Add7EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
+ in7);
+ }
+};
+
+template <typename T>
+struct Add8Functor<SYCLDevice, T> {
+ void operator()(
+ const SYCLDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
+ Add8EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
+ in7, in8);
+ }
+};
+
+template <typename T>
+struct Add8pFunctor<SYCLDevice, T> {
+ void operator()(
+ const SYCLDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
+ Add8pEigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
+ in7, in8);
+ }
+};
+
+template <typename T>
+struct Add9Functor<SYCLDevice, T> {
+ void operator()(
+ const SYCLDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
+ typename TTypes<T>::ConstFlat in9) {
+ Add9EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
+ in7, in8, in9);
+ }
+};
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index e479b97109..3f8717f77f 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -305,4 +305,9 @@ REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE_GPU), PlaceholderOp);
REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE_GPU),
PlaceholderOp);
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE_SYCL), PlaceholderOp);
+REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE_SYCL),
+ PlaceholderOp);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc
index b01263f288..5241a4d916 100644
--- a/tensorflow/core/kernels/control_flow_ops.cc
+++ b/tensorflow/core/kernels/control_flow_ops.cc
@@ -121,8 +121,20 @@ REGISTER_GPU_HOST_REF_KERNEL(string);
SwitchOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+
+#define REGISTER_SYCL_REF_SWITCH(type) \
+ REGISTER_KERNEL_BUILDER(Name("RefSwitch") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("pred") \
+ .TypeConstraint<type>("T"), \
+ SwitchOp)
+REGISTER_SYCL_REF_SWITCH(bool);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_SWITCH);
+
#undef REGISTER_SYCL_KERNEL
-#endif
+#undef REGISTER_SYCL_REF_SWITCH
+
+#endif // TENSORFLOW_USE_SYCL
class RefSelectOp : public OpKernel {
public:
@@ -230,8 +242,18 @@ REGISTER_GPU_REF_KERNEL(bool);
MergeOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+
+#define REGISTER_SYCL_REF_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("RefMerge") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("value_index"), \
+ MergeOp)
+REGISTER_SYCL_REF_KERNEL(bool);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);
#undef REGISTER_SYCL_KERNEL
-#endif
+#undef REGISTER_SYCL_REF_KERNEL
+#endif // TENSORFLOW_USE_SYCL
// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -289,7 +311,15 @@ REGISTER_GPU_REF_KERNEL(bool);
Name("Enter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+
+#define REGISTER_SYCL_REF_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RefEnter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp)
+REGISTER_SYCL_REF_KERNEL(bool);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);
+
#undef REGISTER_SYCL_KERNEL
+#undef REGISTER_SYCL_REF_KERNEL
#endif
// Special GPU kernels for int32 and string.
@@ -349,8 +379,37 @@ REGISTER_GPU_KERNEL(bool);
Name("Exit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+
+#define REGISTER_SYCL_REF_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RefExit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp)
+REGISTER_SYCL_REF_KERNEL(bool);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);
+
#undef REGISTER_SYCL_KERNEL
-#endif
+#undef REGISTER_SYCL_REF_KERNEL
+
+// Special SYCL kernels for int32 and string.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+#define REGISTER_SYCL_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Exit") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("data") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ ExitOp); \
+ REGISTER_KERNEL_BUILDER(Name("RefExit") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("data") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ ExitOp)
+
+REGISTER_SYCL_HOST_KERNEL(int32);
+REGISTER_SYCL_HOST_KERNEL(string);
+#undef REGISTER_SYCL_HOST_KERNEL
+#endif // TENSORFLOW_USE_SYCL
// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -432,8 +491,39 @@ REGISTER_GPU_HOST_KERNEL(string);
NextIterationOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+#define REGISTER_SYCL_REF_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("RefNextIteration") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("data") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ NextIterationOp)
+REGISTER_SYCL_REF_KERNEL(bool);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);
#undef REGISTER_SYCL_KERNEL
-#endif
+#undef REGISTER_SYCL_REF_KERNEL
+
+// Special SYCL kernels for int32 and string.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+#define REGISTER_SYCL_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("NextIteration") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("data") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ NextIterationOp); \
+ REGISTER_KERNEL_BUILDER(Name("RefNextIteration") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("data") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ NextIterationOp)
+
+REGISTER_SYCL_HOST_KERNEL(int32);
+REGISTER_SYCL_HOST_KERNEL(string);
+#undef REGISTER_SYCL_HOST_KERNEL
+#endif // TENSORFLOW_USE_SYCL
// A LoopCond op has one input and one output. The input is a boolean
// scalar representing the taken branches of the "pivot" Switch that
@@ -461,6 +551,14 @@ REGISTER_KERNEL_BUILDER(Name("LoopCond")
.HostMemory("output"),
LoopCondOp);
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("LoopCond")
+ .Device(DEVICE_SYCL)
+ .HostMemory("input")
+ .HostMemory("output"),
+ LoopCondOp);
+#endif // TENSORFLOW_USE_SYCL
+
// ControlTrigger kernels
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_CPU),
ControlTriggerOp);
@@ -468,6 +566,11 @@ REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_GPU),
ControlTriggerOp);
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_SYCL),
+ ControlTriggerOp);
+#endif // TENSORFLOW_USE_SYCL
+
// When called, abort op will abort the current process. This can be used to
// abort remote PSs when needed.
class AbortOp : public OpKernel {
@@ -493,4 +596,5 @@ class AbortOp : public OpKernel {
};
REGISTER_KERNEL_BUILDER(Name("Abort").Device(DEVICE_CPU), AbortOp);
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_exp.cc b/tensorflow/core/kernels/cwise_op_exp.cc
index 0ee47f7dee..2e3a60cf79 100644
--- a/tensorflow/core/kernels/cwise_op_exp.cc
+++ b/tensorflow/core/kernels/cwise_op_exp.cc
@@ -27,6 +27,7 @@ REGISTER5(UnaryOp, CPU, "Exp", functor::exp, float, Eigen::half, double,
.TypeConstraint<TYPE>("T"), \
UnaryOp<SYCLDevice, functor::exp<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/cwise_op_floor_div.cc b/tensorflow/core/kernels/cwise_op_floor_div.cc
index 69dbb70b83..8a600f8f95 100644
--- a/tensorflow/core/kernels/cwise_op_floor_div.cc
+++ b/tensorflow/core/kernels/cwise_op_floor_div.cc
@@ -27,7 +27,7 @@ REGISTER3(BinaryOp, CPU, "FloorDiv", functor::floor_div_real, float,
Name("FloorDiv") \
.Device(DEVICE_SYCL) \
.TypeConstraint<TYPE>("T"), \
- BinaryOp<SYCLDevice, functor::floor_div<TYPE>>);
+ BinaryOp<SYCLDevice, functor::floor_div_real<TYPE>>);
REGISTER_SYCL_KERNEL(float)
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/cwise_op_neg.cc b/tensorflow/core/kernels/cwise_op_neg.cc
index 4221fc0710..c4a9b22883 100644
--- a/tensorflow/core/kernels/cwise_op_neg.cc
+++ b/tensorflow/core/kernels/cwise_op_neg.cc
@@ -27,6 +27,18 @@ REGISTER7(UnaryOp, CPU, "Neg", functor::neg, float, Eigen::half, double, int32,
.TypeConstraint<TYPE>("T"), \
UnaryOp<SYCLDevice, functor::neg<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
+
+// A special SYCL kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Neg")
+ .Device(DEVICE_SYCL)
+ .HostMemory("x")
+ .HostMemory("y")
+ .TypeConstraint<int32>("T"),
+ UnaryOp<CPUDevice, functor::neg<int32>>);
+
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/cwise_op_sub.cc b/tensorflow/core/kernels/cwise_op_sub.cc
index e1326dbed1..eab1e2a09c 100644
--- a/tensorflow/core/kernels/cwise_op_sub.cc
+++ b/tensorflow/core/kernels/cwise_op_sub.cc
@@ -32,6 +32,18 @@ REGISTER(BinaryOp, CPU, "Sub", functor::sub, int32);
.TypeConstraint<TYPE>("T"), \
BinaryOp<SYCLDevice, functor::sub<TYPE>>);
REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
+
+// A special SYCL kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Sub")
+ .Device(DEVICE_SYCL)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::sub<int32>>);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_ops_sycl_common.h b/tensorflow/core/kernels/cwise_ops_sycl_common.h
index a0decbce87..3f6ff7303d 100644
--- a/tensorflow/core/kernels/cwise_ops_sycl_common.h
+++ b/tensorflow/core/kernels/cwise_ops_sycl_common.h
@@ -31,14 +31,6 @@ namespace functor {
typedef Eigen::SyclDevice SYCLDevice;
-template <typename Index, int N> Eigen::array<Index, N> GenerateArrayOfOnes() {
- Eigen::array<Index, N> result;
- for (int i = 0; i < N; ++i) {
- result[i] = 1;
- }
- return result;
-}
-
template <typename OUT, typename RHS>
void Assign(const SYCLDevice& d, OUT out, RHS rhs) {
out.device(d) = rhs;
@@ -67,11 +59,9 @@ struct BinaryFunctor<SYCLDevice, Functor, NDIMS, has_errors> {
typename Functor::tin_type in, bool* error) {
typedef typename Functor::func Binary;
constexpr int NumDims = Functor::tin_type::NumDimensions;
- typedef typename Functor::tin_type::Scalar T;
- typedef typename Functor::tin_type::Index Index;
- Eigen::array<Index, NumDims> scalar_dim = GenerateArrayOfOnes<Index, NumDims>();
- Eigen::TensorMap<Eigen::Tensor<T, NumDims, Eigen::RowMajor>> tmp(scalar.data(), scalar_dim);
- out.device(d) = tmp.broadcast(in.dimensions()).binaryExpr(in, Binary());
+ static_assert(NumDims == 1, "Unexpected size");
+ Eigen::Sizes<1> scalar_dim;
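+      // View the scalar as a rank-1 tensor of size 1, broadcast it to the
+      // input's shape, then apply the binary functor elementwise.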
+      out.device(d) = scalar.reshape(scalar_dim)
+                          .broadcast(in.dimensions())
+                          .binaryExpr(in, Binary());
}
void Right(const SYCLDevice& d, typename Functor::tout_type out,
@@ -79,11 +69,9 @@ struct BinaryFunctor<SYCLDevice, Functor, NDIMS, has_errors> {
typename Functor::tscalar_type scalar, bool* error) {
typedef typename Functor::func Binary;
constexpr int NumDims = Functor::tin_type::NumDimensions;
- typedef typename Functor::tin_type::Scalar T;
- typedef typename Functor::tin_type::Index Index;
- Eigen::array<Index, NumDims> scalar_dim = GenerateArrayOfOnes<Index, NumDims>();
- Eigen::TensorMap<Eigen::Tensor<T, NumDims, Eigen::RowMajor>> tmp(scalar.data(), scalar_dim);
- out.device(d) = in.binaryExpr(tmp.broadcast(in.dimensions()), Binary());
+ static_assert(NumDims == 1, "Unexpected size");
+ Eigen::Sizes<1> scalar_dim;
+      out.device(d) = in.binaryExpr(
+          scalar.reshape(scalar_dim).broadcast(in.dimensions()), Binary());
}
void BCast(const SYCLDevice& d,
diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h
index 848896a64a..aa47315f55 100644
--- a/tensorflow/core/kernels/debug_ops.h
+++ b/tensorflow/core/kernels/debug_ops.h
@@ -188,7 +188,6 @@ class DebugNumericSummaryOp : public OpKernel {
const T* input_flat = input.template flat<T>().data();
element_count = input_shape.num_elements();
- const double element_count_double = static_cast<double>(element_count);
for (int64 i = 0; i < element_count; ++i) {
T x = input_flat[i];
if (Eigen::numext::isnan(x)) {
diff --git a/tensorflow/core/kernels/pack_op.cc b/tensorflow/core/kernels/pack_op.cc
index e072eb36b3..4977ad1d7c 100644
--- a/tensorflow/core/kernels/pack_op.cc
+++ b/tensorflow/core/kernels/pack_op.cc
@@ -34,6 +34,9 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
#if GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
// --------------------------------------------------------------------------
template <typename Device, typename T>
@@ -156,4 +159,26 @@ REGISTER_KERNEL_BUILDER(Name("Pack")
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+
+#define REGISTER_SYCL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Pack").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ PackOp<SYCLDevice, type>)
+
+REGISTER_SYCL(float);
+#undef REGISTER_SYCL
+
+// A special SYCL kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Pack")
+ .Device(DEVICE_SYCL)
+ .HostMemory("values")
+ .HostMemory("output")
+ .TypeConstraint<int32>("T"),
+ PackOp<CPUDevice, int32>);
+
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h
index 1bb1a9fc50..625cea4228 100644
--- a/tensorflow/core/kernels/reduction_ops_common.h
+++ b/tensorflow/core/kernels/reduction_ops_common.h
@@ -40,6 +40,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
template <typename Device>
struct Constants {
@@ -60,13 +63,16 @@ struct Constants {
};
#if defined(EIGEN_HAS_INDEX_LIST)
-template <>
-struct Constants<CPUDevice> {
+struct ConstantsBase {
const Eigen::IndexList<Eigen::type2index<0>> kZero;
const Eigen::IndexList<Eigen::type2index<1>> kOne;
const Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<2>> kZeroTwo;
};
-#endif
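+// The static index lists above are shared by the CPU and SYCL devices.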
+template<> struct Constants<CPUDevice> : ConstantsBase{};
+#ifdef TENSORFLOW_USE_SYCL
+template<> struct Constants<SYCLDevice> : ConstantsBase{};
+#endif // TENSORFLOW_USE_SYCL
+#endif // EIGEN_HAS_INDEX_LIST
class ReductionHelper {
public:
@@ -239,22 +245,31 @@ class ReductionOp : public OpKernel {
namespace functor {
-template <typename Reducer>
-struct ReduceFunctor<CPUDevice, Reducer> {
+template <typename Device, typename Reducer>
+struct ReduceFunctorBase {
template <typename OUT_T, typename IN_T, typename ReductionAxes>
- static void Reduce(const CPUDevice& d, OUT_T out, IN_T in,
+ static void Reduce(const Device& d, OUT_T out, IN_T in,
const ReductionAxes& reduction_axes,
const Reducer& reducer) {
ReduceEigenImpl(d, out, in, reduction_axes, reducer);
}
template <typename OUT_T>
- static void FillIdentity(const CPUDevice& d, OUT_T out,
+ static void FillIdentity(const Device& d, OUT_T out,
const Reducer& reducer) {
FillIdentityEigenImpl(d, out, reducer);
}
};
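+// Both the CPU and SYCL reducers route through the shared Eigen
+// implementation above; only the device type differs.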
+template <typename Reducer>
+struct ReduceFunctor<CPUDevice, Reducer>
+ : ReduceFunctorBase<CPUDevice, Reducer>{};
+#ifdef TENSORFLOW_USE_SYCL
+template <typename Reducer>
+struct ReduceFunctor<SYCLDevice, Reducer>
+ : ReduceFunctorBase<SYCLDevice, Reducer>{};
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_sum.cc b/tensorflow/core/kernels/reduction_ops_sum.cc
index c7c7949fed..3aa38f418e 100644
--- a/tensorflow/core/kernels/reduction_ops_sum.cc
+++ b/tensorflow/core/kernels/reduction_ops_sum.cc
@@ -64,4 +64,31 @@ REGISTER_KERNEL_BUILDER(
#endif
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Sum") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<SYCLDevice, type, Eigen::internal::SumReducer<type>>);
+REGISTER_SYCL_KERNELS(float);
+REGISTER_SYCL_KERNELS(double);
+#undef REGISTER_SYCL_KERNELS
+
+// A special SYCL kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(
+ Name("Sum")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int32>("Tidx")
+ .HostMemory("input")
+ .HostMemory("output")
+ .HostMemory("reduction_indices"),
+ ReductionOp<CPUDevice, int32, Eigen::internal::SumReducer<int32>>);
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/scatter_functor.h b/tensorflow/core/kernels/scatter_functor.h
index a84d89c296..a27cc83e4c 100644
--- a/tensorflow/core/kernels/scatter_functor.h
+++ b/tensorflow/core/kernels/scatter_functor.h
@@ -25,6 +25,9 @@ namespace tensorflow {
class OpKernelContext;
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
namespace scatter_op {
@@ -82,10 +85,9 @@ struct ScatterFunctor {
typename TTypes<Index>::ConstFlat indices);
};
-// Specializations of scatter functor for CPU.
-template <typename T, typename Index, scatter_op::UpdateOp op>
-struct ScatterFunctor<CPUDevice, T, Index, op> {
- Index operator()(OpKernelContext* c, const CPUDevice& d,
+template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterFunctorBase {
+ Index operator()(OpKernelContext* c, const Device& d,
typename TTypes<T>::Matrix params,
typename TTypes<T>::ConstMatrix updates,
typename TTypes<Index>::ConstFlat indices) {
@@ -106,6 +108,15 @@ struct ScatterFunctor<CPUDevice, T, Index, op> {
}
};
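+// The CPU and SYCL scatter functors share the generic update loop above.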
+template <typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterFunctor<CPUDevice, T, Index, op>
+ : ScatterFunctorBase<CPUDevice, T, Index, op>{};
+#ifdef TENSORFLOW_USE_SYCL
+template<typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterFunctor<SYCLDevice, T, Index, op>
+ : ScatterFunctorBase<SYCLDevice, T, Index, op>{};
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc
index 604f753db1..827eb7dbca 100644
--- a/tensorflow/core/kernels/scatter_op.cc
+++ b/tensorflow/core/kernels/scatter_op.cc
@@ -27,6 +27,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
// Check whether updates.shape = indices.shape + params.shape[1:]
static bool ValidShapes(const Tensor& params, const Tensor& updates,
@@ -170,6 +173,20 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU);
#endif // GOOGLE_CUDA
+// Registers SYCL kernels.
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SCATTER_ARITHEMTIC_SYCL(type) \
+ REGISTER_SCATTER_ARITHEMTIC(type, SYCL);
+
+#define REGISTER_SCATTER_UPDATE_SYCL(type) REGISTER_SCATTER_UPDATE(type, SYCL);
+
+REGISTER_SCATTER_ARITHEMTIC_SYCL(float);
+REGISTER_SCATTER_UPDATE_SYCL(float);
+
+#undef REGISTER_SCATTER_ARITHEMTIC_SYCL
+#undef REGISTER_SCATTER_UPDATE_SYCL
+#endif // TENSORFLOW_USE_SYCL
+
#undef REGISTER_SCATTER_ARITHEMTIC
#undef REGISTER_SCATTER_ARITHEMTIC_CPU
#undef REGISTER_SCATTER_ARITHEMTIC_GPU
diff --git a/tensorflow/core/kernels/session_ops.cc b/tensorflow/core/kernels/session_ops.cc
index 3f1538164c..4550115c19 100644
--- a/tensorflow/core/kernels/session_ops.cc
+++ b/tensorflow/core/kernels/session_ops.cc
@@ -67,6 +67,19 @@ TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL);
REGISTER_GPU_KERNEL(bool);
#undef REGISTER_GPU_KERNEL
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("GetSessionHandle") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("handle") \
+ .TypeConstraint<type>("T"), \
+ GetSessionHandleOp)
+
+TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
+REGISTER_SYCL_KERNEL(bool);
+#undef REGISTER_SYCL_KERNEL
+#endif // TENSORFLOW_USE_SYCL
+
class GetSessionTensorOp : public OpKernel {
public:
explicit GetSessionTensorOp(OpKernelConstruction* context)
@@ -97,6 +110,19 @@ TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL);
REGISTER_GPU_KERNEL(bool);
#undef REGISTER_GPU_KERNEL
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("GetSessionTensor") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("handle") \
+ .TypeConstraint<type>("dtype"), \
+ GetSessionTensorOp)
+
+TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
+REGISTER_SYCL_KERNEL(bool);
+#undef REGISTER_SYCL_KERNEL
+#endif // TENSORFLOW_USE_SYCL
+
class DeleteSessionTensorOp : public OpKernel {
public:
explicit DeleteSessionTensorOp(OpKernelConstruction* context)
@@ -117,4 +143,9 @@ REGISTER_KERNEL_BUILDER(
Name("DeleteSessionTensor").Device(DEVICE_GPU).HostMemory("handle"),
DeleteSessionTensorOp);
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(
+ Name("DeleteSessionTensor").Device(DEVICE_SYCL).HostMemory("handle"),
+ DeleteSessionTensorOp);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/shape_ops.cc b/tensorflow/core/kernels/shape_ops.cc
index 7ff812cf27..496865de02 100644
--- a/tensorflow/core/kernels/shape_ops.cc
+++ b/tensorflow/core/kernels/shape_ops.cc
@@ -210,6 +210,43 @@ REGISTER_KERNEL_BUILDER(Name("ShapeN")
ShapeNOp<int64>);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("ShapeN") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("output") \
+ .TypeConstraint<int32>("out_type") \
+ .TypeConstraint<type>("T"), \
+ ShapeNOp<int32>); \
+ REGISTER_KERNEL_BUILDER(Name("ShapeN") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("output") \
+ .TypeConstraint<int64>("out_type") \
+ .TypeConstraint<type>("T"), \
+ ShapeNOp<int64>)
+
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+#undef REGISTER_SYCL_KERNEL
+
+// A special SYCL kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("ShapeN")
+ .Device(DEVICE_SYCL)
+ .HostMemory("input")
+ .HostMemory("output")
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int32>("out_type"),
+ ShapeNOp<int32>);
+REGISTER_KERNEL_BUILDER(Name("ShapeN")
+ .Device(DEVICE_SYCL)
+ .HostMemory("input")
+ .HostMemory("output")
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int64>("out_type"),
+ ShapeNOp<int64>);
+#endif // TENSORFLOW_USE_SYCL
+
class RankOp : public OpKernel {
public:
explicit RankOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index dc33a25cec..e2978eccbd 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -56,6 +56,9 @@ gtl::InlinedVector<int64, 4> IntTensorToInt64Vec(const Tensor& tensor) {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
// Shared code that is not dependent on the type of T. We do this to reduce
// code size by not duplicating all this for all T (float, double, int32, etc.)
@@ -300,4 +303,58 @@ REGISTER_KERNEL_BUILDER(Name("Slice")
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+// Forward declarations of the functor specializations for SYCL.
+namespace functor {
+#define DECLARE_SYCL_SPEC(T, NDIM) \
+ template <> \
+ void Slice<SYCLDevice, T, NDIM>::operator()( \
+ const SYCLDevice& d, typename TTypes<T, NDIM>::Tensor output,\
+ typename TTypes<T, NDIM>::ConstTensor input, \
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \
+ extern template struct Slice<SYCLDevice, T, NDIM>;
+
+#define DECLARE_FOR_N(T) \
+ DECLARE_SYCL_SPEC(T, 1); \
+ DECLARE_SYCL_SPEC(T, 2); \
+ DECLARE_SYCL_SPEC(T, 3); \
+ DECLARE_SYCL_SPEC(T, 4); \
+ DECLARE_SYCL_SPEC(T, 5); \
+ DECLARE_SYCL_SPEC(T, 6);
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N);
+DECLARE_FOR_N(int32);
+
+#undef DECLARE_FOR_N
+#undef DECLARE_SYCL_SPEC
+} // namespace functor
+
+#define REGISTER_SYCL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Slice") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("begin") \
+ .HostMemory("size") \
+ .TypeConstraint<int32>("Index"), \
+ SliceOp<SYCLDevice, type>)
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL);
+
+// A special SYCL kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Slice")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int32>("Index")
+ .HostMemory("input")
+ .HostMemory("begin")
+ .HostMemory("size")
+ .HostMemory("output"),
+ SliceOp<CPUDevice, int32>);
+
+#undef REGISTER_SYCL
+
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/slice_op_cpu_impl.h b/tensorflow/core/kernels/slice_op_cpu_impl.h
index 0b0700ec36..a70805658e 100644
--- a/tensorflow/core/kernels/slice_op_cpu_impl.h
+++ b/tensorflow/core/kernels/slice_op_cpu_impl.h
@@ -34,6 +34,18 @@ DEFINE_CPU_KERNELS(bfloat16);
#undef DEFINE_CPU_KERNELS
+#ifdef TENSORFLOW_USE_SYCL
+using SyclDevice = Eigen::SyclDevice;
+
+#define DEFINE_SYCL_KERNELS(T) \
+ template struct functor::Slice<SyclDevice, T, CPU_PROVIDED_IXDIM>;
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_SYCL_KERNELS);
+DEFINE_SYCL_KERNELS(int32);
+
+#undef DEFINE_SYCL_KERNELS
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SLICE_OP_CPU_IMPL_H_
diff --git a/tensorflow/core/kernels/split_lib.h b/tensorflow/core/kernels/split_lib.h
index 240cce46e0..ff92ffeeb3 100644
--- a/tensorflow/core/kernels/split_lib.h
+++ b/tensorflow/core/kernels/split_lib.h
@@ -48,6 +48,17 @@ struct Split<Eigen::ThreadPoolDevice, T> {
const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes);
};
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct Split<Eigen::SyclDevice, T> {
+ void operator()(const Eigen::SyclDevice& d,
+ typename TTypes<T, 3>::Tensor output,
+ typename TTypes<T, 3>::ConstTensor input,
+ const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
+ const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes);
+};
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/split_lib_cpu.cc b/tensorflow/core/kernels/split_lib_cpu.cc
index 41b2d6f0f5..e377e4d97a 100644
--- a/tensorflow/core/kernels/split_lib_cpu.cc
+++ b/tensorflow/core/kernels/split_lib_cpu.cc
@@ -43,5 +43,24 @@ TF_CALL_ALL_TYPES(DEFINE_CPU_KERNELS)
DEFINE_CPU_KERNELS(quint8)
DEFINE_CPU_KERNELS(bfloat16)
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+void Split<Eigen::SyclDevice, T>::operator()(
+ const Eigen::SyclDevice& d, typename TTypes<T, 3>::Tensor output,
+ typename TTypes<T, 3>::ConstTensor input,
+ const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
+ const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes) {
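+  // Mirrors the CPU specialization above: small outputs are sliced eagerly on
+  // the calling thread, larger ones are queued on the device (the
+  // 131072-element cutoff appears to be an empirical heuristic).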
+ if (output.size() < 131072) {
+ output = input.slice(slice_indices, slice_sizes);
+ } else {
+ output.device(d) = input.slice(slice_indices, slice_sizes);
+ }
+}
+
+#define DEFINE_SYCL_KERNELS(T) template struct Split<Eigen::SyclDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_SYCL_KERNELS)
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc
index 4b12e1f995..cca2fc41c2 100644
--- a/tensorflow/core/kernels/split_op.cc
+++ b/tensorflow/core/kernels/split_op.cc
@@ -36,6 +36,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename T>
class SplitOpBase : public OpKernel {
@@ -243,6 +246,75 @@ class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
};
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+
+template <typename T>
+class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {
+ public:
+ typedef SplitOpBase<SYCLDevice, T> Base;
+ explicit SplitOpSYCL(OpKernelConstruction* c) : Base(c) {}
+
+ void Compute(OpKernelContext* context) override {
+ bool done = false;
+ Base::ComputeEasyCases(context, &done);
+ if (!context->status().ok() || done) {
+ return;
+ }
+ const int32 split_dim = context->input(0).flat<int32>()(0);
+ const int32 num_split = Base::num_outputs();
+ const Tensor& input = context->input(1);
+ const TensorShape& input_shape = input.shape();
+
+    // On 32-bit platforms (e.g. Android), Eigen::DenseIndex is 32 bits wide,
+    // so bounds-check the input size here as well.
+ OP_REQUIRES(
+ context, FastBoundsCheck(input.NumElements(),
+ std::numeric_limits<Eigen::DenseIndex>::max()),
+ errors::InvalidArgument("Split requires input size < ",
+ std::numeric_limits<Eigen::DenseIndex>::max()));
+
+ Eigen::DenseIndex prefix_dim_size;
+ Eigen::DenseIndex split_dim_size;
+ Eigen::DenseIndex suffix_dim_size;
+
+ std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
+ Base::template SetDims<Eigen::DenseIndex>(input_shape, split_dim);
+ auto input_reshaped =
+ input.shaped<T, 3>({prefix_dim_size, split_dim_size, suffix_dim_size});
+
+ const int64 split_dim_output_size = split_dim_size / num_split;
+ TensorShape output_shape(input_shape);
+ output_shape.set_dim(split_dim, split_dim_output_size);
+
+ Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, 0, 0};
+ Eigen::DSizes<Eigen::DenseIndex, 3> sizes{
+ prefix_dim_size, split_dim_output_size, suffix_dim_size};
+
+ for (int i = 0; i < num_split; ++i) {
+ Tensor* result = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(i, output_shape, &result));
+ if (prefix_dim_size * split_dim_output_size * suffix_dim_size > 0) {
+ Eigen::DSizes<Eigen::DenseIndex, 3> slice_indices;
+ Eigen::DSizes<Eigen::DenseIndex, 3> slice_sizes;
+ for (int j = 0; j < 3; ++j) {
+ slice_indices[j] = indices[j];
+ slice_sizes[j] = sizes[j];
+ }
+
+ auto result_shaped = result->shaped<T, 3>(
+ {prefix_dim_size, split_dim_output_size, suffix_dim_size});
+
+ functor::Split<SYCLDevice, T>()(context->eigen_device<SYCLDevice>(),
+ result_shaped, input_reshaped,
+ slice_indices, slice_sizes);
+ }
+ indices[1] += split_dim_output_size;
+ }
+ }
+};
+
+#endif // TENSORFLOW_USE_SYCL
+
#define REGISTER_SPLIT(type) \
REGISTER_KERNEL_BUILDER(Name("Split") \
.Device(DEVICE_CPU) \
@@ -269,4 +341,17 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Split") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("split_dim"), \
+ SplitOpSYCL<type>)
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL);
+#undef REGISTER_SYCL
+
+#endif // TENSORFLOW_USE_SYCL
+
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/tile_ops.cc b/tensorflow/core/kernels/tile_ops.cc
index e49a319aed..36cabaaf7d 100644
--- a/tensorflow/core/kernels/tile_ops.cc
+++ b/tensorflow/core/kernels/tile_ops.cc
@@ -40,6 +40,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
// Forward declarations of functors that will be defined in
// tile_ops_cpu_impl*.cc and tile_ops_gpu.cu.cc.
@@ -225,6 +228,11 @@ inline void TileOp<Device>::HandleCase(
#define HANDLE_TYPE_NAME_GPU(T) \
HANDLE_CASE_DIM(GPUDevice, T, DataTypeToEnum<T>::value);
+#ifdef TENSORFLOW_USE_SYCL
+#define HANDLE_TYPE_NAME_SYCL(T) \
+ HANDLE_CASE_DIM(SYCLDevice, T, DataTypeToEnum<T>::value);
+#endif // TENSORFLOW_USE_SYCL
+
TF_CALL_bool(HANDLE_TYPE_NAME_CPU);
TF_CALL_float(HANDLE_TYPE_NAME_CPU);
TF_CALL_double(HANDLE_TYPE_NAME_CPU);
@@ -248,8 +256,15 @@ TF_CALL_complex64(HANDLE_TYPE_NAME_GPU);
TF_CALL_complex128(HANDLE_TYPE_NAME_GPU);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+TF_CALL_float(HANDLE_TYPE_NAME_SYCL);
+#endif // TENSORFLOW_USE_SYCL
+
#undef HANDLE_TYPE_NAME_CPU
#undef HANDLE_TYPE_NAME_GPU
+#ifdef TENSORFLOW_USE_SYCL
+#undef HANDLE_TYPE_NAME_SYCL
+#endif // TENSORFLOW_USE_SYCL
#undef HANDLE_CASE_DIM
#undef HANDLE_CASE
@@ -578,4 +593,14 @@ REGISTER_KERNEL_BUILDER(Name("TileGrad")
TileGradientOp<GPUDevice>);
#endif // GOOGLE_CUDA
+
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("Tile")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<float>("T")
+ .TypeConstraint<int32>("Tmultiples")
+ .HostMemory("multiples"),
+ TileOp<SYCLDevice>);
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/tile_ops_cpu_impl.h b/tensorflow/core/kernels/tile_ops_cpu_impl.h
index 9cdf69ad0b..650c739ed5 100644
--- a/tensorflow/core/kernels/tile_ops_cpu_impl.h
+++ b/tensorflow/core/kernels/tile_ops_cpu_impl.h
@@ -62,6 +62,30 @@ TF_CALL_complex128(DEFINE_TYPE);
#undef DEFINE_DIM
#undef DEFINE_TYPE
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+
+// Register functors used for TileOp.
+#define DEFINE_DIM(T, NDIM) template struct Tile<SYCLDevice, T, NDIM>;
+#define DEFINE_TYPE(T) DEFINE_DIM(T, CPU_PROVIDED_IXDIM)
+
+TF_CALL_float(DEFINE_TYPE);
+
+#undef DEFINE_DIM
+#undef DEFINE_TYPE
+
+// Register functors used for TileGradientOp.
+#define DEFINE_DIM(T, NDIM) \
+ template struct TileGrad<SYCLDevice, T, NDIM>; \
+ template struct ReduceAndReshape<SYCLDevice, T, NDIM, 1>;
+#define DEFINE_TYPE(T) DEFINE_DIM(T, CPU_PROVIDED_IXDIM)
+
+TF_CALL_float(DEFINE_TYPE);
+
+#undef DEFINE_DIM
+#undef DEFINE_TYPE
+#endif // TENSORFLOW_USE_SYCL
+
} // end namespace functor
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 172449a998..641c991a7e 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -46,6 +46,17 @@ struct ApplyGradientDescent<CPUDevice, T> {
}
};
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct ApplyGradientDescent<SYCLDevice, T> {
+ void operator()(const SYCLDevice& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstFlat grad) {
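+    // Vanilla gradient-descent update, var -= lr * grad, evaluated on the
+    // SYCL device.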
+ var.device(d) -= grad * lr();
+ }
+};
+#endif  // TENSORFLOW_USE_SYCL
+
template <typename T>
struct ApplyAdadelta<CPUDevice, T> {
void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
@@ -357,6 +368,12 @@ TF_CALL_half(REGISTER_CPU_KERNELS);
TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T);
+TF_CALL_float(REGISTER_SYCL_KERNELS);
+#undef REGISTER_SYCL_KERNELS
+#endif  // TENSORFLOW_USE_SYCL
+
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
diff --git a/tensorflow/core/kernels/transpose_functor_cpu.cc b/tensorflow/core/kernels/transpose_functor_cpu.cc
index f8c87e7e2e..30b82f1843 100644
--- a/tensorflow/core/kernels/transpose_functor_cpu.cc
+++ b/tensorflow/core/kernels/transpose_functor_cpu.cc
@@ -114,4 +114,28 @@ Status DoTranspose<Device>(const Device& d, const Tensor& in,
return Status::OK();
}
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+
+template <>
+Status DoTranspose<SYCLDevice>(const SYCLDevice& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ CHECK_GE(in.dims(), 2);
+ CHECK_EQ(in.dims(), out->dims());
+ CHECK_EQ(in.dims(), perm.size());
+ CHECK_EQ(in.dtype(), out->dtype());
+ switch (in.dtype()) {
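+    // DT_FLOAT and DT_INT32 are both 32 bits wide, so their elements can be
+    // permuted as opaque uint32 values.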
+ case DT_FLOAT:
+ case DT_INT32:
+ internal::Transpose<SYCLDevice, uint32>(d, in, perm, out);
+ break;
+
+ default:
+ return errors::Unimplemented("Unsupported dtype on SYCL: ", in.dtype());
+ }
+ return Status::OK();
+}
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/unpack_op.cc b/tensorflow/core/kernels/unpack_op.cc
index 29959cb187..2a14fa3265 100644
--- a/tensorflow/core/kernels/unpack_op.cc
+++ b/tensorflow/core/kernels/unpack_op.cc
@@ -32,6 +32,10 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
+
template <typename Device, typename T>
class UnpackOp : public OpKernel {
public:
@@ -149,4 +153,25 @@ REGISTER_KERNEL_BUILDER(Name("Unpack")
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Unpack").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ UnpackOp<SYCLDevice, type>)
+
+REGISTER_SYCL(float);
+#undef REGISTER_SYCL
+
+// A special SYCL kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Unpack")
+ .Device(DEVICE_SYCL)
+ .HostMemory("value")
+ .HostMemory("output")
+ .TypeConstraint<int32>("T"),
+ UnpackOp<CPUDevice, int32>);
+
+#endif // TENSORFLOW_USE_SYCL
+
} // end namespace tensorflow
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index a3b0512304..4741bc968a 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -416,24 +416,49 @@ REGISTER_OP("SplitV")
.Attr("T: type")
.Attr("Tlen: {int32, int64} = DT_INT64")
.SetShapeFn([](InferenceContext* c) {
- ShapeHandle unused;
+ DimensionHandle split_dimension;
+ TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &split_dimension));
int32 num_outputs = c->num_outputs();
- // Return unknown shapes with the same rank as the input
- // or unknown rank if input's rank isn't known
- // can't determine exact shapes until runtime because
- // we don't know where the tensor containing the split sizes
- // is located
- int32 rank = c->Rank(c->input(0));
+ ShapeHandle input = c->input(0);
+ int32 rank = c->Rank(input);
ShapeHandle output_shape;
+ const Tensor* size_splits = c->input_tensor(1);
if (rank == InferenceContext::kUnknownRank) {
+      // If the rank of the input tensor is unknown, return unknown shapes.
output_shape = c->UnknownShape();
+ for (int i = 0; i < num_outputs; ++i) {
+ c->set_output(i, output_shape);
+ }
} else if (rank == 0) {
+      // Return an error if the input is a scalar.
return errors::InvalidArgument("Can't split scalars");
- } else {
+ } else if (size_splits == nullptr || !c->ValueKnown(split_dimension)) {
+      // If the split dimension or the tensor containing the split sizes is
+      // unknown, return unknown shapes of the same rank as the input.
output_shape = c->UnknownShapeOfRank(rank);
- }
- for (int i = 0; i < num_outputs; ++i) {
- c->set_output(i, output_shape);
+ for (int i = 0; i < num_outputs; ++i) {
+ c->set_output(i, output_shape);
+ }
+ } else {
+      // Determine the output shapes when the split dimension and split
+      // sizes are known.
+ int64 split_dim = c->Value(split_dimension);
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
+ std::vector<int64> data;
+ if (size_splits->dtype() == DT_INT32) {
+ data = AsInt64<int32>(size_splits, size_splits->shape().dim_size(0));
+ } else {
+ data = AsInt64<int64>(size_splits, size_splits->shape().dim_size(0));
+ }
+ if (num_outputs != data.size()) {
+ return errors::InvalidArgument(
+ "Length of size_splits should be equal to num_outputs");
+ }
+ for (int i = 0; i < num_outputs; ++i) {
+ output_shape = c->UnknownShapeOfRank(rank);
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(input, split_dim, c->MakeDim(data[i]), &output_shape));
+ c->set_output(i, output_shape);
+ }
}
return Status::OK();
diff --git a/tensorflow/core/platform/default/logging.cc b/tensorflow/core/platform/default/logging.cc
index 1d03725c78..56985eec15 100644
--- a/tensorflow/core/platform/default/logging.cc
+++ b/tensorflow/core/platform/default/logging.cc
@@ -84,39 +84,49 @@ void LogMessage::GenerateLogMessage() {
namespace {
-int64 MinLogLevel() {
- const char* tf_env_var_val = getenv("TF_CPP_MIN_LOG_LEVEL");
+// Parse log level (int64) from environment variable (char*)
+int64 LogLevelStrToInt(const char* tf_env_var_val) {
if (tf_env_var_val == nullptr) {
return 0;
}
// Ideally we would use env_var / safe_strto64, but it is
// hard to use here without pulling in a lot of dependencies,
- // so we do a poor-man's parsing.
+  // so we use std::istringstream instead.
string min_log_level(tf_env_var_val);
- if (min_log_level == "1") {
- // Maps to WARNING
- return 1;
- } else if (min_log_level == "2") {
- // Maps to ERROR
- return 2;
- } else if (min_log_level == "3") {
- // Maps to FATAL
- return 3;
- } else {
- // Maps to INFO (the default).
- return 0;
+ std::istringstream ss(min_log_level);
+ int64 level;
+ if (!(ss >> level)) {
+    // Invalid log level setting; fall back to the default (0).
+ level = 0;
}
+
+ return level;
+}
+
+int64 MinLogLevelFromEnv() {
+ const char* tf_env_var_val = getenv("TF_CPP_MIN_LOG_LEVEL");
+ return LogLevelStrToInt(tf_env_var_val);
+}
+
+int64 MinVLogLevelFromEnv() {
+ const char* tf_env_var_val = getenv("TF_CPP_MIN_VLOG_LEVEL");
+ return LogLevelStrToInt(tf_env_var_val);
}
} // namespace
LogMessage::~LogMessage() {
// Read the min log level once during the first call to logging.
- static int64 min_log_level = MinLogLevel();
+ static int64 min_log_level = MinLogLevelFromEnv();
if (TF_PREDICT_TRUE(severity_ >= min_log_level)) GenerateLogMessage();
}
+int64 LogMessage::MinVLogLevel() {
+ static int64 min_vlog_level = MinVLogLevelFromEnv();
+ return min_vlog_level;
+}
+
LogMessageFatal::LogMessageFatal(const char* file, int line)
: LogMessage(file, line, FATAL) {}
LogMessageFatal::~LogMessageFatal() {
diff --git a/tensorflow/core/platform/default/logging.h b/tensorflow/core/platform/default/logging.h
index 961fb8b4ad..04ff9e12b6 100644
--- a/tensorflow/core/platform/default/logging.h
+++ b/tensorflow/core/platform/default/logging.h
@@ -41,6 +41,11 @@ class LogMessage : public std::basic_ostringstream<char> {
LogMessage(const char* fname, int line, int severity);
~LogMessage();
+ // Returns the minimum log level for VLOG statements.
+ // E.g., if MinVLogLevel() is 2, then VLOG(2) statements will produce output,
+ // but VLOG(3) will not. Defaults to 0.
+ static int64 MinVLogLevel();
+
protected:
void GenerateLogMessage();
@@ -71,11 +76,18 @@ class LogMessageFatal : public LogMessage {
#define LOG(severity) _TF_LOG_##severity
-// TODO(jeff): Define a proper implementation of VLOG_IS_ON
+#ifdef IS_MOBILE_PLATFORM
+// Turn VLOG off on mobile platforms to keep the binary size down.
#define VLOG_IS_ON(lvl) ((lvl) <= 0)
+#else
+// Otherwise, set the TF_CPP_MIN_VLOG_LEVEL environment variable to control
+// the minimum level at which VLOG statements produce output.
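+// For example, TF_CPP_MIN_VLOG_LEVEL=2 enables VLOG(1) and VLOG(2) output
+// while leaving VLOG(3) silent.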
+#define VLOG_IS_ON(lvl) \
+ ((lvl) <= ::tensorflow::internal::LogMessage::MinVLogLevel())
+#endif
#define VLOG(lvl) \
- if (VLOG_IS_ON(lvl)) \
+ if (TF_PREDICT_FALSE(VLOG_IS_ON(lvl))) \
::tensorflow::internal::LogMessage(__FILE__, __LINE__, tensorflow::INFO)
// CHECK dies with a fatal error if condition is not true. It is *not*