diff options
-rw-r--r-- | tensorflow/python/kernel_tests/variables_test.py | 42 | ||||
-rw-r--r-- | tensorflow/python/ops/variables.py | 48 |
2 files changed, 86 insertions, 4 deletions
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index 942ceedc8b..c2b86089f4 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -696,6 +696,48 @@ class PartitionedVariableTest(test.TestCase): variable_list=[v0], partitions=partitions) + def testPartitionedVariableAssignments(self): + with ops.Graph().as_default(), self.cached_session() as sess: + v0 = variables.Variable(initial_value=[0.0]) + v1 = variables.Variable(initial_value=[1.0]) + v0._set_save_slice_info( + variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1])) + v1._set_save_slice_info( + variables.Variable.SaveSliceInfo(v0.name, [2], [1], [1])) + partitions = [2] + + # Pass variable_list as [v1, v0] to ensure they are properly + # re-sorted to [v0, v1] based on their slice info offsets. + partitioned_variable = variables.PartitionedVariable( + name="two_vars", + shape=[2], + dtype=v0.dtype, + variable_list=[v0, v1], + partitions=partitions) + + deltas_a = constant_op.constant([1.0, 2.0]) + deltas_b = constant_op.constant([3.0, 4.0]) + ones = array_ops.ones([2]) + plus_delta = partitioned_variable.assign_add(deltas_a) + minus_delta = partitioned_variable.assign_sub(deltas_b) + assign_ones = partitioned_variable.assign(ones) + variables.global_variables_initializer().run() + + self.assertEqual([1.0], plus_delta[0].eval()) + self.assertEqual([1.0], v0.eval()) + self.assertEqual([3.0], plus_delta[1].eval()) + self.assertEqual([3.0], v1.eval()) + + self.assertEqual([-2.0], minus_delta[0].eval()) + self.assertEqual([-2.0], v0.eval()) + self.assertEqual([-1.0], minus_delta[1].eval()) + self.assertEqual([-1.0], v1.eval()) + + self.assertEqual([1.0], assign_ones[0].eval()) + self.assertEqual([1.0], v0.eval()) + self.assertEqual([1.0], assign_ones[1].eval()) + self.assertEqual([1.0], v1.eval()) + class VariableContainerTest(test.TestCase): diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 8da1e9fe56..45c8618610 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -2620,10 +2620,50 @@ class PartitionedVariable(object): def _get_partitions(self): return self._partitions - def assign(self, value, use_locking=False): - _ = value, use_locking - raise NotImplementedError( - "assign() has not been implemented for PartitionedVariable.") + def _apply_assign_fn(self, assign_fn, value): + partition_axes = self._partition_axes() + if len(partition_axes) > 1: + raise NotImplementedError( + "Cannot do assign action along more than one dimension: %s. " + "Multi-axis partition assign action is not supported " % + str(partition_axes)) + partition_ix = partition_axes[0] + size_splits_list = [ + var.shape[partition_ix].value for var in self._variable_list + ] + value_list = array_ops.split(value, size_splits_list, axis=partition_ix) + op_list = [ + assign_fn(var, value_list[idx], idx) + for idx, var in enumerate(self._variable_list) + ] + return op_list + + def assign(self, value, use_locking=False, name=None, read_value=True): + assign_fn = lambda var, r_value, idx: var.assign( + r_value, use_locking=use_locking, + name="%s_%d" % (name, idx), read_value=read_value) + assign_list = self._apply_assign_fn(assign_fn, value) + if read_value: + return assign_list + return [assign.op for assign in assign_list] + + def assign_add(self, value, use_locking=False, name=None, read_value=True): + assign_fn = lambda var, r_value, idx: var.assign_add( + r_value, use_locking=use_locking, + name="%s_%d" % (name, idx), read_value=read_value) + assign_list = self._apply_assign_fn(assign_fn, value) + if read_value: + return assign_list + return [assign.op for assign in assign_list] + + def assign_sub(self, value, use_locking=False, name=None, read_value=True): + assign_fn = lambda var, r_value, idx: var.assign_sub( + r_value, use_locking=use_locking, + name="%s_%d" % (name, idx), read_value=read_value) + assign_list = self._apply_assign_fn(assign_fn, value) + if read_value: + return assign_list + return [assign.op for assign in assign_list] @tf_export(v1=["global_variables"]) |