aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/variables_test.py
diff options
context:
space:
mode:
authorGravatar wangsiyu <siyu.wsy@gmail.com>2018-09-23 20:33:19 +0800
committerGravatar wangsiyu <siyu.wsy@gmail.com>2018-09-23 20:37:29 +0800
commit6dd7a09211cc74d11ff1554624b527c432020cbc (patch)
treed58876ca39b268b9a1615cf80fa5d8cf5726e636 /tensorflow/python/kernel_tests/variables_test.py
parent646b3c237deaddddd087d39ab57130b08375c4c7 (diff)
Enable partitioned variable assignments
Diffstat (limited to 'tensorflow/python/kernel_tests/variables_test.py')
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py43
1 files changed, 42 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 2e7975667c..687784c8b7 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -673,7 +673,7 @@ class PartitionedVariableTest(test.TestCase):
v0._set_save_slice_info(
variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1]))
v1._set_save_slice_info(
- variables.Variable.SaveSliceInfo(v0.name, [2], [1], [1]))
+ variables.Variable.SaveSliceInfo(v1.name, [2], [1], [1]))
partitions = [2]
variables.PartitionedVariable(
@@ -696,6 +696,47 @@ class PartitionedVariableTest(test.TestCase):
variable_list=[v0],
partitions=partitions)
+ def testPartitionedVariableAssignments(self):
+ with ops.Graph().as_default(), self.cached_session() as sess:
+ v0 = variables.Variable(initial_value=[0.0])
+ v1 = variables.Variable(initial_value=[1.0])
+ v0._set_save_slice_info(
+ variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1]))
+ v1._set_save_slice_info(
+ variables.Variable.SaveSliceInfo(v0.name, [2], [1], [1]))
+ partitions = [2]
+
+ # Pass variable_list as [v1, v0] to ensure they are properly
+ # re-sorted to [v0, v1] based on their slice info offsets.
+ partitioned_variable = variables.PartitionedVariable(
+ name="two_vars",
+ shape=[2],
+ dtype=v0.dtype,
+ variable_list=[v0, v1],
+ partitions=partitions)
+
+ deltas_a = constant_op.constant([1.0, 2.0])
+ deltas_b = constant_op.constant([3.0, 4.0])
+ ones = array_ops.ones([2])
+ plus_delta = partitioned_variable.assign_add(deltas_a)
+ minus_delta = partitioned_variable.assign_sub(deltas_b)
+ assign_ones = partitioned_variable.assign(ones)
+ variables.global_variables_initializer().run()
+
+ self.assertEqual([1.0], plus_delta[0].eval())
+ self.assertEqual([1.0], v0.eval())
+ self.assertEqual([3.0], plus_delta[1].eval())
+ self.assertEqual([3.0], v1.eval())
+
+ self.assertEqual([-2.0], minus_delta[0].eval())
+ self.assertEqual([-2.0], v0.eval())
+ self.assertEqual([-1.0], minus_delta[1].eval())
+ self.assertEqual([-1.0], v1.eval())
+
+ self.assertEqual([1.0], assign_ones[0].eval())
+ self.assertEqual([1.0], v0.eval())
+ self.assertEqual([1.0], assign_ones[1].eval())
+ self.assertEqual([1.0], v1.eval())
class VariableContainerTest(test.TestCase):