aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2018-04-20 16:18:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-20 16:20:48 -0700
commitcd095e0c455b3df98841ca70ba24fd41935552e7 (patch)
tree81940b0fc6742af2bdcd953b4cad3ec21a97bc59
parenta0071844d0af47f22ab512363b56383acf762dff (diff)
tf.contrib.data.scan: Support eager execution.
PiperOrigin-RevId: 193739234
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD1
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py23
-rw-r--r--tensorflow/contrib/data/python/ops/scan_ops.py1
3 files changed, 17 insertions, 8 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 05a4f5028a..9d1e8b20c2 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -343,6 +343,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/eager:context",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
index e0494736b7..1a97a84b2c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
@@ -24,9 +24,11 @@ import numpy as np
from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import scan_ops
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -57,19 +59,24 @@ class ScanDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
+ @test_util.run_in_graph_and_eager_modes()
def testFibonacci(self):
iterator = dataset_ops.Dataset.from_tensors(1).repeat(None).apply(
scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))
).make_one_shot_iterator()
- next_element = iterator.get_next()
- with self.test_session() as sess:
- self.assertEqual(1, sess.run(next_element))
- self.assertEqual(1, sess.run(next_element))
- self.assertEqual(2, sess.run(next_element))
- self.assertEqual(3, sess.run(next_element))
- self.assertEqual(5, sess.run(next_element))
- self.assertEqual(8, sess.run(next_element))
+ if context.executing_eagerly():
+ next_element = iterator.get_next
+ else:
+ get_next = iterator.get_next()
+ next_element = lambda: get_next
+
+ self.assertEqual(1, self.evaluate(next_element()))
+ self.assertEqual(1, self.evaluate(next_element()))
+ self.assertEqual(2, self.evaluate(next_element()))
+ self.assertEqual(3, self.evaluate(next_element()))
+ self.assertEqual(5, self.evaluate(next_element()))
+ self.assertEqual(8, self.evaluate(next_element()))
def testChangingStateShape(self):
# Test the fixed-point shape invariant calculations: start with
diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py
index 1c88366273..711a538697 100644
--- a/tensorflow/contrib/data/python/ops/scan_ops.py
+++ b/tensorflow/contrib/data/python/ops/scan_ops.py
@@ -144,6 +144,7 @@ class _ScanDataset(dataset_ops.Dataset):
weakened_state_shapes)
self._scan_func = tf_scan_func
+ self._scan_func.add_to_graph(ops.get_default_graph())
def _as_variant_tensor(self):
input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access