diff options
author | Peter Hawkins <phawkins@google.com> | 2017-06-15 19:04:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-15 19:07:58 -0700 |
commit | 56b36d88fce295c151458f42a120fcdcac7a3ca2 (patch) | |
tree | af93a6a2d52115ae23a40b86716ed0333686c2a5 | |
parent | d134d210f3dec36fe45c6cf718f538da9c8f005b (diff) |
[TF:XLA] Add no-op implementation of TensorArrayCloseV3 to the XLA bridge.
PiperOrigin-RevId: 159185414
-rw-r--r-- | tensorflow/compiler/tests/tensor_array_ops_test.py | 34 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 14 |
2 files changed, 30 insertions, 18 deletions
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index 27a2977305..00a7358130 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -573,13 +573,12 @@ class TensorArrayTest(xla_test.XLATestCase): [2000.0, -2000.0]], grad_vals[0]) - # TODO(phawkins): implement TensorArrayClose - # def testCloseTensorArray(self): - # with self.test_session() as session, self.test_scope(): - # ta = tensor_array_ops.TensorArray( - # dtype=dtypes.float32, tensor_array_name="foo", size=3) - # c1 = ta.close() - # session.run(c1) + def testCloseTensorArray(self): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + c1 = ta.close() + session.run(c1) def testSizeTensorArray(self): with self.test_session(), self.test_scope(): @@ -588,17 +587,16 @@ class TensorArrayTest(xla_test.XLATestCase): s = ta.size() self.assertAllEqual(3, s.eval()) - # TODO(phawkins): implement TensorArrayClose - # def testWriteCloseTensorArray(self): - # with self.test_session(), self.test_scope(): - # ta = tensor_array_ops.TensorArray( - # dtype=dtypes.float32, - # tensor_array_name="foo", - # size=3, - # infer_shape=False) - # w0 = ta.write(0, [[4.0, 5.0]]) - # w1 = w0.write(1, [3.0]) - # w1.close().run() # Expected to run without problems + def testWriteCloseTensorArray(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3, + infer_shape=False) + w0 = ta.write(0, [[4.0, 5.0]]) + w1 = w0.write(1, [3.0]) + w1.close().run() # Expected to run without problems # TODO(phawkins): implement while loops. # def _testWhileLoopWritePackGradients(self, dynamic_size, dtype): diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index de542d55e8..c7510bf3d2 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -534,5 +534,19 @@ class TensorArrayGradOp : public XlaOpKernel { REGISTER_XLA_OP(Name("TensorArrayGradV3"), TensorArrayGradOp); +class TensorArrayCloseOp : public XlaOpKernel { + public: + explicit TensorArrayCloseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + // Do nothing; XLA handles resource management. + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayCloseOp); +}; + +REGISTER_XLA_OP(Name("TensorArrayCloseV3"), TensorArrayCloseOp); + } // anonymous namespace } // namespace tensorflow |