diff options
Diffstat (limited to 'tensorflow/python/eager/function_test.py')
-rw-r--r-- | tensorflow/python/eager/function_test.py | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 0488dc9752..380bcf763f 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -397,6 +397,18 @@ class FunctionTest(test.TestCase): compiled = function.defun(f) compiled() + @test_util.run_in_graph_and_eager_modes + def testDefunForcesResourceVariables(self): + + def variable_creator(): + return variables.Variable(0.0).read_value() + + defined = function.defun(variable_creator) + defined() # Create the variable. + self.assertEqual(len(defined.variables), 1) + self.assertIsInstance( + defined.variables[0], resource_variable_ops.ResourceVariable) + def testDefunDifferentiable(self): v = resource_variable_ops.ResourceVariable(1.0) @@ -434,6 +446,22 @@ class FunctionTest(test.TestCase): op = call() self.assertAllEqual(sess.run(op), 2.0) + def testSymbolicGradientVariableZerosLike(self): + with ops.Graph().as_default(): + v = resource_variable_ops.ResourceVariable(1.0) + + @function.defun + def f(x, v): + v.read_value() + return x * x + + x = constant_op.constant(1.0) + l = f(x, v) + _, dv = gradients_impl.gradients(l, [x, v]) + with self.test_session(): + v.initializer.run() + self.assertAllEqual(dv.eval(), 0.0) + def testGraphModeManyFunctions(self): with context.graph_mode(), self.test_session(): |