diff options
author | Anjali Sridhar <anjalisridhar@google.com> | 2018-07-20 11:11:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-20 11:14:44 -0700 |
commit | da929851e2b5446a5aaee29a869428037a72f2b7 (patch) | |
tree | 550f7673fc8664be74a9db635cd5bd5cd5b88fc8 /tensorflow/contrib/distribute/python/values.py | |
parent | dfcf85601a0372f3130baf9fe36605fe528144cc (diff) |
Refactor properties and functions common to Mirrored and TowerLocal Variables.
PiperOrigin-RevId: 205424692
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 89 |
1 files changed, 33 insertions, 56 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 1761a43251..3162aebf5b 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -196,10 +196,43 @@ class DistributedVariable(DistributedDelegate): # to the container without introducing a reference cycle. for v in six.itervalues(index): v._distributed_container = weakref.ref(self) # pylint: disable=protected-access + # tf.keras keeps track of variables initialized using this attribute. When + # tf.keras gets the default session, it initializes all uninitialized vars. + # We need to make _keras_initialized a member of DistributedVariable because + # without this it will use `__getattr__` which will delegate to a component + # variable. + self._keras_initialized = False super(DistributedVariable, self).__init__(index) + def is_initialized(self, name=None): + """Identifies if all the component variables are initialized. + + Args: + name: Name of the final `logical_and` op. + + Returns: + The op that evaluates to True or False depending on if all the + component variables are initialized. + """ + # We have to cast the self._index.values() to a `list` because when we + # use `model_to_estimator` to run tf.keras models, self._index.values() is + # of type `dict_values` and not `list`. + values_list = list(self._index.values()) + result = values_list[0].is_initialized() + # We iterate through the list of values except the last one to allow us to + # name the final `logical_and` op the same name that is passed by the user + # to the `is_initialized` op. For distributed variables, the + # `is_initialized` op is a `logical_and` op. + for v in values_list[1:-1]: + result = math_ops.logical_and(result, v.is_initialized()) + result = math_ops.logical_and(result, values_list[-1].is_initialized(), + name=name) + return result + @property def initializer(self): + # return grouped ops of all the var initializations of component values of + # the mirrored variable return control_flow_ops.group([v.initializer for v in self._index.values()]) @property @@ -296,12 +329,6 @@ class MirroredVariable(DistributedVariable, Mirrored, for v in six.itervalues(index): v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access self._primary_var = primary_var - # tf.keras keeps track of variables initialized using this attribute. When - # tf.keras gets the default session, it initializes all uninitialized vars. - # We need to make _keras_initialized a member of MirroredVariable because - # without this it will use `__getattr__` which will delegate to a component - # variable. - self._keras_initialized = False self._aggregation = aggregation super(MirroredVariable, self).__init__(index) @@ -357,28 +384,6 @@ class MirroredVariable(DistributedVariable, Mirrored, assign_fn = lambda var, *a, **kw: var.assign(*a, **kw) return self._assign_func(f=assign_fn, *args, **kwargs) - def is_initialized(self, name=None): - # We have to cast the self._index.values() to a `list` because when we - # use `model_to_estimator` to run tf.keras models, self._index.values() is - # of type `dict_values` and not `list`. - values_list = list(self._index.values()) - result = values_list[0].is_initialized() - # We iterate through the list of values except the last one to allow us to - # name the final `logical_and` op the same name that is passed by the user - # to the `is_initialized` op. For mirrored variables, the `is_initialized` - # op is a `logical_and` op. - for v in values_list[1:-1]: - result = math_ops.logical_and(result, v.is_initialized()) - result = math_ops.logical_and(result, values_list[-1].is_initialized(), - name=name) - return result - - @property - def initializer(self): - # return grouped ops of all the var initializations of component values of - # the mirrored variable - return control_flow_ops.group([v.initializer for v in self._index.values()]) - @property def aggregation(self): return self._aggregation @@ -466,12 +471,6 @@ class TowerLocalVariable(DistributedVariable, PerDevice, def __init__(self, index, primary_var, aggregation): self._primary_var = primary_var self._aggregation = aggregation - # tf.keras keeps track of variables initialized using this attribute. When - # tf.keras gets the default session, it initializes all uninitialized vars. - # We need to make _keras_initialized a member of TowerLocalVariable because - # without this it will use `__getattr__` which will delegate to a component - # variable. - self._keras_initialized = False super(TowerLocalVariable, self).__init__(index) def assign_sub(self, *args, **kwargs): @@ -486,28 +485,6 @@ class TowerLocalVariable(DistributedVariable, PerDevice, _assert_tower_context() return self.get().assign(*args, **kwargs) - def is_initialized(self, name=None): - # We have to cast the self._index.values() to a `list` because when we - # use `model_to_estimator` to run tf.keras models, self._index.values() is - # of type `dict_values` and not `list`. - values_list = list(self._index.values()) - result = values_list[0].is_initialized() - # We iterate through the list of values except the last one to allow us to - # name the final `logical_and` op the same name that is passed by the user - # to the `is_initialized` op. For tower local variables, the - # `is_initialized` op is a `logical_and` op. - for v in values_list[1:-1]: - result = math_ops.logical_and(result, v.is_initialized()) - result = math_ops.logical_and(result, values_list[-1].is_initialized(), - name=name) - return result - - @property - def initializer(self): - # return grouped ops of all the var initializations of component values of - # the tower local variable - return control_flow_ops.group([v.initializer for v in self._index.values()]) - @property def aggregation(self): return self._aggregation |