author    RJ Ryan <rjryan@google.com>  2017-07-10 16:41:33 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-07-10 16:51:17 -0700
commit    a6773e98e97956b7adf3aa51eb3548261f51d6f7 (patch)
tree      a4fa423385edabe441d7644c6df7a62803e7e2a3
parent    285f9766471e10fa9fee4299940225a33515c010 (diff)
Add a PadV2 op with support for specifying a pad value.
Added a `constant_values` keyword argument to the tf.pad Python API for compatibility with numpy.pad. For now, only scalar values are supported; efficiently supporting a `[D, 2]` tensor of per-dimension pre/post pad values for `constant_values` will require adding Eigen and XLA support first.

PiperOrigin-RevId: 161460091
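Below is a minimal usage sketch of the new keyword argument (not part of the commit; it assumes a TensorFlow build that already includes PadV2), showing that the constant padding behaviour matches numpy.pad:

```python
import numpy as np
import tensorflow as tf  # assumes a build that already includes PadV2

x = np.array([[1, 2], [3, 4]], dtype=np.int32)
paddings = [[1, 1], [2, 2]]  # (before, after) amounts per dimension

# numpy reference behaviour.
np_out = np.pad(x, paddings, mode="constant", constant_values=7)

# tf.pad now accepts the same constant_values keyword (TF 1.x session style,
# matching the era of this commit).
with tf.Session() as sess:
  tf_out = sess.run(tf.pad(x, paddings, mode="CONSTANT", constant_values=7))

print(np.array_equal(np_out, tf_out))  # True
```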
-rw-r--r--  tensorflow/cc/gradients/array_grad.cc         |  8
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/pad_op.cc  | 14
-rw-r--r--  tensorflow/core/kernels/pad_op.cc             | 91
-rw-r--r--  tensorflow/core/kernels/pad_op.h              | 13
-rw-r--r--  tensorflow/core/ops/array_ops.cc              | 39
-rw-r--r--  tensorflow/core/ops/array_ops_test.cc         | 30
-rw-r--r--  tensorflow/python/kernel_tests/pad_op_test.py | 56
-rw-r--r--  tensorflow/python/ops/array_grad.py           | 10
-rw-r--r--  tensorflow/python/ops/array_ops.py            | 12
-rw-r--r--  tensorflow/python/ops/hidden_ops.txt          |  1
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.pbtxt  |  2
11 files changed, 224 insertions(+), 52 deletions(-)
diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc
index 48185db3cb..6545e4ee3e 100644
--- a/tensorflow/cc/gradients/array_grad.cc
+++ b/tensorflow/cc/gradients/array_grad.cc
@@ -269,6 +269,7 @@ Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd", ScatterNdNonAliasingAddGrad);
+template <bool IsPadV2>
Status PadGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
@@ -281,9 +282,14 @@ Status PadGrad(const Scope& scope, const Operation& op,
auto begin = Reshape(scope, pad_before, {-1});
grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x)));
grad_outputs->push_back(NoGradient());
+ // PadV2 adds a "constant_values" input.
+ if (IsPadV2) {
+ grad_outputs->push_back(NoGradient());
+ }
return scope.status();
}
-REGISTER_GRADIENT_OP("Pad", PadGrad);
+REGISTER_GRADIENT_OP("Pad", PadGrad<false>);
+REGISTER_GRADIENT_OP("PadV2", PadGrad<true>);
Status SpaceToBatchGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
index cc13ab0203..d841bd37b3 100644
--- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
@@ -69,12 +69,22 @@ class PadOp : public XlaOpKernel {
dim->set_edge_padding_high(after);
}
- auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
- ctx->SetOutput(0, ctx->builder()->Pad(ctx->Input(0), zero, config));
+ // PadV2 added a "constant_values" input that indicates the pad value.
+ xla::ComputationDataHandle constant_values;
+ if (ctx->num_inputs() == 3) {
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(2)),
+ errors::InvalidArgument("constant_values must be a scalar."));
+ ctx->SetOutput(0,
+ ctx->builder()->Pad(ctx->Input(0), ctx->Input(2), config));
+ } else {
+ auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
+ ctx->SetOutput(0, ctx->builder()->Pad(ctx->Input(0), zero, config));
+ }
}
};
REGISTER_XLA_OP(Name("Pad"), PadOp);
+REGISTER_XLA_OP(Name("PadV2"), PadOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/pad_op.cc b/tensorflow/core/kernels/pad_op.cc
index 4c43193579..6e8b09d050 100644
--- a/tensorflow/core/kernels/pad_op.cc
+++ b/tensorflow/core/kernels/pad_op.cc
@@ -70,6 +70,16 @@ class PadOp : public OpKernel {
"The first dimension of paddings must be the rank of inputs",
in1.shape().DebugString(), " ", in0.shape().DebugString()));
+ T pad_value(0);
+ if (context->num_inputs() == 3) {
+ const Tensor& constant_values = context->input(2);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsScalar(constant_values.shape()),
+ errors::InvalidArgument("constant_values must be a scalar. Found: ",
+ constant_values.shape().DebugString()));
+ pad_value = context->input(2).scalar<T>()();
+ }
+
// Compute the shape of the output tensor, and allocate it.
TensorShape output_shape;
TTypes<int32>::ConstMatrix paddings = in1.matrix<int32>();
@@ -99,27 +109,27 @@ class PadOp : public OpKernel {
// Invoke the dims-specific implementation.
switch (fixed_dims) {
case 0:
- Operate<0>(context, in0.tensor<T, 0>(), paddings, output);
+ Operate<0>(context, in0.tensor<T, 0>(), paddings, pad_value, output);
break;
case 1:
// TODO(irving): Once Pad doesn't need a scalar special case,
// change flat to tensor. That is, once !allow_legacy_scalars().
- Operate<1>(context, in0.flat<T>(), paddings, output);
+ Operate<1>(context, in0.flat<T>(), paddings, pad_value, output);
break;
case 2:
- Operate<2>(context, in0.tensor<T, 2>(), paddings, output);
+ Operate<2>(context, in0.tensor<T, 2>(), paddings, pad_value, output);
break;
case 3:
- Operate<3>(context, in0.tensor<T, 3>(), paddings, output);
+ Operate<3>(context, in0.tensor<T, 3>(), paddings, pad_value, output);
break;
case 4:
- Operate<4>(context, in0.tensor<T, 4>(), paddings, output);
+ Operate<4>(context, in0.tensor<T, 4>(), paddings, pad_value, output);
break;
case 5:
- Operate<5>(context, in0.tensor<T, 5>(), paddings, output);
+ Operate<5>(context, in0.tensor<T, 5>(), paddings, pad_value, output);
break;
case 6:
- Operate<6>(context, in0.tensor<T, 6>(), paddings, output);
+ Operate<6>(context, in0.tensor<T, 6>(), paddings, pad_value, output);
break;
default:
OP_REQUIRES(context, false,
@@ -132,7 +142,8 @@ class PadOp : public OpKernel {
template <int Dims>
void Operate(OpKernelContext* context,
typename TTypes<T, Dims>::ConstTensor input,
- TTypes<int32>::ConstMatrix paddings, Tensor* output) {
+ TTypes<int32>::ConstMatrix paddings, T pad_value,
+ Tensor* output) {
CHECK_EQ(Dims, paddings.dimension(0));
CHECK_EQ(2, paddings.dimension(1));
Eigen::array<std::pair<int32, int32>, Dims> paddings_array;
@@ -141,16 +152,22 @@ class PadOp : public OpKernel {
}
functor::Pad<Device, T, Dims> functor;
functor(context->eigen_device<Device>(), output->tensor<T, Dims>(), input,
- paddings_array);
+ paddings_array, pad_value);
}
};
-#define REGISTER_KERNEL(type) \
- REGISTER_KERNEL_BUILDER(Name("Pad") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .HostMemory("paddings"), \
- PadOp<CPUDevice, type>)
+#define REGISTER_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Pad") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("paddings"), \
+ PadOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER(Name("PadV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("paddings") \
+ .HostMemory("constant_values"), \
+ PadOp<CPUDevice, type>);
TF_CALL_POD_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL
@@ -158,12 +175,12 @@ TF_CALL_POD_TYPES(REGISTER_KERNEL);
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
-#define DECLARE_GPU_SPEC(T, Dims) \
- template <> \
- void Pad<GPUDevice, T, Dims>::operator()( \
- const GPUDevice& d, typename TTypes<T, Dims>::Tensor output, \
- typename TTypes<T, Dims>::ConstTensor input, \
- Eigen::array<std::pair<int32, int32>, Dims> paddings); \
+#define DECLARE_GPU_SPEC(T, Dims) \
+ template <> \
+ void Pad<GPUDevice, T, Dims>::operator()( \
+ const GPUDevice& d, typename TTypes<T, Dims>::Tensor output, \
+ typename TTypes<T, Dims>::ConstTensor input, \
+ Eigen::array<std::pair<int32, int32>, Dims> paddings, T pad_value); \
extern template struct Pad<GPUDevice, T, Dims>;
#define DECLARE_GPU_SPECS(T) \
@@ -185,6 +202,13 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
.TypeConstraint<T>("T") \
.TypeConstraint<int32>("Tpaddings") \
.HostMemory("paddings"), \
+ PadOp<GPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("PadV2") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int32>("Tpaddings") \
+ .HostMemory("paddings") \
+ .HostMemory("constant_values"), \
PadOp<GPUDevice, T>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
@@ -200,6 +224,15 @@ REGISTER_KERNEL_BUILDER(Name("Pad")
.HostMemory("paddings")
.HostMemory("output"),
PadOp<CPUDevice, int32>);
+REGISTER_KERNEL_BUILDER(Name("PadV2")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int32>("Tpaddings")
+ .HostMemory("input")
+ .HostMemory("paddings")
+ .HostMemory("constant_values")
+ .HostMemory("output"),
+ PadOp<CPUDevice, int32>);
#endif
#ifdef TENSORFLOW_USE_SYCL
@@ -210,6 +243,13 @@ REGISTER_KERNEL_BUILDER(Name("Pad")
.TypeConstraint<T>("T") \
.TypeConstraint<int32>("Tpaddings") \
.HostMemory("paddings"), \
+ PadOp<SYCLDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("PadV2") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int32>("Tpaddings") \
+ .HostMemory("paddings") \
+ .HostMemory("constant_values"), \
PadOp<SYCLDevice, T>)
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);
@@ -221,6 +261,15 @@ REGISTER_KERNEL_BUILDER(Name("Pad")
.HostMemory("paddings")
.HostMemory("output"),
PadOp<CPUDevice, int32>);
+REGISTER_KERNEL_BUILDER(Name("PadV2")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int32>("Tpaddings")
+ .HostMemory("input")
+ .HostMemory("paddings")
+ .HostMemory("constant_values")
+ .HostMemory("output"),
+ PadOp<CPUDevice, int32>);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/pad_op.h b/tensorflow/core/kernels/pad_op.h
index 733e0f3083..6a973833e2 100644
--- a/tensorflow/core/kernels/pad_op.h
+++ b/tensorflow/core/kernels/pad_op.h
@@ -27,16 +27,17 @@ namespace functor {
// Functor used by PadOp to do the computations.
template <typename Device, typename T, int Dims>
struct Pad {
- // Pad "input" into "output", as specified by "paddings". See pad_op.cc for
- // details.
+ // Pad "input" into "output", as specified by "paddings" and "pad_value".
+ // See pad_op.cc for details.
void operator()(const Device& d, typename TTypes<T, Dims>::Tensor output,
typename TTypes<T, Dims>::ConstTensor input,
- Eigen::array<std::pair<int32, int32>, Dims> paddings) {
+ Eigen::array<std::pair<int32, int32>, Dims> paddings,
+ T pad_value) {
if (Eigen::internal::is_same<Device, Eigen::GpuDevice>::value &&
(output.size() <= std::numeric_limits<int32>::max())) {
- To32Bit(output).device(d) = To32Bit(input).pad(paddings);
+ To32Bit(output).device(d) = To32Bit(input).pad(paddings, pad_value);
} else {
- output.device(d) = input.pad(paddings);
+ output.device(d) = input.pad(paddings, pad_value);
}
}
};
@@ -46,7 +47,7 @@ struct Pad<Device, T, 0> {
// In the scalar case we simply copy the input.
void operator()(const Device& d, typename TTypes<T, 0>::Tensor output,
typename TTypes<T, 0>::ConstTensor input,
- Eigen::array<std::pair<int32, int32>, 0>) {
+ Eigen::array<std::pair<int32, int32>, 0>, T) {
output.device(d) = input;
}
};
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 7dd3ad31a0..882b4b0cc4 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -2720,6 +2720,45 @@ pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
)doc");
// --------------------------------------------------------------------------
+REGISTER_OP("PadV2")
+ .Input("input: T")
+ .Input("paddings: Tpaddings")
+ .Input("constant_values: T")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("Tpaddings: {int32, int64} = DT_INT32")
+ .SetShapeFn(PadShapeFn)
+ .Doc(R"doc(
+Pads a tensor.
+
+This operation pads `input` according to the `paddings` and `constant_values`
+you specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is
+the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
+how many padding values to add before the contents of `input` in that dimension,
+and `paddings[D, 1]` indicates how many padding values to add after the contents
+of `input` in that dimension. `constant_values` is a scalar tensor of the same
+type as `input` that indicates the value to use for padding `input`.
+
+The padded size of each dimension D of the output is:
+
+`paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
+
+For example:
+
+```
+# 't' is [[1, 1], [2, 2]]
+# 'paddings' is [[1, 1], [2, 2]]
+# 'constant_values' is 0
+# rank of 't' is 2
+pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
+ [0, 0, 1, 1, 0, 0]
+ [0, 0, 2, 2, 0, 0]
+ [0, 0, 0, 0, 0, 0]]
+```
+
+)doc");
+
+// --------------------------------------------------------------------------
REGISTER_OP("MirrorPad")
.Input("input: T")
.Input("paddings: Tpaddings")
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index dc5d46e6fa..1351a2f177 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -349,6 +349,36 @@ TEST(ArrayOpsTest, PadD_ShapeFn) {
}
}
+TEST(ArrayOpsTest, PadV2_ShapeFn) {
+ ShapeInferenceTestOp op("PadV2");
+ op.input_tensors.resize(3);
+
+ // Inputs are input, paddings and constant_values.
+
+ INFER_OK(op, "?;?;?", "?");
+
+ // Check shape of paddings.
+ INFER_ERROR("Shape must be rank 2 but is rank 3", op, "?;[1,2,3];?");
+ INFER_ERROR("Dimension must be 2 but is 4", op, "?;[1,4];?");
+
+ // input.rank and paddings.dim(0) are equal. This is the number of dims in
+ // output.
+ INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];[4,2];[]");
+ INFER_OK(op, "[1,2,3];?;[]", "[?,?,?]");
+ INFER_OK(op, "?;[3,2];[]", "[?,?,?]");
+
+ // Make the paddings tensor known and verify padding values get added.
+  // E.g., if padding is ((1,10),(2,20),(3,30)) then values 11,22,33 are added
+ // to input dims to get output.
+ Tensor paddings_t(DT_INT64, TensorShape{3, 2});
+ test::FillValues<int64>(&paddings_t, {1, 10, 2, 20, 3, 30});
+ op.input_tensors[1] = &paddings_t;
+ INFER_OK(op, "[100,200,300];[3,2];[]", "[111,222,333]");
+ INFER_OK(op, "[100,?,300];[3,2];[]", "[111,?,333]");
+ INFER_OK(op, "?;[3,2];[]", "[?,?,?]");
+ INFER_OK(op, "?;?;[]", "[?,?,?]");
+}
+
TEST(ArrayOpsTest, MirrorPadGrad_ShapeFn) {
ShapeInferenceTestOp op("MirrorPadGrad");
op.input_tensors.resize(2);
diff --git a/tensorflow/python/kernel_tests/pad_op_test.py b/tensorflow/python/kernel_tests/pad_op_test.py
index c709be0b5b..b774c69ceb 100644
--- a/tensorflow/python/kernel_tests/pad_op_test.py
+++ b/tensorflow/python/kernel_tests/pad_op_test.py
@@ -30,8 +30,12 @@ from tensorflow.python.platform import test
class PadOpTest(test.TestCase):
- def _npPad(self, inp, paddings, mode):
- return np.pad(inp, paddings, mode=mode.lower())
+ def _npPad(self, inp, paddings, mode, constant_values=0):
+ mode = mode.lower()
+ if mode == "constant":
+ return np.pad(inp, paddings, mode=mode, constant_values=constant_values)
+ else:
+ return np.pad(inp, paddings, mode=mode)
def testNpPad(self):
self.assertAllEqual(
@@ -47,6 +51,18 @@ class PadOpTest(test.TestCase):
mode="constant"))
self.assertAllEqual(
+ np.array([[1, 1, 1, 1, 1, 1],
+ [1, 3, 3, 1, 1, 1],
+ [1, 4, 4, 1, 1, 1],
+ [1, 5, 5, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1]]),
+ self._npPad(
+ np.array([[3, 3], [4, 4], [5, 5]]),
+ [[1, 2], [1, 3]],
+ mode="constant", constant_values=1))
+
+ self.assertAllEqual(
np.array([[4, 3, 4, 9, 4, 3],
[1, 0, 1, 2, 1, 0],
[4, 3, 4, 9, 4, 3],
@@ -66,35 +82,39 @@ class PadOpTest(test.TestCase):
[[1, 1], [1, 2]],
mode="symmetric"))
- def _testPad(self, np_inputs, paddings, mode):
- np_val = self._npPad(np_inputs, paddings, mode=mode)
+ def _testPad(self, np_inputs, paddings, mode, constant_values):
+ np_val = self._npPad(np_inputs, paddings, mode=mode,
+ constant_values=constant_values)
with self.test_session(use_gpu=True):
- tf_val = array_ops.pad(np_inputs, paddings, mode=mode)
+ tf_val = array_ops.pad(np_inputs, paddings, mode=mode,
+ constant_values=constant_values)
out = tf_val.eval()
self.assertAllEqual(np_val, out)
self.assertShapeEqual(np_val, tf_val)
- def _testGradient(self, x, a, mode):
+ def _testGradient(self, x, a, mode, constant_values):
with self.test_session(use_gpu=True):
inx = ops.convert_to_tensor(x)
xs = list(x.shape)
ina = ops.convert_to_tensor(a)
- y = array_ops.pad(inx, ina, mode=mode)
+ y = array_ops.pad(inx, ina, mode=mode, constant_values=constant_values)
# Expected y's shape to be:
ys = list(np.array(x.shape) + np.sum(np.array(a), axis=1))
jacob_t, jacob_n = gradient_checker.compute_gradient(
inx, xs, y, ys, x_init_value=x)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
- def _testAll(self, np_inputs, paddings):
+ def _testAll(self, np_inputs, paddings, constant_values):
for mode in ("CONSTANT", "REFLECT", "SYMMETRIC", "reflect", "symmetric",
"constant"):
# Zero-sized input is not allowed for REFLECT mode, but we still want
# zero-sized input test cases for the other modes.
if np_inputs.size or mode.upper() != "REFLECT":
- self._testPad(np_inputs, paddings, mode=mode)
+ self._testPad(np_inputs, paddings, mode=mode,
+ constant_values=constant_values)
if np_inputs.dtype == np.float32:
- self._testGradient(np_inputs, paddings, mode=mode)
+ self._testGradient(np_inputs, paddings, mode=mode,
+ constant_values=constant_values)
def testInputDims(self):
with self.test_session(use_gpu=True):
@@ -179,23 +199,25 @@ class PadOpTest(test.TestCase):
for t in [np.int32, np.int64]:
self._testAll(
np.random.randint(-100, 100, (4, 4, 3)).astype(t),
- [[1, 0], [2, 3], [0, 2]])
+ [[1, 0], [2, 3], [0, 2]], 0)
self._testAll(
np.random.randint(-100, 100, (4, 2, 1, 3)).astype(t),
- [[0, 0], [0, 0], [0, 0], [0, 0]])
+ [[0, 0], [0, 0], [0, 0], [0, 0]], -1234)
def testFloatTypes(self):
for t in [np.float32, np.float64]:
- self._testAll(np.random.rand(2, 5).astype(t), [[1, 0], [2, 0]])
- self._testAll(np.random.rand(2, 3, 4).astype(t), [[0, 0], [0, 0], [0, 0]])
- self._testAll(np.random.rand(0, 3, 4).astype(t), [[0, 0], [2, 1], [2, 3]])
+ self._testAll(np.random.rand(2, 5).astype(t), [[1, 0], [2, 0]], 0.0)
+ self._testAll(np.random.rand(2, 3, 4).astype(t),
+ [[0, 0], [0, 0], [0, 0]], -1234.0)
+ self._testAll(np.random.rand(0, 3, 4).astype(t),
+ [[0, 0], [2, 1], [2, 3]], 0.0)
def testComplexTypes(self):
for t in [np.complex64, np.complex128]:
x = np.random.rand(2, 5).astype(t)
- self._testAll(x + 1j * x, [[1, 0], [2, 0]])
+ self._testAll(x + 1j * x, [[1, 0], [2, 0]], 1234.0 - 1234.0j)
x = np.random.rand(3, 2, 1, 1).astype(t)
- self._testAll(x + 1j * x, [[0, 0], [0, 0], [0, 0], [0, 0]])
+ self._testAll(x + 1j * x, [[0, 0], [0, 0], [0, 0], [0, 0]], 0 + 0j)
def testShapeFunctionEdgeCases(self):
# Unknown paddings shape.
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 51b3b81500..73a4b7db9f 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -470,7 +470,6 @@ def _TileGrad(op, grad):
ops.NotDifferentiable("BroadcastGradientArgs")
-@ops.RegisterGradient("Pad")
def _PadGrad(op, grad):
"""Gradient for Pad."""
# Pad introduces values around the original tensor, so the gradient function
@@ -483,7 +482,14 @@ def _PadGrad(op, grad):
# Make it a 1-D tensor.
begin = array_ops.reshape(pad_before, [-1])
sizes = array_ops.shape(x)
- return array_ops.slice(grad, begin, sizes), None
+ x_grad = array_ops.slice(grad, begin, sizes)
+ if len(op.inputs) == 3:
+ return x_grad, None, None
+ else:
+ return x_grad, None
+
+ops.RegisterGradient("Pad")(_PadGrad)
+ops.RegisterGradient("PadV2")(_PadGrad)
# ReverseSequence is just a permutation. The gradient permutes back.
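A small numpy sketch (not part of the commit) of what _PadGrad computes for "CONSTANT" padding: the incoming gradient, which has the padded shape, is simply sliced back to the original input's extent; per this change, neither `paddings` nor (for PadV2) `constant_values` receives a gradient:

```python
import numpy as np

x_shape = (2, 3)
paddings = np.array([[1, 0], [2, 2]])  # (pad_before, pad_after) per dim

# Incoming gradient has the padded shape: (1+2+0, 2+3+2) = (3, 7).
grad = np.arange(3 * 7, dtype=np.float32).reshape(3, 7)

# begin = pad_before (as a 1-D vector); sizes = shape(x).
begin = paddings[:, 0]
x_grad = grad[begin[0]:begin[0] + x_shape[0],
              begin[1]:begin[1] + x_shape[1]]

print(x_grad.shape)  # (2, 3) -- same shape as the original input
```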
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index a963bd265c..be3cadbd21 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1613,7 +1613,7 @@ def sparse_placeholder(dtype, shape=None, name=None):
# pylint: enable=redefined-outer-name
-def pad(tensor, paddings, mode="CONSTANT", name=None): # pylint: disable=invalid-name
+def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0): # pylint: disable=invalid-name
"""Pads a tensor.
This operation pads a `tensor` according to the `paddings` you specify.
@@ -1635,6 +1635,7 @@ def pad(tensor, paddings, mode="CONSTANT", name=None): # pylint: disable=invali
```python
# 't' is [[1, 2, 3], [4, 5, 6]].
# 'paddings' is [[1, 1,], [2, 2]].
+ # 'constant_values' is 0.
# rank of 't' is 2.
pad(t, paddings, "CONSTANT") ==> [[0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 2, 3, 0, 0],
@@ -1657,6 +1658,8 @@ def pad(tensor, paddings, mode="CONSTANT", name=None): # pylint: disable=invali
paddings: A `Tensor` of type `int32`.
mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC" (case-insensitive)
name: A name for the operation (optional).
+ constant_values: In "CONSTANT" mode, the scalar pad value to use. Must be
+ same type as `tensor`.
Returns:
A `Tensor`. Has the same type as `tensor`.
@@ -1669,7 +1672,12 @@ def pad(tensor, paddings, mode="CONSTANT", name=None): # pylint: disable=invali
# NumPy uses all lower-case modes.
mode = mode.upper()
if mode == "CONSTANT":
- return gen_array_ops._pad(tensor, paddings, name=name)
+    # TODO(rjryan): Once the forward compatibility period (3 weeks) has passed
+ # remove the "Pad" fallback here.
+ if constant_values != 0:
+ return gen_array_ops._pad_v2(tensor, paddings, constant_values, name=name)
+ else:
+ return gen_array_ops._pad(tensor, paddings, name=name)
if mode == "REFLECT":
return gen_array_ops._mirror_pad(tensor,
paddings,
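A sketch of the resulting dispatch behaviour (not part of the commit; assumes a TF 1.x graph-mode build with this change): a zero pad value keeps emitting the legacy "Pad" op during the forward-compatibility window, while a non-zero value uses the new "PadV2" op:

```python
import tensorflow as tf  # assumes a build that includes this change

t = tf.constant([[1, 2], [3, 4]])
p = [[1, 1], [1, 1]]

a = tf.pad(t, p)                     # default constant_values=0
b = tf.pad(t, p, constant_values=9)  # non-zero pad value

print(a.op.type)  # "Pad"   -- legacy op, kept for forward compatibility
print(b.op.type)  # "PadV2" -- new op carrying the constant_values input
```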
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index 6408d52d8c..bec6b82a4c 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -14,6 +14,7 @@ MirrorPadGrad
OneHot
Pack
Pad
+PadV2
ParallelConcat
Placeholder
RefIdentity
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index 1857589417..cb0b383934 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -1362,7 +1362,7 @@ tf_module {
}
member_method {
name: "pad"
- argspec: "args=[\'tensor\', \'paddings\', \'mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], "
+ argspec: "args=[\'tensor\', \'paddings\', \'mode\', \'name\', \'constant_values\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\', \'0\'], "
}
member_method {
name: "parallel_stack"