aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/split_op_cpu.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/split_op_cpu.cc')
-rw-r--r--tensorflow/core/kernels/split_op_cpu.cc30
1 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/split_op_cpu.cc b/tensorflow/core/kernels/split_op_cpu.cc
new file mode 100644
index 0000000000..b86deeb8fb
--- /dev/null
+++ b/tensorflow/core/kernels/split_op_cpu.cc
@@ -0,0 +1,30 @@
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/split_op.h"
+
+#include "tensorflow/core/framework/numeric_types.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename T>
+void Split<Eigen::ThreadPoolDevice, T>::operator()(
+ const Eigen::ThreadPoolDevice& d, typename TTypes<T, 3>::Tensor output,
+ typename TTypes<T, 3>::ConstTensor input,
+ const Eigen::DSizes<ptrdiff_t, 3>& slice_indices,
+ const Eigen::DSizes<ptrdiff_t, 3>& slice_sizes) {
+ if (output.size() < 131072) {
+ output = input.slice(slice_indices, slice_sizes);
+ } else {
+ output.device(d) = input.slice(slice_indices, slice_sizes);
+ }
+}
+
+#define DEFINE_CPU_KERNELS(T) template struct Split<Eigen::ThreadPoolDevice, T>;
+
+TF_CALL_ALL_TYPES(DEFINE_CPU_KERNELS)
+
+} // namespace functor
+} // namespace tensorflow