aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-06-15 19:04:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-15 19:07:58 -0700
commit56b36d88fce295c151458f42a120fcdcac7a3ca2 (patch)
treeaf93a6a2d52115ae23a40b86716ed0333686c2a5
parentd134d210f3dec36fe45c6cf718f538da9c8f005b (diff)
[TF:XLA] Add no-op implementation of TensorArrayCloseV3 to the XLA bridge.
PiperOrigin-RevId: 159185414
-rw-r--r--tensorflow/compiler/tests/tensor_array_ops_test.py34
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc14
2 files changed, 30 insertions, 18 deletions
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py
index 27a2977305..00a7358130 100644
--- a/tensorflow/compiler/tests/tensor_array_ops_test.py
+++ b/tensorflow/compiler/tests/tensor_array_ops_test.py
@@ -573,13 +573,12 @@ class TensorArrayTest(xla_test.XLATestCase):
[2000.0, -2000.0]],
grad_vals[0])
- # TODO(phawkins): implement TensorArrayClose
- # def testCloseTensorArray(self):
- # with self.test_session() as session, self.test_scope():
- # ta = tensor_array_ops.TensorArray(
- # dtype=dtypes.float32, tensor_array_name="foo", size=3)
- # c1 = ta.close()
- # session.run(c1)
+ def testCloseTensorArray(self):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=3)
+ c1 = ta.close()
+ session.run(c1)
def testSizeTensorArray(self):
with self.test_session(), self.test_scope():
@@ -588,17 +587,16 @@ class TensorArrayTest(xla_test.XLATestCase):
s = ta.size()
self.assertAllEqual(3, s.eval())
- # TODO(phawkins): implement TensorArrayClose
- # def testWriteCloseTensorArray(self):
- # with self.test_session(), self.test_scope():
- # ta = tensor_array_ops.TensorArray(
- # dtype=dtypes.float32,
- # tensor_array_name="foo",
- # size=3,
- # infer_shape=False)
- # w0 = ta.write(0, [[4.0, 5.0]])
- # w1 = w0.write(1, [3.0])
- # w1.close().run() # Expected to run without problems
+ def testWriteCloseTensorArray(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=3,
+ infer_shape=False)
+ w0 = ta.write(0, [[4.0, 5.0]])
+ w1 = w0.write(1, [3.0])
+ w1.close().run() # Expected to run without problems
# TODO(phawkins): implement while loops.
# def _testWhileLoopWritePackGradients(self, dynamic_size, dtype):
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index de542d55e8..c7510bf3d2 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -534,5 +534,19 @@ class TensorArrayGradOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("TensorArrayGradV3"), TensorArrayGradOp);
+class TensorArrayCloseOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayCloseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ // Do nothing; XLA handles resource management.
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayCloseOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayCloseV3"), TensorArrayCloseOp);
+
} // anonymous namespace
} // namespace tensorflow