diff options
Diffstat (limited to 'tensorflow/core/kernels/tile_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/tile_ops.cc | 76 |
1 files changed, 21 insertions, 55 deletions
diff --git a/tensorflow/core/kernels/tile_ops.cc b/tensorflow/core/kernels/tile_ops.cc index 7c72487d3f..f1da3c8afb 100644 --- a/tensorflow/core/kernels/tile_ops.cc +++ b/tensorflow/core/kernels/tile_ops.cc @@ -44,21 +44,12 @@ typedef Eigen::GpuDevice GPUDevice; typedef Eigen::SyclDevice SYCLDevice; #endif // TENSORFLOW_USE_SYCL -// Forward declarations of functors that will be defined in -// tile_ops_cpu_impl*.cc and tile_ops_gpu.cu.cc. +// Forward declarations of functors that will be defined in tile_ops_impl.h namespace functor { -template <typename Device, typename T, int NDIM> -struct Tile { - void operator()(const Device& d, typename TTypes<T, NDIM>::Tensor out, - typename TTypes<T, NDIM>::ConstTensor in, - const Eigen::array<int32, NDIM>& broadcast_array) const; -}; - template <typename Device, typename T> -struct Tile<Device, T, 0> { - void operator()(const Device& d, typename TTypes<T, 0>::Tensor out, - typename TTypes<T, 0>::ConstTensor in, - const Eigen::array<int32, 0>&) const; +struct Tile { + void operator()(const Device& d, Tensor* out, const Tensor& in, + const gtl::ArraySlice<int32> broadcast_array) const; }; template <typename Device, typename T, int NDIM> @@ -134,21 +125,12 @@ class TileOp : public OpKernel { // If there's no output, there's nothing to do. if (output_shape.num_elements() == 0) return; -#define HANDLE_DIM(DT, NDIM) \ - if (context->input(0).dtype() == DT && input_dims == NDIM) { \ - HandleCase<DT, NDIM>(context, multiples_array, result); \ +#define HANDLE_TYPE(DT) \ + if (context->input(0).dtype() == DT) { \ + HandleCase<DT>(context, multiples_array, result); \ return; \ } -#define HANDLE_TYPE(T) \ - HANDLE_DIM(T, 1) \ - HANDLE_DIM(T, 2) \ - HANDLE_DIM(T, 3) \ - HANDLE_DIM(T, 4) \ - HANDLE_DIM(T, 5) \ - HANDLE_DIM(T, 6) \ - HANDLE_DIM(T, 7) - #define HANDLE_TYPE_NAME(T) HANDLE_TYPE(DataTypeToEnum<T>::value) // Invoke macro using TF_CALL_* so type-filtering for platform applies. @@ -166,7 +148,6 @@ class TileOp : public OpKernel { #undef HANDLE_TYPE_NAME #undef HANDLE_TYPE -#undef HANDLE_DIM OP_REQUIRES(context, false, errors::Unimplemented( @@ -175,21 +156,17 @@ class TileOp : public OpKernel { } private: - template <DataType DT, int NDIM> + template <DataType DT> void HandleCaseImpl(OpKernelContext* context, const gtl::ArraySlice<int32>& multiples_array, Tensor* result) { typedef typename EnumToDataType<DT>::Type T; - Eigen::array<int32, NDIM> broadcast_array; - for (int i = 0; i < NDIM; ++i) { - broadcast_array[i] = multiples_array[i]; - } - functor::Tile<Device, T, NDIM>()( - context->eigen_device<Device>(), result->tensor<T, NDIM>(), - context->input(0).tensor<T, NDIM>(), broadcast_array); + functor::Tile<Device, T>() ( + context->eigen_device<Device>(), result, + context->input(0), multiples_array); } - template <DataType DT, int NDIM> + template <DataType DT> void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int32>& multiples_array, Tensor* result); @@ -198,45 +175,35 @@ class TileOp : public OpKernel { }; template <typename Device> -template <DataType DT, int NDIM> +template <DataType DT> inline void TileOp<Device>::HandleCase( OpKernelContext* context, const gtl::ArraySlice<int32>& multiples_array, Tensor* result) { // TODO(vrv): print out the device name if useful. Currently disabled to avoid // having to use RTTI. - LOG(FATAL) << "TileOp: Invalid combination of Device, DT and NDIM: " + LOG(FATAL) << "TileOp: Invalid combination of Device, DT: " // << typeid(Device).name() << ", " - << DataTypeString(DT) << ", " << NDIM; + << DataTypeString(DT); } -#define HANDLE_CASE(device, T, dtype, ndim) \ +#define HANDLE_CASE(device, dtype) \ template <> \ template <> \ - void TileOp<device>::HandleCase<dtype, ndim>( \ + void TileOp<device>::HandleCase<dtype>( \ OpKernelContext * context, \ const gtl::ArraySlice<int32>& multiples_array, Tensor* result) { \ - HandleCaseImpl<dtype, ndim>(context, multiples_array, result); \ + HandleCaseImpl<dtype>(context, multiples_array, result); \ } -// 0-D handled above -#define HANDLE_CASE_DIM(device, T, dtype) \ - HANDLE_CASE(device, T, dtype, 1); \ - HANDLE_CASE(device, T, dtype, 2); \ - HANDLE_CASE(device, T, dtype, 3); \ - HANDLE_CASE(device, T, dtype, 4); \ - HANDLE_CASE(device, T, dtype, 5); \ - HANDLE_CASE(device, T, dtype, 6); \ - HANDLE_CASE(device, T, dtype, 7); - #define HANDLE_TYPE_NAME_CPU(T) \ - HANDLE_CASE_DIM(CPUDevice, T, DataTypeToEnum<T>::value); + HANDLE_CASE(CPUDevice, DataTypeToEnum<T>::value); #define HANDLE_TYPE_NAME_GPU(T) \ - HANDLE_CASE_DIM(GPUDevice, T, DataTypeToEnum<T>::value); + HANDLE_CASE(GPUDevice, DataTypeToEnum<T>::value); #ifdef TENSORFLOW_USE_SYCL #define HANDLE_TYPE_NAME_SYCL(T) \ - HANDLE_CASE_DIM(SYCLDevice, T, DataTypeToEnum<T>::value); + HANDLE_CASE(SYCLDevice, DataTypeToEnum<T>::value); #endif // TENSORFLOW_USE_SYCL TF_CALL_bool(HANDLE_TYPE_NAME_CPU); @@ -275,7 +242,6 @@ TF_CALL_int64(HANDLE_TYPE_NAME_SYCL); #ifdef TENSORFLOW_USE_SYCL #undef HANDLE_TYPE_NAME_SYCL #endif // TENSORFLOW_USE_SYCL -#undef HANDLE_CASE_DIM #undef HANDLE_CASE // -------------------------------------------------------------------------- |