From a04cd08ee7a8c5245d76a59849e1f7e8ba8a3f52 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Mon, 8 Oct 2018 10:20:52 -0700 Subject: Allow TensorSpec objects as arguments to defun's get_concrete_function Will be helpful for specifying serving signatures when exporting SavedModels PiperOrigin-RevId: 216207284 --- tensorflow/python/eager/function.py | 24 +++++++-------------- tensorflow/python/eager/function_test.py | 37 ++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 16 deletions(-) (limited to 'tensorflow/python') diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index bafe07de2b..93168826b1 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -855,20 +855,12 @@ class Function(object): return ret -def _get_defun_inputs_from_signature(signature): - """Maps a signature to graph-construction inputs.""" - function_inputs = [ - graph_placeholder(spec.dtype, spec.shape) - for spec in nest.flatten(signature) - ] - return nest.pack_sequence_as(signature, function_inputs) - - def _get_defun_inputs_from_args(args): """Maps python function args to graph-construction inputs.""" function_inputs = [ graph_placeholder(arg.dtype, arg.shape) - if isinstance(arg, ops.Tensor) else arg for arg in nest.flatten(args) + if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)) + else arg for arg in nest.flatten(args) ] return nest.pack_sequence_as(args, function_inputs) @@ -912,12 +904,12 @@ def func_graph_from_py_func(name, with func_graph.as_default(), AutomaticControlDependencies() as a: variable_scope.get_variable_scope().set_use_resource(True) - if signature is None: - func_args = _get_defun_inputs_from_args(args) - func_kwargs = _get_defun_inputs_from_args(kwargs) - else: - func_args = _get_defun_inputs_from_signature(signature) - func_kwargs = {} + if signature is not None: + args = signature + kwargs = {} + + func_args = _get_defun_inputs_from_args(args) + func_kwargs = _get_defun_inputs_from_args(kwargs) # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`. # Variables to help check whether mutation happens in calling the function diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index a2cfb4b476..57e545be69 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -172,6 +172,43 @@ class FunctionTest(test.TestCase): out = sq_op(t) self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) + def testInputSpecGraphFunction(self): + matmul = function.defun(math_ops.matmul) + + @function.defun + def sq(a): + return matmul(a, a) + + sq_op = sq.get_concrete_function( + tensor_spec.TensorSpec((None, None), dtypes.float32)) + self.assertEqual([None, None], sq_op.output_shapes.as_list()) + + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + out1 = sq_op(t1) + self.assertAllEqual(out1, math_ops.matmul(t1, t1).numpy()) + + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + out2 = sq_op(t2) + self.assertAllEqual(out2, math_ops.matmul(t2, t2).numpy()) + + def testNestedInputSpecGraphFunction(self): + matmul = function.defun(math_ops.matmul) + + @function.defun + def sq(mats): + ((a, b),) = mats + return matmul(a, b) + + sq_op = sq.get_concrete_function( + [(tensor_spec.TensorSpec((None, None), dtypes.float32), + tensor_spec.TensorSpec((None, None), dtypes.float32))]) + self.assertEqual([None, None], sq_op.output_shapes.as_list()) + + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]]) + out = sq_op(t1, t2) # Flattened structure for inputs to the graph function + self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy()) + def testExecutingStatelessDefunConcurrently(self): @function.defun -- cgit v1.2.3