diff options
author | 2018-04-20 16:18:29 -0700 | |
---|---|---|
committer | 2018-04-20 16:20:48 -0700 | |
commit | cd095e0c455b3df98841ca70ba24fd41935552e7 (patch) | |
tree | 81940b0fc6742af2bdcd953b4cad3ec21a97bc59 | |
parent | a0071844d0af47f22ab512363b56383acf762dff (diff) |
tf.contrib.data.scan: Support eager execution.
PiperOrigin-RevId: 193739234
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py | 23 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/ops/scan_ops.py | 1 |
3 files changed, 17 insertions, 8 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 05a4f5028a..9d1e8b20c2 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -343,6 +343,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:context", "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py index e0494736b7..1a97a84b2c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -24,9 +24,11 @@ import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import scan_ops from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -57,19 +59,24 @@ class ScanDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + @test_util.run_in_graph_and_eager_modes() def testFibonacci(self): iterator = dataset_ops.Dataset.from_tensors(1).repeat(None).apply( scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])) ).make_one_shot_iterator() - next_element = iterator.get_next() - with self.test_session() as sess: - self.assertEqual(1, sess.run(next_element)) - self.assertEqual(1, sess.run(next_element)) - self.assertEqual(2, sess.run(next_element)) - self.assertEqual(3, sess.run(next_element)) - self.assertEqual(5, sess.run(next_element)) - self.assertEqual(8, sess.run(next_element)) + if context.executing_eagerly(): + next_element = iterator.get_next + else: + get_next = iterator.get_next() + next_element = lambda: get_next + + self.assertEqual(1, self.evaluate(next_element())) + self.assertEqual(1, self.evaluate(next_element())) + self.assertEqual(2, self.evaluate(next_element())) + self.assertEqual(3, self.evaluate(next_element())) + self.assertEqual(5, self.evaluate(next_element())) + self.assertEqual(8, self.evaluate(next_element())) def testChangingStateShape(self): # Test the fixed-point shape invariant calculations: start with diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index 1c88366273..711a538697 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -144,6 +144,7 @@ class _ScanDataset(dataset_ops.Dataset): weakened_state_shapes) self._scan_func = tf_scan_func + self._scan_func.add_to_graph(ops.get_default_graph()) def _as_variant_tensor(self): input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access |