diff options
author | Scott Zhu <scottzhu@google.com> | 2018-09-11 14:10:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-11 14:19:30 -0700 |
commit | 72410969ca8dd7f1be48672c6cb943940edb9f31 (patch) | |
tree | 7584f3e1a00cb0e9e85945313a068234d898ccef /tensorflow/python/eager | |
parent | b40ab8d8a024bb934f25ebc3f5260b64c5816ef5 (diff) |
Update defun to support extra params as function attributes.
PiperOrigin-RevId: 212517784
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- | tensorflow/python/eager/function.py | 79 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 61 |
2 files changed, 136 insertions, 4 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 8c30550708..348bf4650f 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -27,6 +27,7 @@ import threading import numpy as np import six +from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import function_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context @@ -60,6 +61,10 @@ cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-acce gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-access +# TODO(scottzhu): Update this to allow arbitrary attribute names in future. +WHITELIST_FUNCTION_ATTRIBUTE_PREFIX = "experimental_" + + def _create_substitute_placeholder(value, name, dtype=None): """Creates a placeholder for `value` and propagates shape info to it.""" # Note: setting ops.control_dependencies(None) ensures we always put @@ -100,6 +105,44 @@ def _get_device_functions(ctx, graph): return tuple(graph._device_functions_outer_to_inner) # pylint: disable=protected-access +def _parse_func_attrs(attributes): + """Convert the keyword arguments into function_def attributes. + + Currently only support primitive types: bool, int, float and string. + + Args: + attributes: the dictionary of attributes. + Returns: + A dict of attributes where the key is the name of attribute and the value + is the AttrValue proto. + Raises: + ValueError: If the kwargs contains unwhitelisted name or unsupported value + types. + """ + attrs = {} + for key, value in attributes.items(): + if not key.startswith(WHITELIST_FUNCTION_ATTRIBUTE_PREFIX): + raise ValueError("Attribute name is not whitelisted. " + "Whitelisted: prefix %s, got: %s" % + (WHITELIST_FUNCTION_ATTRIBUTE_PREFIX, key)) + + if isinstance(value, attr_value_pb2.AttrValue): + attrs[key] = value + # bool type check has to happen before int since bool is a subclass of int. + elif isinstance(value, bool): + attrs[key] = attr_value_pb2.AttrValue(b=value) + elif isinstance(value, int): + attrs[key] = attr_value_pb2.AttrValue(i=value) + elif isinstance(value, float): + attrs[key] = attr_value_pb2.AttrValue(f=value) + elif isinstance(value, str): + attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value)) + else: + raise ValueError("Unsupported attribute type for %s with type %s" % + (key, type(value))) + return attrs + + class FuncGraph(ops.Graph): """Graph representing a function body. @@ -486,7 +529,7 @@ class Function(object): self._num_outputs = len(self._func_graph.outputs) self._output_shapes = tuple( output.shape for output in self._func_graph.outputs) - self._attrs = attrs or {} + self._attrs = _parse_func_attrs(attrs) self._device_functions = tuple( self._func_graph._device_functions_outer_to_inner) # pylint: disable=protected-access @@ -909,7 +952,8 @@ class PolymorphicFunction(object): def __init__(self, python_function, name, - input_signature=None): + input_signature=None, + attributes=None): """Initializes a polymorphic function. Args: @@ -918,6 +962,8 @@ class PolymorphicFunction(object): input_signature: a possibly nested sequence of `TensorSpec` objects specifying the input signature of this function. If `None`, a separate function is instantiated for each inferred input signature. + attributes: dict, extra keyword arguments that will be added as attribute + of the function. Raises: ValueError: if `input_signature` is not None and the `python_function`'s @@ -935,6 +981,7 @@ class PolymorphicFunction(object): self._name = name self._function_cache = collections.OrderedDict() self._variables = [] + self._function_attributes = attributes or {} self._lock = threading.Lock() @@ -1149,7 +1196,8 @@ class PolymorphicFunction(object): if graph_function is None: graph_function = Function( func_graph_from_py_func(self._name, self._python_function, args, - kwds, self._input_signature)) + kwds, self._input_signature), + self._function_attributes) self._variables.extend( [v for v in graph_function.variables if v not in self._variables]) self._function_cache[cache_key] = graph_function @@ -1483,7 +1531,29 @@ def defun(func=None, input_signature=None): TypeError: If `input_signature` is neither `None` nor a sequence of `tf.contrib.eager.TensorSpec` objects. """ + return defun_with_attributes(func=func, input_signature=input_signature) + + +def defun_with_attributes(func=None, input_signature=None, attributes=None): + """Compiles a Python function into a callable TensorFlow graph. + + This function supports adding extra function attributes. See detailed + documentation in defun(). Currently this is not exposed in public API since we + don't expect user to directly use attributes, and attribute won't work by + itself. This assumption might change in future. + Args: + func: function to be compiled. + input_signature: same as defun()'s input_signature. + attributes: A dictionary of arguments which will be added to function def as + attributes. Currently only support primitive types as value, and only + whitelisted attribute name is allowed. Unwhitelisted attribute name or + unsupported value will result into ValueError. + + Returns: + Same as the return value of defun, with attributes added to the function in + graph. + """ if input_signature is not None: _validate_signature(input_signature) @@ -1495,7 +1565,8 @@ def defun(func=None, input_signature=None): name = "function" return tf_decorator.make_decorator( function, - PolymorphicFunction(function, name, input_signature=input_signature)) + PolymorphicFunction(function, name, input_signature=input_signature, + attributes=attributes)) # This code path is for the `foo = tfe.defun(foo, ...)` use case if func is not None: diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 6507bc6d71..e6a49b66cf 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -1501,6 +1501,67 @@ class FunctionTest(test.TestCase): side_effecting_function.python_function() self.assertAllEqual(state, [0, 0]) + def testFunctionWithExtraAttributes(self): + @function.defun_with_attributes(attributes={'experimental_1': 'value1', + 'experimental_2': 2}) + def matmul(x, y): + return math_ops.matmul(x, y) + + def add(x, y): + return math_ops.add(x, y) + defun_add = function.defun_with_attributes( + add, attributes={'experimental_3': True, 'experimental_4': 1.0}) + + with context.graph_mode(), self.test_session(): + with ops.get_default_graph().as_default(): + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + sq = matmul(t, t) + double = defun_add(t, t) + self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22]) + self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8]) + + graph = ops.get_default_graph() + # pylint: disable=protected-access + self.assertEqual(len(graph._functions), 2) + functions = list(graph._functions.values()) + self.assertRegexpMatches( + functions[0].definition.signature.name, '.*matmul.*') + attrs = functions[0].definition.attr + self.assertEqual(len(attrs), 2) + self.assertEqual(attrs['experimental_1'].s, b'value1') + self.assertEqual(attrs['experimental_2'].i, 2) + + self.assertRegexpMatches( + functions[1].definition.signature.name, '.*add.*') + attrs = functions[1].definition.attr + self.assertEqual(len(attrs), 2) + self.assertEqual(attrs['experimental_3'].b, True) + self.assertEqual(attrs['experimental_4'].f, 1.0) + # pylint: enable=protected-access + + def testFunctionWithInvalidAttribute(self): + @function.defun_with_attributes(attributes={'attr1': 'value1'}) + def matmul(x, y): + return math_ops.matmul(x, y) + + with self.assertRaisesRegexp(ValueError, + '.*Attribute name is not whitelisted.*'): + with context.graph_mode(), self.test_session(): + with ops.get_default_graph().as_default(): + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + matmul(t, t) + + @function.defun_with_attributes(attributes={'experimental_1': ['value1']}) + def add(x, y): + return math_ops.add(x, y) + + with self.assertRaisesRegexp(ValueError, + '.*Unsupported attribute type.*'): + with context.graph_mode(), self.test_session(): + with ops.get_default_graph().as_default(): + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + add(t, t) + @test_util.with_c_shapes class AutomaticControlDependenciesTest(test.TestCase): |