aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar Anjali Sridhar <anjalisridhar@google.com>2018-07-20 11:11:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-20 11:14:44 -0700
commitda929851e2b5446a5aaee29a869428037a72f2b7 (patch)
tree550f7673fc8664be74a9db635cd5bd5cd5b88fc8 /tensorflow/contrib/distribute/python/values.py
parentdfcf85601a0372f3130baf9fe36605fe528144cc (diff)
Refactor properties and functions common to Mirrored and TowerLocal Variables.
PiperOrigin-RevId: 205424692
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py89
1 files changed, 33 insertions, 56 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 1761a43251..3162aebf5b 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -196,10 +196,43 @@ class DistributedVariable(DistributedDelegate):
# to the container without introducing a reference cycle.
for v in six.itervalues(index):
v._distributed_container = weakref.ref(self) # pylint: disable=protected-access
+ # tf.keras keeps track of variables initialized using this attribute. When
+ # tf.keras gets the default session, it initializes all uninitialized vars.
+ # We need to make _keras_initialized a member of DistributedVariable because
+ # without this it will use `__getattr__` which will delegate to a component
+ # variable.
+ self._keras_initialized = False
super(DistributedVariable, self).__init__(index)
+ def is_initialized(self, name=None):
+ """Identifies if all the component variables are initialized.
+
+ Args:
+ name: Name of the final `logical_and` op.
+
+ Returns:
+ The op that evaluates to True or False depending on if all the
+ component variables are initialized.
+ """
+ # We have to cast the self._index.values() to a `list` because when we
+ # use `model_to_estimator` to run tf.keras models, self._index.values() is
+ # of type `dict_values` and not `list`.
+ values_list = list(self._index.values())
+ result = values_list[0].is_initialized()
+ # We iterate through the list of values except the last one to allow us to
+ # name the final `logical_and` op the same name that is passed by the user
+ # to the `is_initialized` op. For distributed variables, the
+ # `is_initialized` op is a `logical_and` op.
+ for v in values_list[1:-1]:
+ result = math_ops.logical_and(result, v.is_initialized())
+ result = math_ops.logical_and(result, values_list[-1].is_initialized(),
+ name=name)
+ return result
+
@property
def initializer(self):
+ # return grouped ops of all the var initializations of component values of
+ # the mirrored variable
return control_flow_ops.group([v.initializer for v in self._index.values()])
@property
@@ -296,12 +329,6 @@ class MirroredVariable(DistributedVariable, Mirrored,
for v in six.itervalues(index):
v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access
self._primary_var = primary_var
- # tf.keras keeps track of variables initialized using this attribute. When
- # tf.keras gets the default session, it initializes all uninitialized vars.
- # We need to make _keras_initialized a member of MirroredVariable because
- # without this it will use `__getattr__` which will delegate to a component
- # variable.
- self._keras_initialized = False
self._aggregation = aggregation
super(MirroredVariable, self).__init__(index)
@@ -357,28 +384,6 @@ class MirroredVariable(DistributedVariable, Mirrored,
assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
return self._assign_func(f=assign_fn, *args, **kwargs)
- def is_initialized(self, name=None):
- # We have to cast the self._index.values() to a `list` because when we
- # use `model_to_estimator` to run tf.keras models, self._index.values() is
- # of type `dict_values` and not `list`.
- values_list = list(self._index.values())
- result = values_list[0].is_initialized()
- # We iterate through the list of values except the last one to allow us to
- # name the final `logical_and` op the same name that is passed by the user
- # to the `is_initialized` op. For mirrored variables, the `is_initialized`
- # op is a `logical_and` op.
- for v in values_list[1:-1]:
- result = math_ops.logical_and(result, v.is_initialized())
- result = math_ops.logical_and(result, values_list[-1].is_initialized(),
- name=name)
- return result
-
- @property
- def initializer(self):
- # return grouped ops of all the var initializations of component values of
- # the mirrored variable
- return control_flow_ops.group([v.initializer for v in self._index.values()])
-
@property
def aggregation(self):
return self._aggregation
@@ -466,12 +471,6 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
def __init__(self, index, primary_var, aggregation):
self._primary_var = primary_var
self._aggregation = aggregation
- # tf.keras keeps track of variables initialized using this attribute. When
- # tf.keras gets the default session, it initializes all uninitialized vars.
- # We need to make _keras_initialized a member of TowerLocalVariable because
- # without this it will use `__getattr__` which will delegate to a component
- # variable.
- self._keras_initialized = False
super(TowerLocalVariable, self).__init__(index)
def assign_sub(self, *args, **kwargs):
@@ -486,28 +485,6 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
_assert_tower_context()
return self.get().assign(*args, **kwargs)
- def is_initialized(self, name=None):
- # We have to cast the self._index.values() to a `list` because when we
- # use `model_to_estimator` to run tf.keras models, self._index.values() is
- # of type `dict_values` and not `list`.
- values_list = list(self._index.values())
- result = values_list[0].is_initialized()
- # We iterate through the list of values except the last one to allow us to
- # name the final `logical_and` op the same name that is passed by the user
- # to the `is_initialized` op. For tower local variables, the
- # `is_initialized` op is a `logical_and` op.
- for v in values_list[1:-1]:
- result = math_ops.logical_and(result, v.is_initialized())
- result = math_ops.logical_and(result, values_list[-1].is_initialized(),
- name=name)
- return result
-
- @property
- def initializer(self):
- # return grouped ops of all the var initializations of component values of
- # the tower local variable
- return control_flow_ops.group([v.initializer for v in self._index.values()])
-
@property
def aggregation(self):
return self._aggregation