aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-12-23 08:59:26 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-23 09:03:02 -0800
commitc975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0 (patch)
tree55db4d5d76c5acaa565feef0fe510644b162760e /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
parentc0c2775ce3de682f7913d1aeaf50bbc4d1521934 (diff)
[TF:XLA] Refactor large list of hard-coded compile-time constant arguments to operators. Add a new .CompileTimeConstInput() annotation on kernel registrations instead.
PiperOrigin-RevId: 180008567
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc10
1 files changed, 6 insertions, 4 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 8a742ff11c..9224072a3c 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -192,7 +192,8 @@ class TensorArrayOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayOp);
};
-REGISTER_XLA_OP(Name("TensorArrayV3"), TensorArrayOp);
+REGISTER_XLA_OP(Name("TensorArrayV3").CompileTimeConstInput("size"),
+ TensorArrayOp);
class TensorArrayWriteOp : public XlaOpKernel {
public:
@@ -414,8 +415,8 @@ class TensorArrayScatterOp : public XlaOpKernel {
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto index = b->Slice(indices, {i}, {i + 1}, {1});
auto start_indices =
- b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
- xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
+ b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
+ xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
}
}
@@ -537,7 +538,8 @@ class TensorArraySplitOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySplitOp);
};
-REGISTER_XLA_OP(Name("TensorArraySplitV3"), TensorArraySplitOp);
+REGISTER_XLA_OP(Name("TensorArraySplitV3").CompileTimeConstInput("lengths"),
+ TensorArraySplitOp);
class TensorArraySizeOp : public XlaOpKernel {
public: