diff options
author | Asim Shankar <ashankar@google.com> | 2018-04-25 03:07:04 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-25 03:09:54 -0700 |
commit | 3cd2d2a66bb296cbc97be9d7cea6c9bdded60a8c (patch) | |
tree | 282d7dabfbf1d3487fe9c16490db005121551dc6 /tensorflow/contrib/framework | |
parent | ffd3499094b6201169113eb4db6ae7409a9f0e2e (diff) |
Make CriticalSection work inside a Dataset with eager execution enabled.
tf.colocate_with() might be provided with eager tensors when
constructing TensorFlow functions (like the subgraph for map()
inside a tf.data.Dataset).
Prior to this change, the added test would fail with:
"Tensor.op is meaningless when eager execution is enabled."
PiperOrigin-RevId: 194217166
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r-- | tensorflow/contrib/framework/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/contrib/framework/python/ops/critical_section_test.py | 21 |
2 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index f675cc0cf0..249debbdf6 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -178,6 +178,8 @@ cuda_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:tensor_array_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:context", ], ) diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/contrib/framework/python/ops/critical_section_test.py index ba660295cb..df7d7e9dae 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_test.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_test.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.framework.python.ops import critical_section_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -330,6 +332,25 @@ class CriticalSectionTest(test.TestCase): self.evaluate(v.initializer) self.assertEqual(10, self.evaluate(out)) + @test_util.run_in_graph_and_eager_modes() + def testInsideFunction(self): + cs = critical_section_ops.CriticalSection() + v = resource_variable_ops.ResourceVariable(1) + def fn(): + return v.read_value() + + # map() creates a TensorFlow function. + ds = dataset_ops.Dataset.range(1).map(lambda _: cs.execute(fn)) + + def get_first(): + if context.executing_eagerly(): + return self.evaluate(ds.make_one_shot_iterator().get_next()) + itr = ds.make_initializable_iterator() + self.evaluate([v.initializer, itr.initializer]) + return self.evaluate(itr.get_next()) + + self.assertEqual(1, get_first()) + # TODO(ebrevdo): Re-enable once CriticalSection is in core. # # def testCriticalSectionAndExecuteOpSaverRoundTrip(self): |