diff options
author | 2017-12-13 08:56:20 -0800 | |
---|---|---|
committer | 2017-12-13 09:00:51 -0800 | |
commit | 185c593cb71cb6d8116ba05c97e9385642648f1b (patch) | |
tree | 8853277b53c58d69e93735be50e75f4f5afa9516 /tensorflow/python/framework/function_test.py | |
parent | 2b1b7dffcd2c76876efdbcfc431424e259da3bf4 (diff) |
Automated g4 rollback of changelist 178759398
PiperOrigin-RevId: 178909147
Diffstat (limited to 'tensorflow/python/framework/function_test.py')
-rw-r--r-- | tensorflow/python/framework/function_test.py | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 886c6f04b9..f5a97eb197 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -941,6 +941,48 @@ class FunctionTest(test.TestCase): self.assertEqual(100, sess.run(result_2)) self.assertEqual((4.0, 100), sess.run((result_1, result_2))) + def testStatefulFunction(self): + + @function.Defun() + def FunctionWithStatelessOp(): + return constant_op.constant(42.0) + + @function.Defun() + def FunctionWithStatefulOp(): + return random_ops.random_uniform([100], maxval=10, dtype=dtypes.int32) + + @function.Defun() + def FunctionWithStatelessFunctionCall(): + return FunctionWithStatelessOp() + + @function.Defun() + def FunctionWithStatefulFunctionCall(): + return FunctionWithStatefulOp() + + # Test that the `is_stateful` bit is propagated. + self.assertFalse(FunctionWithStatelessOp.definition.signature.is_stateful) + self.assertTrue(FunctionWithStatefulOp.definition.signature.is_stateful) + self.assertFalse( + FunctionWithStatelessFunctionCall.definition.signature.is_stateful) + self.assertTrue( + FunctionWithStatefulFunctionCall.definition.signature.is_stateful) + + # Ensure that two invocations of the same random-number-generating + # function produce different results. + result1 = FunctionWithStatefulFunctionCall() + result2 = FunctionWithStatefulFunctionCall() + + # Statefulness affects how the function is treated by the various + # optimization passes, so run the test in each optimizer + # configuration. + for config in _OptimizerOptions(): + with session.Session(config=config) as sess: + val1, val2 = sess.run((result1, result2)) + self.assertFalse(all(val1 == val2)) + val3, val4 = sess.run((result1, result2)) + self.assertFalse(all(val3 == val1)) + self.assertFalse(all(val4 == val2)) + @test_util.with_c_api class FunctionsFromProtos(test.TestCase): |