aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/framework
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2018-04-25 03:07:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-25 03:09:54 -0700
commit3cd2d2a66bb296cbc97be9d7cea6c9bdded60a8c (patch)
tree282d7dabfbf1d3487fe9c16490db005121551dc6 /tensorflow/contrib/framework
parentffd3499094b6201169113eb4db6ae7409a9f0e2e (diff)
Make CriticalSection work inside a Dataset with eager execution enabled.
tf.colocate_with() might be provided with eager tensors when constructing TensorFlow functions (like the subgraph for map() inside a tf.data.Dataset). Prior to this change, the added test would fail with: "Tensor.op is meaningless when eager execution is enabled." PiperOrigin-RevId: 194217166
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r--tensorflow/contrib/framework/BUILD2
-rw-r--r--tensorflow/contrib/framework/python/ops/critical_section_test.py21
2 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD
index f675cc0cf0..249debbdf6 100644
--- a/tensorflow/contrib/framework/BUILD
+++ b/tensorflow/contrib/framework/BUILD
@@ -178,6 +178,8 @@ cuda_py_test(
"//tensorflow/python:platform_test",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:tensor_array_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/eager:context",
],
)
diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/contrib/framework/python/ops/critical_section_test.py
index ba660295cb..df7d7e9dae 100644
--- a/tensorflow/contrib/framework/python/ops/critical_section_test.py
+++ b/tensorflow/contrib/framework/python/ops/critical_section_test.py
@@ -19,6 +19,8 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.framework.python.ops import critical_section_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@@ -330,6 +332,25 @@ class CriticalSectionTest(test.TestCase):
self.evaluate(v.initializer)
self.assertEqual(10, self.evaluate(out))
+ @test_util.run_in_graph_and_eager_modes()
+ def testInsideFunction(self):
+ cs = critical_section_ops.CriticalSection()
+ v = resource_variable_ops.ResourceVariable(1)
+ def fn():
+ return v.read_value()
+
+ # map() creates a TensorFlow function.
+ ds = dataset_ops.Dataset.range(1).map(lambda _: cs.execute(fn))
+
+ def get_first():
+ if context.executing_eagerly():
+ return self.evaluate(ds.make_one_shot_iterator().get_next())
+ itr = ds.make_initializable_iterator()
+ self.evaluate([v.initializer, itr.initializer])
+ return self.evaluate(itr.get_next())
+
+ self.assertEqual(1, get_first())
+
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
#
# def testCriticalSectionAndExecuteOpSaverRoundTrip(self):