diff options
author | Priya Gupta <priyag@google.com> | 2018-07-17 23:08:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-17 23:12:29 -0700 |
commit | f1de0ddd55dcae6237ea7d21ccddcc6467a6cf8b (patch) | |
tree | 4238193e76288333027c86aeac2dc86f165af641 /tensorflow/contrib/distribute/python/values.py | |
parent | aa15692e54390cf3967d51bc60acf5f783df9c08 (diff) |
Add support for MirroredVariables in init_from_checkpoint and warm_start in estimator.
PiperOrigin-RevId: 205030626
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 1b5e00bc79..1761a43251 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -33,7 +33,6 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib @@ -336,23 +335,27 @@ class MirroredVariable(DistributedVariable, Mirrored, raise ValueError("You must specify an aggregation method to update a " "MirroredVariable in Tower Context.") - def merge_fn(strategy, value): + def merge_fn(strategy, value, *other_args, **other_kwargs): return strategy.update( self, f, strategy.reduce( - aggregation=self._aggregation, value=value, destinations=self)) + aggregation=self._aggregation, value=value, destinations=self), + *other_args, **other_kwargs) return distribute_lib.get_tower_context().merge_call(merge_fn, *args, **kwargs) def assign_sub(self, *args, **kwargs): - return self._assign_func(f=state_ops.assign_sub, *args, **kwargs) + assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw) + return self._assign_func(f=assign_sub_fn, *args, **kwargs) def assign_add(self, *args, **kwargs): - return self._assign_func(f=state_ops.assign_add, *args, **kwargs) + assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw) + return self._assign_func(f=assign_add_fn, *args, **kwargs) def assign(self, *args, **kwargs): - return self._assign_func(f=state_ops.assign, *args, **kwargs) + assign_fn = lambda var, *a, **kw: var.assign(*a, **kw) + return self._assign_func(f=assign_fn, *args, **kwargs) def is_initialized(self, name=None): # We have to cast the self._index.values() to a `list` because when we |