diff options
author | Russell Power <power@google.com> | 2018-10-01 17:45:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-01 17:51:41 -0700 |
commit | beede8525be5386451bf0098992c37416d1864db (patch) | |
tree | f44f87836607482f89e779be5bd6f730d0605947 /tensorflow/contrib/tpu | |
parent | 16d0079efd10eb1c5ac09522e01cff7ecdcbdfd5 (diff) |
Make Keras/TPU more robust to closed TF sessions.
PiperOrigin-RevId: 215313156
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/keras_support.py | 278 |
1 file changed, 155 insertions, 123 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 696656e840..a3a7fd8bb0 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -46,6 +46,7 @@ from __future__ import print_function import abc import collections +import contextlib import re import sys import time @@ -94,21 +95,56 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +# TODO(b/114775106): temporary shim to optionally initialize the TPU +# This increases the odds our session is initialized, but shouldn't be needed. +def _maybe_initialize_tpu(session): + """Initialize the TPU if it has not already been initialized.""" + try: + + def test_op(): + return constant_op.constant(1) + constant_op.constant(1) + + session.run(tpu.rewrite(test_op)) + except errors.FailedPreconditionError as _: + session.run(tpu.initialize_system()) + + +@contextlib.contextmanager +def _tpu_session_context(): + """Initialize the TPU and cleans cache entries for bad sessions.""" + try: + _maybe_initialize_tpu(K.get_session()) + yield + except (errors.FailedPreconditionError, errors.AbortedError) as e: + K.clear_session() + raise Exception(""" +An error occurred connecting or initializing your TPU. + +The session has been reset. re-run keras_to_tpu_model to create a new session. +""" + e) + + def setup_tpu_session(cluster_resolver): """Construct or return a `tf.Session` connected to the given cluster.""" master = cluster_resolver.master() # Use the existing session if we're already connected to this TPU - if (K.get_session()._target == master and - getattr(K.get_session(), '_tpu_initialized', None)): - return + # N.B K.get_session() is a non-trivial operation, and may fail if the remote + # session has been reset. 
+ try: + default_session = K.get_session() + if (default_session._target == master and + getattr(default_session, '_tpu_initialized', None)): + return + except errors.AbortedError as _: + # We lost the remote session and need to re-initialize. + logging.warning('Lost remote session: creating a new session.') cluster_spec = cluster_resolver.cluster_spec() config = config_pb2.ConfigProto(isolate_session_state=True) if cluster_spec: config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) - logging.info('Initialize') tpu_session = tf_session.Session(target=master, config=config) tpu_session.run(tpu.initialize_system()) tpu_session._tpu_initialized = True @@ -1391,97 +1427,74 @@ class KerasTPUModel(models.Model): raise EnvironmentError('KerasTPUModel currently does not support eager ' 'mode.') - assert not self._numpy_to_infeed_manager_list # Ensure empty. - - infeed_managers = [] # Managers to clean up at the end of the fit call. - if isinstance(x, dataset_ops.Dataset): - # TODO(b/111413240): Support taking a tf.data.Dataset directly. - raise ValueError( - 'Taking a Dataset directly is not yet supported. Please ' - 'wrap your dataset construction code in a function and ' - 'pass that to fit instead. For examples, see: ' - 'https://github.com/tensorflow/tpu/tree/master/models/experimental' - '/keras') - if callable(x): - with ops.device('/job:%s/device:CPU:0' % - self._tpu_assignment.worker_name): - dataset = x() - if steps_per_epoch is None: - raise ValueError('When using tf.data as input to a model, you ' - 'should specify the steps_per_epoch argument.') - if y is not None: - raise ValueError('When using tf.data as input to a model, y must be ' - 'None') - infeed_manager = TPUDatasetInfeedManager( - dataset, self._tpu_assignment, model_fn_lib.ModeKeys.TRAIN) + with _tpu_session_context(): + assert not self._numpy_to_infeed_manager_list # Ensure empty. + + infeed_managers = [] # Managers to clean up at the end of the fit call. 
+ if isinstance(x, dataset_ops.Dataset): + # TODO(b/111413240): Support taking a tf.data.Dataset directly. + raise ValueError( + 'Taking a Dataset directly is not yet supported. Please ' + 'wrap your dataset construction code in a function and ' + 'pass that to fit instead. For examples, see: ' + 'https://github.com/tensorflow/tpu/tree/master/models/experimental' + '/keras') + if callable(x): + with ops.device( + '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name): + dataset = x() + if steps_per_epoch is None: + raise ValueError('When using tf.data as input to a model, you ' + 'should specify the steps_per_epoch argument.') + if y is not None: + raise ValueError('When using tf.data as input to a model, y must ' + 'be None') + infeed_manager = TPUDatasetInfeedManager( + dataset, self._tpu_assignment, model_fn_lib.ModeKeys.TRAIN) + # Use dummy numpy inputs for the rest of Keras' shape checking. We + # intercept them when building the model. + x = infeed_manager.dummy_x + y = infeed_manager.dummy_y + infeed_managers.append((x, infeed_manager)) + + if isinstance(validation_data, dataset_ops.Dataset): + # TODO(b/111413240): Support taking a tf.data.Dataset directly. + raise ValueError( + 'Taking a Dataset directly is not yet supported. Please ' + 'wrap your dataset construction code in a function and ' + 'pass that to fit instead. For examples, see: ' + 'https://github.com/tensorflow/tpu/tree/master/models/experimental' + '/keras') + if callable(validation_data): + dataset = validation_data() + if validation_steps is None: + raise ValueError('When using tf.data as validation for a model, you ' + 'should specify the validation_steps argument.') + infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment, + model_fn_lib.ModeKeys.EVAL) # Use dummy numpy inputs for the rest of Keras' shape checking. We # intercept them when building the model. 
- x = infeed_manager.dummy_x - y = infeed_manager.dummy_y - infeed_managers.append((x, infeed_manager)) + val_x = infeed_manager.dummy_x + val_y = infeed_manager.dummy_y + infeed_managers.append((val_x, infeed_manager)) + validation_data = (val_x, val_y) - if isinstance(validation_data, dataset_ops.Dataset): - # TODO(b/111413240): Support taking a tf.data.Dataset directly. - raise ValueError( - 'Taking a Dataset directly is not yet supported. Please ' - 'wrap your dataset construction code in a function and ' - 'pass that to fit instead. For examples, see: ' - 'https://github.com/tensorflow/tpu/tree/master/models/experimental' - '/keras') - if callable(validation_data): - dataset = validation_data() - if validation_steps is None: - raise ValueError('When using tf.data as validation for a model, you ' - 'should specify the validation_steps argument.') - infeed_manager = TPUDatasetInfeedManager( - dataset, self._tpu_assignment, model_fn_lib.ModeKeys.EVAL) - # Use dummy numpy inputs for the rest of Keras' shape checking. We - # intercept them when building the model. 
- val_x = infeed_manager.dummy_x - val_y = infeed_manager.dummy_y - infeed_managers.append((val_x, infeed_manager)) - validation_data = (val_x, val_y) - - self._numpy_to_infeed_manager_list = infeed_managers - try: - if not kwargs.get('_pipeline', True): - logging.info('Running non-pipelined training loop (`_pipeline=%s`).', - kwargs['_pipeline']) - kwargs.pop('_pipeline') - return super(KerasTPUModel, self).fit( - x, - y, - batch_size, - epochs, - verbose, - callbacks, - validation_split, - validation_data, - shuffle, - class_weight, - sample_weight, - initial_epoch, - steps_per_epoch, - validation_steps, - **kwargs) - return self._pipeline_fit( - x, - y, - batch_size, - epochs, - verbose, - callbacks, - validation_split, - validation_data, - shuffle, - class_weight, - sample_weight, - initial_epoch, - steps_per_epoch, - validation_steps, - **kwargs) - finally: - self._numpy_to_infeed_manager_list = [] + self._numpy_to_infeed_manager_list = infeed_managers + try: + if not kwargs.get('_pipeline', True): + logging.info('Running non-pipelined training loop (`_pipeline=%s`).', + kwargs['_pipeline']) + kwargs.pop('_pipeline') + return super(KerasTPUModel, self).fit( + x, y, batch_size, epochs, verbose, callbacks, validation_split, + validation_data, shuffle, class_weight, sample_weight, + initial_epoch, steps_per_epoch, validation_steps, **kwargs) + return self._pipeline_fit(x, y, batch_size, epochs, verbose, callbacks, + validation_split, validation_data, shuffle, + class_weight, sample_weight, initial_epoch, + steps_per_epoch, validation_steps, **kwargs) + finally: + self._numpy_to_infeed_manager_list = [] def evaluate(self, x=None, @@ -1492,37 +1505,38 @@ class KerasTPUModel(models.Model): steps=None): assert not self._numpy_to_infeed_manager_list # Ensure empty. - infeed_managers = [] # Managers to clean up at the end of the fit call. - if isinstance(x, dataset_ops.Dataset): - # TODO(b/111413240): Support taking a tf.data.Dataset directly. 
- raise ValueError( - 'Taking a Dataset directly is not yet supported. Please ' - 'wrap your dataset construction code in a function and ' - 'pass that to fit instead. For examples, see: ' - 'https://github.com/tensorflow/tpu/tree/master/models/experimental' - '/keras') - if callable(x): - dataset = x() - if steps is None: - raise ValueError('When using tf.data as input to a model, you ' - 'should specify the steps argument.') - if y is not None: - raise ValueError('When using tf.data as input to a model, y must be ' - 'None') - infeed_manager = TPUDatasetInfeedManager( - dataset, self._tpu_assignment, model_fn_lib.ModeKeys.EVAL) - # Use dummy numpy inputs for the rest of Keras' shape checking. We - # intercept them when building the model. - x = infeed_manager.dummy_x - y = infeed_manager.dummy_y - infeed_managers.append((x, infeed_manager)) - - self._numpy_to_infeed_manager_list = infeed_managers - try: - return super(KerasTPUModel, self).evaluate(x, y, batch_size, verbose, - sample_weight, steps) - finally: - self._numpy_to_infeed_manager_list = [] + with _tpu_session_context(): + infeed_managers = [] # Managers to clean up at the end of the fit call. + if isinstance(x, dataset_ops.Dataset): + # TODO(b/111413240): Support taking a tf.data.Dataset directly. + raise ValueError( + 'Taking a Dataset directly is not yet supported. Please ' + 'wrap your dataset construction code in a function and ' + 'pass that to fit instead. For examples, see: ' + 'https://github.com/tensorflow/tpu/tree/master/models/experimental' + '/keras') + if callable(x): + dataset = x() + if steps is None: + raise ValueError('When using tf.data as input to a model, you ' + 'should specify the steps argument.') + if y is not None: + raise ValueError('When using tf.data as input to a model, y must be ' + 'None') + infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment, + model_fn_lib.ModeKeys.EVAL) + # Use dummy numpy inputs for the rest of Keras' shape checking. 
We + # intercept them when building the model. + x = infeed_manager.dummy_x + y = infeed_manager.dummy_y + infeed_managers.append((x, infeed_manager)) + + self._numpy_to_infeed_manager_list = infeed_managers + try: + return super(KerasTPUModel, self).evaluate(x, y, batch_size, verbose, + sample_weight, steps) + finally: + self._numpy_to_infeed_manager_list = [] def _pipeline_fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, @@ -1910,6 +1924,24 @@ class KerasTPUModel(models.Model): return val_x, val_y, val_sample_weights + def predict(self, + x, + batch_size=None, + verbose=0, + steps=None, + max_queue_size=10, + workers=1, + use_multiprocessing=False): + with _tpu_session_context(): + return super(KerasTPUModel, self).predict( + x, + batch_size=batch_size, + verbose=verbose, + steps=steps, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing) + @property def optimizer(self): if self._tpu_model: |