aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-09-14 18:36:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-14 18:39:54 -0700
commit7097a5d46ecdbc498418dfe46bccca31632a9718 (patch)
tree7e96529dc51ee4b231109c473bad0455d9210976 /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
parentf7d8b836f461eadbfc888630c64288f6091426b0 (diff)
[XLA] Add helper method xla::MakeEdgePaddingConfig for creating an xla::PaddingConfig with only edge padding.
[TF:XLA] Remove helper method XlaHelpers::PadWithZeros, replace callers with direct calls to Pad(). PiperOrigin-RevId: 168779358
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc10
1 files changed, 7 insertions, 3 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index bdd7e73302..7f1597e9ad 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -223,7 +223,9 @@ class TensorArrayWriteOp : public XlaOpKernel {
xla::ComputationDataHandle flow = ctx->Input(3);
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
- auto start_indices = XlaHelpers::PadWithZeros(b, index, elem_shape.dims());
+ auto start_indices =
+ b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
+ xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
TensorShape slice_shape = elem_shape;
slice_shape.InsertDim(0, 1LL);
@@ -266,7 +268,8 @@ class TensorArrayReadOp : public XlaOpKernel {
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
auto start_indices =
- XlaHelpers::PadWithZeros(b, index, ta_shape.dims() - 1);
+ b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
+ xla::MakeEdgePaddingConfig({{0, ta_shape.dims() - 1}}));
auto slice_shape = ta_shape.dim_sizes();
slice_shape[0] = 1LL;
@@ -371,7 +374,8 @@ class TensorArrayScatterOp : public XlaOpKernel {
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto index = b->Slice(indices, {i}, {i + 1}, {1});
auto start_indices =
- XlaHelpers::PadWithZeros(b, index, elem_shape.dims());
+ b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
+ xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
}