aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar Priya Gupta <priyag@google.com>2018-07-17 23:08:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-17 23:12:29 -0700
commitf1de0ddd55dcae6237ea7d21ccddcc6467a6cf8b (patch)
tree4238193e76288333027c86aeac2dc86f165af641 /tensorflow/contrib/distribute/python/values.py
parentaa15692e54390cf3967d51bc60acf5f783df9c08 (diff)
Add support for MirroredVariables in init_from_checkpoint and warm_start in estimator.
PiperOrigin-RevId: 205030626
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py15
1 files changed, 9 insertions, 6 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 1b5e00bc79..1761a43251 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -33,7 +33,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
@@ -336,23 +335,27 @@ class MirroredVariable(DistributedVariable, Mirrored,
raise ValueError("You must specify an aggregation method to update a "
"MirroredVariable in Tower Context.")
- def merge_fn(strategy, value):
+ def merge_fn(strategy, value, *other_args, **other_kwargs):
return strategy.update(
self, f,
strategy.reduce(
- aggregation=self._aggregation, value=value, destinations=self))
+ aggregation=self._aggregation, value=value, destinations=self),
+ *other_args, **other_kwargs)
return distribute_lib.get_tower_context().merge_call(merge_fn, *args,
**kwargs)
def assign_sub(self, *args, **kwargs):
- return self._assign_func(f=state_ops.assign_sub, *args, **kwargs)
+ assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
+ return self._assign_func(f=assign_sub_fn, *args, **kwargs)
def assign_add(self, *args, **kwargs):
- return self._assign_func(f=state_ops.assign_add, *args, **kwargs)
+ assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
+ return self._assign_func(f=assign_add_fn, *args, **kwargs)
def assign(self, *args, **kwargs):
- return self._assign_func(f=state_ops.assign, *args, **kwargs)
+ assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
+ return self._assign_func(f=assign_fn, *args, **kwargs)
def is_initialized(self, name=None):
# We have to cast the self._index.values() to a `list` because when we