aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/function_def_to_graph_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/function_def_to_graph_test.py')
-rw-r--r--tensorflow/python/framework/function_def_to_graph_test.py34
1 files changed, 33 insertions, 1 deletions
diff --git a/tensorflow/python/framework/function_def_to_graph_test.py b/tensorflow/python/framework/function_def_to_graph_test.py
index 0f4e6ef54f..cd2a16ed5a 100644
--- a/tensorflow/python/framework/function_def_to_graph_test.py
+++ b/tensorflow/python/framework/function_def_to_graph_test.py
@@ -18,7 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
@@ -79,7 +81,6 @@ class FunctionDefToGraphTest(test.TestCase):
g = function_def_to_graph.function_def_to_graph(
fdef, input_shapes=[None, tensor_shape.matrix(5, 7)])
- print(g.as_graph_def())
self.assertIsNone(g.inputs[0].shape.dims)
self.assertSequenceEqual(g.inputs[1].shape.dims, [5, 7])
self.assertSequenceEqual(g.outputs[0].shape.dims, [5, 7])
@@ -179,6 +180,37 @@ class FunctionDefToGraphDefTest(test.TestCase):
self.assertEqual(g.node[0].attr["shape"].shape.unknown_rank, False)
self.assertFalse("shape" in g.node[2].attr)
+ def testFunctionCallsFromFunction(self):
+ x = constant_op.constant(5.0)
+ y = constant_op.constant(10.0)
+
+ @function.Defun()
+ def fn():
+
+ @function.Defun()
+ def inner_fn():
+ return x + y
+
+ return inner_fn()
+
+ # Instantiate the function in this graph so that
+ # `function_def_to_graph` can find it.
+ fn()
+
+ def fn2():
+ return 2 * fn()
+
+ fdef = function._DefinedFunction(fn2, [], []).definition
+ func_graph = function_def_to_graph.function_def_to_graph(fdef)
+ with func_graph.as_default():
+ x_ph, y_ph = func_graph.inputs
+ with self.test_session(graph=func_graph) as sess:
+ self.assertEqual(
+ sess.run(func_graph.outputs[0], feed_dict={
+ x_ph: 5.0,
+ y_ph: 10.0
+ }), 30.0)
+
if __name__ == "__main__":
test.main()