author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-10-10 02:22:32 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-10-10 02:26:33 -0700
commit ee7c9597f4ab8e586e921f9fe3e3c1383417169c (patch)
tree   7fe6502ffe520a045d517cbf43df767ffd86242a
parent 7575e0949703a4dd0ec19e51e568e9abba037728 (diff)
Emit xla::Or in TensorArrayScatterV3 for PRED types instead of xla::Add
Previously we emitted xla::Add, which isn't supported by some XLA backends on PRED types.

PiperOrigin-RevId: 216497939
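Background for the fix: a freshly created TensorArray is zero-initialized, and for PRED (boolean) values "zero" is false, so logical OR combines an update with the existing contents exactly the way addition does for numeric types. A rough numpy model of the combine rule used below (the helper name is illustrative, not an XLA API):

import numpy as np

def combine(current, update, dtype):
    # PRED (bool) slices are merged with logical OR; numeric dtypes
    # keep the original additive behavior.
    if dtype == np.bool_:
        return np.logical_or(current, update)
    return current + update

# OR over an all-false buffer preserves the update, just as addition
# over an all-zero buffer does for numeric dtypes.
buf = np.zeros(4, dtype=np.bool_)
print(combine(buf, np.array([True, False, True, False]), np.bool_))
# [ True False  True False]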
-rw-r--r--  tensorflow/compiler/tests/tensor_array_ops_test.py     | 37
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 26
2 files changed, 52 insertions(+), 11 deletions(-)
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py
index 78244d0b36..46ca371c8a 100644
--- a/tensorflow/compiler/tests/tensor_array_ops_test.py
+++ b/tensorflow/compiler/tests/tensor_array_ops_test.py
@@ -920,6 +920,34 @@ class TensorArrayTest(xla_test.XLATestCase):
   def testTensorArrayEvalEmptyWithDefault(self):
     self._testTensorArrayEvalEmptyWithDefault()
 
+  def _testTensorArrayScatterRead(self, tf_dtype):
+    with self.cached_session() as session, self.test_scope():
+      convert = _make_converter(tf_dtype)
+
+      ta = tensor_array_ops.TensorArray(
+          dtype=tf_dtype,
+          tensor_array_name="foo",
+          size=10)
+
+      indices = constant_op.constant([1, 8])
+      value = constant_op.constant(convert([[1.0, -1.0], [10.0, -10.0]]))
+      id0 = array_ops.placeholder(dtypes.int32)
+      id1 = array_ops.placeholder(dtypes.int32)
+
+      w = ta.scatter(indices, value)
+      r0 = w.read(id0)
+      r1 = w.read(id1)
+
+      # Test aggregation of read
+      read_vals = session.run([r0, r1], feed_dict={id0: 1, id1: 8})
+      self.assertAllEqual(convert([1.0, -1.0]), read_vals[0])
+      self.assertAllEqual(convert([10.0, -10.0]), read_vals[1])
+
+  def testTensorArrayScatterRead(self):
+    for dtype in self.numeric_tf_types:
+      self._testTensorArrayScatterRead(dtype)
+    self._testTensorArrayScatterRead(dtypes.bool)
+
   def testTensorArrayScatterReadAndGradients(self):
     with self.cached_session() as session, self.test_scope():
       ta = tensor_array_ops.TensorArray(
@@ -929,15 +957,18 @@ class TensorArrayTest(xla_test.XLATestCase):
 
       indices = constant_op.constant([1, 8])
       value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
+      id0 = array_ops.placeholder(dtypes.int32)
+      id1 = array_ops.placeholder(dtypes.int32)
 
       w = ta.scatter(indices, value)
-      r0 = w.read(1)
-      r1 = w.read(8)
+      r0 = w.read(id0)
+      r1 = w.read(id1)
 
       # Test combined gradients + aggregation of read(0).
       grad = gradients_impl.gradients(
           ys=[r0, r1], xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]])
-      read_vals, grad_vals = session.run([[r0, r1], grad])
+      read_vals, grad_vals = session.run([[r0, r1], grad],
+                                         feed_dict={id0: 1, id1: 8})
 
       self.assertEqual(len(read_vals), 2)
       self.assertEqual(len(grad_vals), 1)
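For reference, the scatter-then-read pattern the new test checks looks roughly like this outside the test harness (TF 1.x graph-mode API; without the XLA test scope this runs on the stock TensorArray kernels rather than the compiled path):

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ta = tf.TensorArray(dtype=tf.bool, size=10, tensor_array_name="foo")
w = ta.scatter(tf.constant([1, 8]),
               tf.constant([[True, False], [False, True]]))
r0 = w.read(1)
r1 = w.read(8)

with tf.Session() as sess:
    print(sess.run([r0, r1]))  # [array([ True, False]), array([False,  True])]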
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 94108b764f..6cdfaf4d97 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -123,9 +123,10 @@ Status GetTensorArrayShape(const XlaResource* resource,
 xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand,
                            const xla::XlaOp& update,
                            absl::Span<const int64> update_dims,
-                           const xla::XlaOp& start_indices) {
+                           const xla::XlaOp& start_indices, DataType dtype) {
   xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims);
-  xla::XlaOp sum = xla::Add(current, update);
+  xla::XlaOp sum =
+      dtype == DT_BOOL ? xla::Or(current, update) : xla::Add(current, update);
   return xla::DynamicUpdateSlice(operand, sum, start_indices);
 }
 
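In effect DynamicAddSlice now picks its combiner from the element type. A numpy model of its read-modify-write behavior (dynamic_add_slice here is a stand-in, not part of XLA):

import numpy as np

def dynamic_add_slice(operand, update, start, dtype):
    # Read the slice of `operand` at `start`, combine it with `update`
    # (OR for bool, addition otherwise), and write it back.
    idx = tuple(slice(s, s + d) for s, d in zip(start, update.shape))
    current = operand[idx]
    if dtype == np.bool_:
        combined = np.logical_or(current, update)
    else:
        combined = current + update
    result = operand.copy()
    result[idx] = combined
    return result

ta = np.zeros((10, 2), dtype=np.bool_)
ta = dynamic_add_slice(ta, np.array([[True, False]]), (1, 0), np.bool_)
print(ta[1])  # [ True False]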
@@ -222,8 +223,8 @@ class TensorArrayWriteOp : public XlaOpKernel {
     slice_shape.InsertDim(0, 1LL);
     auto update = xla::Reshape(value, slice_shape.dim_sizes());
 
-    xla::XlaOp written =
-        DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
+    xla::XlaOp written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(),
+                                         start_indices, dtype_);
     OP_REQUIRES_OK(ctx, resource->SetValue(written));
 
     ctx->SetOutput(0, flow);
@@ -391,7 +392,11 @@ class TensorArrayScatterOp : public XlaOpKernel {
     }
 
     if (scatter_all_elements_in_order) {
-      ta = xla::Add(ta, value);
+      if (dtype_ == DT_BOOL) {
+        ta = xla::Or(ta, value);
+      } else {
+        ta = xla::Add(ta, value);
+      }
     } else {
       auto slice_dims = value_shape.dim_sizes();
       slice_dims[0] = 1LL;
@@ -414,7 +419,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
         auto start_indices =
             xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
                      xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
-        ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
+        ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices, dtype_);
       }
     }
 
@@ -522,8 +527,13 @@ class TensorArraySplitOp : public XlaOpKernel {
                                 value_shape.DebugString(), " vs. ",
                                 ta_shape.DebugString()));
 
-    OP_REQUIRES_OK(ctx, resource->SetValue(xla::Add(
-                            ta, xla::Reshape(value, ta_shape.dim_sizes()))));
+    const xla::XlaOp reshape = xla::Reshape(value, ta_shape.dim_sizes());
+    if (dtype_ == DT_BOOL) {
+      ta = xla::Or(ta, reshape);
+    } else {
+      ta = xla::Add(ta, reshape);
+    }
+    OP_REQUIRES_OK(ctx, resource->SetValue(ta));
 
     ctx->SetOutput(0, flow);
   }
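TensorArraySplitOp applies the same rule after reshaping the flat input value to the array's [num_elements, element_length, ...] shape. A numpy model under the op's equal-length restriction (split_merge is illustrative only):

import numpy as np

def split_merge(ta, value, dtype):
    # Reshape the flat `value` to the TensorArray's shape, then merge:
    # logical OR for bool, addition for everything else.
    reshaped = value.reshape(ta.shape)
    if dtype == np.bool_:
        return np.logical_or(ta, reshaped)
    return ta + reshaped

ta = np.zeros((2, 2), dtype=np.bool_)         # two elements of length 2
value = np.array([True, False, False, True])  # flat input, split in two
print(split_merge(ta, value, np.bool_))
# [[ True False]
#  [False  True]]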