aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Akshay Agrawal <akshayka@google.com>2018-08-06 15:17:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-06 15:21:01 -0700
commit080e15b9287968d8ae6f0e7aa041593bd4b82c8d (patch)
treec9dc23b2da3cfe762e0d339ceabc388788a48f41
parentcb29a6b2217a140c248188015da424669fd08e54 (diff)
Add an optional input signature for functions generated with defun.
An input signature is a possibly nested collection of `TensorSpec` objects declaring the shapes and dtypes of a function's arguments. tfe.defun propagates these shapes and dtypes to the graph function it generates for the traced Python function. Since the shapes may be partially specified, this makes it possible to generate functions with partial shape information. When an input signature is specified, every argument to the `defun`-wrapped function *must* be a Tensor. Input signatures cannot be specified for functions with keyword arguments, but positional arguments that are called in a keyword-like fashion are acceptable. For example, the following code snippet is valid @tfe.defun(input_signature=(TensorSpec([], tf.float32), TensorSpec([], tf.int32)) def foo(b, a): ... foo(b=tf.constant(1.0), a=tf.constant(2)) foo(a=tf.constant(2), b=tf.constant(1.0)) while the next code snippet is invalid @tfe.defun(input_signature=(TensorSpec([], tf.float32), TensorSpec([], tf.int32)) def foo(b, **kwargs): ... # This will fail --- arbitrary kwargs are not allowed. foo(b=tf.constant(1.0), a=tf.constant(2)) This change also adds benchmarks that approximately measure the time taken to compute the cache key and verify the signature (and execute an empty function). Signatures introduce an overhead of ~100 us. The benchmarks for the non-signature path are the same as they were before this change. entry { name: "MicroBenchmarks.benchmark_defun_with_signature" iters: 30000 wall_time: 349.911403656 extras { key: "examples_per_sec" value { double_value: 2857.86627572 } } } entry { name: "MicroBenchmarks.benchmark_defun_with_signature_and_kwargs" iters: 30000 wall_time: 360.46500206 extras { key: "examples_per_sec" value { double_value: 2774.19442743 } } } entry { name: "MicroBenchmarks.benchmark_defun_without_signature" iters: 30000 wall_time: 259.087236722 extras { key: "examples_per_sec" value { double_value: 3859.70383046 } } } entry { name: "MicroBenchmarks.benchmark_defun_without_signature_and_with_kwargs" iters: 30000 wall_time: 272.486400604 extras { key: "examples_per_sec" value { double_value: 3669.90792121 } } } PiperOrigin-RevId: 207617442
-rw-r--r--tensorflow/contrib/eager/python/tfe.py3
-rw-r--r--tensorflow/python/eager/BUILD1
-rw-r--r--tensorflow/python/eager/benchmarks_test.py50
-rw-r--r--tensorflow/python/eager/function.py332
-rw-r--r--tensorflow/python/eager/function_test.py247
-rw-r--r--tensorflow/python/framework/tensor_spec.py9
6 files changed, 562 insertions, 80 deletions
diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py
index 2f0ab616e4..de11d00a1a 100644
--- a/tensorflow/contrib/eager/python/tfe.py
+++ b/tensorflow/contrib/eager/python/tfe.py
@@ -71,6 +71,8 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@run_test_in_graph_and_eager_modes
@@run_all_tests_in_graph_and_eager_modes
+@@TensorSpec
+
@@DEVICE_PLACEMENT_EXPLICIT
@@DEVICE_PLACEMENT_WARN
@@DEVICE_PLACEMENT_SILENT
@@ -114,6 +116,7 @@ from tensorflow.python.eager.execution_callbacks import inf_callback
from tensorflow.python.eager.execution_callbacks import inf_nan_callback
from tensorflow.python.eager.execution_callbacks import nan_callback
from tensorflow.python.eager.execution_callbacks import seterr
+from tensorflow.python.framework.tensor_spec import TensorSpec
from tensorflow.python.framework.ops import enable_eager_execution
from tensorflow.python.framework.ops import enable_eager_execution_internal as enable_remote_eager_execution
from tensorflow.python.framework.ops import eager_run as run
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 32a8452f62..de93b1e2e1 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -249,6 +249,7 @@ py_library(
"//tensorflow/python/eager:execute",
"//tensorflow/python/eager:tape",
"//third_party/py/numpy",
+ "@six_archive//:six",
],
)
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index afc4bf0066..1a78559ac0 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -38,8 +38,10 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import function
from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
@@ -527,6 +529,54 @@ class MicroBenchmarks(test.Benchmark):
self._benchmark_defun_matmul(
m, transpose_b=True, num_iters=self._num_iters_100_by_784)
+ def benchmark_defun_without_signature(self):
+
+ def func(t1, t2, t3, t4, t5, t6, t7, t8):
+ del t1, t2, t3, t4, t5, t6, t7, t8
+ return None
+
+ defined = function.defun(func)
+ t = constant_op.constant(0.0)
+ cache_computation = lambda: defined(t, t, t, t, t, t, t, t)
+ self._run(cache_computation, 30000)
+
+ def benchmark_defun_without_signature_and_with_kwargs(self):
+
+ def func(t1, t2, t3, t4, t5, t6, t7, t8):
+ del t1, t2, t3, t4, t5, t6, t7, t8
+ return None
+
+ defined = function.defun(func)
+ t = constant_op.constant(0.0)
+ def cache_computation():
+ return defined(t1=t, t2=t, t3=t, t4=t, t5=t, t6=t, t7=t, t8=t)
+ self._run(cache_computation, 30000)
+
+ def benchmark_defun_with_signature(self):
+
+ def func(t1, t2, t3, t4, t5, t6, t7, t8):
+ del t1, t2, t3, t4, t5, t6, t7, t8
+ return None
+
+ defined = function.defun(
+ func, input_signature=[tensor_spec.TensorSpec([], dtypes.float32)] * 8)
+ t = constant_op.constant(0.0)
+ signature_computation = lambda: defined(t, t, t, t, t, t, t, t)
+ self._run(signature_computation, 30000)
+
+ def benchmark_defun_with_signature_and_kwargs(self):
+
+ def func(t1, t2, t3, t4, t5, t6, t7, t8):
+ del t1, t2, t3, t4, t5, t6, t7, t8
+ return None
+
+ defined = function.defun(
+ func, input_signature=[tensor_spec.TensorSpec([], dtypes.float32)] * 8)
+ t = constant_op.constant(0.0)
+ def signature_computation():
+ return defined(t1=t, t2=t, t3=t, t4=t, t5=t, t6=t, t7=t, t8=t)
+ self._run(signature_computation, 30000)
+
def benchmark_matmul_read_variable_op_2_by_2_CPU(self):
with context.device(CPU):
m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 29e234efd8..ca0c6b18eb 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -24,6 +24,7 @@ import functools
import threading
import numpy as np
+import six
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
@@ -35,6 +36,7 @@ from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
@@ -43,6 +45,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
def capture_value(tensor_map, value, dtype, name):
@@ -659,43 +662,68 @@ class GraphModeFunction(object):
return ret
-def _get_defun_inputs(args):
- """Maps the inputs args to graph inputs."""
- ret = []
- flat_args = nest.flatten(args)
- for a in flat_args:
- if isinstance(a, ops.Tensor):
- ret.append(graph_placeholder(a.dtype, a.shape))
- else:
- ret.append(a)
- return nest.pack_sequence_as(args, ret)
+def _get_defun_inputs_from_signature(signature):
+ """Maps a signature to graph-construction inputs."""
+ function_inputs = [
+ graph_placeholder(spec.dtype, spec.shape)
+ for spec in nest.flatten(signature)
+ ]
+ return nest.pack_sequence_as(signature, function_inputs)
+
+
+def _get_defun_inputs_from_args(args):
+ """Maps python function args to graph-construction inputs."""
+ function_inputs = [
+ graph_placeholder(arg.dtype, arg.shape) if isinstance(arg, ops.Tensor)
+ else arg for arg in nest.flatten(args)
+ ]
+ return nest.pack_sequence_as(args, function_inputs)
-def _deterministic_dict_values(kwds):
- return tuple(kwds[key] for key in sorted(kwds))
+def _trace_and_define_function(name, python_func, compiled, args, kwds,
+ signature=None):
+ """Defines and returns graph-mode version of `python_func`.
+ Args:
+ name: an identifier for the function.
+ python_func: the Python function to trace.
+ compiled: whether the graph function should be compiled through XLA.
+ args: the positional args with which the Python function should be called;
+ ignored if a signature is provided.
+ kwds: the keyword args with which the Python function should be called;
+ ignored if a signature is provided.
+ signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
+ and dtypes of the arguments. When a signature is provided, `args` and
+ `kwds` are ignored, and `python_func` is traced with Tensors conforming
+ to `signature`. If `None`, the shapes and dtypes are inferred from the
+ inputs.
-def _trace_and_define_function(name, func, compiled, args, kwds):
- """Defines and returns graph-mode version of func."""
+ Returns:
+ A GraphModeFunction.
+ """
graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
- tmp_graph = CapturingGraph()
+ func_graph = CapturingGraph()
# Inherit the graph key, since this is used for matching variables in
# optimizers.
- tmp_graph._graph_key = graph_key # pylint: disable=protected-access
+ func_graph._graph_key = graph_key # pylint: disable=protected-access
# Copy the graph collections to ensure summaries and other things work. This
# lets the function access (but not mutate) collections of the containing
# graph, such as the global step and the summary writer collections.
curr_graph = ops.get_default_graph()
for collection in curr_graph.collections:
- tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
+ func_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
collection)
if context.executing_eagerly():
- tmp_graph.seed = context.global_seed()
+ func_graph.seed = context.global_seed()
else:
- tmp_graph.seed = curr_graph.seed
- with tmp_graph.as_default(), AutomaticControlDependencies() as a:
- func_args = _get_defun_inputs(args)
- func_kwds = _get_defun_inputs(kwds)
+ func_graph.seed = curr_graph.seed
+ with func_graph.as_default(), AutomaticControlDependencies() as a:
+ if signature is None:
+ func_args = _get_defun_inputs_from_args(args)
+ func_kwds = _get_defun_inputs_from_args(kwds)
+ else:
+ func_args = _get_defun_inputs_from_signature(signature)
+ func_kwds = {}
# Variables to help check whether mutation happens in calling the function
# Copy the recursive list, tuple and map structure, but not base objects
@@ -711,7 +739,7 @@ def _trace_and_define_function(name, func, compiled, args, kwds):
this_tape = tape.push_new_tape()
try:
- func_outputs = func(*func_args, **func_kwds)
+ func_outputs = python_func(*func_args, **func_kwds)
func_outputs = nest.map_structure(convert, func_outputs)
def check_mutation(n1, n2):
@@ -740,11 +768,11 @@ def _trace_and_define_function(name, func, compiled, args, kwds):
# call to convert_to_tensor, so we manually capture all such tensors.
outputs_list = _flatten(func_outputs)
func_def_outputs = [
- tmp_graph.capture(x) for x in outputs_list
+ func_graph.capture(x) for x in outputs_list
if x is not None
]
- captures = tmp_graph.captures
+ captures = func_graph.captures
ids = list(sorted(captures.keys()))
if ids:
extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids])
@@ -755,20 +783,20 @@ def _trace_and_define_function(name, func, compiled, args, kwds):
x.shape if isinstance(x, ops.Tensor) else None
for x in func_def_outputs)
- func_kwds_values = _deterministic_dict_values(func_kwds)
+ # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
flat_inputs = [
- x for x in nest.flatten(func_args) + nest.flatten(func_kwds_values)
+ x for x in nest.flatten(func_args) + nest.flatten(func_kwds)
if isinstance(x, ops.Tensor)
]
all_inputs = flat_inputs + list(extra_placeholders)
all_ignored_ops = frozenset(x.op for x in all_inputs)
fname = _inference_name(name)
- operations = tuple(x for x in tmp_graph.get_operations()
+ operations = tuple(x for x in func_graph.get_operations()
if x not in all_ignored_ops)
# Register any other functions defined in the graph
# TODO(ashankar): Oh lord, forgive me for this lint travesty.
if context.executing_eagerly():
- for f in tmp_graph._functions.values(): # pylint: disable=protected-access
+ for f in func_graph._functions.values(): # pylint: disable=protected-access
# TODO(ashankar): What about the gradient registry?
_register(f._c_func.func) # pylint: disable=protected-access
@@ -777,41 +805,55 @@ def _trace_and_define_function(name, func, compiled, args, kwds):
attrs[_xla_compile_attr] = attr_value_pb2.AttrValue(b=True)
return GraphModeFunction(
- fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs,
+ fname, all_inputs, extra_inputs, func_graph, operations, func_def_outputs,
func_outputs, output_shapes, variables, attrs)
-# Defun uses this instead of Tensor as a cache key. Using dtype because
-# TensorFlow graphs are not parametric wrt dtypes, and using shapes for
-# performance reasons, as much TensorFlow code specializes on known shapes to
-# produce slimmer graphs.
-_TensorDtype = collections.namedtuple("_TensorDtype", ["dtype", "shape"])
-_ZeroDtype = collections.namedtuple("_ZeroDtype", ["dtype", "shape"])
+_TensorType = collections.namedtuple("_TensorType", ["dtype", "shape"])
+
+def _encode_arg(arg):
+ """A canonical representation for this argument, for use in a cache key."""
-def _cache_key(x):
- """Cache key for tfe functions."""
- if isinstance(x, ops.Tensor):
- return _TensorDtype(x.dtype, x._shape_tuple()) # pylint: disable=protected-access
- if isinstance(x, ops.IndexedSlices):
- if x.dense_shape is not None:
+ # `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes
+ # are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes
+ # are used for both performance reasons, as much TensorFlow code specializes
+ # on known shapes to produce slimmer graphs, and correctness, as some
+ # high-level APIs require shapes to be fully-known.
+ #
+ # TODO(akshayka): Add support for sparse tensors.
+ #
+ # pylint: disable=protected-access
+ if isinstance(arg, ops.Tensor):
+ return _TensorType(arg.dtype, arg._shape_tuple())
+ elif isinstance(arg, ops.IndexedSlices):
+ if arg.dense_shape is not None:
return tuple([
- _TensorDtype(x.values.dtype, x.values._shape_tuple()), # pylint: disable=protected-access
- _TensorDtype(x.indices.dtype, x.indices._shape_tuple()), # pylint: disable=protected-access
- _TensorDtype(x.dense_shape.dtype, x.dense_shape._shape_tuple()) # pylint: disable=protected-access
+ _TensorType(arg.values.dtype, arg.values._shape_tuple()),
+ _TensorType(arg.indices.dtype, arg.indices._shape_tuple()),
+ _TensorType(arg.dense_shape.dtype, arg.dense_shape._shape_tuple()),
])
else:
return tuple([
- _TensorDtype(x.values.dtype, x.values._shape_tuple()), # pylint: disable=protected-access
- _TensorDtype(x.indices.dtype, x.indices._shape_tuple()) # pylint: disable=protected-access
+ _TensorType(arg.values.dtype, arg.values._shape_tuple()),
+ _TensorType(arg.indices.dtype, arg.indices._shape_tuple()),
])
- if isinstance(x, np.ndarray):
- return ("array", x.shape, tuple(x.reshape(-1)))
- if isinstance(x, (list, tuple)):
- return tuple([_cache_key(a) for a in x])
- if isinstance(x, dict):
- return tuple(tuple([_cache_key(k), _cache_key(v)]) for k, v in x.items())
- return x
+ # pylint: enable=protected-access
+ elif isinstance(arg, np.ndarray):
+ # TODO(akshayka): Consider instead converting this NumPy array to a Tensor
+ # and encoding it with a _TensorType.
+ return ("array", arg.shape, tuple(arg.reshape(-1)))
+ elif isinstance(arg, (list, tuple)):
+ return tuple([_encode_arg(elem) for elem in arg])
+ elif isinstance(arg, dict):
+ return tuple(
+ (_encode_arg(key), _encode_arg(arg[key])) for key in sorted(arg))
+ else:
+ return arg
+
+
+def _deterministic_dict_values(dictionary):
+ return tuple(dictionary[key] for key in sorted(dictionary))
class _PolymorphicFunction(object):
@@ -826,16 +868,37 @@ class _PolymorphicFunction(object):
synchronization is necessary.
"""
- def __init__(self, python_function, name, compiled=False):
+ def __init__(self,
+ python_function,
+ name,
+ input_signature=None,
+ compiled=False):
"""Initializes a polymorphic function.
Args:
python_function: the function to be wrapped.
name: the name given to it.
+ input_signature: a possibly nested sequence of `TensorSpec` objects
+ specifying the input signature of this function. If `None`, a separate
+ function is instantiated for each inferred input signature.
compiled: if True, the framework will attempt to compile func with XLA.
+
+ Raises:
+ ValueError: if `input_signature` is not None and the `python_function`'s
+ argspec has keyword arguments.
+ TypeError: if `input_signature` contains anything other than
+ `TensorSpec` objects, or (if not None) is anything other than a tuple or
+ list.
"""
- self._python_function = python_function
+ if isinstance(python_function, functools.partial):
+ self._python_function = python_function.func
+ self._args_to_prepend = python_function.args or tuple()
+ self._kwds_to_include = python_function.keywords or {}
+ else:
+ self._python_function = python_function
+ self._args_to_prepend = tuple()
+ self._kwds_to_include = {}
self._name = name
self._compiled = compiled
self._arguments_to_functions = {}
@@ -843,6 +906,35 @@ class _PolymorphicFunction(object):
self._lock = threading.Lock()
+ fullargspec = tf_inspect.getfullargspec(self._python_function)
+ # A cache mapping from argument name to index, for canonicalizing
+ # arguments that are called in a keyword-like fashion.
+ self._args_to_indices = {arg: i for i, arg in enumerate(fullargspec.args)}
+ # A cache mapping from arg index to default value, for canonicalization.
+ offset = len(fullargspec.args) - len(fullargspec.defaults or [])
+ self._arg_indices_to_default_values = {
+ offset + index: default
+ for index, default in enumerate(fullargspec.defaults or [])
+ }
+ if input_signature is None:
+ self._input_signature = None
+ else:
+ if fullargspec.varkw is not None or fullargspec.kwonlyargs:
+ raise ValueError("Cannot define a TensorFlow function from a Python "
+ "function with keyword arguments when "
+ "input_signature is provided.")
+
+ if not isinstance(input_signature, (tuple, list)):
+ raise TypeError("input_signature must be either a tuple or a "
+ "list, received " + str(type(input_signature)))
+
+ self._input_signature = tuple(input_signature)
+ self._flat_input_signature = tuple(nest.flatten(input_signature))
+ if any(not isinstance(arg, tensor_spec.TensorSpec)
+ for arg in self._flat_input_signature):
+ raise TypeError("Invalid input_signature %s; input_signature must be "
+ "a possibly nested sequence of TensorSpec objects.")
+
def __get__(self, instance, owner):
"""Makes it possible to defun instance methods."""
del owner
@@ -861,36 +953,119 @@ class _PolymorphicFunction(object):
# then `instance` will be `foo` (and `owner` will be `Foo`).
return functools.partial(self.__call__, instance)
+ def _cache_key(self, args, kwds):
+ """Computes the cache key given inputs."""
+ if self._input_signature is None:
+ inputs = (args, kwds) if kwds else args
+ cache_key = tuple(_encode_arg(arg) for arg in inputs)
+ else:
+ del args, kwds
+ cache_key = self._flat_input_signature
+ # The graph, or whether we're executing eagerly, should be a part of the
+ # cache key so we don't improperly capture tensors such as variables.
+ return cache_key + (context.executing_eagerly() or ops.get_default_graph(),)
+
+ def _canonicalize_function_inputs(self, *args, **kwds):
+ """Canonicalizes `args` and `kwds`.
+
+ Canonicalize the inputs to the Python function using its fullargspec. In
+ particular, we parse the varags and kwargs that this
+ `_PolymorphicFunction` was called with into a tuple corresponding to the
+ Python function's positional (named) arguments and a dictionary
+ corresponding to its kwargs.
+
+ Args:
+ *args: The varargs this object was called with.
+ **kwds: The keyword args this function was called with.
+
+ Returns:
+ A canonicalized ordering of the inputs.
+
+ Raises:
+ ValueError: If a keyword in `kwds` cannot be matched with a positional
+ argument when an input signature is specified, or when the inputs
+ do not conform to the input signature.
+ """
+ args = self._args_to_prepend + args
+ kwds = dict(kwds, **self._kwds_to_include)
+ # Maps from index of arg to its corresponding value, according to `args`
+ # and `kwds`; seeded with the default values for the named args that aren't
+ # in `args`.
+ arg_indices_to_values = {
+ index: default
+ for index, default in six.iteritems(self._arg_indices_to_default_values)
+ if index >= len(args)
+ }
+ consumed_args = []
+ for arg, value in six.iteritems(kwds):
+ index = self._args_to_indices.get(arg, None)
+ if index is not None:
+ arg_indices_to_values[index] = value
+ consumed_args.append(arg)
+ elif self._input_signature is not None:
+ raise ValueError("Cannot define a TensorFlow function from a Python "
+ "function with keyword arguments when "
+ "input_signature is provided.")
+ for arg in consumed_args:
+ # After this loop, `kwds` will only contain true keyword arguments, as
+ # opposed to named arguments called in a keyword-like fashion.
+ kwds.pop(arg)
+ inputs = args + _deterministic_dict_values(arg_indices_to_values)
+ if self._input_signature is None:
+ return inputs, kwds
+ else:
+ assert not kwds
+ try:
+ nest.assert_same_structure(self._input_signature, inputs)
+ except (ValueError, TypeError):
+ raise ValueError("Structure of Python function inputs does not match "
+ "input_signature.")
+ flat_inputs = nest.flatten(inputs)
+ if any(not isinstance(arg, ops.Tensor) for arg in flat_inputs):
+ raise ValueError("When input_signature is provided, all inputs to "
+ "the Python function must be Tensors.")
+ tensor_specs = [tensor_spec.TensorSpec.from_tensor(tensor)
+ for tensor in flat_inputs]
+ if any(not spec.is_compatible_with(other)
+ for spec, other in zip(self._flat_input_signature, tensor_specs)):
+ raise ValueError("Python inputs incompatible with input_signature: "
+ "inputs (%s), input_signature (%s)" %
+ (str(inputs), str(self._input_signature)))
+ return inputs, {}
+
def _maybe_define_function(self, *args, **kwds):
"""Gets a function for these inputs, defining it if necessary.
Args:
- *args: args for the Python function; used to compute the signature
- **kwds: kwds for the Python function; used to compute the signature
+ *args: args for the Python function.
+ **kwds: keywords for the Python function.
Returns:
A graph function corresponding to the input signature implied by args and
kwds, as well as the inputs that the object should be called with.
- """
- # TODO(apassos): Better error messages for non-hashable arguments.
- kwd_values = _deterministic_dict_values(kwds)
- inputs = args + kwd_values
- signature = tuple(_cache_key(x) for x in inputs)
- # The graph, or whether we're executing eagerly, should be a part of the
- # signature so we don't improperly capture tensors such as variables.
- signature += tuple([context.executing_eagerly() or ops.get_default_graph()])
+ Raises:
+ ValueError: If inputs are incompatible with the input signature.
+ TypeError: If the function inputs include non-hashable objects
+ """
+ args, kwds = self._canonicalize_function_inputs(*args, **kwds)
+ cache_key = self._cache_key(args, kwds)
with self._lock:
- if signature not in self._arguments_to_functions:
+ try:
+ graph_function = self._arguments_to_functions.get(cache_key, None)
+ except TypeError:
+ raise TypeError("Arguments supplied to `defun`-generated functions "
+ "must be hashable.")
+
+ if graph_function is None:
graph_function = _trace_and_define_function(
- self._name, self._python_function, self._compiled, args, kwds)
- self._arguments_to_functions[signature] = graph_function
+ self._name, self._python_function, self._compiled, args, kwds,
+ self._input_signature)
self._variables.extend(
[v for v in graph_function.variables if v not in self._variables])
- return graph_function, inputs
- else:
- return self._arguments_to_functions[signature], inputs
+ self._arguments_to_functions[cache_key] = graph_function
+ return graph_function, (args, kwds)
def __call__(self, *args, **kwds):
"""Calls a graph function specialized for this input signature."""
@@ -910,7 +1085,7 @@ class _PolymorphicFunction(object):
# TODO(akshayka): Remove the `compiled` flag and create a separate
# API for xla compilation (`defun` is already complicated enough
# as it is, and the keyword argument makes 'compiled' an overloaded concept)
-def defun(func=None, compiled=False):
+def defun(func=None, input_signature=None, compiled=False):
"""Compiles a Python function into a callable TensorFlow graph.
`defun` (short for "define function") trace-compiles a Python function
@@ -1165,6 +1340,13 @@ def defun(func=None, compiled=False):
def foo(...):
...
+ input_signature: A possibly nested sequence of
+ `tf.contrib.eager.TensorSpec` objects specifying the shapes and dtypes of
+ the Tensors that will be supplied to this function. If `None`, a separate
+ function is instantiated for each inferred input signature. If a
+ signature is specified, every input to `func` must be a `Tensor`, and
+ `func` cannot accept `**kwargs`.
+
compiled: If True, an attempt to compile `func` with XLA will be made.
If it fails, function will be run normally. Experimental. Currently
supported only for execution on TPUs. For the vast majority of users,
@@ -1183,7 +1365,9 @@ def defun(func=None, compiled=False):
except AttributeError:
name = "function"
return tf_decorator.make_decorator(
- function, _PolymorphicFunction(function, name, compiled=compiled))
+ function,
+ _PolymorphicFunction(
+ function, name, input_signature=input_signature, compiled=compiled))
# This code path is for the `foo = tfe.defun(foo, ...)` use case
if func is not None:
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 5efdecdbc6..b568af9bce 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import functools
import sys
from tensorflow.core.protobuf import config_pb2
@@ -33,6 +34,7 @@ from tensorflow.python.framework import function as tf_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops
@@ -50,6 +52,7 @@ from tensorflow.python.training import adam
from tensorflow.python.training import momentum
from tensorflow.python.training import training_ops
from tensorflow.python.util import compat
+from tensorflow.python.util import nest
@test_util.with_c_shapes
@@ -897,6 +900,237 @@ class FunctionTest(test.TestCase):
_ = defined(x) # ensure the variables list remains the same
self.assertAllEqual(defined.variables, [v])
+ def testPythonFunctionWithDefaultArgs(self):
+
+ def func(foo, bar=1, baz=2):
+ del foo
+ del bar
+ del baz
+ return
+
+ defined = function.defun(func)
+ defined(0, baz=20)
+ # `True` corresponds to the fact that we're executing eagerly
+ self.assertIn((0, 1, 20, True), defined._arguments_to_functions)
+
+ defined(1) # bar=1, baz=2
+ self.assertIn((1, 1, 2, True), defined._arguments_to_functions)
+
+ # This matches the previous call.
+ defined(foo=1)
+ self.assertEqual(len(defined._arguments_to_functions), 2)
+
+ defined(1, 2, 3)
+ self.assertIn((1, 2, 3, True), defined._arguments_to_functions)
+
+ # This matches the previous call.
+ defined(1, bar=2, baz=3)
+ self.assertEqual(len(defined._arguments_to_functions), 3)
+
+ # This matches the previous call.
+ defined(1, baz=3, bar=2)
+ self.assertEqual(len(defined._arguments_to_functions), 3)
+
+ def testFunctoolsPartialUnwrappedCorrectly(self):
+
+ def full_function(a, b, c=3):
+ return a, b, c
+
+ partial = functools.partial(full_function, 1, c=3)
+ a, b, c = partial(2)
+
+ defined = function.defun(partial)
+ func_a, func_b, func_c = defined(2)
+ self.assertEqual(func_a.numpy(), a)
+ self.assertEqual(func_b.numpy(), b)
+ self.assertEqual(func_c.numpy(), c)
+
+ def testInputSignatureWithCompatibleInputs(self):
+
+ def foo(a):
+ self.assertEqual(a.shape, (2,))
+ return a
+
+ signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)]
+ defined = function.defun(foo, input_signature=signature)
+ a = array_ops.ones([2])
+ out = defined(a)
+ self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertAllEqual(out, a)
+
+ def bar(a):
+ self.assertEqual(a._shape_tuple(), (2, None))
+ return a
+
+ signature = [tensor_spec.TensorSpec((2, None), dtypes.float32)]
+ defined = function.defun(bar, input_signature=signature)
+ a = array_ops.ones([2, 1])
+ out = defined(a)
+ self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertAllEqual(out, a)
+
+ # Changing the second dimension shouldn't create a new function.
+ b = array_ops.ones([2, 3])
+ out = defined(b)
+ self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertAllEqual(out, b)
+
+ def testNestedInputSignatures(self):
+
+ def foo(a, b):
+ self.assertEqual(a[0]._shape_tuple(), (2, None))
+ self.assertEqual(a[1]._shape_tuple(), (2, None))
+ self.assertEqual(b._shape_tuple(), (1,))
+ return [a, b]
+
+ signature = [[tensor_spec.TensorSpec((2, None), dtypes.float32)] * 2,
+ tensor_spec.TensorSpec((1,), dtypes.float32)]
+ defined = function.defun(foo, input_signature=signature)
+ a = array_ops.ones([2, 1])
+ b = array_ops.ones([1])
+ out = defined([a, a], b)
+ self.assertEqual(len(defined._arguments_to_functions), 1)
+ nest.assert_same_structure(out, [[a, a], b])
+ self.assertAllEqual(out[0][0], a)
+ self.assertAllEqual(out[0][1], a)
+ self.assertAllEqual(out[1], b)
+
+ # Changing the unspecified dimensions shouldn't create a new function.
+ a = array_ops.ones([2, 3])
+ b = array_ops.ones([2, 5])
+ c = array_ops.ones([1])
+ out = defined([a, b], c)
+ self.assertEqual(len(defined._arguments_to_functions), 1)
+ nest.assert_same_structure(out, [[a, b], c])
+ self.assertAllEqual(out[0][0], a)
+ self.assertAllEqual(out[0][1], b)
+ self.assertAllEqual(out[1], c)
+
+ def bar(a):
+ self.assertEqual(a['a']._shape_tuple(), (2, None))
+ self.assertEqual(a['b']._shape_tuple(), (2, None))
+ self.assertEqual(a['c']._shape_tuple(), (1,))
+ return a
+
+ signature = [{
+ 'a': tensor_spec.TensorSpec((2, None), dtypes.float32),
+ 'b': tensor_spec.TensorSpec((2, None), dtypes.float32),
+ 'c': tensor_spec.TensorSpec((1,), dtypes.float32)
+ }]
+ a = array_ops.ones([2, 3])
+ b = array_ops.ones([1])
+ inputs = {'a': a, 'b': a, 'c': b}
+ defined = function.defun(bar, input_signature=signature)
+ out = defined(inputs)
+ nest.assert_same_structure(out, inputs)
+ self.assertAllEqual(out['a'], inputs['a'])
+ self.assertAllEqual(out['b'], inputs['b'])
+ self.assertAllEqual(out['c'], inputs['c'])
+
+ def testInputSignatureMustBeSequenceOfTensorSpecs(self):
+
+ def foo(a, b):
+ del a
+ del b
+
+ # Signatures must consist exclusively of `TensorSpec` objects.
+ signature = [(2, 3), tensor_spec.TensorSpec([2, 3], dtypes.float32)]
+ with self.assertRaisesRegexp(TypeError, 'Invalid input_signature.*'):
+ function.defun(foo, input_signature=signature)(1, 2)
+
+ # Signatures must be either lists or tuples on their outermost levels.
+ signature = {'t1': tensor_spec.TensorSpec([], dtypes.float32)}
+ with self.assertRaisesRegexp(TypeError, 'input_signature must be either a '
+ 'tuple or a list.*'):
+ function.defun(foo, input_signature=signature)(1, 2)
+
+ def testInputsIncompatibleWithSignatureRaisesError(self):
+
+ def foo(a):
+ return a
+
+ signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)]
+ defined = function.defun(foo, input_signature=signature)
+
+ # Invalid shapes.
+ with self.assertRaisesRegexp(ValueError, 'Python inputs incompatible.*'):
+ defined(array_ops.ones([3]))
+
+ with self.assertRaisesRegexp(ValueError, 'Python inputs incompatible.*'):
+ defined(array_ops.ones([2, 1]))
+
+ # Wrong number of arguments.
+ with self.assertRaisesRegexp(ValueError,
+ 'Structure of Python function inputs.*'):
+ defined(array_ops.ones([2]), array_ops.ones([2]))
+ with self.assertRaisesRegexp(ValueError,
+ 'Structure of Python function inputs.*'):
+ defined()
+
+ def testInputSignatureForFunctionWithNonTensorInputsNotAllowed(self):
+
+ def foo(a, training=True):
+ if training:
+ return a
+ else:
+ return -1.0 * a
+
+ signature = [tensor_spec.TensorSpec([], dtypes.float32)] * 2
+ defined = function.defun(foo, input_signature=signature)
+ a = constant_op.constant(1.0)
+ with self.assertRaisesRegexp(
+ ValueError, 'When input_signature is provided, '
+ 'all inputs to the Python function must be Tensors.'):
+ defined(a, training=True)
+
+ def testInputSignatureWithKeywordPositionalArgs(self):
+
+ @function.defun(input_signature=[
+ tensor_spec.TensorSpec([], dtypes.float32),
+ tensor_spec.TensorSpec([], dtypes.int64)
+ ])
+ def foo(flt, integer):
+ return flt, integer
+
+ flt = constant_op.constant(1.0)
+ integer = constant_op.constant(2, dtypes.int64)
+
+ out1, out2 = foo(flt, integer)
+ self.assertEqual(len(foo._arguments_to_functions), 1)
+ self.assertEqual(out1.numpy(), 1.0)
+ self.assertEqual(out2.numpy(), 2)
+
+ out1, out2 = foo(flt=flt, integer=integer)
+ self.assertEqual(len(foo._arguments_to_functions), 1)
+ self.assertEqual(out1.numpy(), 1.0)
+ self.assertEqual(out2.numpy(), 2)
+
+ out1, out2 = foo(integer=integer, flt=flt)
+ self.assertEqual(len(foo._arguments_to_functions), 1)
+ self.assertEqual(out1.numpy(), 1.0)
+ self.assertEqual(out2.numpy(), 2)
+
+ out1, out2 = foo(flt, integer=integer)
+ self.assertEqual(len(foo._arguments_to_functions), 1)
+ self.assertEqual(out1.numpy(), 1.0)
+ self.assertEqual(out2.numpy(), 2)
+
+ def testInputSignatureWithKeywordArgsFails(self):
+
+ def foo(a, **kwargs):
+ del a
+ del kwargs
+
+ with self.assertRaisesRegexp(
+ ValueError, 'Cannot define a TensorFlow function from a Python '
+ 'function with keyword arguments when input_signature.*'):
+ function.defun(
+ foo,
+ input_signature=[
+ tensor_spec.TensorSpec([], dtypes.float32),
+ tensor_spec.TensorSpec([], dtypes.int64)
+ ])
+
def testTensorKeywordArguments(self):
def foo(a, b):
@@ -964,7 +1198,9 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(f(x=constant_op.constant(1.0)), 2.0)
- def testDecoratingInstanceMethod(self):
+ def testDefuningInstanceMethod(self):
+
+ integer = constant_op.constant(2, dtypes.int64)
class Foo(object):
@@ -972,13 +1208,14 @@ class FunctionTest(test.TestCase):
return tensor
@function.defun
- def two(self, tensor):
- return self.one(tensor)
+ def two(self, tensor, other=integer):
+ return self.one(tensor), other
foo = Foo()
t = constant_op.constant(1.0)
- out = foo.two(t)
- self.assertEqual(float(out), 1.0)
+ one, two = foo.two(t)
+ self.assertEqual(one.numpy(), 1.0)
+ self.assertEqual(two.numpy(), 2)
def testPythonCallWithSideEffects(self):
state = []
diff --git a/tensorflow/python/framework/tensor_spec.py b/tensorflow/python/framework/tensor_spec.py
index 6676cfcaa3..fbea930fe0 100644
--- a/tensorflow/python/framework/tensor_spec.py
+++ b/tensorflow/python/framework/tensor_spec.py
@@ -34,7 +34,7 @@ class TensorSpec(object):
construction and configuration.
"""
- __slots__ = ["_shape", "_dtype", "_name"]
+ __slots__ = ["_shape", "_shape_tuple", "_dtype", "_name"]
def __init__(self, shape, dtype, name=None):
"""Creates a TensorSpec.
@@ -49,6 +49,10 @@ class TensorSpec(object):
not convertible to a `tf.DType`.
"""
self._shape = tensor_shape.TensorShape(shape)
+ try:
+ self._shape_tuple = tuple(self.shape.as_list())
+ except ValueError:
+ self._shape_tuple = None
self._dtype = dtypes.as_dtype(dtype)
self._name = name
@@ -104,6 +108,9 @@ class TensorSpec(object):
return "TensorSpec(shape={}, dtype={}, name={})".format(
self.shape, repr(self.dtype), repr(self.name))
+ def __hash__(self):
+ return hash((self._shape_tuple, self.dtype))
+
def __eq__(self, other):
return self.shape == other.shape and self.dtype == other.dtype