aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-10-08 10:20:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 10:31:16 -0700
commita04cd08ee7a8c5245d76a59849e1f7e8ba8a3f52 (patch)
treef814d8d2e5e85bc6c67f55afcc22d6818fc5ebba /tensorflow/python
parent0e42fd6d0a88b30ab57959f38c79bea19d745ec3 (diff)
Allow TensorSpec objects as arguments to defun's get_concrete_function
Will be helpful for specifying serving signatures when exporting SavedModels PiperOrigin-RevId: 216207284
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/eager/function.py24
-rw-r--r--tensorflow/python/eager/function_test.py37
2 files changed, 45 insertions, 16 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index bafe07de2b..93168826b1 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -855,20 +855,12 @@ class Function(object):
return ret
-def _get_defun_inputs_from_signature(signature):
- """Maps a signature to graph-construction inputs."""
- function_inputs = [
- graph_placeholder(spec.dtype, spec.shape)
- for spec in nest.flatten(signature)
- ]
- return nest.pack_sequence_as(signature, function_inputs)
-
-
def _get_defun_inputs_from_args(args):
"""Maps python function args to graph-construction inputs."""
function_inputs = [
graph_placeholder(arg.dtype, arg.shape)
- if isinstance(arg, ops.Tensor) else arg for arg in nest.flatten(args)
+ if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec))
+ else arg for arg in nest.flatten(args)
]
return nest.pack_sequence_as(args, function_inputs)
@@ -912,12 +904,12 @@ def func_graph_from_py_func(name,
with func_graph.as_default(), AutomaticControlDependencies() as a:
variable_scope.get_variable_scope().set_use_resource(True)
- if signature is None:
- func_args = _get_defun_inputs_from_args(args)
- func_kwargs = _get_defun_inputs_from_args(kwargs)
- else:
- func_args = _get_defun_inputs_from_signature(signature)
- func_kwargs = {}
+ if signature is not None:
+ args = signature
+ kwargs = {}
+
+ func_args = _get_defun_inputs_from_args(args)
+ func_kwargs = _get_defun_inputs_from_args(kwargs)
# Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
# Variables to help check whether mutation happens in calling the function
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index a2cfb4b476..57e545be69 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -172,6 +172,43 @@ class FunctionTest(test.TestCase):
out = sq_op(t)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
+ def testInputSpecGraphFunction(self):
+ matmul = function.defun(math_ops.matmul)
+
+ @function.defun
+ def sq(a):
+ return matmul(a, a)
+
+ sq_op = sq.get_concrete_function(
+ tensor_spec.TensorSpec((None, None), dtypes.float32))
+ self.assertEqual([None, None], sq_op.output_shapes.as_list())
+
+ t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ out1 = sq_op(t1)
+ self.assertAllEqual(out1, math_ops.matmul(t1, t1).numpy())
+
+ t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ out2 = sq_op(t2)
+ self.assertAllEqual(out2, math_ops.matmul(t2, t2).numpy())
+
+ def testNestedInputSpecGraphFunction(self):
+ matmul = function.defun(math_ops.matmul)
+
+ @function.defun
+ def sq(mats):
+ ((a, b),) = mats
+ return matmul(a, b)
+
+ sq_op = sq.get_concrete_function(
+ [(tensor_spec.TensorSpec((None, None), dtypes.float32),
+ tensor_spec.TensorSpec((None, None), dtypes.float32))])
+ self.assertEqual([None, None], sq_op.output_shapes.as_list())
+
+ t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]])
+ out = sq_op(t1, t2) # Flattened structure for inputs to the graph function
+ self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy())
+
def testExecutingStatelessDefunConcurrently(self):
@function.defun