aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/function_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/eager/function_test.py')
-rw-r--r--tensorflow/python/eager/function_test.py28
1 files changed, 28 insertions, 0 deletions
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 0488dc9752..380bcf763f 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -397,6 +397,18 @@ class FunctionTest(test.TestCase):
compiled = function.defun(f)
compiled()
+ @test_util.run_in_graph_and_eager_modes
+ def testDefunForcesResourceVariables(self):
+
+ def variable_creator():
+ return variables.Variable(0.0).read_value()
+
+ defined = function.defun(variable_creator)
+ defined() # Create the variable.
+ self.assertEqual(len(defined.variables), 1)
+ self.assertIsInstance(
+ defined.variables[0], resource_variable_ops.ResourceVariable)
+
def testDefunDifferentiable(self):
v = resource_variable_ops.ResourceVariable(1.0)
@@ -434,6 +446,22 @@ class FunctionTest(test.TestCase):
op = call()
self.assertAllEqual(sess.run(op), 2.0)
+ def testSymbolicGradientVariableZerosLike(self):
+ with ops.Graph().as_default():
+ v = resource_variable_ops.ResourceVariable(1.0)
+
+ @function.defun
+ def f(x, v):
+ v.read_value()
+ return x * x
+
+ x = constant_op.constant(1.0)
+ l = f(x, v)
+ _, dv = gradients_impl.gradients(l, [x, v])
+ with self.test_session():
+ v.initializer.run()
+ self.assertAllEqual(dv.eval(), 0.0)
+
def testGraphModeManyFunctions(self):
with context.graph_mode(), self.test_session():