aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager
diff options
context:
space:
mode:
authorGravatar Katherine Wu <kathywu@google.com>2018-09-24 15:47:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 15:51:08 -0700
commit084f84f2ce44b8a1909b59bcc940652a95cd6fc9 (patch)
treee730e3d4c9ed85e242e7a892694d07364f707e87 /tensorflow/python/eager
parent6995db405617abc90da3331094aa8af5e6b57fd1 (diff)
PolymorphicFunction cache key is changed to use the init graph instead of the default graph in the scope.
PiperOrigin-RevId: 214345046
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--tensorflow/python/eager/function.py23
-rw-r--r--tensorflow/python/eager/function_test.py30
2 files changed, 43 insertions, 10 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 1f5d479882..b28befeb62 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -1157,7 +1157,7 @@ class PolymorphicFunction(object):
# then `instance` will be `foo` (and `owner` will be `Foo`).
return functools.partial(self.__call__, instance)
- def _cache_key(self, args, kwargs, ctx, graph):
+ def _cache_key(self, args, kwargs):
"""Computes the cache key given inputs and execution context."""
if self._input_signature is None:
inputs = (args, kwargs) if kwargs else args
@@ -1166,19 +1166,23 @@ class PolymorphicFunction(object):
del args, kwargs
cache_key = self._flat_input_signature
- # The graph, or whether we're executing eagerly, should be a part of the
- # cache key so we don't improperly capture tensors such as variables.
- executing_eagerly = ctx.executing_eagerly()
- execution_context = executing_eagerly or graph
+ with ops.init_scope():
+ init_graph = ops.get_default_graph()
+
+ # The graph, or whether we're executing eagerly, should be a part of the
+ # cache key so we don't improperly capture tensors such as variables.
+ executing_eagerly = context.executing_eagerly()
+ execution_context = executing_eagerly or init_graph
+ default_graph = ops.get_default_graph()
# Putting the device in the cache key ensures that call-site device
# annotations are respected.
- device_functions = _get_device_functions(ctx, graph)
+ device_functions = _get_device_functions(context.context(), default_graph)
# `ops.colocate_with` directives translate into `ops.device` directives when
# eager execution is enabled.
- colocation_stack = (None if executing_eagerly else
- tuple(graph._colocation_stack.peek_objs())) # pylint: disable=protected-access
+ colocation_stack = (() if executing_eagerly else
+ tuple(default_graph._colocation_stack.peek_objs())) # pylint: disable=protected-access
return cache_key + (execution_context, device_functions, colocation_stack)
@@ -1281,8 +1285,7 @@ class PolymorphicFunction(object):
"""
if self._input_signature is None or args is not None or kwargs is not None:
args, kwargs = self._canonicalize_function_inputs(*args, **kwargs)
- cache_key = self._cache_key(args, kwargs, context.context(),
- ops.get_default_graph())
+ cache_key = self._cache_key(args, kwargs)
with self._lock:
try:
graph_function = self._function_cache.get(cache_key, None)
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 04f42f63d4..59faf967c5 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -1812,6 +1812,36 @@ class FunctionTest(test.TestCase):
# Grappler fallback to use the CPU impl even called with GPU function.
self.assertEquals(y_value, 3.0)
+ def testDefunFunctionSeparateGraphs(self):
+ with context.graph_mode():
+
+ @function.defun
+ def add(x):
+ return x + 5
+
+ @function.defun
+ def maybe_add(x, should_add):
+ if should_add:
+ return add(x)
+ else:
+ return x
+
+ with ops.Graph().as_default():
+ x = constant_op.constant(11)
+ maybe_add(x, True)
+ self.assertEqual(len(maybe_add._function_cache), 1)
+ self.assertEqual(len(add._function_cache), 1)
+
+ maybe_add(x, False)
+ self.assertEqual(len(maybe_add._function_cache), 2)
+ self.assertEqual(len(add._function_cache), 1)
+
+ with ops.Graph().as_default():
+ x = constant_op.constant(11)
+ maybe_add(x, True)
+ self.assertEqual(len(maybe_add._function_cache), 3)
+ self.assertEqual(len(add._function_cache), 2)
+
@test_util.with_c_shapes
class AutomaticControlDependenciesTest(test.TestCase):