diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 26 |
1 files changed, 18 insertions, 8 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 94108b764f..6cdfaf4d97 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -123,9 +123,10 @@ Status GetTensorArrayShape(const XlaResource* resource, xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand, const xla::XlaOp& update, absl::Span<const int64> update_dims, - const xla::XlaOp& start_indices) { + const xla::XlaOp& start_indices, DataType dtype) { xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims); - xla::XlaOp sum = xla::Add(current, update); + xla::XlaOp sum = + dtype == DT_BOOL ? xla::Or(current, update) : xla::Add(current, update); return xla::DynamicUpdateSlice(operand, sum, start_indices); } @@ -222,8 +223,8 @@ class TensorArrayWriteOp : public XlaOpKernel { slice_shape.InsertDim(0, 1LL); auto update = xla::Reshape(value, slice_shape.dim_sizes()); - xla::XlaOp written = - DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); + xla::XlaOp written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), + start_indices, dtype_); OP_REQUIRES_OK(ctx, resource->SetValue(written)); ctx->SetOutput(0, flow); @@ -391,7 +392,11 @@ class TensorArrayScatterOp : public XlaOpKernel { } if (scatter_all_elements_in_order) { - ta = xla::Add(ta, value); + if (dtype_ == DT_BOOL) { + ta = xla::Or(ta, value); + } else { + ta = xla::Add(ta, value); + } } else { auto slice_dims = value_shape.dim_sizes(); slice_dims[0] = 1LL; @@ -414,7 +419,7 @@ class TensorArrayScatterOp : public XlaOpKernel { auto start_indices = xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0), xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); - ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); + ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices, dtype_); } } @@ -522,8 +527,13 @@ class TensorArraySplitOp : public XlaOpKernel { value_shape.DebugString(), " vs. ", ta_shape.DebugString())); - OP_REQUIRES_OK(ctx, resource->SetValue(xla::Add( - ta, xla::Reshape(value, ta_shape.dim_sizes())))); + const xla::XlaOp reshape = xla::Reshape(value, ta_shape.dim_sizes()); + if (dtype_ == DT_BOOL) { + ta = xla::Or(ta, reshape); + } else { + ta = xla::Add(ta, reshape); + } + OP_REQUIRES_OK(ctx, resource->SetValue(ta)); ctx->SetOutput(0, flow); } |