aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar Anjali Sridhar <anjalisridhar@google.com>2018-07-06 13:50:29 -0700
committerGravatar Yifei Feng <yifeif@google.com>2018-07-06 15:17:59 -0700
commit90fc5e3819ed62e93228a9c2c29dede0f0f8cfd6 (patch)
tree0e50e14646a382fbdf5edec988f9818bb93b12c0 /tensorflow/contrib/distribute/python/values.py
parentd64754c5c768f26b6a95b350cfd8c7ded2590dc9 (diff)
Allow is_initialized and initializer to be called on MirroredVariables and TowerLocalVariables.
PiperOrigin-RevId: 203520287
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py56
1 files changed, 56 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index b36ac563d2..1b5e00bc79 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -297,6 +297,12 @@ class MirroredVariable(DistributedVariable, Mirrored,
for v in six.itervalues(index):
v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access
self._primary_var = primary_var
+ # tf.keras keeps track of variables initialized using this attribute. When
+ # tf.keras gets the default session, it initializes all uninitialized vars.
+ # We need to make _keras_initialized a member of MirroredVariable because
+ # without this it will use `__getattr__` which will delegate to a component
+ # variable.
+ self._keras_initialized = False
self._aggregation = aggregation
super(MirroredVariable, self).__init__(index)
@@ -348,6 +354,28 @@ class MirroredVariable(DistributedVariable, Mirrored,
def assign(self, *args, **kwargs):
return self._assign_func(f=state_ops.assign, *args, **kwargs)
+ def is_initialized(self, name=None):
+ # We have to cast the self._index.values() to a `list` because when we
+ # use `model_to_estimator` to run tf.keras models, self._index.values() is
+ # of type `dict_values` and not `list`.
+ values_list = list(self._index.values())
+ result = values_list[0].is_initialized()
+ # We iterate through the list of values except the last one to allow us to
+ # name the final `logical_and` op the same name that is passed by the user
+ # to the `is_initialized` op. For mirrored variables, the `is_initialized`
+ # op is a `logical_and` op.
+ for v in values_list[1:-1]:
+ result = math_ops.logical_and(result, v.is_initialized())
+ result = math_ops.logical_and(result, values_list[-1].is_initialized(),
+ name=name)
+ return result
+
+ @property
+ def initializer(self):
+ # return grouped ops of all the var initializations of component values of
+ # the mirrored variable
+ return control_flow_ops.group([v.initializer for v in self._index.values()])
+
@property
def aggregation(self):
return self._aggregation
@@ -435,6 +463,12 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
def __init__(self, index, primary_var, aggregation):
self._primary_var = primary_var
self._aggregation = aggregation
+ # tf.keras keeps track of variables initialized using this attribute. When
+ # tf.keras gets the default session, it initializes all uninitialized vars.
+ # We need to make _keras_initialized a member of TowerLocalVariable because
+ # without this it will use `__getattr__` which will delegate to a component
+ # variable.
+ self._keras_initialized = False
super(TowerLocalVariable, self).__init__(index)
def assign_sub(self, *args, **kwargs):
@@ -449,6 +483,28 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
_assert_tower_context()
return self.get().assign(*args, **kwargs)
+ def is_initialized(self, name=None):
+ # We have to cast the self._index.values() to a `list` because when we
+ # use `model_to_estimator` to run tf.keras models, self._index.values() is
+ # of type `dict_values` and not `list`.
+ values_list = list(self._index.values())
+ result = values_list[0].is_initialized()
+ # We iterate through the list of values except the last one to allow us to
+ # name the final `logical_and` op the same name that is passed by the user
+ # to the `is_initialized` op. For tower local variables, the
+ # `is_initialized` op is a `logical_and` op.
+ for v in values_list[1:-1]:
+ result = math_ops.logical_and(result, v.is_initialized())
+ result = math_ops.logical_and(result, values_list[-1].is_initialized(),
+ name=name)
+ return result
+
+ @property
+ def initializer(self):
+ # return grouped ops of all the var initializations of component values of
+ # the tower local variable
+ return control_flow_ops.group([v.initializer for v in self._index.values()])
+
@property
def aggregation(self):
return self._aggregation