aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc26
1 files changed, 18 insertions, 8 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 94108b764f..6cdfaf4d97 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -123,9 +123,10 @@ Status GetTensorArrayShape(const XlaResource* resource,
xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand,
const xla::XlaOp& update,
absl::Span<const int64> update_dims,
- const xla::XlaOp& start_indices) {
+ const xla::XlaOp& start_indices, DataType dtype) {
xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims);
- xla::XlaOp sum = xla::Add(current, update);
+ xla::XlaOp sum =
+ dtype == DT_BOOL ? xla::Or(current, update) : xla::Add(current, update);
return xla::DynamicUpdateSlice(operand, sum, start_indices);
}
@@ -222,8 +223,8 @@ class TensorArrayWriteOp : public XlaOpKernel {
slice_shape.InsertDim(0, 1LL);
auto update = xla::Reshape(value, slice_shape.dim_sizes());
- xla::XlaOp written =
- DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
+ xla::XlaOp written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(),
+ start_indices, dtype_);
OP_REQUIRES_OK(ctx, resource->SetValue(written));
ctx->SetOutput(0, flow);
@@ -391,7 +392,11 @@ class TensorArrayScatterOp : public XlaOpKernel {
}
if (scatter_all_elements_in_order) {
- ta = xla::Add(ta, value);
+ if (dtype_ == DT_BOOL) {
+ ta = xla::Or(ta, value);
+ } else {
+ ta = xla::Add(ta, value);
+ }
} else {
auto slice_dims = value_shape.dim_sizes();
slice_dims[0] = 1LL;
@@ -414,7 +419,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
auto start_indices =
xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
- ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
+ ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices, dtype_);
}
}
@@ -522,8 +527,13 @@ class TensorArraySplitOp : public XlaOpKernel {
value_shape.DebugString(), " vs. ",
ta_shape.DebugString()));
- OP_REQUIRES_OK(ctx, resource->SetValue(xla::Add(
- ta, xla::Reshape(value, ta_shape.dim_sizes()))));
+ const xla::XlaOp reshape = xla::Reshape(value, ta_shape.dim_sizes());
+ if (dtype_ == DT_BOOL) {
+ ta = xla::Or(ta, reshape);
+ } else {
+ ta = xla::Add(ta, reshape);
+ }
+ OP_REQUIRES_OK(ctx, resource->SetValue(ta));
ctx->SetOutput(0, flow);
}