diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-01 07:39:41 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-01 07:39:41 -0700 |
commit | eef9a3e816932b7c34f426d00d7df53c91130402 (patch) | |
tree | 5a8b508b46f13df0d19b10b453d09a63b48b7323 /tensorflow/python/ops | |
parent | 7c5eb354a6b5b2d5a2e27d8ce3dc4861cb51153c (diff) | |
parent | 8eb27871583d9fc61e046493acaa0df2839bc1c7 (diff) |
Merge pull request #22473 from wangsiyu:assign_in_part_vars
PiperOrigin-RevId: 215211485
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r-- | tensorflow/python/ops/variables.py | 48 |
1 files changed, 44 insertions, 4 deletions
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 8da1e9fe56..45c8618610 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -2620,10 +2620,50 @@ class PartitionedVariable(object): def _get_partitions(self): return self._partitions - def assign(self, value, use_locking=False): - _ = value, use_locking - raise NotImplementedError( - "assign() has not been implemented for PartitionedVariable.") + def _apply_assign_fn(self, assign_fn, value): + partition_axes = self._partition_axes() + if len(partition_axes) > 1: + raise NotImplementedError( + "Cannot do assign action along more than one dimension: %s. " + "Multi-axis partition assign action is not supported " % + str(partition_axes)) + partition_ix = partition_axes[0] + size_splits_list = [ + var.shape[partition_ix].value for var in self._variable_list + ] + value_list = array_ops.split(value, size_splits_list, axis=partition_ix) + op_list = [ + assign_fn(var, value_list[idx], idx) + for idx, var in enumerate(self._variable_list) + ] + return op_list + + def assign(self, value, use_locking=False, name=None, read_value=True): + assign_fn = lambda var, r_value, idx: var.assign( + r_value, use_locking=use_locking, + name="%s_%d" % (name, idx), read_value=read_value) + assign_list = self._apply_assign_fn(assign_fn, value) + if read_value: + return assign_list + return [assign.op for assign in assign_list] + + def assign_add(self, value, use_locking=False, name=None, read_value=True): + assign_fn = lambda var, r_value, idx: var.assign_add( + r_value, use_locking=use_locking, + name="%s_%d" % (name, idx), read_value=read_value) + assign_list = self._apply_assign_fn(assign_fn, value) + if read_value: + return assign_list + return [assign.op for assign in assign_list] + + def assign_sub(self, value, use_locking=False, name=None, read_value=True): + assign_fn = lambda var, r_value, idx: var.assign_sub( + r_value, use_locking=use_locking, + name="%s_%d" % (name, idx), read_value=read_value) + assign_list = self._apply_assign_fn(assign_fn, value) + if read_value: + return assign_list + return [assign.op for assign in assign_list] @tf_export(v1=["global_variables"]) |