aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/partitioned_variables_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/partitioned_variables_test.py')
-rw-r--r--tensorflow/python/kernel_tests/partitioned_variables_test.py80
1 files changed, 54 insertions, 26 deletions
diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py
index f5c6255c34..ba9359d923 100644
--- a/tensorflow/python/kernel_tests/partitioned_variables_test.py
+++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py
@@ -25,12 +25,15 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training import gradient_descent
class PartitionerCreatorsTest(test.TestCase):
@@ -543,32 +546,6 @@ class PartitionedVariablesTestCase(test.TestCase):
partitioned_variables.create_partitioned_variables(
[10, 43], [1, 50], rnd.initialized_value())
- def testControlDepsNone(self):
- with self.test_session() as session:
- c = constant_op.constant(1.0)
- with ops.control_dependencies([c]):
- # d get the control dependency.
- d = constant_op.constant(2.0)
- # Partitioned variables do not.
- var_x = variable_scope.get_variable(
- "x",
- shape=[2],
- initializer=init_ops.ones_initializer(),
- partitioner=partitioned_variables.variable_axis_size_partitioner(4))
-
- ops_before_read = session.graph.get_operations()
- var_x.as_tensor() # Caches the ops for subsequent reads.
- reading_ops = [
- op for op in session.graph.get_operations()
- if op not in ops_before_read
- ]
-
- self.assertEqual([c.op], d.op.control_inputs)
- # Tests that no control dependencies are added to reading a partitioned
- # variable which is similar to reading a variable.
- for op in reading_ops:
- self.assertEqual([], op.control_inputs)
-
def testConcat(self):
with self.test_session() as session:
var_x = variable_scope.get_variable(
@@ -594,6 +571,57 @@ class PartitionedVariablesTestCase(test.TestCase):
variables.global_variables_initializer().run()
self.assertAllClose(value.eval(), var_x.as_tensor().eval())
+ def testVariableCreationInALoop(self):
+ """Tests the variable created inside a loop can be used outside the loop."""
+ with self.test_session():
+ with variable_scope.variable_scope("ascope") as scope:
+ def Body(i, _):
+ var_x = variable_scope.get_variable(
+ "x",
+ shape=[2],
+ initializer=init_ops.ones_initializer(),
+ partitioner=partitioned_variables.variable_axis_size_partitioner(
+ 4))
+ return (i + 1, var_x.as_tensor())
+
+ cond = lambda i, _: i < 2
+ _, x = control_flow_ops.while_loop(
+ cond, Body, (0, constant_op.constant([7, 8], dtypes.float32)))
+ variables.global_variables_initializer().run()
+ self.assertAllClose([1.0, 1.0], x.eval())
+
+ scope.reuse_variables()
+ var_x = variable_scope.get_variable(
+ "x",
+ shape=[2],
+ initializer=init_ops.ones_initializer(),
+ partitioner=partitioned_variables.variable_axis_size_partitioner(4))
+
+ self.assertAllClose([1.0, 1.0], var_x.as_tensor().eval())
+
+ def testReadInWhileLoop(self):
+ """Tests the value is current (not cached) when read within a loop."""
+ with self.test_session():
+ var_x = variable_scope.get_variable(
+ "x",
+ shape=[2],
+ initializer=init_ops.ones_initializer(),
+ partitioner=partitioned_variables.variable_axis_size_partitioner(4))
+
+ def Body(i, _):
+ # Use a SGD step to update the variable's value.
+ loss = math_ops.reduce_sum(var_x)
+ optimizer = gradient_descent.GradientDescentOptimizer(1.0)
+ minimize = optimizer.minimize(loss * 0.7)
+ with ops.control_dependencies([minimize]):
+ return (i + 1, var_x.as_tensor())
+
+ cond = lambda i, _: i < 2
+ _, x = control_flow_ops.while_loop(
+ cond, Body, (0, constant_op.constant([7, 8], dtypes.float32)))
+ variables.global_variables_initializer().run()
+ self.assertAllClose([-0.4, -0.4], x.eval())
+
if __name__ == "__main__":
test.main()