diff options
author | 2018-09-23 20:33:19 +0800 | |
---|---|---|
committer | 2018-09-23 20:37:29 +0800 | |
commit | 6dd7a09211cc74d11ff1554624b527c432020cbc (patch) | |
tree | d58876ca39b268b9a1615cf80fa5d8cf5726e636 | |
parent | 646b3c237deaddddd087d39ab57130b08375c4c7 (diff) |
Enable partitioned variable assignments
-rw-r--r-- | tensorflow/python/kernel_tests/variables_test.py | 43 | ||||
-rw-r--r-- | tensorflow/python/ops/variables.py | 47 |
2 files changed, 85 insertions, 5 deletions
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index 2e7975667c..687784c8b7 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -673,7 +673,7 @@ class PartitionedVariableTest(test.TestCase): v0._set_save_slice_info( variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1])) v1._set_save_slice_info( - variables.Variable.SaveSliceInfo(v0.name, [2], [1], [1])) + variables.Variable.SaveSliceInfo(v1.name, [2], [1], [1])) partitions = [2] variables.PartitionedVariable( @@ -696,6 +696,47 @@ class PartitionedVariableTest(test.TestCase): variable_list=[v0], partitions=partitions) + def testPartitionedVariableAssignments(self): + with ops.Graph().as_default(), self.cached_session() as sess: + v0 = variables.Variable(initial_value=[0.0]) + v1 = variables.Variable(initial_value=[1.0]) + v0._set_save_slice_info( + variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1])) + v1._set_save_slice_info( + variables.Variable.SaveSliceInfo(v0.name, [2], [1], [1])) + partitions = [2] + + # Pass variable_list as [v1, v0] to ensure they are properly + # re-sorted to [v0, v1] based on their slice info offsets. + partitioned_variable = variables.PartitionedVariable( + name="two_vars", + shape=[2], + dtype=v0.dtype, + variable_list=[v0, v1], + partitions=partitions) + + deltas_a = constant_op.constant([1.0, 2.0]) + deltas_b = constant_op.constant([3.0, 4.0]) + ones = array_ops.ones([2]) + plus_delta = partitioned_variable.assign_add(deltas_a) + minus_delta = partitioned_variable.assign_sub(deltas_b) + assign_ones = partitioned_variable.assign(ones) + variables.global_variables_initializer().run() + + self.assertEqual([1.0], plus_delta[0].eval()) + self.assertEqual([1.0], v0.eval()) + self.assertEqual([3.0], plus_delta[1].eval()) + self.assertEqual([3.0], v1.eval()) + + self.assertEqual([-2.0], minus_delta[0].eval()) + self.assertEqual([-2.0], v0.eval()) + self.assertEqual([-1.0], minus_delta[1].eval()) + self.assertEqual([-1.0], v1.eval()) + + self.assertEqual([1.0], assign_ones[0].eval()) + self.assertEqual([1.0], v0.eval()) + self.assertEqual([1.0], assign_ones[1].eval()) + self.assertEqual([1.0], v1.eval()) class VariableContainerTest(test.TestCase): diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 7a46157739..2d6a767fed 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -2395,11 +2395,50 @@ class PartitionedVariable(object): def _get_partitions(self): return self._partitions - def assign(self, value, use_locking=False): - _ = value, use_locking - raise NotImplementedError( - "assign() has not been implemented for PartitionedVariable.") + def _apply_assign_fn(self, + assign_fn, + value): + partition_axes = self._partition_axes() + if len(partition_axes) > 1: + raise NotImplementedError( + "Cannot concatenate along more than one dimension: %s. " + "Multi-axis partition assign_fn is not supported" % str(partition_axes)) + partition_ix = partition_axes[0] + size_splits_list = [ + var.shape[partition_ix].value for var in self._variable_list] + value_list = array_ops.split( + value, size_splits_list, axis=partition_ix) + op_list = [ + assign_fn(var, value_list[idx], idx) \ + for idx, var in enumerate(self._variable_list)] + return op_list + def assign(self, value, use_locking=False, name=None, read_value=True): + assign_fn = lambda var, r_value, idx: var.assign( + r_value, use_locking=use_locking, + name="%s_%d" % (name, idx), read_value=read_value) + assign_list = self._apply_assign_fn(assign_fn, value) + if read_value: + return assign_list + return [assign.op for assign in assign_list] + + def assign_add(self, value, use_locking=False, name=None, read_value=True): + assign_fn = lambda var, r_value, idx: var.assign_add( + r_value, use_locking=use_locking, + name="%s_%d" % (name, idx), read_value=read_value) + assign_list = self._apply_assign_fn(assign_fn, value) + if read_value: + return assign_list + return [assign.op for assign in assign_list] + + def assign_sub(self, value, use_locking=False, name=None, read_value=True): + assign_fn = lambda var, r_value, idx: var.assign_sub( + r_value, use_locking=use_locking, + name="%s_%d" % (name, idx), read_value=read_value) + assign_list = self._apply_assign_fn(assign_fn, value) + if read_value: + return assign_list + return [assign.op for assign in assign_list] @tf_export("global_variables") def global_variables(scope=None): |