aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 07:39:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 07:39:41 -0700
commiteef9a3e816932b7c34f426d00d7df53c91130402 (patch)
tree5a8b508b46f13df0d19b10b453d09a63b48b7323
parent7c5eb354a6b5b2d5a2e27d8ce3dc4861cb51153c (diff)
parent8eb27871583d9fc61e046493acaa0df2839bc1c7 (diff)
Merge pull request #22473 from wangsiyu:assign_in_part_vars
PiperOrigin-RevId: 215211485
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py42
-rw-r--r--tensorflow/python/ops/variables.py48
2 files changed, 86 insertions, 4 deletions
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 942ceedc8b..c2b86089f4 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -696,6 +696,48 @@ class PartitionedVariableTest(test.TestCase):
variable_list=[v0],
partitions=partitions)
+ def testPartitionedVariableAssignments(self):
+ with ops.Graph().as_default(), self.cached_session() as sess:
+ v0 = variables.Variable(initial_value=[0.0])
+ v1 = variables.Variable(initial_value=[1.0])
+ v0._set_save_slice_info(
+ variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1]))
+ v1._set_save_slice_info(
+ variables.Variable.SaveSliceInfo(v0.name, [2], [1], [1]))
+ partitions = [2]
+
+ # Pass variable_list as [v1, v0] to ensure they are properly
+ # re-sorted to [v0, v1] based on their slice info offsets.
+ partitioned_variable = variables.PartitionedVariable(
+ name="two_vars",
+ shape=[2],
+ dtype=v0.dtype,
+ variable_list=[v0, v1],
+ partitions=partitions)
+
+ deltas_a = constant_op.constant([1.0, 2.0])
+ deltas_b = constant_op.constant([3.0, 4.0])
+ ones = array_ops.ones([2])
+ plus_delta = partitioned_variable.assign_add(deltas_a)
+ minus_delta = partitioned_variable.assign_sub(deltas_b)
+ assign_ones = partitioned_variable.assign(ones)
+ variables.global_variables_initializer().run()
+
+ self.assertEqual([1.0], plus_delta[0].eval())
+ self.assertEqual([1.0], v0.eval())
+ self.assertEqual([3.0], plus_delta[1].eval())
+ self.assertEqual([3.0], v1.eval())
+
+ self.assertEqual([-2.0], minus_delta[0].eval())
+ self.assertEqual([-2.0], v0.eval())
+ self.assertEqual([-1.0], minus_delta[1].eval())
+ self.assertEqual([-1.0], v1.eval())
+
+ self.assertEqual([1.0], assign_ones[0].eval())
+ self.assertEqual([1.0], v0.eval())
+ self.assertEqual([1.0], assign_ones[1].eval())
+ self.assertEqual([1.0], v1.eval())
+
class VariableContainerTest(test.TestCase):
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 8da1e9fe56..45c8618610 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -2620,10 +2620,50 @@ class PartitionedVariable(object):
def _get_partitions(self):
return self._partitions
- def assign(self, value, use_locking=False):
- _ = value, use_locking
- raise NotImplementedError(
- "assign() has not been implemented for PartitionedVariable.")
+ def _apply_assign_fn(self, assign_fn, value):
+ partition_axes = self._partition_axes()
+ if len(partition_axes) > 1:
+ raise NotImplementedError(
+ "Cannot do assign action along more than one dimension: %s. "
+ "Multi-axis partition assign action is not supported " %
+ str(partition_axes))
+ partition_ix = partition_axes[0]
+ size_splits_list = [
+ var.shape[partition_ix].value for var in self._variable_list
+ ]
+ value_list = array_ops.split(value, size_splits_list, axis=partition_ix)
+ op_list = [
+ assign_fn(var, value_list[idx], idx)
+ for idx, var in enumerate(self._variable_list)
+ ]
+ return op_list
+
+ def assign(self, value, use_locking=False, name=None, read_value=True):
+ assign_fn = lambda var, r_value, idx: var.assign(
+ r_value, use_locking=use_locking,
+ name="%s_%d" % (name, idx), read_value=read_value)
+ assign_list = self._apply_assign_fn(assign_fn, value)
+ if read_value:
+ return assign_list
+ return [assign.op for assign in assign_list]
+
+ def assign_add(self, value, use_locking=False, name=None, read_value=True):
+ assign_fn = lambda var, r_value, idx: var.assign_add(
+ r_value, use_locking=use_locking,
+ name="%s_%d" % (name, idx), read_value=read_value)
+ assign_list = self._apply_assign_fn(assign_fn, value)
+ if read_value:
+ return assign_list
+ return [assign.op for assign in assign_list]
+
+ def assign_sub(self, value, use_locking=False, name=None, read_value=True):
+ assign_fn = lambda var, r_value, idx: var.assign_sub(
+ r_value, use_locking=use_locking,
+ name="%s_%d" % (name, idx), read_value=read_value)
+ assign_list = self._apply_assign_fn(assign_fn, value)
+ if read_value:
+ return assign_list
+ return [assign.op for assign in assign_list]
@tf_export(v1=["global_variables"])