diff options
Diffstat (limited to 'tensorflow/core/kernels/pad_op.cc')
-rw-r--r-- | tensorflow/core/kernels/pad_op.cc | 91 |
1 files changed, 70 insertions, 21 deletions
diff --git a/tensorflow/core/kernels/pad_op.cc b/tensorflow/core/kernels/pad_op.cc index 4c43193579..6e8b09d050 100644 --- a/tensorflow/core/kernels/pad_op.cc +++ b/tensorflow/core/kernels/pad_op.cc @@ -70,6 +70,16 @@ class PadOp : public OpKernel { "The first dimension of paddings must be the rank of inputs", in1.shape().DebugString(), " ", in0.shape().DebugString())); + T pad_value(0); + if (context->num_inputs() == 3) { + const Tensor& constant_values = context->input(2); + OP_REQUIRES( + context, TensorShapeUtils::IsScalar(constant_values.shape()), + errors::InvalidArgument("constant_values must be a scalar. Found: ", + constant_values.shape().DebugString())); + pad_value = context->input(2).scalar<T>()(); + } + // Compute the shape of the output tensor, and allocate it. TensorShape output_shape; TTypes<int32>::ConstMatrix paddings = in1.matrix<int32>(); @@ -99,27 +109,27 @@ class PadOp : public OpKernel { // Invoke the dims-specific implementation. switch (fixed_dims) { case 0: - Operate<0>(context, in0.tensor<T, 0>(), paddings, output); + Operate<0>(context, in0.tensor<T, 0>(), paddings, pad_value, output); break; case 1: // TODO(irving): Once Pad doesn't need a scalar special case, // change flat to tensor. That is, once !allow_legacy_scalars(). - Operate<1>(context, in0.flat<T>(), paddings, output); + Operate<1>(context, in0.flat<T>(), paddings, pad_value, output); break; case 2: - Operate<2>(context, in0.tensor<T, 2>(), paddings, output); + Operate<2>(context, in0.tensor<T, 2>(), paddings, pad_value, output); break; case 3: - Operate<3>(context, in0.tensor<T, 3>(), paddings, output); + Operate<3>(context, in0.tensor<T, 3>(), paddings, pad_value, output); break; case 4: - Operate<4>(context, in0.tensor<T, 4>(), paddings, output); + Operate<4>(context, in0.tensor<T, 4>(), paddings, pad_value, output); break; case 5: - Operate<5>(context, in0.tensor<T, 5>(), paddings, output); + Operate<5>(context, in0.tensor<T, 5>(), paddings, pad_value, output); break; case 6: - Operate<6>(context, in0.tensor<T, 6>(), paddings, output); + Operate<6>(context, in0.tensor<T, 6>(), paddings, pad_value, output); break; default: OP_REQUIRES(context, false, @@ -132,7 +142,8 @@ class PadOp : public OpKernel { template <int Dims> void Operate(OpKernelContext* context, typename TTypes<T, Dims>::ConstTensor input, - TTypes<int32>::ConstMatrix paddings, Tensor* output) { + TTypes<int32>::ConstMatrix paddings, T pad_value, + Tensor* output) { CHECK_EQ(Dims, paddings.dimension(0)); CHECK_EQ(2, paddings.dimension(1)); Eigen::array<std::pair<int32, int32>, Dims> paddings_array; @@ -141,16 +152,22 @@ class PadOp : public OpKernel { } functor::Pad<Device, T, Dims> functor; functor(context->eigen_device<Device>(), output->tensor<T, Dims>(), input, - paddings_array); + paddings_array, pad_value); } }; -#define REGISTER_KERNEL(type) \ - REGISTER_KERNEL_BUILDER(Name("Pad") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<type>("T") \ - .HostMemory("paddings"), \ - PadOp<CPUDevice, type>) +#define REGISTER_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Pad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .HostMemory("paddings"), \ + PadOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER(Name("PadV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .HostMemory("paddings") \ + .HostMemory("constant_values"), \ + PadOp<CPUDevice, type>); TF_CALL_POD_TYPES(REGISTER_KERNEL); #undef REGISTER_KERNEL @@ -158,12 +175,12 @@ TF_CALL_POD_TYPES(REGISTER_KERNEL); #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPEC(T, Dims) \ - template <> \ - void Pad<GPUDevice, T, Dims>::operator()( \ - const GPUDevice& d, typename TTypes<T, Dims>::Tensor output, \ - typename TTypes<T, Dims>::ConstTensor input, \ - Eigen::array<std::pair<int32, int32>, Dims> paddings); \ +#define DECLARE_GPU_SPEC(T, Dims) \ + template <> \ + void Pad<GPUDevice, T, Dims>::operator()( \ + const GPUDevice& d, typename TTypes<T, Dims>::Tensor output, \ + typename TTypes<T, Dims>::ConstTensor input, \ + Eigen::array<std::pair<int32, int32>, Dims> paddings, T pad_value); \ extern template struct Pad<GPUDevice, T, Dims>; #define DECLARE_GPU_SPECS(T) \ @@ -185,6 +202,13 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); .TypeConstraint<T>("T") \ .TypeConstraint<int32>("Tpaddings") \ .HostMemory("paddings"), \ + PadOp<GPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("PadV2") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<T>("T") \ + .TypeConstraint<int32>("Tpaddings") \ + .HostMemory("paddings") \ + .HostMemory("constant_values"), \ PadOp<GPUDevice, T>) TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL); @@ -200,6 +224,15 @@ REGISTER_KERNEL_BUILDER(Name("Pad") .HostMemory("paddings") .HostMemory("output"), PadOp<CPUDevice, int32>); +REGISTER_KERNEL_BUILDER(Name("PadV2") + .Device(DEVICE_GPU) + .TypeConstraint<int32>("T") + .TypeConstraint<int32>("Tpaddings") + .HostMemory("input") + .HostMemory("paddings") + .HostMemory("constant_values") + .HostMemory("output"), + PadOp<CPUDevice, int32>); #endif #ifdef TENSORFLOW_USE_SYCL @@ -210,6 +243,13 @@ REGISTER_KERNEL_BUILDER(Name("Pad") .TypeConstraint<T>("T") \ .TypeConstraint<int32>("Tpaddings") \ .HostMemory("paddings"), \ + PadOp<SYCLDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("PadV2") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint<T>("T") \ + .TypeConstraint<int32>("Tpaddings") \ + .HostMemory("paddings") \ + .HostMemory("constant_values"), \ PadOp<SYCLDevice, T>) TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL); @@ -221,6 +261,15 @@ REGISTER_KERNEL_BUILDER(Name("Pad") .HostMemory("paddings") .HostMemory("output"), PadOp<CPUDevice, int32>); +REGISTER_KERNEL_BUILDER(Name("PadV2") + .Device(DEVICE_SYCL) + .TypeConstraint<int32>("T") + .TypeConstraint<int32>("Tpaddings") + .HostMemory("input") + .HostMemory("paddings") + .HostMemory("constant_values") + .HostMemory("output"), + PadOp<CPUDevice, int32>); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL |