aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-08-16 08:33:16 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-16 10:01:57 -0700
commit79b1489dcd77f51ae262e1b3f3c0830c90c17e1a (patch)
tree30bca1ac4baa1768bd0f515701688425af690177 /tensorflow
parent9c565882833bc48026515ec4590a6de1de5ff3ad (diff)
Shard tile_ops.cc for cpu, to reduce compilation times.
Change: 130411083
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/makefile/tf_op_files.txt5
-rw-r--r--tensorflow/core/kernels/BUILD8
-rw-r--r--tensorflow/core/kernels/tile_ops.cc156
-rw-r--r--tensorflow/core/kernels/tile_ops_cpu_impl.h68
-rw-r--r--tensorflow/core/kernels/tile_ops_cpu_impl_1.cc18
-rw-r--r--tensorflow/core/kernels/tile_ops_cpu_impl_2.cc18
-rw-r--r--tensorflow/core/kernels/tile_ops_cpu_impl_3.cc18
-rw-r--r--tensorflow/core/kernels/tile_ops_cpu_impl_4.cc18
-rw-r--r--tensorflow/core/kernels/tile_ops_cpu_impl_5.cc18
-rw-r--r--tensorflow/core/kernels/tile_ops_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/tile_ops_impl.h (renamed from tensorflow/core/kernels/tile_ops.h)6
11 files changed, 259 insertions, 76 deletions
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index f859631888..12cc2b0160 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -7,6 +7,11 @@ tensorflow/core/kernels/transpose_functor_cpu.cc
tensorflow/core/kernels/training_ops.cc
tensorflow/core/kernels/topk_op.cc
tensorflow/core/kernels/tile_ops.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_1.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_2.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_3.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_4.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_5.cc
tensorflow/core/kernels/strided_slice_op_inst_6.cc
tensorflow/core/kernels/strided_slice_op_inst_5.cc
tensorflow/core/kernels/strided_slice_op_inst_4.cc
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 7f8b2b439d..7735e97842 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1940,7 +1940,8 @@ filegroup(
"save_restore_tensor.h",
"softplus_op.h",
"softsign_op.h",
- "tile_ops.h",
+ "tile_ops_cpu_impl.h",
+ "tile_ops_impl.h",
"training_ops.h",
"transpose_functor.h",
"transpose_op.h",
@@ -2019,6 +2020,11 @@ filegroup(
"stack_ops.cc",
"summary_op.cc",
"tile_ops.cc",
+ "tile_ops_cpu_impl_1.cc",
+ "tile_ops_cpu_impl_2.cc",
+ "tile_ops_cpu_impl_3.cc",
+ "tile_ops_cpu_impl_4.cc",
+ "tile_ops_cpu_impl_5.cc",
"topk_op.cc",
"training_ops.cc",
"transpose_functor_cpu.cc",
diff --git a/tensorflow/core/kernels/tile_ops.cc b/tensorflow/core/kernels/tile_ops.cc
index 5990bfbcf3..4b2d2fa589 100644
--- a/tensorflow/core/kernels/tile_ops.cc
+++ b/tensorflow/core/kernels/tile_ops.cc
@@ -21,23 +21,70 @@ limitations under the License.
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA
-#include "tensorflow/core/kernels/tile_ops.h"
-
#include <vector>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+// Forward declarations of functors that will be defined in
+// tile_ops_cpu_impl*.cc and tile_ops_gpu.cu.cc.
+namespace functor {
+template <typename Device, typename T, int NDIM>
+struct Tile {
+ void operator()(const Device& d, typename TTypes<T, NDIM>::Tensor out,
+ typename TTypes<T, NDIM>::ConstTensor in,
+ const Eigen::array<int32, NDIM>& broadcast_array) const;
+};
+
+template <typename Device, typename T>
+struct Tile<Device, T, 0> {
+ void operator()(const Device& d, typename TTypes<T, 0>::Tensor out,
+ typename TTypes<T, 0>::ConstTensor in,
+ const Eigen::array<int32, 0>&) const;
+};
+
+template <typename Device, typename T, int NDIM>
+struct TileGrad {
+ void operator()(const Device& d, typename TTypes<T, NDIM>::Tensor out,
+ typename TTypes<T, NDIM>::ConstTensor in,
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes,
+ bool first) const;
+};
+
+template <typename Device, typename T>
+struct TileGrad<Device, T, 0> {
+ void operator()(const Device& d, typename TTypes<T, 0>::Tensor out,
+ typename TTypes<T, 0>::ConstTensor in,
+ const Eigen::DSizes<Eigen::DenseIndex, 0>&,
+ const Eigen::DSizes<Eigen::DenseIndex, 0>&, bool first) const;
+};
+
+template <typename Device, typename T, int NDIM, int REDUCEDNDIM>
+struct ReduceAndReshape {
+ void operator()(
+ const Device& d, typename TTypes<T, NDIM>::Tensor out,
+ typename TTypes<T, NDIM>::ConstTensor in,
+ const Eigen::DSizes<Eigen::DenseIndex, REDUCEDNDIM>& reduce_dim,
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& reshape_dim) const;
+};
+} // namespace functor
+
// --------------------------------------------------------------------------
template <typename Device>
class TileOp : public OpKernel {
@@ -153,7 +200,7 @@ inline void TileOp<Device>::HandleCase(
<< DataTypeString(DT) << ", " << NDIM;
}
-#define HANDLE_CASE(device, dtype, ndim) \
+#define HANDLE_CASE(device, T, dtype, ndim) \
template <> \
template <> \
void TileOp<device>::HandleCase<dtype, ndim>( \
@@ -163,15 +210,18 @@ inline void TileOp<Device>::HandleCase(
}
// 0-D handled above
-#define HANDLE_CASE_DIM(device, dtype) \
- HANDLE_CASE(device, dtype, 1); \
- HANDLE_CASE(device, dtype, 2); \
- HANDLE_CASE(device, dtype, 3); \
- HANDLE_CASE(device, dtype, 4); \
- HANDLE_CASE(device, dtype, 5);
+#define HANDLE_CASE_DIM(device, T, dtype) \
+ HANDLE_CASE(device, T, dtype, 1); \
+ HANDLE_CASE(device, T, dtype, 2); \
+ HANDLE_CASE(device, T, dtype, 3); \
+ HANDLE_CASE(device, T, dtype, 4); \
+ HANDLE_CASE(device, T, dtype, 5);
#define HANDLE_TYPE_NAME_CPU(T) \
- HANDLE_CASE_DIM(CPUDevice, DataTypeToEnum<T>::value);
+ HANDLE_CASE_DIM(CPUDevice, T, DataTypeToEnum<T>::value);
+
+#define HANDLE_TYPE_NAME_GPU(T) \
+ HANDLE_CASE_DIM(GPUDevice, T, DataTypeToEnum<T>::value);
TF_CALL_bool(HANDLE_TYPE_NAME_CPU);
TF_CALL_float(HANDLE_TYPE_NAME_CPU);
@@ -186,15 +236,16 @@ TF_CALL_complex128(HANDLE_TYPE_NAME_CPU);
TF_CALL_string(HANDLE_TYPE_NAME_CPU);
#if GOOGLE_CUDA
-HANDLE_CASE_DIM(GPUDevice, DT_FLOAT);
-HANDLE_CASE_DIM(GPUDevice, DT_DOUBLE);
-HANDLE_CASE_DIM(GPUDevice, DT_INT16);
-HANDLE_CASE_DIM(GPUDevice, DT_INT32);
-HANDLE_CASE_DIM(GPUDevice, DT_INT64);
-HANDLE_CASE_DIM(GPUDevice, DT_HALF);
+TF_CALL_float(HANDLE_TYPE_NAME_GPU);
+TF_CALL_double(HANDLE_TYPE_NAME_GPU);
+TF_CALL_int16(HANDLE_TYPE_NAME_GPU);
+TF_CALL_int32(HANDLE_TYPE_NAME_GPU);
+TF_CALL_int64(HANDLE_TYPE_NAME_GPU);
+TF_CALL_half(HANDLE_TYPE_NAME_GPU);
#endif // GOOGLE_CUDA
#undef HANDLE_TYPE_NAME_CPU
+#undef HANDLE_TYPE_NAME_GPU
#undef HANDLE_CASE_DIM
#undef HANDLE_CASE
@@ -385,7 +436,7 @@ inline void TileGradientOp<Device>::HandleCase(
<< ", " << NDIM;
}
-#define HANDLE_CASE(device, dtype, ndim) \
+#define HANDLE_CASE(device, T, dtype, ndim) \
template <> \
template <> \
void TileGradientOp<device>::HandleCase<dtype, ndim>( \
@@ -395,15 +446,18 @@ inline void TileGradientOp<Device>::HandleCase(
}
// 0-D handled specially above
-#define HANDLE_CASE_DIM(device, dtype) \
- HANDLE_CASE(device, dtype, 1); \
- HANDLE_CASE(device, dtype, 2); \
- HANDLE_CASE(device, dtype, 3); \
- HANDLE_CASE(device, dtype, 4); \
- HANDLE_CASE(device, dtype, 5);
+#define HANDLE_CASE_DIM(device, T, dtype) \
+ HANDLE_CASE(device, T, dtype, 1); \
+ HANDLE_CASE(device, T, dtype, 2); \
+ HANDLE_CASE(device, T, dtype, 3); \
+ HANDLE_CASE(device, T, dtype, 4); \
+ HANDLE_CASE(device, T, dtype, 5);
#define HANDLE_TYPE_NAME_CPU(T) \
- HANDLE_CASE_DIM(CPUDevice, DataTypeToEnum<T>::value);
+ HANDLE_CASE_DIM(CPUDevice, T, DataTypeToEnum<T>::value);
+
+#define HANDLE_TYPE_NAME_GPU(T) \
+ HANDLE_CASE_DIM(GPUDevice, T, DataTypeToEnum<T>::value);
TF_CALL_float(HANDLE_TYPE_NAME_CPU);
TF_CALL_double(HANDLE_TYPE_NAME_CPU);
@@ -415,16 +469,16 @@ TF_CALL_complex64(HANDLE_TYPE_NAME_CPU);
TF_CALL_complex128(HANDLE_TYPE_NAME_CPU);
#if GOOGLE_CUDA
-HANDLE_CASE_DIM(GPUDevice, DT_FLOAT);
-HANDLE_CASE_DIM(GPUDevice, DT_DOUBLE);
-HANDLE_CASE_DIM(GPUDevice, DT_INT16);
-HANDLE_CASE_DIM(GPUDevice, DT_INT32);
-HANDLE_CASE_DIM(GPUDevice, DT_INT64);
-HANDLE_CASE_DIM(GPUDevice, DT_HALF);
-
+TF_CALL_float(HANDLE_TYPE_NAME_GPU);
+TF_CALL_double(HANDLE_TYPE_NAME_GPU);
+TF_CALL_int16(HANDLE_TYPE_NAME_GPU);
+TF_CALL_int32(HANDLE_TYPE_NAME_GPU);
+TF_CALL_int64(HANDLE_TYPE_NAME_GPU);
+TF_CALL_half(HANDLE_TYPE_NAME_GPU);
#endif // GOOGLE_CUDA
#undef HANDLE_TYPE_NAME_CPU
+#undef HANDLE_TYPE_NAME_GPU
#undef HANDLE_CASE_DIM
#undef HANDLE_CASE
@@ -436,46 +490,6 @@ REGISTER_KERNEL_BUILDER(Name("TileGrad")
TileGradientOp<CPUDevice>);
#if GOOGLE_CUDA
-#define DEFINE_GPU_TYPE(T) \
- DEFINE_GPU_DIM(T, 1) \
- DEFINE_GPU_DIM(T, 2) \
- DEFINE_GPU_DIM(T, 3) \
- DEFINE_GPU_DIM(T, 4) \
- DEFINE_GPU_DIM(T, 5)
-
-#define DEFINE_GPU_DIM(T, NDIM) \
- template <> \
- void Tile<GPUDevice, T, NDIM>::operator()( \
- const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
- typename TTypes<T, NDIM>::ConstTensor in, \
- const Eigen::array<int32, NDIM>& broadcast_array) const; \
- extern template struct Tile<GPUDevice, T, NDIM>; \
- template <> \
- void TileGrad<GPUDevice, T, NDIM>::operator()( \
- const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
- typename TTypes<T, NDIM>::ConstTensor in, \
- const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \
- const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes, bool first) const; \
- extern template struct TileGrad<GPUDevice, T, NDIM>; \
- template <> \
- void ReduceAndReshape<GPUDevice, T, NDIM, 1>::operator()( \
- const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
- typename TTypes<T, NDIM>::ConstTensor in, \
- const Eigen::DSizes<Eigen::DenseIndex, 1>& reduce_dim, \
- const Eigen::DSizes<Eigen::DenseIndex, NDIM>& reshape_dim) const; \
- extern template struct ReduceAndReshape<GPUDevice, T, NDIM, 1>;
-
-namespace functor {
-DEFINE_GPU_TYPE(float);
-DEFINE_GPU_TYPE(double);
-DEFINE_GPU_TYPE(int64);
-DEFINE_GPU_TYPE(int32);
-DEFINE_GPU_TYPE(int16);
-DEFINE_GPU_TYPE(Eigen::half);
-} // end namespace functor
-
-#undef DEFINE_GPU_DIM
-#undef DEFINE_GPU_TYPE
REGISTER_KERNEL_BUILDER(Name("Tile")
.Device(DEVICE_GPU)
diff --git a/tensorflow/core/kernels/tile_ops_cpu_impl.h b/tensorflow/core/kernels/tile_ops_cpu_impl.h
new file mode 100644
index 0000000000..9cdf69ad0b
--- /dev/null
+++ b/tensorflow/core/kernels/tile_ops_cpu_impl.h
@@ -0,0 +1,68 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/tile_ops_impl.h"
+
+namespace tensorflow {
+namespace functor {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+// Explicitly instantiate the functors used by TileOp.
+#define DEFINE_DIM(T, NDIM) template struct Tile<CPUDevice, T, NDIM>;
+#define DEFINE_TYPE(T) DEFINE_DIM(T, CPU_PROVIDED_IXDIM)
+
+TF_CALL_bool(DEFINE_TYPE);
+TF_CALL_float(DEFINE_TYPE);
+TF_CALL_double(DEFINE_TYPE);
+TF_CALL_uint8(DEFINE_TYPE);
+TF_CALL_int32(DEFINE_TYPE);
+TF_CALL_int16(DEFINE_TYPE);
+TF_CALL_int64(DEFINE_TYPE);
+TF_CALL_half(DEFINE_TYPE);
+TF_CALL_complex64(DEFINE_TYPE);
+TF_CALL_complex128(DEFINE_TYPE);
+TF_CALL_string(DEFINE_TYPE);
+
+#undef DEFINE_DIM
+#undef DEFINE_TYPE
+
+// Explicitly instantiate the functors used by TileGradientOp.
+#define DEFINE_DIM(T, NDIM) \
+ template struct TileGrad<CPUDevice, T, NDIM>; \
+ template struct ReduceAndReshape<CPUDevice, T, NDIM, 1>;
+#define DEFINE_TYPE(T) DEFINE_DIM(T, CPU_PROVIDED_IXDIM)
+
+TF_CALL_float(DEFINE_TYPE);
+TF_CALL_double(DEFINE_TYPE);
+TF_CALL_int16(DEFINE_TYPE);
+TF_CALL_int32(DEFINE_TYPE);
+TF_CALL_int64(DEFINE_TYPE);
+TF_CALL_half(DEFINE_TYPE);
+TF_CALL_complex64(DEFINE_TYPE);
+TF_CALL_complex128(DEFINE_TYPE);
+
+#undef DEFINE_DIM
+#undef DEFINE_TYPE
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_
diff --git a/tensorflow/core/kernels/tile_ops_cpu_impl_1.cc b/tensorflow/core/kernels/tile_ops_cpu_impl_1.cc
new file mode 100644
index 0000000000..4795505749
--- /dev/null
+++ b/tensorflow/core/kernels/tile_ops_cpu_impl_1.cc
@@ -0,0 +1,18 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define CPU_PROVIDED_IXDIM 1
+#include "tensorflow/core/kernels/tile_ops_cpu_impl.h"
+#undef CPU_PROVIDED_IXDIM
diff --git a/tensorflow/core/kernels/tile_ops_cpu_impl_2.cc b/tensorflow/core/kernels/tile_ops_cpu_impl_2.cc
new file mode 100644
index 0000000000..7fcd31c783
--- /dev/null
+++ b/tensorflow/core/kernels/tile_ops_cpu_impl_2.cc
@@ -0,0 +1,18 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define CPU_PROVIDED_IXDIM 2
+#include "tensorflow/core/kernels/tile_ops_cpu_impl.h"
+#undef CPU_PROVIDED_IXDIM
diff --git a/tensorflow/core/kernels/tile_ops_cpu_impl_3.cc b/tensorflow/core/kernels/tile_ops_cpu_impl_3.cc
new file mode 100644
index 0000000000..3e835b43d2
--- /dev/null
+++ b/tensorflow/core/kernels/tile_ops_cpu_impl_3.cc
@@ -0,0 +1,18 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define CPU_PROVIDED_IXDIM 3
+#include "tensorflow/core/kernels/tile_ops_cpu_impl.h"
+#undef CPU_PROVIDED_IXDIM
diff --git a/tensorflow/core/kernels/tile_ops_cpu_impl_4.cc b/tensorflow/core/kernels/tile_ops_cpu_impl_4.cc
new file mode 100644
index 0000000000..872f654cb9
--- /dev/null
+++ b/tensorflow/core/kernels/tile_ops_cpu_impl_4.cc
@@ -0,0 +1,18 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define CPU_PROVIDED_IXDIM 4
+#include "tensorflow/core/kernels/tile_ops_cpu_impl.h"
+#undef CPU_PROVIDED_IXDIM
diff --git a/tensorflow/core/kernels/tile_ops_cpu_impl_5.cc b/tensorflow/core/kernels/tile_ops_cpu_impl_5.cc
new file mode 100644
index 0000000000..91e332e53a
--- /dev/null
+++ b/tensorflow/core/kernels/tile_ops_cpu_impl_5.cc
@@ -0,0 +1,18 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define CPU_PROVIDED_IXDIM 5
+#include "tensorflow/core/kernels/tile_ops_cpu_impl.h"
+#undef CPU_PROVIDED_IXDIM
diff --git a/tensorflow/core/kernels/tile_ops_gpu.cu.cc b/tensorflow/core/kernels/tile_ops_gpu.cu.cc
index 3870c1a7bb..787ffb4ea7 100644
--- a/tensorflow/core/kernels/tile_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/tile_ops_gpu.cu.cc
@@ -17,8 +17,8 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "tensorflow/core/kernels/tile_ops.h"
#include <stdio.h>
+#include "tensorflow/core/kernels/tile_ops_impl.h"
namespace tensorflow {
namespace functor {
diff --git a/tensorflow/core/kernels/tile_ops.h b/tensorflow/core/kernels/tile_ops_impl.h
index b79ac4586c..c41e4bd74b 100644
--- a/tensorflow/core/kernels/tile_ops.h
+++ b/tensorflow/core/kernels/tile_ops_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_TILE_OPS_H_
-#define TENSORFLOW_KERNELS_TILE_OPS_H_
+#ifndef TENSORFLOW_KERNELS_TILE_IMPL_OPS_H_
+#define TENSORFLOW_KERNELS_TILE_IMPL_OPS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -91,4 +91,4 @@ struct ReduceAndReshape {
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_TILE_OPS_H_
+#endif // TENSORFLOW_KERNELS_TILE_IMPL_OPS_H_