diff options
author | 2018-10-01 12:03:53 -0700 | |
---|---|---|
committer | 2018-10-01 12:09:58 -0700 | |
commit | c4b3ce081b8abfae5560814ec445f0169cb4c368 (patch) | |
tree | 4e6e934716d9f394ba48ab3b81cfa23a62dd3532 /tensorflow/python/eager | |
parent | 694367b574dcaf5ac90f3e42b8dee8fa51ca9f38 (diff) |
Add new attributes for the defun forward/backward functions.
PiperOrigin-RevId: 215255826
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- | tensorflow/python/eager/function.py | 39 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 15 |
2 files changed, 44 insertions, 10 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index dd3e1a3723..60a4f018cd 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -21,6 +21,7 @@ from __future__ import print_function import collections import functools +import re import sys import threading import weakref @@ -61,9 +62,15 @@ cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-acce # This is to avoid a circular dependency with gradients_impl gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-access +FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name" +BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name" # TODO(scottzhu): Update this to allow arbitrary attribute names in future. -WHITELIST_FUNCTION_ATTRIBUTE_PREFIX = "experimental_" +WHITELIST_FUNCTION_ATTRIBUTE_REGEX = [ + "experimental_.*", + FORWARD_FUNCTION_ATTRIBUTE_NAME, + BACKWARD_FUNCTION_ATTRIBUTE_NAME +] def _create_substitute_placeholder(value, name=None, dtype=None): @@ -140,10 +147,11 @@ def _parse_func_attrs(attributes): """ attrs = {} for key, value in attributes.items(): - if not key.startswith(WHITELIST_FUNCTION_ATTRIBUTE_PREFIX): + if not any([re.match(reg, key) + for reg in WHITELIST_FUNCTION_ATTRIBUTE_REGEX]): raise ValueError("Attribute name is not whitelisted. " "Whitelisted: prefix %s, got: %s" % - (WHITELIST_FUNCTION_ATTRIBUTE_PREFIX, key)) + (WHITELIST_FUNCTION_ATTRIBUTE_REGEX, key)) if isinstance(value, attr_value_pb2.AttrValue): attrs[key] = value @@ -154,7 +162,7 @@ def _parse_func_attrs(attributes): attrs[key] = attr_value_pb2.AttrValue(i=value) elif isinstance(value, float): attrs[key] = attr_value_pb2.AttrValue(f=value) - elif isinstance(value, str): + elif isinstance(value, (str, bytes)): attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value)) else: raise ValueError("Unsupported attribute type for %s with type %s" % @@ -705,6 +713,7 @@ class Function(object): def _construct_backprop_function(self): """Constructs the backprop function object for this function.""" backwards_graph = FuncGraph(_backward_name(self._func_graph.name)) + forward_function_name = _forward_name(self._func_graph.name) with backwards_graph.as_default(): gradients_wrt_outputs = [ graph_placeholder(x.dtype, x.shape) for x in self._func_graph.outputs @@ -715,11 +724,11 @@ class Function(object): grad_ys=gradients_wrt_outputs, src_graph=self._func_graph) - self._forward_function = _EagerDefinedFunction( - _forward_name( - self._func_graph.name), self._func_graph, self._func_graph.inputs, - self._func_graph.outputs + list(backwards_graph.captures.keys()), - self._attrs) + backwards_graph_captures = list(backwards_graph.captures.keys()) + + backward_function_attr = _parse_func_attrs( + {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name}) + backward_function_attr.update(self._attrs) # The ordering of `backwards_graph.inputs` is important: inputs of # `self._backward_graph_function` correspond to outputs of @@ -732,7 +741,17 @@ class Function(object): grad for grad in _flatten(gradients_wrt_inputs) if grad is not None) backwards_graph.structured_outputs = gradients_wrt_inputs self._backward_graph_function = Function( - backwards_graph, attrs=self._attrs) + backwards_graph, attrs=backward_function_attr) + + forward_function_attr = _parse_func_attrs({ + BACKWARD_FUNCTION_ATTRIBUTE_NAME: + self._backward_graph_function._inference_function.name}) # pylint: disable=protected-access + forward_function_attr.update(self._attrs) + + self._forward_function = _EagerDefinedFunction( + forward_function_name, self._func_graph, self._func_graph.inputs, + self._func_graph.outputs + backwards_graph_captures, + forward_function_attr) def _backprop_call(self, args): """Calls the forward function and records the result on a tape. diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 34a2648e26..afe3ba9893 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -1687,6 +1687,21 @@ class FunctionTest(test.TestCase): self.assertRegexpMatches(captured_function_names[i], expected_func_name_regex[i]) + # Check the forward and backward function has the correct attributes. + self.assertEquals( + functions[1].definition.attr['backward_function_name'].s, + functions[2].name) + self.assertEquals( + functions[2].definition.attr['forward_function_name'].s, + functions[1].name) + + self.assertEquals( + functions[4].definition.attr['backward_function_name'].s, + functions[5].name) + self.assertEquals( + functions[5].definition.attr['forward_function_name'].s, + functions[4].name) + sq = defun_matmul(t, t) double = add(t, t) self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22]) |