author    TensorFlower Gardener <gardener@tensorflow.org> 2018-10-01 07:39:41 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-10-01 07:39:41 -0700
commit    eef9a3e816932b7c34f426d00d7df53c91130402 (patch)
tree      5a8b508b46f13df0d19b10b453d09a63b48b7323 /tensorflow/python/ops
parent    7c5eb354a6b5b2d5a2e27d8ce3dc4861cb51153c (diff)
parent    8eb27871583d9fc61e046493acaa0df2839bc1c7 (diff)
Merge pull request #22473 from wangsiyu:assign_in_part_vars
PiperOrigin-RevId: 215211485
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--  tensorflow/python/ops/variables.py | 48
1 file changed, 44 insertions(+), 4 deletions(-)
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 8da1e9fe56..45c8618610 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -2620,10 +2620,50 @@ class PartitionedVariable(object):
   def _get_partitions(self):
     return self._partitions
 
-  def assign(self, value, use_locking=False):
-    _ = value, use_locking
-    raise NotImplementedError(
-        "assign() has not been implemented for PartitionedVariable.")
+  def _apply_assign_fn(self, assign_fn, value):
+    partition_axes = self._partition_axes()
+    if len(partition_axes) > 1:
+      raise NotImplementedError(
+          "Cannot do assign action along more than one dimension: %s. "
+          "Multi-axis partition assign action is not supported " %
+          str(partition_axes))
+    partition_ix = partition_axes[0]
+    size_splits_list = [
+        var.shape[partition_ix].value for var in self._variable_list
+    ]
+    value_list = array_ops.split(value, size_splits_list, axis=partition_ix)
+    op_list = [
+        assign_fn(var, value_list[idx], idx)
+        for idx, var in enumerate(self._variable_list)
+    ]
+    return op_list
+
+  def assign(self, value, use_locking=False, name=None, read_value=True):
+    assign_fn = lambda var, r_value, idx: var.assign(
+        r_value, use_locking=use_locking,
+        name="%s_%d" % (name, idx), read_value=read_value)
+    assign_list = self._apply_assign_fn(assign_fn, value)
+    if read_value:
+      return assign_list
+    return [assign.op for assign in assign_list]
+
+  def assign_add(self, value, use_locking=False, name=None, read_value=True):
+    assign_fn = lambda var, r_value, idx: var.assign_add(
+        r_value, use_locking=use_locking,
+        name="%s_%d" % (name, idx), read_value=read_value)
+    assign_list = self._apply_assign_fn(assign_fn, value)
+    if read_value:
+      return assign_list
+    return [assign.op for assign in assign_list]
+
+  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
+    assign_fn = lambda var, r_value, idx: var.assign_sub(
+        r_value, use_locking=use_locking,
+        name="%s_%d" % (name, idx), read_value=read_value)
+    assign_list = self._apply_assign_fn(assign_fn, value)
+    if read_value:
+      return assign_list
+    return [assign.op for assign in assign_list]
 
 
 @tf_export(v1=["global_variables"])