Diffstat (limited to 'tensorflow/core/kernels/tile_ops.cc')
-rw-r--r--  tensorflow/core/kernels/tile_ops.cc | 76
1 file changed, 21 insertions(+), 55 deletions(-)
diff --git a/tensorflow/core/kernels/tile_ops.cc b/tensorflow/core/kernels/tile_ops.cc
index 7c72487d3f..f1da3c8afb 100644
--- a/tensorflow/core/kernels/tile_ops.cc
+++ b/tensorflow/core/kernels/tile_ops.cc
@@ -44,21 +44,12 @@ typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
-// Forward declarations of functors that will be defined in
-// tile_ops_cpu_impl*.cc and tile_ops_gpu.cu.cc.
+// Forward declarations of functors that will be defined in tile_ops_impl.h
namespace functor {
-template <typename Device, typename T, int NDIM>
-struct Tile {
- void operator()(const Device& d, typename TTypes<T, NDIM>::Tensor out,
- typename TTypes<T, NDIM>::ConstTensor in,
- const Eigen::array<int32, NDIM>& broadcast_array) const;
-};
-
template <typename Device, typename T>
-struct Tile<Device, T, 0> {
- void operator()(const Device& d, typename TTypes<T, 0>::Tensor out,
- typename TTypes<T, 0>::ConstTensor in,
- const Eigen::array<int32, 0>&) const;
+struct Tile {
+ void operator()(const Device& d, Tensor* out, const Tensor& in,
+ const gtl::ArraySlice<int32> broadcast_array) const;
};
template <typename Device, typename T, int NDIM>
@@ -134,21 +125,12 @@ class TileOp : public OpKernel {
// If there's no output, there's nothing to do.
if (output_shape.num_elements() == 0) return;
-#define HANDLE_DIM(DT, NDIM) \
- if (context->input(0).dtype() == DT && input_dims == NDIM) { \
- HandleCase<DT, NDIM>(context, multiples_array, result); \
+#define HANDLE_TYPE(DT) \
+ if (context->input(0).dtype() == DT) { \
+ HandleCase<DT>(context, multiples_array, result); \
return; \
}
-#define HANDLE_TYPE(T) \
- HANDLE_DIM(T, 1) \
- HANDLE_DIM(T, 2) \
- HANDLE_DIM(T, 3) \
- HANDLE_DIM(T, 4) \
- HANDLE_DIM(T, 5) \
- HANDLE_DIM(T, 6) \
- HANDLE_DIM(T, 7)
-
#define HANDLE_TYPE_NAME(T) HANDLE_TYPE(DataTypeToEnum<T>::value)
// Invoke macro using TF_CALL_* so type-filtering for platform applies.
@@ -166,7 +148,6 @@ class TileOp : public OpKernel {
#undef HANDLE_TYPE_NAME
#undef HANDLE_TYPE
-#undef HANDLE_DIM
OP_REQUIRES(context, false,
errors::Unimplemented(
@@ -175,21 +156,17 @@ class TileOp : public OpKernel {
}
private:
- template <DataType DT, int NDIM>
+ template <DataType DT>
void HandleCaseImpl(OpKernelContext* context,
const gtl::ArraySlice<int32>& multiples_array,
Tensor* result) {
typedef typename EnumToDataType<DT>::Type T;
- Eigen::array<int32, NDIM> broadcast_array;
- for (int i = 0; i < NDIM; ++i) {
- broadcast_array[i] = multiples_array[i];
- }
- functor::Tile<Device, T, NDIM>()(
- context->eigen_device<Device>(), result->tensor<T, NDIM>(),
- context->input(0).tensor<T, NDIM>(), broadcast_array);
+ functor::Tile<Device, T>()(
+ context->eigen_device<Device>(), result,
+ context->input(0), multiples_array);
}
- template <DataType DT, int NDIM>
+ template <DataType DT>
void HandleCase(OpKernelContext* context,
const gtl::ArraySlice<int32>& multiples_array,
Tensor* result);
@@ -198,45 +175,35 @@ class TileOp : public OpKernel {
};
template <typename Device>
-template <DataType DT, int NDIM>
+template <DataType DT>
inline void TileOp<Device>::HandleCase(
OpKernelContext* context, const gtl::ArraySlice<int32>& multiples_array,
Tensor* result) {
// TODO(vrv): print out the device name if useful. Currently disabled to avoid
// having to use RTTI.
- LOG(FATAL) << "TileOp: Invalid combination of Device, DT and NDIM: "
+ LOG(FATAL) << "TileOp: Invalid combination of Device, DT: "
// << typeid(Device).name() << ", "
- << DataTypeString(DT) << ", " << NDIM;
+ << DataTypeString(DT);
}
-#define HANDLE_CASE(device, T, dtype, ndim) \
+#define HANDLE_CASE(device, dtype) \
template <> \
template <> \
- void TileOp<device>::HandleCase<dtype, ndim>( \
+ void TileOp<device>::HandleCase<dtype>( \
OpKernelContext * context, \
const gtl::ArraySlice<int32>& multiples_array, Tensor* result) { \
- HandleCaseImpl<dtype, ndim>(context, multiples_array, result); \
+ HandleCaseImpl<dtype>(context, multiples_array, result); \
}
-// 0-D handled above
-#define HANDLE_CASE_DIM(device, T, dtype) \
- HANDLE_CASE(device, T, dtype, 1); \
- HANDLE_CASE(device, T, dtype, 2); \
- HANDLE_CASE(device, T, dtype, 3); \
- HANDLE_CASE(device, T, dtype, 4); \
- HANDLE_CASE(device, T, dtype, 5); \
- HANDLE_CASE(device, T, dtype, 6); \
- HANDLE_CASE(device, T, dtype, 7);
-
#define HANDLE_TYPE_NAME_CPU(T) \
- HANDLE_CASE_DIM(CPUDevice, T, DataTypeToEnum<T>::value);
+ HANDLE_CASE(CPUDevice, DataTypeToEnum<T>::value);
#define HANDLE_TYPE_NAME_GPU(T) \
- HANDLE_CASE_DIM(GPUDevice, T, DataTypeToEnum<T>::value);
+ HANDLE_CASE(GPUDevice, DataTypeToEnum<T>::value);
#ifdef TENSORFLOW_USE_SYCL
#define HANDLE_TYPE_NAME_SYCL(T) \
- HANDLE_CASE_DIM(SYCLDevice, T, DataTypeToEnum<T>::value);
+ HANDLE_CASE(SYCLDevice, DataTypeToEnum<T>::value);
#endif // TENSORFLOW_USE_SYCL
TF_CALL_bool(HANDLE_TYPE_NAME_CPU);
@@ -275,7 +242,6 @@ TF_CALL_int64(HANDLE_TYPE_NAME_SYCL);
#ifdef TENSORFLOW_USE_SYCL
#undef HANDLE_TYPE_NAME_SYCL
#endif // TENSORFLOW_USE_SYCL
-#undef HANDLE_CASE_DIM
#undef HANDLE_CASE
// --------------------------------------------------------------------------
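The net effect of this change is that the per-rank dispatch (HANDLE_DIM / HANDLE_CASE_DIM expanding NDIM = 1..7) moves out of TileOp and into a single rank-erased functor::Tile<Device, T>, whose definition the new comment points at tile_ops_impl.h. As a rough illustration only, here is a minimal standalone sketch of that shape, using a toy Tensor, a toy CPUDevice tag, std::vector<int32_t> standing in for gtl::ArraySlice<int32>, and hypothetical Tile1D/Tile2D helpers; it is not the actual tile_ops_impl.h code:

// sketch.cc -- standalone illustration, not TensorFlow source.
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

struct Tensor {  // toy stand-in for tensorflow::Tensor
  std::vector<int64_t> shape;
  std::vector<float> data;
  int dims() const { return static_cast<int>(shape.size()); }
};

// Hypothetical fixed-rank kernel for rank 1: repeat the input m0 times.
static void Tile1D(Tensor* out, const Tensor& in, int32_t m0) {
  out->shape = {in.shape[0] * m0};
  out->data.clear();
  for (int32_t r = 0; r < m0; ++r)
    out->data.insert(out->data.end(), in.data.begin(), in.data.end());
}

// Hypothetical fixed-rank kernel for rank 2: tile rows by m0, columns by m1.
static void Tile2D(Tensor* out, const Tensor& in, int32_t m0, int32_t m1) {
  const int64_t rows = in.shape[0], cols = in.shape[1];
  out->shape = {rows * m0, cols * m1};
  out->data.assign(static_cast<size_t>(rows * m0 * cols * m1), 0.f);
  for (int64_t i = 0; i < rows * m0; ++i)
    for (int64_t j = 0; j < cols * m1; ++j)
      out->data[i * cols * m1 + j] = in.data[(i % rows) * cols + (j % cols)];
}

// Rank-erased functor: one entry point per (Device, T), matching the new
// declaration above; the rank switch now lives here instead of in TileOp.
// T selects the element type in the real kernel; fixed to float for brevity.
template <typename Device, typename T>
struct Tile {
  void operator()(const Device& /*d*/, Tensor* out, const Tensor& in,
                  const std::vector<int32_t>& multiples) const {
    switch (in.dims()) {
      case 1: Tile1D(out, in, multiples[0]); break;
      case 2: Tile2D(out, in, multiples[0], multiples[1]); break;
      default: throw std::runtime_error("rank not handled in this sketch");
    }
  }
};

struct CPUDevice {};  // toy device tag

int main() {
  Tensor in{{2, 2}, {1, 2, 3, 4}};
  Tensor out;
  Tile<CPUDevice, float>()(CPUDevice{}, &out, in, {2, 3});
  for (float v : out.data) std::cout << v << ' ';  // the 4x6 tiled result
  std::cout << '\n';
  return 0;
}

Note how this mirrors the diff: the real functor (per the new forward declaration) receives the whole Tensor and the multiples slice, so any rank handling happens behind that one entry point, which is why HandleCase, HandleCaseImpl, and the HANDLE_CASE macros above all shed their NDIM parameter.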