diff options
author | Lasse Espeholt <lespeholt@google.com> | 2018-09-24 02:17:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 02:21:47 -0700 |
commit | b57bdf414edb27b82a95c5f4e2729fafd4cf2dc7 (patch) | |
tree | 8a811a6bea40e1f0b388910a669687962f31b604 /tensorflow | |
parent | cdcc7d31cce91169dc686387522d7015ac57db0e (diff) |
Clean-up of function.py.
PiperOrigin-RevId: 214232622
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/c/eager/BUILD | 5 | ||||
-rw-r--r-- | tensorflow/python/eager/BUILD | 5 | ||||
-rw-r--r-- | tensorflow/python/eager/function.py | 86 |
3 files changed, 53 insertions, 43 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 37be52f57d..3ee31a6a7a 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -68,7 +68,10 @@ tf_cuda_library( tf_cuda_library( name = "c_api_internal", hdrs = ["c_api_internal.h"], - visibility = ["//tensorflow:internal"], + visibility = [ + "//learning/deepmind/courier:__pkg__", + "//tensorflow:internal", + ], deps = [ ":c_api", "//tensorflow/c:c_api", diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index f571da308e..d3d997e6df 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -17,7 +17,10 @@ cc_library( "pywrap_tensor.h", "pywrap_tfe.h", ], - visibility = ["//tensorflow:internal"], + visibility = [ + "//learning/deepmind/courier:__pkg__", + "//tensorflow:internal", + ], deps = [ "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index bcb1881264..1f5d479882 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -850,7 +850,7 @@ def _get_defun_inputs_from_args(args): def func_graph_from_py_func(name, python_func, args, - kwds, + kwargs, signature=None, func_graph=None): """Returns a `FuncGraph` generated from `python_func`. @@ -860,11 +860,11 @@ def func_graph_from_py_func(name, python_func: the Python function to trace. args: the positional args with which the Python function should be called; ignored if a signature is provided. - kwds: the keyword args with which the Python function should be called; + kwargs: the keyword args with which the Python function should be called; ignored if a signature is provided. signature: a possibly nested sequence of `TensorSpecs` specifying the shapes and dtypes of the arguments. When a signature is provided, `args` and - `kwds` are ignored, and `python_func` is traced with Tensors conforming + `kwargs` are ignored, and `python_func` is traced with Tensors conforming to `signature`. If `None`, the shapes and dtypes are inferred from the inputs. func_graph: Optional. An instance of FuncGraph. If provided, we will use @@ -885,16 +885,17 @@ def func_graph_from_py_func(name, if signature is None: func_args = _get_defun_inputs_from_args(args) - func_kwds = _get_defun_inputs_from_args(kwds) + func_kwargs = _get_defun_inputs_from_args(kwargs) else: func_args = _get_defun_inputs_from_signature(signature) - func_kwds = {} + func_kwargs = {} # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`. # Variables to help check whether mutation happens in calling the function # Copy the recursive list, tuple and map structure, but not base objects func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args)) - func_kwds_before = nest.pack_sequence_as(func_kwds, nest.flatten(func_kwds)) + func_kwargs_before = nest.pack_sequence_as( + func_kwargs, nest.flatten(func_kwargs)) def convert(x): """Converts an argument to a Tensor.""" @@ -913,7 +914,7 @@ def func_graph_from_py_func(name, this_tape = tape.push_new_tape() try: - func_outputs = python_func(*func_args, **func_kwds) + func_outputs = python_func(*func_args, **func_kwargs) # invariant: `func_outputs` contains only Tensors and `None`s. func_outputs = nest.map_structure(convert, func_outputs) @@ -933,16 +934,16 @@ def func_graph_from_py_func(name, raise ValueError(errmsg) check_mutation(func_args_before, func_args) - check_mutation(func_kwds_before, func_kwds) + check_mutation(func_kwargs_before, func_kwargs) finally: tape.pop_tape(this_tape) - # Variables in `func_args`, `func_kwds` should be explicit inputs + # Variables in `func_args`, `func_kwargs` should be explicit inputs # to the function, not captured inputs. tape_variables = this_tape.watched_variables() arg_variables = set() inputs = [] - for arg in nest.flatten(func_args) + nest.flatten(func_kwds): + for arg in nest.flatten(func_args) + nest.flatten(func_kwargs): if isinstance(arg, resource_variable_ops.ResourceVariable): try: resource_placeholder = func_graph.captures.pop(arg.handle) @@ -1073,11 +1074,11 @@ class PolymorphicFunction(object): if isinstance(python_function, functools.partial): self._python_function = python_function.func self._args_to_prepend = python_function.args or tuple() - self._kwds_to_include = python_function.keywords or {} + self._kwargs_to_include = python_function.keywords or {} else: self._python_function = python_function self._args_to_prepend = tuple() - self._kwds_to_include = {} + self._kwargs_to_include = {} self._name = name self._function_cache = collections.OrderedDict() self._function_attributes = attributes or {} @@ -1115,9 +1116,9 @@ class PolymorphicFunction(object): self._input_signature = tuple(input_signature) self._flat_input_signature = tuple(nest.flatten(input_signature)) - def __call__(self, *args, **kwds): + def __call__(self, *args, **kwargs): """Calls a graph function specialized to the inputs.""" - graph_function, inputs = self._maybe_define_function(*args, **kwds) + graph_function, inputs = self._maybe_define_function(args, kwargs) return graph_function(*inputs) @property @@ -1135,7 +1136,7 @@ class PolymorphicFunction(object): *args: inputs to specialize on. **kwargs: inputs to specialize on. """ - graph_function, _ = self._maybe_define_function(*args, **kwargs) + graph_function, _ = self._maybe_define_function(args, kwargs) return graph_function def __get__(self, instance, owner): @@ -1156,13 +1157,13 @@ class PolymorphicFunction(object): # then `instance` will be `foo` (and `owner` will be `Foo`). return functools.partial(self.__call__, instance) - def _cache_key(self, args, kwds, ctx, graph): + def _cache_key(self, args, kwargs, ctx, graph): """Computes the cache key given inputs and execution context.""" if self._input_signature is None: - inputs = (args, kwds) if kwds else args + inputs = (args, kwargs) if kwargs else args cache_key = tuple(_encode_arg(arg) for arg in inputs) else: - del args, kwds + del args, kwargs cache_key = self._flat_input_signature # The graph, or whether we're executing eagerly, should be a part of the @@ -1181,8 +1182,8 @@ class PolymorphicFunction(object): return cache_key + (execution_context, device_functions, colocation_stack) - def _canonicalize_function_inputs(self, *args, **kwds): - """Canonicalizes `args` and `kwds`. + def _canonicalize_function_inputs(self, *args, **kwargs): + """Canonicalizes `args` and `kwargs`. Canonicalize the inputs to the Python function using its fullargspec. In particular, we parse the varags and kwargs that this @@ -1192,28 +1193,28 @@ class PolymorphicFunction(object): Args: *args: The varargs this object was called with. - **kwds: The keyword args this function was called with. + **kwargs: The keyword args this function was called with. Returns: A canonicalized ordering of the inputs. Raises: - ValueError: If a keyword in `kwds` cannot be matched with a positional + ValueError: If a keyword in `kwargs` cannot be matched with a positional argument when an input signature is specified, or when the inputs do not conform to the input signature. """ args = self._args_to_prepend + args - kwds = dict(kwds, **self._kwds_to_include) + kwargs = dict(kwargs, **self._kwargs_to_include) # Maps from index of arg to its corresponding value, according to `args` - # and `kwds`; seeded with the default values for the named args that aren't - # in `args`. + # and `kwargs`; seeded with the default values for the named args that + # aren't in `args`. arg_indices_to_values = { index: default for index, default in six.iteritems(self._arg_indices_to_default_values) if index >= len(args) } consumed_args = [] - for arg, value in six.iteritems(kwds): + for arg, value in six.iteritems(kwargs): index = self._args_to_indices.get(arg, None) if index is not None: arg_indices_to_values[index] = value @@ -1223,9 +1224,9 @@ class PolymorphicFunction(object): "function with keyword arguments when " "input_signature is provided.") for arg in consumed_args: - # After this loop, `kwds` will only contain true keyword arguments, as + # After this loop, `kwargs` will only contain true keyword arguments, as # opposed to named arguments called in a keyword-like fashion. - kwds.pop(arg) + kwargs.pop(arg) inputs = args + _deterministic_dict_values(arg_indices_to_values) flat_inputs = nest.flatten(inputs) @@ -1239,9 +1240,9 @@ class PolymorphicFunction(object): inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_inputs) if self._input_signature is None: - return inputs, kwds + return inputs, kwargs else: - assert not kwds + assert not kwargs try: nest.assert_same_structure(self._input_signature, inputs) except (ValueError, TypeError): @@ -1260,24 +1261,27 @@ class PolymorphicFunction(object): (str(inputs), str(self._input_signature))) return inputs, {} - def _maybe_define_function(self, *args, **kwds): + def _maybe_define_function(self, args, kwargs): """Gets a function for these inputs, defining it if necessary. + `args` and `kwargs` can be None if this `PolymorphicFunction` was created + with an `input_signature`. + Args: - *args: args for the Python function. - **kwds: keywords for the Python function. + args: The varargs for the Python function. + kwargs: The keyword args for the Python function. Returns: A graph function corresponding to the input signature implied by args and - kwds, as well as the inputs that the object should be called with. + kwargs, as well as the inputs that the object should be called with. Raises: ValueError: If inputs are incompatible with the input signature. TypeError: If the function inputs include non-hashable objects """ - - args, kwds = self._canonicalize_function_inputs(*args, **kwds) - cache_key = self._cache_key(args, kwds, context.context(), + if self._input_signature is None or args is not None or kwargs is not None: + args, kwargs = self._canonicalize_function_inputs(*args, **kwargs) + cache_key = self._cache_key(args, kwargs, context.context(), ops.get_default_graph()) with self._lock: try: @@ -1289,11 +1293,11 @@ class PolymorphicFunction(object): if graph_function is None: graph_function = Function( func_graph_from_py_func(self._name, self._python_function, args, - kwds, self._input_signature), + kwargs, self._input_signature), self._function_attributes) self._function_cache[cache_key] = graph_function return graph_function, [ - t for t in nest.flatten((args, kwds)) + t for t in nest.flatten((args, kwargs)) if isinstance(t, (ops.Tensor, resource_variable_ops.ResourceVariable)) ] @@ -1933,9 +1937,9 @@ def automatic_control_dependencies(f): The wrapped function. """ - def wrapper(*args, **kwds): + def wrapper(*args, **kwargs): with AutomaticControlDependencies() as a: - result = f(*args, **kwds) + result = f(*args, **kwargs) result_flat = [a.mark_as_return(t) for t in nest.flatten(result)] return nest.pack_sequence_as(result, result_flat) |