aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/function_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/eager/function_test.py')
-rw-r--r--tensorflow/python/eager/function_test.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index a2cfb4b476..57e545be69 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -172,6 +172,43 @@ class FunctionTest(test.TestCase):
out = sq_op(t)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
+ def testInputSpecGraphFunction(self):
+ matmul = function.defun(math_ops.matmul)
+
+ @function.defun
+ def sq(a):
+ return matmul(a, a)
+
+ sq_op = sq.get_concrete_function(
+ tensor_spec.TensorSpec((None, None), dtypes.float32))
+ self.assertEqual([None, None], sq_op.output_shapes.as_list())
+
+ t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ out1 = sq_op(t1)
+ self.assertAllEqual(out1, math_ops.matmul(t1, t1).numpy())
+
+ t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ out2 = sq_op(t2)
+ self.assertAllEqual(out2, math_ops.matmul(t2, t2).numpy())
+
+ def testNestedInputSpecGraphFunction(self):
+ matmul = function.defun(math_ops.matmul)
+
+ @function.defun
+ def sq(mats):
+ ((a, b),) = mats
+ return matmul(a, b)
+
+ sq_op = sq.get_concrete_function(
+ [(tensor_spec.TensorSpec((None, None), dtypes.float32),
+ tensor_spec.TensorSpec((None, None), dtypes.float32))])
+ self.assertEqual([None, None], sq_op.output_shapes.as_list())
+
+ t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]])
+ out = sq_op(t1, t2) # Flattened structure for inputs to the graph function
+ self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy())
+
def testExecutingStatelessDefunConcurrently(self):
@function.defun