diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-10 02:22:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-10 02:26:33 -0700 |
commit | ee7c9597f4ab8e586e921f9fe3e3c1383417169c (patch) | |
tree | 7fe6502ffe520a045d517cbf43df767ffd86242a | |
parent | 7575e0949703a4dd0ec19e51e568e9abba037728 (diff) |
Emit xla::Or in TensorArrayScatterV3 for PRED types instead of xla::Add
Previosuly we emitted xla::Add what isn't supported by some XLA backend
on PRED types.
PiperOrigin-RevId: 216497939
-rw-r--r-- | tensorflow/compiler/tests/tensor_array_ops_test.py | 37 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 26 |
2 files changed, 52 insertions, 11 deletions
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index 78244d0b36..46ca371c8a 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -920,6 +920,34 @@ class TensorArrayTest(xla_test.XLATestCase): def testTensorArrayEvalEmptyWithDefault(self): self._testTensorArrayEvalEmptyWithDefault() + def _testTensorArrayScatterRead(self, tf_dtype): + with self.cached_session() as session, self.test_scope(): + convert = _make_converter(tf_dtype) + + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, + tensor_array_name="foo", + size=10) + + indices = constant_op.constant([1, 8]) + value = constant_op.constant(convert([[1.0, -1.0], [10.0, -10.0]])) + id0 = array_ops.placeholder(dtypes.int32) + id1 = array_ops.placeholder(dtypes.int32) + + w = ta.scatter(indices, value) + r0 = w.read(id0) + r1 = w.read(id1) + + # Test aggregation of read + read_vals = session.run([r0, r1], feed_dict={id0: 1, id1: 8}) + self.assertAllEqual(convert([1.0, -1.0]), read_vals[0]) + self.assertAllEqual(convert([10.0, -10.0]), read_vals[1]) + + def testTensorArrayScatterRead(self): + for dtype in self.numeric_tf_types: + self._testTensorArrayScatterRead(dtype) + self._testTensorArrayScatterRead(dtypes.bool) + def testTensorArrayScatterReadAndGradients(self): with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( @@ -929,15 +957,18 @@ class TensorArrayTest(xla_test.XLATestCase): indices = constant_op.constant([1, 8]) value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) + id0 = array_ops.placeholder(dtypes.int32) + id1 = array_ops.placeholder(dtypes.int32) w = ta.scatter(indices, value) - r0 = w.read(1) - r1 = w.read(8) + r0 = w.read(id0) + r1 = w.read(id1) # Test combined gradients + aggregation of read(0). grad = gradients_impl.gradients( ys=[r0, r1], xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]]) - read_vals, grad_vals = session.run([[r0, r1], grad]) + read_vals, grad_vals = session.run([[r0, r1], grad], + feed_dict={id0: 1, id1: 8}) self.assertEqual(len(read_vals), 2) self.assertEqual(len(grad_vals), 1) diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 94108b764f..6cdfaf4d97 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -123,9 +123,10 @@ Status GetTensorArrayShape(const XlaResource* resource, xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand, const xla::XlaOp& update, absl::Span<const int64> update_dims, - const xla::XlaOp& start_indices) { + const xla::XlaOp& start_indices, DataType dtype) { xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims); - xla::XlaOp sum = xla::Add(current, update); + xla::XlaOp sum = + dtype == DT_BOOL ? xla::Or(current, update) : xla::Add(current, update); return xla::DynamicUpdateSlice(operand, sum, start_indices); } @@ -222,8 +223,8 @@ class TensorArrayWriteOp : public XlaOpKernel { slice_shape.InsertDim(0, 1LL); auto update = xla::Reshape(value, slice_shape.dim_sizes()); - xla::XlaOp written = - DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); + xla::XlaOp written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), + start_indices, dtype_); OP_REQUIRES_OK(ctx, resource->SetValue(written)); ctx->SetOutput(0, flow); @@ -391,7 +392,11 @@ class TensorArrayScatterOp : public XlaOpKernel { } if (scatter_all_elements_in_order) { - ta = xla::Add(ta, value); + if (dtype_ == DT_BOOL) { + ta = xla::Or(ta, value); + } else { + ta = xla::Add(ta, value); + } } else { auto slice_dims = value_shape.dim_sizes(); slice_dims[0] = 1LL; @@ -414,7 +419,7 @@ class TensorArrayScatterOp : public XlaOpKernel { auto start_indices = xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0), xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); - ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); + ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices, dtype_); } } @@ -522,8 +527,13 @@ class TensorArraySplitOp : public XlaOpKernel { value_shape.DebugString(), " vs. ", ta_shape.DebugString())); - OP_REQUIRES_OK(ctx, resource->SetValue(xla::Add( - ta, xla::Reshape(value, ta_shape.dim_sizes())))); + const xla::XlaOp reshape = xla::Reshape(value, ta_shape.dim_sizes()); + if (dtype_ == DT_BOOL) { + ta = xla::Or(ta, reshape); + } else { + ta = xla::Add(ta, reshape); + } + OP_REQUIRES_OK(ctx, resource->SetValue(ta)); ctx->SetOutput(0, flow); } |