diff options
author | Peter Hawkins <phawkins@google.com> | 2017-12-23 08:59:26 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-23 09:03:02 -0800 |
commit | c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0 (patch) | |
tree | 55db4d5d76c5acaa565feef0fe510644b162760e /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | |
parent | c0c2775ce3de682f7913d1aeaf50bbc4d1521934 (diff) |
[TF:XLA] Refactor large list of hard-coded compile-time constant arguments to operators. Add a new .CompileTimeConstInput() annotation on kernel registrations instead.
PiperOrigin-RevId: 180008567
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 8a742ff11c..9224072a3c 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -192,7 +192,8 @@ class TensorArrayOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayOp); }; -REGISTER_XLA_OP(Name("TensorArrayV3"), TensorArrayOp); +REGISTER_XLA_OP(Name("TensorArrayV3").CompileTimeConstInput("size"), + TensorArrayOp); class TensorArrayWriteOp : public XlaOpKernel { public: @@ -414,8 +415,8 @@ class TensorArrayScatterOp : public XlaOpKernel { // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. auto index = b->Slice(indices, {i}, {i + 1}, {1}); auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); } } @@ -537,7 +538,8 @@ class TensorArraySplitOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySplitOp); }; -REGISTER_XLA_OP(Name("TensorArraySplitV3"), TensorArraySplitOp); +REGISTER_XLA_OP(Name("TensorArraySplitV3").CompileTimeConstInput("lengths"), + TensorArraySplitOp); class TensorArraySizeOp : public XlaOpKernel { public: |