aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Lasse Espeholt <lespeholt@google.com>2018-09-24 02:17:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 02:21:47 -0700
commitb57bdf414edb27b82a95c5f4e2729fafd4cf2dc7 (patch)
tree8a811a6bea40e1f0b388910a669687962f31b604 /tensorflow
parentcdcc7d31cce91169dc686387522d7015ac57db0e (diff)
Clean-up of function.py.
PiperOrigin-RevId: 214232622
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/c/eager/BUILD5
-rw-r--r--tensorflow/python/eager/BUILD5
-rw-r--r--tensorflow/python/eager/function.py86
3 files changed, 53 insertions, 43 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 37be52f57d..3ee31a6a7a 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -68,7 +68,10 @@ tf_cuda_library(
tf_cuda_library(
name = "c_api_internal",
hdrs = ["c_api_internal.h"],
- visibility = ["//tensorflow:internal"],
+ visibility = [
+ "//learning/deepmind/courier:__pkg__",
+ "//tensorflow:internal",
+ ],
deps = [
":c_api",
"//tensorflow/c:c_api",
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index f571da308e..d3d997e6df 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -17,7 +17,10 @@ cc_library(
"pywrap_tensor.h",
"pywrap_tfe.h",
],
- visibility = ["//tensorflow:internal"],
+ visibility = [
+ "//learning/deepmind/courier:__pkg__",
+ "//tensorflow:internal",
+ ],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index bcb1881264..1f5d479882 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -850,7 +850,7 @@ def _get_defun_inputs_from_args(args):
def func_graph_from_py_func(name,
python_func,
args,
- kwds,
+ kwargs,
signature=None,
func_graph=None):
"""Returns a `FuncGraph` generated from `python_func`.
@@ -860,11 +860,11 @@ def func_graph_from_py_func(name,
python_func: the Python function to trace.
args: the positional args with which the Python function should be called;
ignored if a signature is provided.
- kwds: the keyword args with which the Python function should be called;
+ kwargs: the keyword args with which the Python function should be called;
ignored if a signature is provided.
signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
and dtypes of the arguments. When a signature is provided, `args` and
- `kwds` are ignored, and `python_func` is traced with Tensors conforming
+ `kwargs` are ignored, and `python_func` is traced with Tensors conforming
to `signature`. If `None`, the shapes and dtypes are inferred from the
inputs.
func_graph: Optional. An instance of FuncGraph. If provided, we will use
@@ -885,16 +885,17 @@ def func_graph_from_py_func(name,
if signature is None:
func_args = _get_defun_inputs_from_args(args)
- func_kwds = _get_defun_inputs_from_args(kwds)
+ func_kwargs = _get_defun_inputs_from_args(kwargs)
else:
func_args = _get_defun_inputs_from_signature(signature)
- func_kwds = {}
+ func_kwargs = {}
# Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
# Variables to help check whether mutation happens in calling the function
# Copy the recursive list, tuple and map structure, but not base objects
func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args))
- func_kwds_before = nest.pack_sequence_as(func_kwds, nest.flatten(func_kwds))
+ func_kwargs_before = nest.pack_sequence_as(
+ func_kwargs, nest.flatten(func_kwargs))
def convert(x):
"""Converts an argument to a Tensor."""
@@ -913,7 +914,7 @@ def func_graph_from_py_func(name,
this_tape = tape.push_new_tape()
try:
- func_outputs = python_func(*func_args, **func_kwds)
+ func_outputs = python_func(*func_args, **func_kwargs)
# invariant: `func_outputs` contains only Tensors and `None`s.
func_outputs = nest.map_structure(convert, func_outputs)
@@ -933,16 +934,16 @@ def func_graph_from_py_func(name,
raise ValueError(errmsg)
check_mutation(func_args_before, func_args)
- check_mutation(func_kwds_before, func_kwds)
+ check_mutation(func_kwargs_before, func_kwargs)
finally:
tape.pop_tape(this_tape)
- # Variables in `func_args`, `func_kwds` should be explicit inputs
+ # Variables in `func_args`, `func_kwargs` should be explicit inputs
# to the function, not captured inputs.
tape_variables = this_tape.watched_variables()
arg_variables = set()
inputs = []
- for arg in nest.flatten(func_args) + nest.flatten(func_kwds):
+ for arg in nest.flatten(func_args) + nest.flatten(func_kwargs):
if isinstance(arg, resource_variable_ops.ResourceVariable):
try:
resource_placeholder = func_graph.captures.pop(arg.handle)
@@ -1073,11 +1074,11 @@ class PolymorphicFunction(object):
if isinstance(python_function, functools.partial):
self._python_function = python_function.func
self._args_to_prepend = python_function.args or tuple()
- self._kwds_to_include = python_function.keywords or {}
+ self._kwargs_to_include = python_function.keywords or {}
else:
self._python_function = python_function
self._args_to_prepend = tuple()
- self._kwds_to_include = {}
+ self._kwargs_to_include = {}
self._name = name
self._function_cache = collections.OrderedDict()
self._function_attributes = attributes or {}
@@ -1115,9 +1116,9 @@ class PolymorphicFunction(object):
self._input_signature = tuple(input_signature)
self._flat_input_signature = tuple(nest.flatten(input_signature))
- def __call__(self, *args, **kwds):
+ def __call__(self, *args, **kwargs):
"""Calls a graph function specialized to the inputs."""
- graph_function, inputs = self._maybe_define_function(*args, **kwds)
+ graph_function, inputs = self._maybe_define_function(args, kwargs)
return graph_function(*inputs)
@property
@@ -1135,7 +1136,7 @@ class PolymorphicFunction(object):
*args: inputs to specialize on.
**kwargs: inputs to specialize on.
"""
- graph_function, _ = self._maybe_define_function(*args, **kwargs)
+ graph_function, _ = self._maybe_define_function(args, kwargs)
return graph_function
def __get__(self, instance, owner):
@@ -1156,13 +1157,13 @@ class PolymorphicFunction(object):
# then `instance` will be `foo` (and `owner` will be `Foo`).
return functools.partial(self.__call__, instance)
- def _cache_key(self, args, kwds, ctx, graph):
+ def _cache_key(self, args, kwargs, ctx, graph):
"""Computes the cache key given inputs and execution context."""
if self._input_signature is None:
- inputs = (args, kwds) if kwds else args
+ inputs = (args, kwargs) if kwargs else args
cache_key = tuple(_encode_arg(arg) for arg in inputs)
else:
- del args, kwds
+ del args, kwargs
cache_key = self._flat_input_signature
# The graph, or whether we're executing eagerly, should be a part of the
@@ -1181,8 +1182,8 @@ class PolymorphicFunction(object):
return cache_key + (execution_context, device_functions, colocation_stack)
- def _canonicalize_function_inputs(self, *args, **kwds):
- """Canonicalizes `args` and `kwds`.
+ def _canonicalize_function_inputs(self, *args, **kwargs):
+ """Canonicalizes `args` and `kwargs`.
Canonicalize the inputs to the Python function using its fullargspec. In
particular, we parse the varags and kwargs that this
@@ -1192,28 +1193,28 @@ class PolymorphicFunction(object):
Args:
*args: The varargs this object was called with.
- **kwds: The keyword args this function was called with.
+ **kwargs: The keyword args this function was called with.
Returns:
A canonicalized ordering of the inputs.
Raises:
- ValueError: If a keyword in `kwds` cannot be matched with a positional
+ ValueError: If a keyword in `kwargs` cannot be matched with a positional
argument when an input signature is specified, or when the inputs
do not conform to the input signature.
"""
args = self._args_to_prepend + args
- kwds = dict(kwds, **self._kwds_to_include)
+ kwargs = dict(kwargs, **self._kwargs_to_include)
# Maps from index of arg to its corresponding value, according to `args`
- # and `kwds`; seeded with the default values for the named args that aren't
- # in `args`.
+ # and `kwargs`; seeded with the default values for the named args that
+ # aren't in `args`.
arg_indices_to_values = {
index: default
for index, default in six.iteritems(self._arg_indices_to_default_values)
if index >= len(args)
}
consumed_args = []
- for arg, value in six.iteritems(kwds):
+ for arg, value in six.iteritems(kwargs):
index = self._args_to_indices.get(arg, None)
if index is not None:
arg_indices_to_values[index] = value
@@ -1223,9 +1224,9 @@ class PolymorphicFunction(object):
"function with keyword arguments when "
"input_signature is provided.")
for arg in consumed_args:
- # After this loop, `kwds` will only contain true keyword arguments, as
+ # After this loop, `kwargs` will only contain true keyword arguments, as
# opposed to named arguments called in a keyword-like fashion.
- kwds.pop(arg)
+ kwargs.pop(arg)
inputs = args + _deterministic_dict_values(arg_indices_to_values)
flat_inputs = nest.flatten(inputs)
@@ -1239,9 +1240,9 @@ class PolymorphicFunction(object):
inputs = nest.pack_sequence_as(structure=inputs,
flat_sequence=flat_inputs)
if self._input_signature is None:
- return inputs, kwds
+ return inputs, kwargs
else:
- assert not kwds
+ assert not kwargs
try:
nest.assert_same_structure(self._input_signature, inputs)
except (ValueError, TypeError):
@@ -1260,24 +1261,27 @@ class PolymorphicFunction(object):
(str(inputs), str(self._input_signature)))
return inputs, {}
- def _maybe_define_function(self, *args, **kwds):
+ def _maybe_define_function(self, args, kwargs):
"""Gets a function for these inputs, defining it if necessary.
+ `args` and `kwargs` can be None if this `PolymorphicFunction` was created
+ with an `input_signature`.
+
Args:
- *args: args for the Python function.
- **kwds: keywords for the Python function.
+ args: The varargs for the Python function.
+ kwargs: The keyword args for the Python function.
Returns:
A graph function corresponding to the input signature implied by args and
- kwds, as well as the inputs that the object should be called with.
+ kwargs, as well as the inputs that the object should be called with.
Raises:
ValueError: If inputs are incompatible with the input signature.
TypeError: If the function inputs include non-hashable objects
"""
-
- args, kwds = self._canonicalize_function_inputs(*args, **kwds)
- cache_key = self._cache_key(args, kwds, context.context(),
+ if self._input_signature is None or args is not None or kwargs is not None:
+ args, kwargs = self._canonicalize_function_inputs(*args, **kwargs)
+ cache_key = self._cache_key(args, kwargs, context.context(),
ops.get_default_graph())
with self._lock:
try:
@@ -1289,11 +1293,11 @@ class PolymorphicFunction(object):
if graph_function is None:
graph_function = Function(
func_graph_from_py_func(self._name, self._python_function, args,
- kwds, self._input_signature),
+ kwargs, self._input_signature),
self._function_attributes)
self._function_cache[cache_key] = graph_function
return graph_function, [
- t for t in nest.flatten((args, kwds))
+ t for t in nest.flatten((args, kwargs))
if isinstance(t, (ops.Tensor, resource_variable_ops.ResourceVariable))
]
@@ -1933,9 +1937,9 @@ def automatic_control_dependencies(f):
The wrapped function.
"""
- def wrapper(*args, **kwds):
+ def wrapper(*args, **kwargs):
with AutomaticControlDependencies() as a:
- result = f(*args, **kwds)
+ result = f(*args, **kwargs)
result_flat = [a.mark_as_return(t) for t in nest.flatten(result)]
return nest.pack_sequence_as(result, result_flat)