diff options
author | Katherine Wu <kathywu@google.com> | 2018-09-24 15:47:17 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 15:51:08 -0700 |
commit | 084f84f2ce44b8a1909b59bcc940652a95cd6fc9 (patch) | |
tree | e730e3d4c9ed85e242e7a892694d07364f707e87 /tensorflow/python/eager | |
parent | 6995db405617abc90da3331094aa8af5e6b57fd1 (diff) |
PolymorphicFunction cache key is changed to use the init graph instead of the default graph in the scope.
PiperOrigin-RevId: 214345046
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- | tensorflow/python/eager/function.py | 23 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 30 |
2 files changed, 43 insertions(+), 10 deletions(-)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 1f5d479882..b28befeb62 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -1157,7 +1157,7 @@ class PolymorphicFunction(object):
       # then `instance` will be `foo` (and `owner` will be `Foo`).
       return functools.partial(self.__call__, instance)
 
-  def _cache_key(self, args, kwargs, ctx, graph):
+  def _cache_key(self, args, kwargs):
     """Computes the cache key given inputs and execution context."""
     if self._input_signature is None:
       inputs = (args, kwargs) if kwargs else args
@@ -1166,19 +1166,23 @@ class PolymorphicFunction(object):
       del args, kwargs
       cache_key = self._flat_input_signature
 
-    # The graph, or whether we're executing eagerly, should be a part of the
-    # cache key so we don't improperly capture tensors such as variables.
-    executing_eagerly = ctx.executing_eagerly()
-    execution_context = executing_eagerly or graph
+    with ops.init_scope():
+      init_graph = ops.get_default_graph()
+
+      # The graph, or whether we're executing eagerly, should be a part of the
+      # cache key so we don't improperly capture tensors such as variables.
+      executing_eagerly = context.executing_eagerly()
+      execution_context = executing_eagerly or init_graph
+
+    default_graph = ops.get_default_graph()
     # Putting the device in the cache key ensures that call-site device
     # annotations are respected.
-    device_functions = _get_device_functions(ctx, graph)
+    device_functions = _get_device_functions(context.context(), default_graph)
 
     # `ops.colocate_with` directives translate into `ops.device` directives when
     # eager execution is enabled.
-    colocation_stack = (None if executing_eagerly else
-                        tuple(graph._colocation_stack.peek_objs()))  # pylint: disable=protected-access
+    colocation_stack = (() if executing_eagerly else
+                        tuple(default_graph._colocation_stack.peek_objs()))  # pylint: disable=protected-access
 
     return cache_key + (execution_context, device_functions, colocation_stack)
@@ -1281,8 +1285,7 @@ class PolymorphicFunction(object):
     """
     if self._input_signature is None or args is not None or kwargs is not None:
       args, kwargs = self._canonicalize_function_inputs(*args, **kwargs)
-    cache_key = self._cache_key(args, kwargs, context.context(),
-                                ops.get_default_graph())
+    cache_key = self._cache_key(args, kwargs)
     with self._lock:
       try:
         graph_function = self._function_cache.get(cache_key, None)
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 04f42f63d4..59faf967c5 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -1812,6 +1812,36 @@ class FunctionTest(test.TestCase):
       # Grappler fallback to use the CPU impl even called with GPU function.
       self.assertEquals(y_value, 3.0)
 
+  def testDefunFunctionSeparateGraphs(self):
+    with context.graph_mode():
+
+      @function.defun
+      def add(x):
+        return x + 5
+
+      @function.defun
+      def maybe_add(x, should_add):
+        if should_add:
+          return add(x)
+        else:
+          return x
+
+      with ops.Graph().as_default():
+        x = constant_op.constant(11)
+        maybe_add(x, True)
+        self.assertEqual(len(maybe_add._function_cache), 1)
+        self.assertEqual(len(add._function_cache), 1)
+
+        maybe_add(x, False)
+        self.assertEqual(len(maybe_add._function_cache), 2)
+        self.assertEqual(len(add._function_cache), 1)
+
+      with ops.Graph().as_default():
+        x = constant_op.constant(11)
+        maybe_add(x, True)
+        self.assertEqual(len(maybe_add._function_cache), 3)
+        self.assertEqual(len(add._function_cache), 2)
+
 
 @test_util.with_c_shapes
 class AutomaticControlDependenciesTest(test.TestCase):