diff options
Diffstat (limited to 'tensorflow/python/eager/function_test.py')
-rw-r--r-- | tensorflow/python/eager/function_test.py | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index a2cfb4b476..57e545be69 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -172,6 +172,43 @@ class FunctionTest(test.TestCase): out = sq_op(t) self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) + def testInputSpecGraphFunction(self): + matmul = function.defun(math_ops.matmul) + + @function.defun + def sq(a): + return matmul(a, a) + + sq_op = sq.get_concrete_function( + tensor_spec.TensorSpec((None, None), dtypes.float32)) + self.assertEqual([None, None], sq_op.output_shapes.as_list()) + + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + out1 = sq_op(t1) + self.assertAllEqual(out1, math_ops.matmul(t1, t1).numpy()) + + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + out2 = sq_op(t2) + self.assertAllEqual(out2, math_ops.matmul(t2, t2).numpy()) + + def testNestedInputSpecGraphFunction(self): + matmul = function.defun(math_ops.matmul) + + @function.defun + def sq(mats): + ((a, b),) = mats + return matmul(a, b) + + sq_op = sq.get_concrete_function( + [(tensor_spec.TensorSpec((None, None), dtypes.float32), + tensor_spec.TensorSpec((None, None), dtypes.float32))]) + self.assertEqual([None, None], sq_op.output_shapes.as_list()) + + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]]) + out = sq_op(t1, t2) # Flattened structure for inputs to the graph function + self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy()) + def testExecutingStatelessDefunConcurrently(self): @function.defun |