aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 07:39:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 07:39:41 -0700
commiteef9a3e816932b7c34f426d00d7df53c91130402 (patch)
tree5a8b508b46f13df0d19b10b453d09a63b48b7323 /tensorflow/python/kernel_tests
parent7c5eb354a6b5b2d5a2e27d8ce3dc4861cb51153c (diff)
parent8eb27871583d9fc61e046493acaa0df2839bc1c7 (diff)
Merge pull request #22473 from wangsiyu:assign_in_part_vars
PiperOrigin-RevId: 215211485
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py42
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 942ceedc8b..c2b86089f4 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -696,6 +696,48 @@ class PartitionedVariableTest(test.TestCase):
variable_list=[v0],
partitions=partitions)
+ def testPartitionedVariableAssignments(self):
+ with ops.Graph().as_default(), self.cached_session() as sess:
+ v0 = variables.Variable(initial_value=[0.0])
+ v1 = variables.Variable(initial_value=[1.0])
+ v0._set_save_slice_info(
+ variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1]))
+ v1._set_save_slice_info(
+ variables.Variable.SaveSliceInfo(v0.name, [2], [1], [1]))
+ partitions = [2]
+
+ # Pass variable_list as [v1, v0] to ensure they are properly
+ # re-sorted to [v0, v1] based on their slice info offsets.
+ partitioned_variable = variables.PartitionedVariable(
+ name="two_vars",
+ shape=[2],
+ dtype=v0.dtype,
+ variable_list=[v0, v1],
+ partitions=partitions)
+
+ deltas_a = constant_op.constant([1.0, 2.0])
+ deltas_b = constant_op.constant([3.0, 4.0])
+ ones = array_ops.ones([2])
+ plus_delta = partitioned_variable.assign_add(deltas_a)
+ minus_delta = partitioned_variable.assign_sub(deltas_b)
+ assign_ones = partitioned_variable.assign(ones)
+ variables.global_variables_initializer().run()
+
+ self.assertEqual([1.0], plus_delta[0].eval())
+ self.assertEqual([1.0], v0.eval())
+ self.assertEqual([3.0], plus_delta[1].eval())
+ self.assertEqual([3.0], v1.eval())
+
+ self.assertEqual([-2.0], minus_delta[0].eval())
+ self.assertEqual([-2.0], v0.eval())
+ self.assertEqual([-1.0], minus_delta[1].eval())
+ self.assertEqual([-1.0], v1.eval())
+
+ self.assertEqual([1.0], assign_ones[0].eval())
+ self.assertEqual([1.0], v0.eval())
+ self.assertEqual([1.0], assign_ones[1].eval())
+ self.assertEqual([1.0], v1.eval())
+
class VariableContainerTest(test.TestCase):