aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/function.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/eager/function.py')
-rw-r--r--tensorflow/python/eager/function.py24
1 files changed, 8 insertions, 16 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index bafe07de2b..93168826b1 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -855,20 +855,12 @@ class Function(object):
return ret
-def _get_defun_inputs_from_signature(signature):
- """Maps a signature to graph-construction inputs."""
- function_inputs = [
- graph_placeholder(spec.dtype, spec.shape)
- for spec in nest.flatten(signature)
- ]
- return nest.pack_sequence_as(signature, function_inputs)
-
-
def _get_defun_inputs_from_args(args):
"""Maps python function args to graph-construction inputs."""
function_inputs = [
graph_placeholder(arg.dtype, arg.shape)
- if isinstance(arg, ops.Tensor) else arg for arg in nest.flatten(args)
+ if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec))
+ else arg for arg in nest.flatten(args)
]
return nest.pack_sequence_as(args, function_inputs)
@@ -912,12 +904,12 @@ def func_graph_from_py_func(name,
with func_graph.as_default(), AutomaticControlDependencies() as a:
variable_scope.get_variable_scope().set_use_resource(True)
- if signature is None:
- func_args = _get_defun_inputs_from_args(args)
- func_kwargs = _get_defun_inputs_from_args(kwargs)
- else:
- func_args = _get_defun_inputs_from_signature(signature)
- func_kwargs = {}
+ if signature is not None:
+ args = signature
+ kwargs = {}
+
+ func_args = _get_defun_inputs_from_args(args)
+ func_kwargs = _get_defun_inputs_from_args(kwargs)
# Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
# Variables to help check whether mutation happens in calling the function