diff options
author | Anjali Sridhar <anjalisridhar@google.com> | 2018-07-06 13:50:29 -0700 |
---|---|---|
committer | Yifei Feng <yifeif@google.com> | 2018-07-06 15:17:59 -0700 |
commit | 90fc5e3819ed62e93228a9c2c29dede0f0f8cfd6 (patch) | |
tree | 0e50e14646a382fbdf5edec988f9818bb93b12c0 /tensorflow/contrib/distribute/python/values.py | |
parent | d64754c5c768f26b6a95b350cfd8c7ded2590dc9 (diff) |
Allow is_initialized and initializer to be called on MirroredVariables and TowerLocalVariables.
PiperOrigin-RevId: 203520287
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 56 |
1 files changed, 56 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index b36ac563d2..1b5e00bc79 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -297,6 +297,12 @@ class MirroredVariable(DistributedVariable, Mirrored, for v in six.itervalues(index): v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access self._primary_var = primary_var + # tf.keras keeps track of variables initialized using this attribute. When + # tf.keras gets the default session, it initializes all uninitialized vars. + # We need to make _keras_initialized a member of MirroredVariable because + # without this it will use `__getattr__` which will delegate to a component + # variable. + self._keras_initialized = False self._aggregation = aggregation super(MirroredVariable, self).__init__(index) @@ -348,6 +354,28 @@ class MirroredVariable(DistributedVariable, Mirrored, def assign(self, *args, **kwargs): return self._assign_func(f=state_ops.assign, *args, **kwargs) + def is_initialized(self, name=None): + # We have to cast the self._index.values() to a `list` because when we + # use `model_to_estimator` to run tf.keras models, self._index.values() is + # of type `dict_values` and not `list`. + values_list = list(self._index.values()) + result = values_list[0].is_initialized() + # We iterate through the list of values except the last one to allow us to + # name the final `logical_and` op the same name that is passed by the user + # to the `is_initialized` op. For mirrored variables, the `is_initialized` + # op is a `logical_and` op. + for v in values_list[1:-1]: + result = math_ops.logical_and(result, v.is_initialized()) + result = math_ops.logical_and(result, values_list[-1].is_initialized(), + name=name) + return result + + @property + def initializer(self): + # return grouped ops of all the var initializations of component values of + # the mirrored variable + return control_flow_ops.group([v.initializer for v in self._index.values()]) + @property def aggregation(self): return self._aggregation @@ -435,6 +463,12 @@ class TowerLocalVariable(DistributedVariable, PerDevice, def __init__(self, index, primary_var, aggregation): self._primary_var = primary_var self._aggregation = aggregation + # tf.keras keeps track of variables initialized using this attribute. When + # tf.keras gets the default session, it initializes all uninitialized vars. + # We need to make _keras_initialized a member of TowerLocalVariable because + # without this it will use `__getattr__` which will delegate to a component + # variable. + self._keras_initialized = False super(TowerLocalVariable, self).__init__(index) def assign_sub(self, *args, **kwargs): @@ -449,6 +483,28 @@ class TowerLocalVariable(DistributedVariable, PerDevice, _assert_tower_context() return self.get().assign(*args, **kwargs) + def is_initialized(self, name=None): + # We have to cast the self._index.values() to a `list` because when we + # use `model_to_estimator` to run tf.keras models, self._index.values() is + # of type `dict_values` and not `list`. + values_list = list(self._index.values()) + result = values_list[0].is_initialized() + # We iterate through the list of values except the last one to allow us to + # name the final `logical_and` op the same name that is passed by the user + # to the `is_initialized` op. For tower local variables, the + # `is_initialized` op is a `logical_and` op. + for v in values_list[1:-1]: + result = math_ops.logical_and(result, v.is_initialized()) + result = math_ops.logical_and(result, values_list[-1].is_initialized(), + name=name) + return result + + @property + def initializer(self): + # return grouped ops of all the var initializations of component values of + # the tower local variable + return control_flow_ops.group([v.initializer for v in self._index.values()]) + @property def aggregation(self): return self._aggregation |