aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager
diff options
context:
space:
mode:
authorGravatar Scott Zhu <scottzhu@google.com>2018-09-11 14:10:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-11 14:19:30 -0700
commit72410969ca8dd7f1be48672c6cb943940edb9f31 (patch)
tree7584f3e1a00cb0e9e85945313a068234d898ccef /tensorflow/python/eager
parentb40ab8d8a024bb934f25ebc3f5260b64c5816ef5 (diff)
Update defun to support extra params as function attributes.
PiperOrigin-RevId: 212517784
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--tensorflow/python/eager/function.py79
-rw-r--r--tensorflow/python/eager/function_test.py61
2 files changed, 136 insertions, 4 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 8c30550708..348bf4650f 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -27,6 +27,7 @@ import threading
import numpy as np
import six
+from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
@@ -60,6 +61,10 @@ cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-acce
gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-access
+# TODO(scottzhu): Update this to allow arbitrary attribute names in future.
+WHITELIST_FUNCTION_ATTRIBUTE_PREFIX = "experimental_"
+
+
def _create_substitute_placeholder(value, name, dtype=None):
"""Creates a placeholder for `value` and propagates shape info to it."""
# Note: setting ops.control_dependencies(None) ensures we always put
@@ -100,6 +105,44 @@ def _get_device_functions(ctx, graph):
return tuple(graph._device_functions_outer_to_inner) # pylint: disable=protected-access
+def _parse_func_attrs(attributes):
+ """Convert the keyword arguments into function_def attributes.
+
+ Currently only support primitive types: bool, int, float and string.
+
+ Args:
+ attributes: the dictionary of attributes.
+ Returns:
+ A dict of attributes where the key is the name of attribute and the value
+ is the AttrValue proto.
+ Raises:
+ ValueError: If the kwargs contains unwhitelisted name or unsupported value
+ types.
+ """
+ attrs = {}
+ for key, value in attributes.items():
+ if not key.startswith(WHITELIST_FUNCTION_ATTRIBUTE_PREFIX):
+ raise ValueError("Attribute name is not whitelisted. "
+ "Whitelisted: prefix %s, got: %s" %
+ (WHITELIST_FUNCTION_ATTRIBUTE_PREFIX, key))
+
+ if isinstance(value, attr_value_pb2.AttrValue):
+ attrs[key] = value
+ # bool type check has to happen before int since bool is a subclass of int.
+ elif isinstance(value, bool):
+ attrs[key] = attr_value_pb2.AttrValue(b=value)
+ elif isinstance(value, int):
+ attrs[key] = attr_value_pb2.AttrValue(i=value)
+ elif isinstance(value, float):
+ attrs[key] = attr_value_pb2.AttrValue(f=value)
+ elif isinstance(value, str):
+ attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value))
+ else:
+ raise ValueError("Unsupported attribute type for %s with type %s" %
+ (key, type(value)))
+ return attrs
+
+
class FuncGraph(ops.Graph):
"""Graph representing a function body.
@@ -486,7 +529,7 @@ class Function(object):
self._num_outputs = len(self._func_graph.outputs)
self._output_shapes = tuple(
output.shape for output in self._func_graph.outputs)
- self._attrs = attrs or {}
+ self._attrs = _parse_func_attrs(attrs)
self._device_functions = tuple(
self._func_graph._device_functions_outer_to_inner) # pylint: disable=protected-access
@@ -909,7 +952,8 @@ class PolymorphicFunction(object):
def __init__(self,
python_function,
name,
- input_signature=None):
+ input_signature=None,
+ attributes=None):
"""Initializes a polymorphic function.
Args:
@@ -918,6 +962,8 @@ class PolymorphicFunction(object):
input_signature: a possibly nested sequence of `TensorSpec` objects
specifying the input signature of this function. If `None`, a separate
function is instantiated for each inferred input signature.
+ attributes: dict, extra keyword arguments that will be added as attribute
+ of the function.
Raises:
ValueError: if `input_signature` is not None and the `python_function`'s
@@ -935,6 +981,7 @@ class PolymorphicFunction(object):
self._name = name
self._function_cache = collections.OrderedDict()
self._variables = []
+ self._function_attributes = attributes or {}
self._lock = threading.Lock()
@@ -1149,7 +1196,8 @@ class PolymorphicFunction(object):
if graph_function is None:
graph_function = Function(
func_graph_from_py_func(self._name, self._python_function, args,
- kwds, self._input_signature))
+ kwds, self._input_signature),
+ self._function_attributes)
self._variables.extend(
[v for v in graph_function.variables if v not in self._variables])
self._function_cache[cache_key] = graph_function
@@ -1483,7 +1531,29 @@ def defun(func=None, input_signature=None):
TypeError: If `input_signature` is neither `None` nor a sequence of
`tf.contrib.eager.TensorSpec` objects.
"""
+ return defun_with_attributes(func=func, input_signature=input_signature)
+
+
+def defun_with_attributes(func=None, input_signature=None, attributes=None):
+ """Compiles a Python function into a callable TensorFlow graph.
+
+ This function supports adding extra function attributes. See detailed
+ documentation in defun(). Currently this is not exposed in public API since we
+ don't expect user to directly use attributes, and attribute won't work by
+ itself. This assumption might change in future.
+ Args:
+ func: function to be compiled.
+ input_signature: same as defun()'s input_signature.
+ attributes: A dictionary of arguments which will be added to function def as
+ attributes. Currently only support primitive types as value, and only
+ whitelisted attribute name is allowed. Unwhitelisted attribute name or
+ unsupported value will result into ValueError.
+
+ Returns:
+ Same as the return value of defun, with attributes added to the function in
+ graph.
+ """
if input_signature is not None:
_validate_signature(input_signature)
@@ -1495,7 +1565,8 @@ def defun(func=None, input_signature=None):
name = "function"
return tf_decorator.make_decorator(
function,
- PolymorphicFunction(function, name, input_signature=input_signature))
+ PolymorphicFunction(function, name, input_signature=input_signature,
+ attributes=attributes))
# This code path is for the `foo = tfe.defun(foo, ...)` use case
if func is not None:
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 6507bc6d71..e6a49b66cf 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -1501,6 +1501,67 @@ class FunctionTest(test.TestCase):
side_effecting_function.python_function()
self.assertAllEqual(state, [0, 0])
+ def testFunctionWithExtraAttributes(self):
+ @function.defun_with_attributes(attributes={'experimental_1': 'value1',
+ 'experimental_2': 2})
+ def matmul(x, y):
+ return math_ops.matmul(x, y)
+
+ def add(x, y):
+ return math_ops.add(x, y)
+ defun_add = function.defun_with_attributes(
+ add, attributes={'experimental_3': True, 'experimental_4': 1.0})
+
+ with context.graph_mode(), self.test_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ sq = matmul(t, t)
+ double = defun_add(t, t)
+ self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
+ self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
+
+ graph = ops.get_default_graph()
+ # pylint: disable=protected-access
+ self.assertEqual(len(graph._functions), 2)
+ functions = list(graph._functions.values())
+ self.assertRegexpMatches(
+ functions[0].definition.signature.name, '.*matmul.*')
+ attrs = functions[0].definition.attr
+ self.assertEqual(len(attrs), 2)
+ self.assertEqual(attrs['experimental_1'].s, b'value1')
+ self.assertEqual(attrs['experimental_2'].i, 2)
+
+ self.assertRegexpMatches(
+ functions[1].definition.signature.name, '.*add.*')
+ attrs = functions[1].definition.attr
+ self.assertEqual(len(attrs), 2)
+ self.assertEqual(attrs['experimental_3'].b, True)
+ self.assertEqual(attrs['experimental_4'].f, 1.0)
+ # pylint: enable=protected-access
+
+ def testFunctionWithInvalidAttribute(self):
+ @function.defun_with_attributes(attributes={'attr1': 'value1'})
+ def matmul(x, y):
+ return math_ops.matmul(x, y)
+
+ with self.assertRaisesRegexp(ValueError,
+ '.*Attribute name is not whitelisted.*'):
+ with context.graph_mode(), self.test_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ matmul(t, t)
+
+ @function.defun_with_attributes(attributes={'experimental_1': ['value1']})
+ def add(x, y):
+ return math_ops.add(x, y)
+
+ with self.assertRaisesRegexp(ValueError,
+ '.*Unsupported attribute type.*'):
+ with context.graph_mode(), self.test_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ add(t, t)
+
@test_util.with_c_shapes
class AutomaticControlDependenciesTest(test.TestCase):