diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/variables_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/variables_test.py | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index 942ceedc8b..c2b86089f4 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -696,6 +696,48 @@ class PartitionedVariableTest(test.TestCase): variable_list=[v0], partitions=partitions) + def testPartitionedVariableAssignments(self): + with ops.Graph().as_default(), self.cached_session() as sess: + v0 = variables.Variable(initial_value=[0.0]) + v1 = variables.Variable(initial_value=[1.0]) + v0._set_save_slice_info( + variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1])) + v1._set_save_slice_info( + variables.Variable.SaveSliceInfo(v0.name, [2], [1], [1])) + partitions = [2] + + # Pass variable_list as [v1, v0] to ensure they are properly + # re-sorted to [v0, v1] based on their slice info offsets. + partitioned_variable = variables.PartitionedVariable( + name="two_vars", + shape=[2], + dtype=v0.dtype, + variable_list=[v0, v1], + partitions=partitions) + + deltas_a = constant_op.constant([1.0, 2.0]) + deltas_b = constant_op.constant([3.0, 4.0]) + ones = array_ops.ones([2]) + plus_delta = partitioned_variable.assign_add(deltas_a) + minus_delta = partitioned_variable.assign_sub(deltas_b) + assign_ones = partitioned_variable.assign(ones) + variables.global_variables_initializer().run() + + self.assertEqual([1.0], plus_delta[0].eval()) + self.assertEqual([1.0], v0.eval()) + self.assertEqual([3.0], plus_delta[1].eval()) + self.assertEqual([3.0], v1.eval()) + + self.assertEqual([-2.0], minus_delta[0].eval()) + self.assertEqual([-2.0], v0.eval()) + self.assertEqual([-1.0], minus_delta[1].eval()) + self.assertEqual([-1.0], v1.eval()) + + self.assertEqual([1.0], assign_ones[0].eval()) + self.assertEqual([1.0], v0.eval()) + self.assertEqual([1.0], assign_ones[1].eval()) + self.assertEqual([1.0], v1.eval()) + class VariableContainerTest(test.TestCase): |