aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar wangsiyu <siyu.wsy@gmail.com>2018-09-23 20:33:19 +0800
committerGravatar wangsiyu <siyu.wsy@gmail.com>2018-09-23 20:37:29 +0800
commit6dd7a09211cc74d11ff1554624b527c432020cbc (patch)
treed58876ca39b268b9a1615cf80fa5d8cf5726e636
parent646b3c237deaddddd087d39ab57130b08375c4c7 (diff)
Enable partitioned variable assignments
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py43
-rw-r--r--tensorflow/python/ops/variables.py47
2 files changed, 85 insertions, 5 deletions
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 2e7975667c..687784c8b7 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -673,7 +673,7 @@ class PartitionedVariableTest(test.TestCase):
v0._set_save_slice_info(
variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1]))
v1._set_save_slice_info(
- variables.Variable.SaveSliceInfo(v0.name, [2], [1], [1]))
+ variables.Variable.SaveSliceInfo(v1.name, [2], [1], [1]))
partitions = [2]
variables.PartitionedVariable(
@@ -696,6 +696,47 @@ class PartitionedVariableTest(test.TestCase):
variable_list=[v0],
partitions=partitions)
+ def testPartitionedVariableAssignments(self):
+ with ops.Graph().as_default(), self.cached_session() as sess:
+ v0 = variables.Variable(initial_value=[0.0])
+ v1 = variables.Variable(initial_value=[1.0])
+ v0._set_save_slice_info(
+ variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1]))
+ v1._set_save_slice_info(
+ variables.Variable.SaveSliceInfo(v0.name, [2], [1], [1]))
+ partitions = [2]
+
+ # Pass variable_list as [v1, v0] to ensure they are properly
+ # re-sorted to [v0, v1] based on their slice info offsets.
+ partitioned_variable = variables.PartitionedVariable(
+ name="two_vars",
+ shape=[2],
+ dtype=v0.dtype,
+ variable_list=[v0, v1],
+ partitions=partitions)
+
+ deltas_a = constant_op.constant([1.0, 2.0])
+ deltas_b = constant_op.constant([3.0, 4.0])
+ ones = array_ops.ones([2])
+ plus_delta = partitioned_variable.assign_add(deltas_a)
+ minus_delta = partitioned_variable.assign_sub(deltas_b)
+ assign_ones = partitioned_variable.assign(ones)
+ variables.global_variables_initializer().run()
+
+ self.assertEqual([1.0], plus_delta[0].eval())
+ self.assertEqual([1.0], v0.eval())
+ self.assertEqual([3.0], plus_delta[1].eval())
+ self.assertEqual([3.0], v1.eval())
+
+ self.assertEqual([-2.0], minus_delta[0].eval())
+ self.assertEqual([-2.0], v0.eval())
+ self.assertEqual([-1.0], minus_delta[1].eval())
+ self.assertEqual([-1.0], v1.eval())
+
+ self.assertEqual([1.0], assign_ones[0].eval())
+ self.assertEqual([1.0], v0.eval())
+ self.assertEqual([1.0], assign_ones[1].eval())
+ self.assertEqual([1.0], v1.eval())
class VariableContainerTest(test.TestCase):
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 7a46157739..2d6a767fed 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -2395,11 +2395,50 @@ class PartitionedVariable(object):
def _get_partitions(self):
return self._partitions
- def assign(self, value, use_locking=False):
- _ = value, use_locking
- raise NotImplementedError(
- "assign() has not been implemented for PartitionedVariable.")
+ def _apply_assign_fn(self,
+ assign_fn,
+ value):
+ partition_axes = self._partition_axes()
+ if len(partition_axes) > 1:
+ raise NotImplementedError(
+ "Cannot concatenate along more than one dimension: %s. "
+ "Multi-axis partition assign_fn is not supported" % str(partition_axes))
+ partition_ix = partition_axes[0]
+ size_splits_list = [
+ var.shape[partition_ix].value for var in self._variable_list]
+ value_list = array_ops.split(
+ value, size_splits_list, axis=partition_ix)
+ op_list = [
+ assign_fn(var, value_list[idx], idx) \
+ for idx, var in enumerate(self._variable_list)]
+ return op_list
+ def assign(self, value, use_locking=False, name=None, read_value=True):
+ assign_fn = lambda var, r_value, idx: var.assign(
+ r_value, use_locking=use_locking,
+ name="%s_%d" % (name, idx), read_value=read_value)
+ assign_list = self._apply_assign_fn(assign_fn, value)
+ if read_value:
+ return assign_list
+ return [assign.op for assign in assign_list]
+
+ def assign_add(self, value, use_locking=False, name=None, read_value=True):
+ assign_fn = lambda var, r_value, idx: var.assign_add(
+ r_value, use_locking=use_locking,
+ name="%s_%d" % (name, idx), read_value=read_value)
+ assign_list = self._apply_assign_fn(assign_fn, value)
+ if read_value:
+ return assign_list
+ return [assign.op for assign in assign_list]
+
+ def assign_sub(self, value, use_locking=False, name=None, read_value=True):
+ assign_fn = lambda var, r_value, idx: var.assign_sub(
+ r_value, use_locking=use_locking,
+ name="%s_%d" % (name, idx), read_value=read_value)
+ assign_list = self._apply_assign_fn(assign_fn, value)
+ if read_value:
+ return assign_list
+ return [assign.op for assign in assign_list]
@tf_export("global_variables")
def global_variables(scope=None):