aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/kernels/bcast_ops.cc10
-rw-r--r--tensorflow/core/kernels/cwise_op_floor_div.cc10
-rw-r--r--tensorflow/core/kernels/cwise_op_isfinite.cc10
-rw-r--r--tensorflow/core/kernels/cwise_op_isinf.cc10
-rw-r--r--tensorflow/core/kernels/cwise_op_isnan.cc10
-rw-r--r--tensorflow/core/kernels/cwise_ops_sycl_common.h30
-rw-r--r--tensorflow/core/kernels/sendrecv_ops.cc2
-rw-r--r--tensorflow/python/kernel_tests/basic_gpu_test.py10
-rwxr-xr-xthird_party/sycl/crosstool/computecpp.tpl2
9 files changed, 88 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/bcast_ops.cc b/tensorflow/core/kernels/bcast_ops.cc
index 10354cbb56..db8842a547 100644
--- a/tensorflow/core/kernels/bcast_ops.cc
+++ b/tensorflow/core/kernels/bcast_ops.cc
@@ -90,4 +90,14 @@ REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
.HostMemory("r1"),
BCastGradArgsOp);
+#if TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .HostMemory("s0")
+ .HostMemory("s1")
+ .HostMemory("r0")
+ .HostMemory("r1"),
+ BCastGradArgsOp);
+#endif
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_floor_div.cc b/tensorflow/core/kernels/cwise_op_floor_div.cc
index 83b2771ed2..7930d83413 100644
--- a/tensorflow/core/kernels/cwise_op_floor_div.cc
+++ b/tensorflow/core/kernels/cwise_op_floor_div.cc
@@ -18,6 +18,16 @@ limitations under the License.
namespace tensorflow {
REGISTER5(BinaryOp, CPU, "FloorDiv", functor::safe_floor_div, uint8, uint16,
int16, int32, int64);
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("FloorDiv") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<TYPE>("T"), \
+ BinaryOp<SYCLDevice, functor::floor_div<TYPE>>);
+TF_CALL_INTEGRAL_TYPES(REGISTER_SYCL_KERNEL);
+#undef REGISTER_SYCL_KERNEL
+#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA
REGISTER4(BinaryOp, GPU, "FloorDiv", functor::floor_div, uint8, uint16, int16,
int64);
diff --git a/tensorflow/core/kernels/cwise_op_isfinite.cc b/tensorflow/core/kernels/cwise_op_isfinite.cc
index 954b5d25bd..e38b271318 100644
--- a/tensorflow/core/kernels/cwise_op_isfinite.cc
+++ b/tensorflow/core/kernels/cwise_op_isfinite.cc
@@ -18,6 +18,16 @@ limitations under the License.
namespace tensorflow {
REGISTER3(UnaryOp, CPU, "IsFinite", functor::isfinite, float, Eigen::half,
double);
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("IsFinite") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<TYPE>("T"), \
+ UnaryOp<SYCLDevice, functor::isfinite<TYPE>>);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
+#undef REGISTER_SYCL_KERNEL
+#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "IsFinite", functor::isfinite, float, Eigen::half,
double);
diff --git a/tensorflow/core/kernels/cwise_op_isinf.cc b/tensorflow/core/kernels/cwise_op_isinf.cc
index 407dadcb69..bf056dbe0e 100644
--- a/tensorflow/core/kernels/cwise_op_isinf.cc
+++ b/tensorflow/core/kernels/cwise_op_isinf.cc
@@ -17,6 +17,16 @@ limitations under the License.
namespace tensorflow {
REGISTER3(UnaryOp, CPU, "IsInf", functor::isinf, float, Eigen::half, double);
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("IsInf") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<TYPE>("T"), \
+ UnaryOp<SYCLDevice, functor::isinf<TYPE>>);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
+#undef REGISTER_SYCL_KERNEL
+#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "IsInf", functor::isinf, float, Eigen::half, double);
#endif
diff --git a/tensorflow/core/kernels/cwise_op_isnan.cc b/tensorflow/core/kernels/cwise_op_isnan.cc
index f150b2f3f4..d2bac23882 100644
--- a/tensorflow/core/kernels/cwise_op_isnan.cc
+++ b/tensorflow/core/kernels/cwise_op_isnan.cc
@@ -17,6 +17,16 @@ limitations under the License.
namespace tensorflow {
REGISTER3(UnaryOp, CPU, "IsNan", functor::isnan, float, Eigen::half, double);
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("IsNan") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<TYPE>("T"), \
+ UnaryOp<SYCLDevice, functor::isnan<TYPE>>);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
+#undef REGISTER_SYCL_KERNEL
+#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "IsNan", functor::isnan, float, Eigen::half, double);
#endif
diff --git a/tensorflow/core/kernels/cwise_ops_sycl_common.h b/tensorflow/core/kernels/cwise_ops_sycl_common.h
index 43385d7146..32b4dd0933 100644
--- a/tensorflow/core/kernels/cwise_ops_sycl_common.h
+++ b/tensorflow/core/kernels/cwise_ops_sycl_common.h
@@ -44,7 +44,7 @@ template <typename Functor>
struct UnaryFunctor<SYCLDevice, Functor> {
void operator()(const SYCLDevice& d, typename Functor::tout_type out,
typename Functor::tin_type in) {
- LOG(FATAL) << "UnaryFunctor::operator() NOT IMPLEMENTED ! ";
+ To32Bit(out).device(d) = To32Bit(in).unaryExpr(typename Functor::func());
}
};
@@ -54,19 +54,21 @@ struct BinaryFunctor<SYCLDevice, Functor, NDIMS, has_errors> {
void operator()(const SYCLDevice& d, typename Functor::tout_type out,
typename Functor::tin_type in0,
typename Functor::tin_type in1, bool* error) {
- Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
+ To32Bit(out).device(d) = To32Bit(in0).binaryExpr(in1, typename Functor::func());
}
void Left(const SYCLDevice& d, typename Functor::tout_type out,
typename Functor::tscalar_type scalar,
typename Functor::tin_type in, bool* error) {
- LOG(FATAL) << "BinaryFunctor::Left NOT IMPLEMENTED ! ";
+ typedef typename Functor::func Binary;
+ To32Bit(out).device(d) = To32Bit(in).binaryExpr(typename Functor::tin_type(scalar.data(),in.dimensions()), Binary());
}
void Right(const SYCLDevice& d, typename Functor::tout_type out,
typename Functor::tin_type in,
typename Functor::tscalar_type scalar, bool* error) {
- LOG(FATAL) << "BinaryFunctor::Right NOT IMPLEMENTED ! ";
+ typedef typename Functor::func Binary;
+ To32Bit(out).device(d) = To32Bit(in).binaryExpr(typename Functor::tin_type(scalar.data(),in.dimensions()), Binary());
}
void BCast(const SYCLDevice& d,
@@ -76,7 +78,25 @@ struct BinaryFunctor<SYCLDevice, Functor, NDIMS, has_errors> {
typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
bool* error) {
- LOG(FATAL) << "BinaryFunctor::BCast NOT IMPLEMENTED ";
+ typedef typename Functor::in_type T;
+ typename Functor::func func;
+ if ((NDIMS == 2) && Functor::use_bcast_optimization &&
+ use_bcast_optimization<T>::value) {
+ const bool bcast0_all_one = AllOne<NDIMS>(bcast0);
+ const bool bcast1_all_one = AllOne<NDIMS>(bcast1);
+ if (bcast0_all_one && !bcast1_all_one) {
+ To32Bit(out).device(d) =
+ To32Bit(in0).binaryExpr(To32Bit(in1).broadcast(bcast1), func);
+ return;
+ }
+ if (!bcast0_all_one && bcast1_all_one) {
+ To32Bit(out).device(d) =
+ To32Bit(in0).broadcast(bcast0).binaryExpr(To32Bit(in1), func);
+ return;
+ }
+ }
+ To32Bit(out).device(d) = To32Bit(in0).broadcast(bcast0).binaryExpr(
+ To32Bit(in1).broadcast(bcast1), func);
}
};
diff --git a/tensorflow/core/kernels/sendrecv_ops.cc b/tensorflow/core/kernels/sendrecv_ops.cc
index b0cfcfce09..f3609ea706 100644
--- a/tensorflow/core/kernels/sendrecv_ops.cc
+++ b/tensorflow/core/kernels/sendrecv_ops.cc
@@ -80,6 +80,8 @@ REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_GPU), SendOp);
#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_SYCL), SendOp);
+REGISTER_KERNEL_BUILDER(
+ Name("_HostSend").Device(DEVICE_SYCL).HostMemory("tensor"), SendOp);
#endif
REGISTER_KERNEL_BUILDER(Name("_HostSend").Device(DEVICE_CPU), SendOp);
diff --git a/tensorflow/python/kernel_tests/basic_gpu_test.py b/tensorflow/python/kernel_tests/basic_gpu_test.py
index 22c6a08e8a..4c92e1ea73 100644
--- a/tensorflow/python/kernel_tests/basic_gpu_test.py
+++ b/tensorflow/python/kernel_tests/basic_gpu_test.py
@@ -22,6 +22,7 @@ import tensorflow as tf
import math
import numpy as np
from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops.gen_array_ops import _broadcast_gradient_args
class GPUBinaryOpsTest(tf.test.TestCase):
def _compareGPU(self, x, y, np_func, tf_func):
@@ -46,6 +47,15 @@ class GPUBinaryOpsTest(tf.test.TestCase):
self._compareGPU(x, y, np.subtract, tf.sub)
self._compareGPU(x, y, np.multiply, tf.mul)
self._compareGPU(x, y + 0.1, np.true_divide, tf.truediv)
+
+ def _GetGradientArgs(self, xs, ys):
+ with self.test_session(use_gpu=True) as sess:
+ return sess.run(_broadcast_gradient_args(xs, ys))
+
+ def testBroadcast(self):
+ r0, r1 = self._GetGradientArgs([2, 3, 5], [1])
+ self.assertAllEqual(r0, [])
+ self.assertAllEqual(r1, [0, 1, 2])
if __name__ == "__main__":
tf.test.main()
diff --git a/third_party/sycl/crosstool/computecpp.tpl b/third_party/sycl/crosstool/computecpp.tpl
index 56cc1fb3fd..d319a1eb75 100755
--- a/third_party/sycl/crosstool/computecpp.tpl
+++ b/third_party/sycl/crosstool/computecpp.tpl
@@ -50,7 +50,7 @@ def main():
x = subprocess.call([COMPUTECPP_DRIVER] +computecpp_compiler_flags )
if(x == 0):
- host_compiler_flags = ['-DTENSORFLOW_USE_SYCL', '-Wno-unused-variable', '-I', COMPUTECPP_INCLUDE, "--include",bc_out] + host_compiler_flags
+ host_compiler_flags = ['-D_GLIBCXX_USE_CXX11_ABI=0', '-DTENSORFLOW_USE_SYCL', '-Wno-unused-variable', '-I', COMPUTECPP_INCLUDE, "--include",bc_out] + host_compiler_flags
return subprocess.call([CPU_CXX_COMPILER] +host_compiler_flags )
return x
else: