aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/function_test.py
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2017-12-13 08:56:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-13 09:00:51 -0800
commit185c593cb71cb6d8116ba05c97e9385642648f1b (patch)
tree8853277b53c58d69e93735be50e75f4f5afa9516 /tensorflow/python/framework/function_test.py
parent2b1b7dffcd2c76876efdbcfc431424e259da3bf4 (diff)
Automated g4 rollback of changelist 178759398
PiperOrigin-RevId: 178909147
Diffstat (limited to 'tensorflow/python/framework/function_test.py')
-rw-r--r--tensorflow/python/framework/function_test.py42
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 886c6f04b9..f5a97eb197 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -941,6 +941,48 @@ class FunctionTest(test.TestCase):
self.assertEqual(100, sess.run(result_2))
self.assertEqual((4.0, 100), sess.run((result_1, result_2)))
+ def testStatefulFunction(self):
+
+ @function.Defun()
+ def FunctionWithStatelessOp():
+ return constant_op.constant(42.0)
+
+ @function.Defun()
+ def FunctionWithStatefulOp():
+ return random_ops.random_uniform([100], maxval=10, dtype=dtypes.int32)
+
+ @function.Defun()
+ def FunctionWithStatelessFunctionCall():
+ return FunctionWithStatelessOp()
+
+ @function.Defun()
+ def FunctionWithStatefulFunctionCall():
+ return FunctionWithStatefulOp()
+
+ # Test that the `is_stateful` bit is propagated.
+ self.assertFalse(FunctionWithStatelessOp.definition.signature.is_stateful)
+ self.assertTrue(FunctionWithStatefulOp.definition.signature.is_stateful)
+ self.assertFalse(
+ FunctionWithStatelessFunctionCall.definition.signature.is_stateful)
+ self.assertTrue(
+ FunctionWithStatefulFunctionCall.definition.signature.is_stateful)
+
+ # Ensure that two invocations of the same random-number-generating
+ # function produce different results.
+ result1 = FunctionWithStatefulFunctionCall()
+ result2 = FunctionWithStatefulFunctionCall()
+
+ # Statefulness affects how the function is treated by the various
+ # optimization passes, so run the test in each optimizer
+ # configuration.
+ for config in _OptimizerOptions():
+ with session.Session(config=config) as sess:
+ val1, val2 = sess.run((result1, result2))
+ self.assertFalse(all(val1 == val2))
+ val3, val4 = sess.run((result1, result2))
+ self.assertFalse(all(val3 == val1))
+ self.assertFalse(all(val4 == val2))
+
@test_util.with_c_api
class FunctionsFromProtos(test.TestCase):