aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager
diff options
context:
space:
mode:
authorGravatar Scott Zhu <scottzhu@google.com>2018-10-01 12:03:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 12:09:58 -0700
commitc4b3ce081b8abfae5560814ec445f0169cb4c368 (patch)
tree4e6e934716d9f394ba48ab3b81cfa23a62dd3532 /tensorflow/python/eager
parent694367b574dcaf5ac90f3e42b8dee8fa51ca9f38 (diff)
Add new attributes for the defun forward/backward functions.
PiperOrigin-RevId: 215255826
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--tensorflow/python/eager/function.py39
-rw-r--r--tensorflow/python/eager/function_test.py15
2 files changed, 44 insertions, 10 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index dd3e1a3723..60a4f018cd 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections
import functools
+import re
import sys
import threading
import weakref
@@ -61,9 +62,15 @@ cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-acce
# This is to avoid a circular dependency with gradients_impl
gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-access
+FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
+BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"
# TODO(scottzhu): Update this to allow arbitrary attribute names in future.
-WHITELIST_FUNCTION_ATTRIBUTE_PREFIX = "experimental_"
+WHITELIST_FUNCTION_ATTRIBUTE_REGEX = [
+ "experimental_.*",
+ FORWARD_FUNCTION_ATTRIBUTE_NAME,
+ BACKWARD_FUNCTION_ATTRIBUTE_NAME
+]
def _create_substitute_placeholder(value, name=None, dtype=None):
@@ -140,10 +147,11 @@ def _parse_func_attrs(attributes):
"""
attrs = {}
for key, value in attributes.items():
- if not key.startswith(WHITELIST_FUNCTION_ATTRIBUTE_PREFIX):
+ if not any([re.match(reg, key)
+ for reg in WHITELIST_FUNCTION_ATTRIBUTE_REGEX]):
raise ValueError("Attribute name is not whitelisted. "
"Whitelisted: prefix %s, got: %s" %
- (WHITELIST_FUNCTION_ATTRIBUTE_PREFIX, key))
+ (WHITELIST_FUNCTION_ATTRIBUTE_REGEX, key))
if isinstance(value, attr_value_pb2.AttrValue):
attrs[key] = value
@@ -154,7 +162,7 @@ def _parse_func_attrs(attributes):
attrs[key] = attr_value_pb2.AttrValue(i=value)
elif isinstance(value, float):
attrs[key] = attr_value_pb2.AttrValue(f=value)
- elif isinstance(value, str):
+ elif isinstance(value, (str, bytes)):
attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value))
else:
raise ValueError("Unsupported attribute type for %s with type %s" %
@@ -705,6 +713,7 @@ class Function(object):
def _construct_backprop_function(self):
"""Constructs the backprop function object for this function."""
backwards_graph = FuncGraph(_backward_name(self._func_graph.name))
+ forward_function_name = _forward_name(self._func_graph.name)
with backwards_graph.as_default():
gradients_wrt_outputs = [
graph_placeholder(x.dtype, x.shape) for x in self._func_graph.outputs
@@ -715,11 +724,11 @@ class Function(object):
grad_ys=gradients_wrt_outputs,
src_graph=self._func_graph)
- self._forward_function = _EagerDefinedFunction(
- _forward_name(
- self._func_graph.name), self._func_graph, self._func_graph.inputs,
- self._func_graph.outputs + list(backwards_graph.captures.keys()),
- self._attrs)
+ backwards_graph_captures = list(backwards_graph.captures.keys())
+
+ backward_function_attr = _parse_func_attrs(
+ {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
+ backward_function_attr.update(self._attrs)
# The ordering of `backwards_graph.inputs` is important: inputs of
# `self._backward_graph_function` correspond to outputs of
@@ -732,7 +741,17 @@ class Function(object):
grad for grad in _flatten(gradients_wrt_inputs) if grad is not None)
backwards_graph.structured_outputs = gradients_wrt_inputs
self._backward_graph_function = Function(
- backwards_graph, attrs=self._attrs)
+ backwards_graph, attrs=backward_function_attr)
+
+ forward_function_attr = _parse_func_attrs({
+ BACKWARD_FUNCTION_ATTRIBUTE_NAME:
+ self._backward_graph_function._inference_function.name}) # pylint: disable=protected-access
+ forward_function_attr.update(self._attrs)
+
+ self._forward_function = _EagerDefinedFunction(
+ forward_function_name, self._func_graph, self._func_graph.inputs,
+ self._func_graph.outputs + backwards_graph_captures,
+ forward_function_attr)
def _backprop_call(self, args):
"""Calls the forward function and records the result on a tape.
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 34a2648e26..afe3ba9893 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -1687,6 +1687,21 @@ class FunctionTest(test.TestCase):
self.assertRegexpMatches(captured_function_names[i],
expected_func_name_regex[i])
+ # Check the forward and backward function has the correct attributes.
+ self.assertEquals(
+ functions[1].definition.attr['backward_function_name'].s,
+ functions[2].name)
+ self.assertEquals(
+ functions[2].definition.attr['forward_function_name'].s,
+ functions[1].name)
+
+ self.assertEquals(
+ functions[4].definition.attr['backward_function_name'].s,
+ functions[5].name)
+ self.assertEquals(
+ functions[5].definition.attr['forward_function_name'].s,
+ functions[4].name)
+
sq = defun_matmul(t, t)
double = add(t, t)
self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])