aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/broadcast_to_op_gpu.cu.cc
diff options
context:
space:
mode:
authorGravatar Yong Tang <yong.tang.github@outlook.com>2018-04-15 19:22:39 -0700
committerGravatar Jonathan Hseu <vomjom@vomjom.net>2018-04-15 19:22:39 -0700
commitba1c53a5f2bb106e16ec7503dbd4d0db9ecc9799 (patch)
treed7b9281a71d4a1355183d2776a1df8cb41a1e2bc /tensorflow/core/kernels/broadcast_to_op_gpu.cu.cc
parent0586c57292a7bd1a79b4a03270c0f1c32d02a4af (diff)
Add support for explicit broadcasting in TensorFlow (#15243)
* Add support for explicit broadcasting in TensorFlow This fix tries to add support for explicit broadcasting in TensorFlow, as was suggested in 14509. This fix adds the op of tf.broadcast_to, which is equivalent to the numpy.broadcast_to in numpy. This fix fixes 14509. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Register BroadcastTo op in array_ops.cc Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Sanitize with clang-format -i Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add test cases for tf.broadcast_to Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Sanitize bazel BUILD and python. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Split broadcast_to_ops_test from array_ops_test Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Support int64 shape Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Improve shape inference for broadcast_to Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add scalar input support for broadcast_to Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Update API defs tensorflow/core/api_def/update_api_def.sh Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Update API golden ``` bazel-bin/tensorflow/tools/api/tests/api_compatibility_test --update_goldens True ``` Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Update docstring for broadcast_to Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Enable GPU kernel for BroadcastTo Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Enable use_gpu=True for test cases Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Hide the ops and export to tf.contrib.framework for now. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add the op to the _allowed_symbol in tf.contrib.framework Otherwise the symbol will be hidden Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Fix pylint sanity issue. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/core/kernels/broadcast_to_op_gpu.cu.cc')
-rw-r--r--tensorflow/core/kernels/broadcast_to_op_gpu.cu.cc34
1 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/broadcast_to_op_gpu.cu.cc b/tensorflow/core/kernels/broadcast_to_op_gpu.cu.cc
new file mode 100644
index 0000000000..6459571085
--- /dev/null
+++ b/tensorflow/core/kernels/broadcast_to_op_gpu.cu.cc
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/broadcast_to_op.h"
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+// This namespace block only emits explicit template instantiations.
+typedef Eigen::GpuDevice GPUDevice;  // Device tag used in the instantiations.
+// Explicitly instantiate functor::BroadcastTo on GPUDevice, once per type.
+#define INSTANTIATE_GPU_KERNEL(Type) /* expands to one instantiation */ \
+ template class functor::BroadcastTo<GPUDevice, Type>;
+TF_CALL_GPU_ALL_TYPES(INSTANTIATE_GPU_KERNEL);  // Type list defined elsewhere.
+#undef INSTANTIATE_GPU_KERNEL  // Helper macro is not needed past this point.
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA