aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu
diff options
context:
space:
mode:
authorGravatar Russell Power <power@google.com>2018-10-01 17:45:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 17:51:41 -0700
commitbeede8525be5386451bf0098992c37416d1864db (patch)
treef44f87836607482f89e779be5bd6f730d0605947 /tensorflow/contrib/tpu
parent16d0079efd10eb1c5ac09522e01cff7ecdcbdfd5 (diff)
Make Keras/TPU more robust to closed TF sessions.
PiperOrigin-RevId: 215313156
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py278
1 file changed, 155 insertions, 123 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 696656e840..a3a7fd8bb0 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -46,6 +46,7 @@ from __future__ import print_function
import abc
import collections
+import contextlib
import re
import sys
import time
@@ -94,21 +95,56 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
+# TODO(b/114775106): temporary shim to optionally initialize the TPU
+# This increases the odds our session is initialized, but shouldn't be needed.
+def _maybe_initialize_tpu(session):
+ """Initialize the TPU if it has not already been initialized."""
+ try:
+
+ def test_op():
+ return constant_op.constant(1) + constant_op.constant(1)
+
+ session.run(tpu.rewrite(test_op))
+ except errors.FailedPreconditionError as _:
+ session.run(tpu.initialize_system())
+
+
+@contextlib.contextmanager
+def _tpu_session_context():
+  """Initialize the TPU and clean cache entries for bad sessions."""
+ try:
+ _maybe_initialize_tpu(K.get_session())
+ yield
+ except (errors.FailedPreconditionError, errors.AbortedError) as e:
+ K.clear_session()
+ raise Exception("""
+An error occurred connecting or initializing your TPU.
+
+The session has been reset. Re-run keras_to_tpu_model to create a new session.
+""" + str(e))
+
+
def setup_tpu_session(cluster_resolver):
"""Construct or return a `tf.Session` connected to the given cluster."""
master = cluster_resolver.master()
# Use the existing session if we're already connected to this TPU
- if (K.get_session()._target == master and
- getattr(K.get_session(), '_tpu_initialized', None)):
- return
+ # N.B K.get_session() is a non-trivial operation, and may fail if the remote
+ # session has been reset.
+ try:
+ default_session = K.get_session()
+ if (default_session._target == master and
+ getattr(default_session, '_tpu_initialized', None)):
+ return
+ except errors.AbortedError as _:
+ # We lost the remote session and need to re-initialize.
+ logging.warning('Lost remote session: creating a new session.')
cluster_spec = cluster_resolver.cluster_spec()
config = config_pb2.ConfigProto(isolate_session_state=True)
if cluster_spec:
config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
- logging.info('Initialize')
tpu_session = tf_session.Session(target=master, config=config)
tpu_session.run(tpu.initialize_system())
tpu_session._tpu_initialized = True
@@ -1391,97 +1427,74 @@ class KerasTPUModel(models.Model):
raise EnvironmentError('KerasTPUModel currently does not support eager '
'mode.')
- assert not self._numpy_to_infeed_manager_list # Ensure empty.
-
- infeed_managers = [] # Managers to clean up at the end of the fit call.
- if isinstance(x, dataset_ops.Dataset):
- # TODO(b/111413240): Support taking a tf.data.Dataset directly.
- raise ValueError(
- 'Taking a Dataset directly is not yet supported. Please '
- 'wrap your dataset construction code in a function and '
- 'pass that to fit instead. For examples, see: '
- 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
- '/keras')
- if callable(x):
- with ops.device('/job:%s/device:CPU:0' %
- self._tpu_assignment.worker_name):
- dataset = x()
- if steps_per_epoch is None:
- raise ValueError('When using tf.data as input to a model, you '
- 'should specify the steps_per_epoch argument.')
- if y is not None:
- raise ValueError('When using tf.data as input to a model, y must be '
- 'None')
- infeed_manager = TPUDatasetInfeedManager(
- dataset, self._tpu_assignment, model_fn_lib.ModeKeys.TRAIN)
+ with _tpu_session_context():
+ assert not self._numpy_to_infeed_manager_list # Ensure empty.
+
+ infeed_managers = [] # Managers to clean up at the end of the fit call.
+ if isinstance(x, dataset_ops.Dataset):
+ # TODO(b/111413240): Support taking a tf.data.Dataset directly.
+ raise ValueError(
+ 'Taking a Dataset directly is not yet supported. Please '
+ 'wrap your dataset construction code in a function and '
+ 'pass that to fit instead. For examples, see: '
+ 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
+ '/keras')
+ if callable(x):
+ with ops.device(
+ '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
+ dataset = x()
+ if steps_per_epoch is None:
+ raise ValueError('When using tf.data as input to a model, you '
+ 'should specify the steps_per_epoch argument.')
+ if y is not None:
+ raise ValueError('When using tf.data as input to a model, y must '
+ 'be None')
+ infeed_manager = TPUDatasetInfeedManager(
+ dataset, self._tpu_assignment, model_fn_lib.ModeKeys.TRAIN)
+ # Use dummy numpy inputs for the rest of Keras' shape checking. We
+ # intercept them when building the model.
+ x = infeed_manager.dummy_x
+ y = infeed_manager.dummy_y
+ infeed_managers.append((x, infeed_manager))
+
+ if isinstance(validation_data, dataset_ops.Dataset):
+ # TODO(b/111413240): Support taking a tf.data.Dataset directly.
+ raise ValueError(
+ 'Taking a Dataset directly is not yet supported. Please '
+ 'wrap your dataset construction code in a function and '
+ 'pass that to fit instead. For examples, see: '
+ 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
+ '/keras')
+ if callable(validation_data):
+ dataset = validation_data()
+ if validation_steps is None:
+ raise ValueError('When using tf.data as validation for a model, you '
+ 'should specify the validation_steps argument.')
+ infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
+ model_fn_lib.ModeKeys.EVAL)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
- x = infeed_manager.dummy_x
- y = infeed_manager.dummy_y
- infeed_managers.append((x, infeed_manager))
+ val_x = infeed_manager.dummy_x
+ val_y = infeed_manager.dummy_y
+ infeed_managers.append((val_x, infeed_manager))
+ validation_data = (val_x, val_y)
- if isinstance(validation_data, dataset_ops.Dataset):
- # TODO(b/111413240): Support taking a tf.data.Dataset directly.
- raise ValueError(
- 'Taking a Dataset directly is not yet supported. Please '
- 'wrap your dataset construction code in a function and '
- 'pass that to fit instead. For examples, see: '
- 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
- '/keras')
- if callable(validation_data):
- dataset = validation_data()
- if validation_steps is None:
- raise ValueError('When using tf.data as validation for a model, you '
- 'should specify the validation_steps argument.')
- infeed_manager = TPUDatasetInfeedManager(
- dataset, self._tpu_assignment, model_fn_lib.ModeKeys.EVAL)
- # Use dummy numpy inputs for the rest of Keras' shape checking. We
- # intercept them when building the model.
- val_x = infeed_manager.dummy_x
- val_y = infeed_manager.dummy_y
- infeed_managers.append((val_x, infeed_manager))
- validation_data = (val_x, val_y)
-
- self._numpy_to_infeed_manager_list = infeed_managers
- try:
- if not kwargs.get('_pipeline', True):
- logging.info('Running non-pipelined training loop (`_pipeline=%s`).',
- kwargs['_pipeline'])
- kwargs.pop('_pipeline')
- return super(KerasTPUModel, self).fit(
- x,
- y,
- batch_size,
- epochs,
- verbose,
- callbacks,
- validation_split,
- validation_data,
- shuffle,
- class_weight,
- sample_weight,
- initial_epoch,
- steps_per_epoch,
- validation_steps,
- **kwargs)
- return self._pipeline_fit(
- x,
- y,
- batch_size,
- epochs,
- verbose,
- callbacks,
- validation_split,
- validation_data,
- shuffle,
- class_weight,
- sample_weight,
- initial_epoch,
- steps_per_epoch,
- validation_steps,
- **kwargs)
- finally:
- self._numpy_to_infeed_manager_list = []
+ self._numpy_to_infeed_manager_list = infeed_managers
+ try:
+ if not kwargs.get('_pipeline', True):
+ logging.info('Running non-pipelined training loop (`_pipeline=%s`).',
+ kwargs['_pipeline'])
+ kwargs.pop('_pipeline')
+ return super(KerasTPUModel, self).fit(
+ x, y, batch_size, epochs, verbose, callbacks, validation_split,
+ validation_data, shuffle, class_weight, sample_weight,
+ initial_epoch, steps_per_epoch, validation_steps, **kwargs)
+ return self._pipeline_fit(x, y, batch_size, epochs, verbose, callbacks,
+ validation_split, validation_data, shuffle,
+ class_weight, sample_weight, initial_epoch,
+ steps_per_epoch, validation_steps, **kwargs)
+ finally:
+ self._numpy_to_infeed_manager_list = []
def evaluate(self,
x=None,
@@ -1492,37 +1505,38 @@ class KerasTPUModel(models.Model):
steps=None):
assert not self._numpy_to_infeed_manager_list # Ensure empty.
- infeed_managers = [] # Managers to clean up at the end of the fit call.
- if isinstance(x, dataset_ops.Dataset):
- # TODO(b/111413240): Support taking a tf.data.Dataset directly.
- raise ValueError(
- 'Taking a Dataset directly is not yet supported. Please '
- 'wrap your dataset construction code in a function and '
- 'pass that to fit instead. For examples, see: '
- 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
- '/keras')
- if callable(x):
- dataset = x()
- if steps is None:
- raise ValueError('When using tf.data as input to a model, you '
- 'should specify the steps argument.')
- if y is not None:
- raise ValueError('When using tf.data as input to a model, y must be '
- 'None')
- infeed_manager = TPUDatasetInfeedManager(
- dataset, self._tpu_assignment, model_fn_lib.ModeKeys.EVAL)
- # Use dummy numpy inputs for the rest of Keras' shape checking. We
- # intercept them when building the model.
- x = infeed_manager.dummy_x
- y = infeed_manager.dummy_y
- infeed_managers.append((x, infeed_manager))
-
- self._numpy_to_infeed_manager_list = infeed_managers
- try:
- return super(KerasTPUModel, self).evaluate(x, y, batch_size, verbose,
- sample_weight, steps)
- finally:
- self._numpy_to_infeed_manager_list = []
+ with _tpu_session_context():
+ infeed_managers = [] # Managers to clean up at the end of the fit call.
+ if isinstance(x, dataset_ops.Dataset):
+ # TODO(b/111413240): Support taking a tf.data.Dataset directly.
+ raise ValueError(
+ 'Taking a Dataset directly is not yet supported. Please '
+ 'wrap your dataset construction code in a function and '
+ 'pass that to fit instead. For examples, see: '
+ 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
+ '/keras')
+ if callable(x):
+ dataset = x()
+ if steps is None:
+ raise ValueError('When using tf.data as input to a model, you '
+ 'should specify the steps argument.')
+ if y is not None:
+ raise ValueError('When using tf.data as input to a model, y must be '
+ 'None')
+ infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
+ model_fn_lib.ModeKeys.EVAL)
+ # Use dummy numpy inputs for the rest of Keras' shape checking. We
+ # intercept them when building the model.
+ x = infeed_manager.dummy_x
+ y = infeed_manager.dummy_y
+ infeed_managers.append((x, infeed_manager))
+
+ self._numpy_to_infeed_manager_list = infeed_managers
+ try:
+ return super(KerasTPUModel, self).evaluate(x, y, batch_size, verbose,
+ sample_weight, steps)
+ finally:
+ self._numpy_to_infeed_manager_list = []
def _pipeline_fit(self, x, y, batch_size, epochs, verbose, callbacks,
validation_split, validation_data, shuffle, class_weight,
@@ -1910,6 +1924,24 @@ class KerasTPUModel(models.Model):
return val_x, val_y, val_sample_weights
+ def predict(self,
+ x,
+ batch_size=None,
+ verbose=0,
+ steps=None,
+ max_queue_size=10,
+ workers=1,
+ use_multiprocessing=False):
+ with _tpu_session_context():
+ return super(KerasTPUModel, self).predict(
+ x,
+ batch_size=batch_size,
+ verbose=verbose,
+ steps=steps,
+ max_queue_size=max_queue_size,
+ workers=workers,
+ use_multiprocessing=use_multiprocessing)
+
@property
def optimizer(self):
if self._tpu_model: