author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-09-25 17:27:27 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-09-25 17:36:50 -0700
commitc1e303ed8fa1bf11aaea16e68b14ba2f5ab5dde0 (patch)
treec35ec5ea71adf95fc180c4172f8ee4c26187c52b /tensorflow/contrib/tpu
parenta7f14807417ea78aee8ea275536902f0aaa94fd4 (diff)
Support dynamic LR for Keras optimizer by setting the global Keras session.
PiperOrigin-RevId: 214532827
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/keras_support.py  564
1 file changed, 294 insertions(+), 270 deletions(-)
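For context, the point of this change: once the TPU session is installed as the global Keras session, Keras' stock learning-rate machinery (callbacks that call `K.set_value(model.optimizer.lr, ...)`) mutates the live optimizer variable instead of a stale per-model session's copy. A minimal sketch of that pattern, run here on CPU for brevity (the `tpu_model` wrapping is elided; model, data, and schedule are illustrative):

```
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1), loss='mse')

# Halve the LR each epoch; internally this runs K.set_value(optimizer.lr, ...)
# against the global Keras session -- the session this patch points at the TPU.
schedule = tf.keras.callbacks.LearningRateScheduler(
    lambda epoch: 0.1 * 0.5 ** epoch)

x = np.random.rand(32, 4).astype('float32')
y = np.random.rand(32, 1).astype('float32')
model.fit(x, y, epochs=3, callbacks=[schedule], verbose=0)
print(K.get_value(model.optimizer.lr))  # reflects the final scheduled value
```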
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 03e06b8142..f67e0e6aca 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -46,12 +46,12 @@ from __future__ import print_function
import abc
import collections
-import contextlib
import re
import sys
import time
import numpy as np
+import six
from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver as tpu_cluster_resolver_lib
from tensorflow.contrib.framework.python.framework import experimental
@@ -90,34 +90,34 @@ from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
-_SESSIONS = {}
-
-
-def tpu_session(cluster_resolver):
+def setup_tpu_session(cluster_resolver):
"""Construct or return a `tf.Session` connected to the given cluster."""
- global _SESSIONS
master = cluster_resolver.master()
- if master not in _SESSIONS:
- cluster_spec = cluster_resolver.cluster_spec()
- config = config_pb2.ConfigProto(isolate_session_state=True)
- if cluster_spec:
- config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
- logging.info('Connecting to: %s', master)
- graph = ops.Graph()
- session = tf_session.Session(graph=graph, target=master, config=config)
- with graph.as_default():
- session.run(tpu.initialize_system())
+ # Use the existing session if we're already connected to this TPU
+ if (K.get_session()._target == master and
+ getattr(K.get_session(), '_tpu_initialized', None)):
+ return
+
+ cluster_spec = cluster_resolver.cluster_spec()
+ config = config_pb2.ConfigProto(isolate_session_state=True)
+ if cluster_spec:
+ config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
- _SESSIONS[master] = session
- return _SESSIONS[master]
+ logging.info('Initializing TPU session; connecting to: %s', master)
+ tpu_session = tf_session.Session(target=master, config=config)
+ tpu_session.run(tpu.initialize_system())
+ tpu_session._tpu_initialized = True
+ # N.B. `K.set_session()` installs this session as the global Keras
+ # session: subsequent `K.get_session()` calls, and all Keras ops,
+ # will run against the TPU session.
+ K.set_session(tpu_session)
-def reset_tpu_sessions():
- _SESSIONS.clear()
try:
from scipy.sparse import issparse # pylint: disable=g-import-not-at-top
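The rewritten entry point is idempotent per TPU master; a minimal usage sketch, assuming a hypothetical TPU address (the resolver arguments are illustrative):

```
from tensorflow.contrib.cluster_resolver import TPUClusterResolver

resolver = TPUClusterResolver(tpu='grpc://10.2.3.4:8470')  # hypothetical address
setup_tpu_session(resolver)
# K.get_session() now returns the TPU-initialized session. A second call
# with the same master is a no-op, guarded by the `_tpu_initialized` flag.
setup_tpu_session(resolver)
```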
@@ -134,9 +134,7 @@ def get_tpu_system_metadata(tpu_cluster_resolver):
cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
tpu_system_metadata = (
tpu_system_metadata_lib._query_tpu_system_metadata(
- master,
- cluster_def=cluster_def,
- query_topology=False))
+ master, cluster_def=cluster_def, query_topology=False))
return tpu_system_metadata
@@ -157,6 +155,8 @@ class TPUDistributionStrategy(object):
replication, typically using all available TPU cores. If set to
`True`, force model replication onto a single core, i.e., no
replication.
+ Raises:
+ Exception: No TPU found on the given worker.
"""
if tpu_cluster_resolver is None:
@@ -172,7 +172,8 @@ class TPUDistributionStrategy(object):
for device in metadata.devices:
if 'TPU:0' in device.name:
self._worker_name = worker_re.search(device.name).group(1)
- break
+ return
+ raise Exception('No TPU found on the given worker.')
def _make_assignment_for_model(self, cpu_model):
"""Makes a `TPUAssignment` for the passed in `cpu_model`."""
@@ -183,8 +184,7 @@ class TPUDistributionStrategy(object):
'Degrading to a single core.')
num_cores = 1
- return TPUAssignment(
- worker_name=self._worker_name, num_cores=num_cores)
+ return TPUAssignment(worker_name=self._worker_name, num_cores=num_cores)
class TPUAssignment(object):
@@ -280,9 +280,9 @@ class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
super(KerasCrossShardOptimizer, self).__init__()
self._name = name
self._opt = opt
+ logging.info('KerasCrossShard: %s %s', self._opt, self._opt.weights)
def get_updates(self, loss, params):
- logging.info('Get updates: %s', loss)
self._opt.get_gradients = self.get_gradients
return self._opt.get_updates(loss, params)
@@ -291,17 +291,15 @@ class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
grads = super(KerasCrossShardOptimizer, self).get_gradients(loss, params)
return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads]
- def set_weights(self, weights):
- # TODO(power): Figure out whether we really need this given there is no
- # caller for this API yet.
- self._opt.set_weights()
-
def get_weights(self):
return self._opt.get_weights()
- @property
- def lr(self):
- return self._opt.lr
+ def get_config(self):
+ return self._opt.get_config()
+
+ # Delegate all remaining attribute lookups (e.g. `lr`) to the
+ # underlying optimizer.
+ def __getattr__(self, key):
+ return getattr(self._opt, key)
class TPUModelOp(
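The `__getattr__` fallback above replaces the explicit `lr` property: any attribute missing on the wrapper (`lr`, `momentum`, `decay`, ...) now resolves to the wrapped optimizer, so `K.set_value(model.optimizer.lr, ...)` reaches the real variable. A self-contained sketch of the delegation pattern (class names are illustrative):

```
class Wrapped(object):
  lr = 0.1  # stands in for the optimizer's learning-rate variable

class Wrapper(object):
  def __init__(self, opt):
    self._opt = opt

  def __getattr__(self, key):
    # Invoked only when normal attribute lookup fails, so attributes
    # defined on the wrapper itself (like _opt) are never intercepted.
    return getattr(self._opt, key)

assert Wrapper(Wrapped()).lr == 0.1
```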
@@ -327,14 +325,34 @@ def _replicated_optimizer(opt):
return KerasCrossShardOptimizer(opt)
-def clone_metrics(metrics):
+def _clone_metrics(metrics):
"""Returns a copy of metrics. A copy is created for stateful metrics."""
if metrics is None:
return None
- return [
- m.__class__.from_config(m.get_config())
- if isinstance(m, metrics_module.Metric) else m for m in metrics
- ]
+ with variable_scope.variable_scope(
+ 'metrics', reuse=variable_scope.AUTO_REUSE):
+ return [
+ m.__class__.from_config(m.get_config()) if isinstance(
+ m, metrics_module.Metric) else m for m in metrics
+ ]
+
+
+def _clone_optimizer(optimizer, config=None):
+ """Returns a cloned optimizer with the provided optimizer.config or config."""
+ if not isinstance(optimizer, keras_optimizers.Optimizer):
+ # On the first call to tpu_model(model), Keras may not yet have wrapped
+ # the TF optimizer in the TFOptimizer helper (e.g. the given model isn't
+ # compiled, or no optimizer is set); the generated tpu_model is compiled
+ # with a raw TF optimizer later.
+ return optimizer
+
+ if isinstance(optimizer, keras_optimizers.TFOptimizer):
+ return keras_optimizers.TFOptimizer(optimizer.optimizer)
+
+ if config is None:
+ config = optimizer.get_config()
+ logging.info('Cloning %s %s', optimizer.__class__.__name__, config)
+ return optimizer.__class__.from_config(config)
class TPURewriteContext(object):
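Cloning through `get_config()`/`from_config()` yields a fresh optimizer with identical hyperparameters but none of the original's slot variables, which is what each specialized graph needs. A minimal sketch with a stock Keras optimizer (values are illustrative):

```
from tensorflow.keras import optimizers

opt = optimizers.SGD(lr=0.05, momentum=0.9)
clone = opt.__class__.from_config(opt.get_config())
assert clone is not opt
assert clone.get_config() == opt.get_config()  # same hyperparameters, new variables
```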
@@ -425,6 +443,7 @@ class TPURewriteContext(object):
return (r, q)
else:
raise ValueError('Invalid shape passed to qr: %s' % input_shape)
+
gen_linalg_ops.qr = qr
ops.name_scope = _name_scope
@@ -440,9 +459,9 @@ class TPURewriteContext(object):
gen_linalg_ops.qr = self._default_qr
-class SizedInfeed(collections.namedtuple('SizedInfeed',
- ['sharded_infeed_tensors',
- 'infeed_ops'])):
+class SizedInfeed(
+ collections.namedtuple('SizedInfeed',
+ ['sharded_infeed_tensors', 'infeed_ops'])):
"""Represents an instantiation of the infeed ops for a concrete input shape.
sharded_infeed_tensors: A data structure of Tensors used to represent the
@@ -628,12 +647,13 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
infeed_tensors, [spec.shape for spec in input_specs],
name='infeed-enqueue-%s-%d' % (execution_mode, shard_id),
device_ordinal=shard_id))
- return SizedInfeed(infeed_ops=infeed_op,
- sharded_infeed_tensors=shard_infeed_tensors)
+ return SizedInfeed(
+ infeed_ops=infeed_op, sharded_infeed_tensors=shard_infeed_tensors)
class TPUDatasetInfeedManager(TPUInfeedManager):
"""Manages infeed for a `tf.data.Dataset` into a TPU computation.
+
"""
class DatasetInfeedInstance(TPUInfeedInstance):
@@ -651,16 +671,13 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
return {}
# pylint: disable=redefined-outer-name
- def __init__(self, dataset, tpu_assignment, tpu_session, mode):
+ def __init__(self, dataset, tpu_assignment, mode):
"""Constructs a TPUDatasetInfeedManager.
- Must be called within a `KerasTPUModel.tpu_session` context!
-
Args:
dataset: A `tf.data.Dataset` to infeed.
tpu_assignment: The `TPUAssignment` used to configure the
Keras TPU model.
- tpu_session: The `tf.Session` object used for running the TPU model.
mode: ModeKeys enum.
"""
self._verify_dataset_shape(dataset)
@@ -672,7 +689,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
dummy_y_shape = dataset.output_shapes[1].as_list()
dummy_y_shape[0] *= tpu_assignment.num_towers
self._iterator = dataset.make_initializable_iterator()
- tpu_session.run(self._iterator.initializer)
+ K.get_session().run(self._iterator.initializer)
self._get_next_ops = []
ctrl_deps = []
@@ -685,10 +702,10 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
- self._dummy_x = np.zeros(dummy_x_shape,
- dtype=dataset.output_types[0].as_numpy_dtype)
- self._dummy_y = np.zeros(dummy_y_shape,
- dtype=dataset.output_types[1].as_numpy_dtype)
+ self._dummy_x = np.zeros(
+ dummy_x_shape, dtype=dataset.output_types[0].as_numpy_dtype)
+ self._dummy_y = np.zeros(
+ dummy_y_shape, dtype=dataset.output_types[1].as_numpy_dtype)
input_specs = []
if isinstance(self._iterator.output_shapes, tuple):
@@ -719,9 +736,8 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
raise ValueError('The dataset must return a tuple of tf.Tensors, '
'instead it returns: %s' % dataset.output_classes)
if len(dataset.output_classes) != 2:
- raise ValueError(
- 'The dataset must return a 2-element tuple, got '
- '%s output classes instead.' % (dataset.output_classes,))
+ raise ValueError('The dataset must return a 2-element tuple, got '
+ '%s output classes instead.' % (dataset.output_classes,))
for i, cls in enumerate(dataset.output_classes):
if cls != ops.Tensor:
raise ValueError('The dataset returned a non-Tensor type (%s) at '
@@ -730,8 +746,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
if not shape:
raise ValueError('The dataset returns a scalar tensor in '
'tuple index %d. Did you forget to batch? '
- '(Output shapes: %s).' % (i,
- dataset.output_shapes))
+ '(Output shapes: %s).' % (i, dataset.output_shapes))
for j, dim in enumerate(shape):
if dim.value is None:
if j == 0:
@@ -771,8 +786,8 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
[spec.shape for spec in input_specs],
name='infeed-enqueue-%s-%d' % (execution_mode, shard_id),
device_ordinal=shard_id))
- return SizedInfeed(infeed_ops=infeed_ops,
- sharded_infeed_tensors=shard_infeed_tensors)
+ return SizedInfeed(
+ infeed_ops=infeed_ops, sharded_infeed_tensors=shard_infeed_tensors)
def _inject_tpu_inputs_for_dataset(tpu_assignment, mode,
@@ -858,12 +873,7 @@ class TPUFunction(object):
self._tpu_assignment = tpu_assignment
self._compilation_cache = {}
self._cloned_model = None
-
- # Copy optimizer configuration. This is done prior to `_specialize_model`
- # as the configuration may require evaluating variables in the CPU session.
- self._optimizer_config = None
- if not isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
- self._optimizer_config = self.model.optimizer.get_config()
+ self._cloned_optimizer = None
def _specialize_model(self, input_specs, infeed_manager):
"""Specialize `self.model` (a Keras model) for the given input shapes."""
@@ -909,53 +919,51 @@ class TPUFunction(object):
tpu_targets.append(tensor)
# Clone our CPU model, running within the TPU device context.
+ #
+ # We use the id of the original model as a key to avoid weight collisions
+ # (if a user re-runs the same model multiple times, e.g. in Colab).
with TPURewriteContext(tpu_input_map):
- with variable_scope.variable_scope('tpu_model_%s' % id(self.model)):
+ with variable_scope.variable_scope('tpu_%s' % id(self.model)):
with keras_tpu_variables.replicated_scope(
self._tpu_assignment.num_towers):
- self._cloned_model = models.clone_model(self.model)
+ if not self._cloned_optimizer:
+ self._cloned_optimizer = _clone_optimizer(
+ self.model.cpu_optimizer)
- # When running on more than one core, concatenate outputs at the end of
- # processing. In backprop stage, the gradients will be calculdated
- # according to the local inputs as gradient of cross-replica-concat being
- # zero for any outputs other than those from mlocal core so the loss
- # calculation is identical.
- num_towers = self.model._tpu_assignment.num_towers
- if num_towers > 1 and (is_training or is_test):
- new_outputs = [
- _cross_replica_concat(
- o, core_id, num_towers, name='model output ({})'.format(o.name))
- for o in self._cloned_model.outputs
- ]
- self._cloned_model.outputs = new_outputs
- tpu_targets = [
- _cross_replica_concat(
- tensor,
- core_id,
- num_towers,
- name='model target ({})'.format(tensor.name))
- for tensor in tpu_targets
- ]
-
- # Create a copy of the optimizer for this graph.
- if isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
- cloned_optimizer = keras_optimizers.TFOptimizer(
- self.model.optimizer.optimizer)
- else:
- logging.info('Cloning %s %s', self.model.optimizer.__class__.__name__,
- self._optimizer_config)
- cloned_optimizer = self.model.optimizer.__class__.from_config(
- self._optimizer_config)
+ self._cloned_model = models.clone_model(self.model)
- if is_training or is_test:
- self._cloned_model.compile(
- optimizer=_replicated_optimizer(cloned_optimizer),
- loss=self.model.loss,
- loss_weights=self.model.loss_weights,
- metrics=clone_metrics(self.model.metrics),
- weighted_metrics=clone_metrics(self.model.weighted_metrics),
- target_tensors=tpu_targets,
- )
+ # When running on more than one core, concatenate outputs at the end
+ # of processing. In the backprop stage, the gradients will be
+ # calculated from the local inputs, as the gradient of
+ # cross-replica-concat is zero for any outputs other than those
+ # from the local core, so the loss calculation is identical.
+ num_towers = self.model._tpu_assignment.num_towers
+ if num_towers > 1 and (is_training or is_test):
+ new_outputs = [
+ _cross_replica_concat(
+ o, core_id, num_towers,
+ name='model output ({})'.format(o.name))
+ for o in self._cloned_model.outputs
+ ]
+ self._cloned_model.outputs = new_outputs
+ tpu_targets = [
+ _cross_replica_concat(
+ tensor,
+ core_id,
+ num_towers,
+ name='model target ({})'.format(tensor.name))
+ for tensor in tpu_targets
+ ]
+
+ if is_training or is_test:
+ self._cloned_model.compile(
+ optimizer=_replicated_optimizer(self._cloned_optimizer),
+ loss=self.model.loss,
+ loss_weights=self.model.loss_weights,
+ metrics=_clone_metrics(self.model.metrics),
+ weighted_metrics=_clone_metrics(self.model.weighted_metrics),
+ target_tensors=tpu_targets,
+ )
# Compute our outfeed depending on the execution mode
if is_training:
@@ -1089,15 +1097,14 @@ class TPUFunction(object):
# unique input shape.
shape_key = tuple([tuple(spec.shape.as_list()) for spec in input_specs])
if shape_key not in self._compilation_cache:
- with self.model.tpu_session():
- logging.info(
- 'New input shapes; (re-)compiling: mode=%s '
- '(# of cores %d), %s', self.execution_mode,
- self._tpu_assignment.num_towers, input_specs)
- new_tpu_model_ops = self._specialize_model(input_specs,
- infeed_manager)
- self._compilation_cache[shape_key] = new_tpu_model_ops
- self._test_model_compiles(new_tpu_model_ops)
+ logging.info(
+ 'New input shapes; (re-)compiling: mode=%s '
+ '(# of cores %d), %s', self.execution_mode,
+ self._tpu_assignment.num_towers, input_specs)
+ new_tpu_model_ops = self._specialize_model(input_specs,
+ infeed_manager)
+ self._compilation_cache[shape_key] = new_tpu_model_ops
+ self._test_model_compiles(new_tpu_model_ops)
return self._compilation_cache[shape_key]
@@ -1195,11 +1202,10 @@ class TPUFunction(object):
# Initialize our TPU weights on the first compile.
self.model._initialize_weights(self._cloned_model)
- with self.model.tpu_session() as session:
- _, _, outfeed_outputs = session.run([
- tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
- tpu_model_ops.outfeed_op
- ], infeed_dict)
+ _, _, outfeed_outputs = K.get_session().run([
+ tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
+ tpu_model_ops.outfeed_op
+ ], infeed_dict)
return self._process_outputs(outfeed_outputs)
def pipeline_run(self, cur_step_inputs, next_step_inputs):
@@ -1231,8 +1237,8 @@ class TPUFunction(object):
next_step_infeed_manager = self._lookup_infeed_manager(next_step_inputs)
cur_step_infeed_manager = self._lookup_infeed_manager(cur_step_inputs)
- if (next_step_infeed_manager is not None
- and cur_step_infeed_manager is not None):
+ if (next_step_infeed_manager is not None and
+ cur_step_infeed_manager is not None):
assert type(next_step_infeed_manager) is type(cur_step_infeed_manager)
next_input_tensors, next_step_inputs = (
@@ -1257,14 +1263,12 @@ class TPUFunction(object):
infeed_dict = None
if cur_infeed_instance and cur_input_tensors and cur_step_infeed_manager:
- cur_input_specs = cur_infeed_instance.make_input_specs(
- cur_input_tensors)
+ cur_input_specs = cur_infeed_instance.make_input_specs(cur_input_tensors)
cur_tpu_model_ops = self._tpu_model_ops_for_input_specs(
cur_input_specs, cur_step_infeed_manager)
- if (next_infeed_instance
- and next_input_tensors
- and next_step_infeed_manager):
+ if (next_infeed_instance and next_input_tensors and
+ next_step_infeed_manager):
next_input_specs = next_infeed_instance.make_input_specs(
next_input_tensors)
next_tpu_model_ops = self._tpu_model_ops_for_input_specs(
@@ -1275,26 +1279,24 @@ class TPUFunction(object):
self.model._initialize_weights(self._cloned_model)
if next_tpu_model_ops and cur_tpu_model_ops:
- with self.model.tpu_session() as session:
- _, _, outfeed_outputs = session.run([
- next_tpu_model_ops.infeed_op, cur_tpu_model_ops.execute_op,
- cur_tpu_model_ops.outfeed_op
- ], infeed_dict)
+ _, _, outfeed_outputs = K.get_session().run([
+ next_tpu_model_ops.infeed_op, cur_tpu_model_ops.execute_op,
+ cur_tpu_model_ops.outfeed_op
+ ], infeed_dict)
return self._process_outputs(outfeed_outputs)
+
if cur_tpu_model_ops:
- with self.model.tpu_session() as session:
- _, outfeed_outputs = session.run([
- cur_tpu_model_ops.execute_op, cur_tpu_model_ops.outfeed_op])
+ _, outfeed_outputs = K.get_session().run(
+ [cur_tpu_model_ops.execute_op, cur_tpu_model_ops.outfeed_op])
return self._process_outputs(outfeed_outputs)
+
if next_tpu_model_ops:
- with self.model.tpu_session() as session:
- session.run(next_tpu_model_ops.infeed_op, infeed_dict)
+ K.get_session().run(next_tpu_model_ops.infeed_op, infeed_dict)
return None
raise RuntimeError('Internal error: both current & next tpu_model_ops '
'were None')
-
class KerasTPUModel(models.Model):
"""TPU compatible Keras model wrapper."""
@@ -1321,8 +1323,6 @@ class KerasTPUModel(models.Model):
self._tpu_model = None
self._tpu_weights_initialized = False
- self._session = tpu_session(cluster_resolver)
-
# If the input CPU model has already been compiled, compile our TPU model
# immediately.
if self._cpu_model.optimizer:
@@ -1359,15 +1359,20 @@ class KerasTPUModel(models.Model):
if target_tensors:
raise ValueError('target_tensors is not supported for TPU execution.')
+ self._cpu_model.compile(
+ _clone_optimizer(optimizer),
+ loss,
+ _clone_metrics(metrics),
+ loss_weights,
+ sample_weight_mode,
+ _clone_metrics(weighted_metrics),
+ target_tensors,
+ **kwargs)
+
super(KerasTPUModel, self).compile(optimizer, loss, metrics, loss_weights,
sample_weight_mode, weighted_metrics,
target_tensors, **kwargs)
- if not self._cpu_model.optimizer:
- self._cpu_model.compile(optimizer, loss, metrics, loss_weights,
- sample_weight_mode, weighted_metrics,
- target_tensors, **kwargs)
-
def fit(self,
x=None,
y=None,
@@ -1400,8 +1405,8 @@ class KerasTPUModel(models.Model):
'https://github.com/tensorflow/tpu/tree/master/models/experimental'
'/keras')
if callable(x):
- with self.tpu_session() as sess,\
- ops.device('/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
+ with ops.device('/job:%s/device:CPU:0' %
+ self._tpu_assignment.worker_name):
dataset = x()
if steps_per_epoch is None:
raise ValueError('When using tf.data as input to a model, you '
@@ -1410,7 +1415,7 @@ class KerasTPUModel(models.Model):
raise ValueError('When using tf.data as input to a model, y must be '
'None')
infeed_manager = TPUDatasetInfeedManager(
- dataset, self._tpu_assignment, sess, model_fn_lib.ModeKeys.TRAIN)
+ dataset, self._tpu_assignment, model_fn_lib.ModeKeys.TRAIN)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
x = infeed_manager.dummy_x
@@ -1426,26 +1431,24 @@ class KerasTPUModel(models.Model):
'https://github.com/tensorflow/tpu/tree/master/models/experimental'
'/keras')
if callable(validation_data):
- with self.tpu_session() as sess:
- dataset = validation_data()
- if validation_steps is None:
- raise ValueError('When using tf.data as validation for a model, you '
- 'should specify the validation_steps argument.')
- infeed_manager = TPUDatasetInfeedManager(
- dataset, self._tpu_assignment, sess, model_fn_lib.ModeKeys.EVAL)
- # Use dummy numpy inputs for the rest of Keras' shape checking. We
- # intercept them when building the model.
- val_x = infeed_manager.dummy_x
- val_y = infeed_manager.dummy_y
- infeed_managers.append((val_x, infeed_manager))
- validation_data = (val_x, val_y)
+ dataset = validation_data()
+ if validation_steps is None:
+ raise ValueError('When using tf.data as validation for a model, you '
+ 'should specify the validation_steps argument.')
+ infeed_manager = TPUDatasetInfeedManager(
+ dataset, self._tpu_assignment, model_fn_lib.ModeKeys.EVAL)
+ # Use dummy numpy inputs for the rest of Keras' shape checking. We
+ # intercept them when building the model.
+ val_x = infeed_manager.dummy_x
+ val_y = infeed_manager.dummy_y
+ infeed_managers.append((val_x, infeed_manager))
+ validation_data = (val_x, val_y)
self._numpy_to_infeed_manager_list = infeed_managers
try:
if not kwargs.get('_pipeline', True):
- logging.info(
- 'Running non-pipelined training loop (`_pipeline=%s`).',
- kwargs['_pipeline'])
+ logging.info('Running non-pipelined training loop (`_pipeline=%s`).',
+ kwargs['_pipeline'])
kwargs.pop('_pipeline')
return super(KerasTPUModel, self).fit(
x,
@@ -1501,50 +1504,32 @@ class KerasTPUModel(models.Model):
'https://github.com/tensorflow/tpu/tree/master/models/experimental'
'/keras')
if callable(x):
- with self.tpu_session() as sess:
- dataset = x()
- if steps is None:
- raise ValueError('When using tf.data as input to a model, you '
- 'should specify the steps argument.')
- if y is not None:
- raise ValueError('When using tf.data as input to a model, y must be '
- 'None')
- infeed_manager = TPUDatasetInfeedManager(
- dataset, self._tpu_assignment, sess, model_fn_lib.ModeKeys.EVAL)
- # Use dummy numpy inputs for the rest of Keras' shape checking. We
- # intercept them when building the model.
- x = infeed_manager.dummy_x
- y = infeed_manager.dummy_y
- infeed_managers.append((x, infeed_manager))
+ dataset = x()
+ if steps is None:
+ raise ValueError('When using tf.data as input to a model, you '
+ 'should specify the steps argument.')
+ if y is not None:
+ raise ValueError('When using tf.data as input to a model, y must be '
+ 'None')
+ infeed_manager = TPUDatasetInfeedManager(
+ dataset, self._tpu_assignment, model_fn_lib.ModeKeys.EVAL)
+ # Use dummy numpy inputs for the rest of Keras' shape checking. We
+ # intercept them when building the model.
+ x = infeed_manager.dummy_x
+ y = infeed_manager.dummy_y
+ infeed_managers.append((x, infeed_manager))
self._numpy_to_infeed_manager_list = infeed_managers
try:
- return super(KerasTPUModel, self).evaluate(
- x,
- y,
- batch_size,
- verbose,
- sample_weight,
- steps)
+ return super(KerasTPUModel, self).evaluate(x, y, batch_size, verbose,
+ sample_weight, steps)
finally:
self._numpy_to_infeed_manager_list = []
- def _pipeline_fit(self,
- x,
- y,
- batch_size,
- epochs,
- verbose,
- callbacks,
- validation_split,
- validation_data,
- shuffle,
- class_weight,
- sample_weight,
- initial_epoch,
- steps_per_epoch,
- validation_steps,
- **kwargs):
+ def _pipeline_fit(self, x, y, batch_size, epochs, verbose, callbacks,
+ validation_split, validation_data, shuffle, class_weight,
+ sample_weight, initial_epoch, steps_per_epoch,
+ validation_steps, **kwargs):
# Similar to super.fit(...), but modified to support software pipelining.
# Backwards compatibility
@@ -1572,13 +1557,8 @@ class KerasTPUModel(models.Model):
# Prepare validation data
val_x, val_y, val_sample_weights = self._prepare_validation_data(
- validation_data,
- validation_split,
- validation_steps,
- x,
- y,
- sample_weights,
- batch_size)
+ validation_data, validation_split, validation_steps, x, y,
+ sample_weights, batch_size)
return self._pipeline_fit_loop(
x,
y,
@@ -1751,8 +1731,8 @@ class KerasTPUModel(models.Model):
for i in indices_for_conversion_to_dense:
ins_batch[i] = ins_batch[i].toarray()
- outs = f.pipeline_run(cur_step_inputs=ins_last_batch,
- next_step_inputs=ins_batch)
+ outs = f.pipeline_run(
+ cur_step_inputs=ins_last_batch, next_step_inputs=ins_batch)
ins_last_batch = ins_batch
if batch_index == 0:
@@ -1824,8 +1804,8 @@ class KerasTPUModel(models.Model):
next_step_inputs = ins
else:
next_step_inputs = None
- outs = f.pipeline_run(cur_step_inputs=ins,
- next_step_inputs=next_step_inputs)
+ outs = f.pipeline_run(
+ cur_step_inputs=ins, next_step_inputs=next_step_inputs)
except errors.OutOfRangeError:
logging.warning('Your dataset iterator ran out of data; '
'interrupting training. Make sure that your '
@@ -1845,25 +1825,21 @@ class KerasTPUModel(models.Model):
break
if do_validation:
- val_outs = training_arrays.test_loop(self,
- val_inputs,
- val_targets,
- sample_weights=val_sample_weights,
- steps=validation_steps,
- verbose=0)
+ val_outs = training_arrays.test_loop(
+ self,
+ val_inputs,
+ val_targets,
+ sample_weights=val_sample_weights,
+ steps=validation_steps,
+ verbose=0)
if not isinstance(val_outs, list):
val_outs = [val_outs]
# Same labels assumed.
for l, o in zip(self.metrics_names, val_outs):
epoch_logs['val_' + l] = o
- def _prepare_validation_data(self,
- validation_data,
- validation_split,
- validation_steps,
- x,
- y,
- sample_weights,
+ def _prepare_validation_data(self, validation_data, validation_split,
+ validation_steps, x, y, sample_weights,
batch_size):
"""Prepares the validation dataset.
@@ -1921,8 +1897,10 @@ class KerasTPUModel(models.Model):
x, val_x = (slice_arrays(x, 0, split_at), slice_arrays(x, split_at))
y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at))
- sample_weights, val_sample_weights = (slice_arrays(
- sample_weights, 0, split_at), slice_arrays(sample_weights, split_at))
+ sample_weights, val_sample_weights = (
+ slice_arrays(sample_weights, 0, split_at),
+ slice_arrays(sample_weights, split_at)
+ )
elif validation_steps:
val_x = []
val_y = []
@@ -1934,11 +1912,20 @@ class KerasTPUModel(models.Model):
return val_x, val_y, val_sample_weights
+ @property
+ def optimizer(self):
+ if self._tpu_model:
+ return self._tpu_model.optimizer
+ return self._cpu_model.optimizer
+
+ @optimizer.setter
+ def optimizer(self, optimizer):
+ self._optimizer = optimizer
+
def _make_train_function(self):
if not self.train_function:
self.train_function = TPUFunction(
- self,
- model_fn_lib.ModeKeys.TRAIN,
+ self, model_fn_lib.ModeKeys.TRAIN,
tpu_assignment=self._tpu_assignment)
return self.train_function
@@ -1973,18 +1960,48 @@ class KerasTPUModel(models.Model):
self._tpu_weights_initialized = True
weights = self._cpu_model.get_weights()
- with self.tpu_session():
- logging.info('Setting weights on TPU model.')
- cloned_model.set_weights(weights)
+
+ if isinstance(self.cpu_optimizer, keras_optimizers.TFOptimizer):
+ cpu_optimizer_config = {}
+ else:
+ cpu_optimizer_config = self.cpu_optimizer.get_config()
+
+ logging.info('Setting weights on TPU model.')
+ cloned_model.set_weights(weights)
+ for k, v in six.iteritems(cpu_optimizer_config):
+ opt_var = getattr(self._tpu_model.optimizer, k)
+ if isinstance(opt_var, variables.Variable):
+ logging.info('CPU -> TPU %s: %s {%s}', k, v, K.get_value(opt_var))
+ K.get_session().run(opt_var.assign(v))
+ else:
+ logging.warning('Cannot update non-variable config: %s', k)
+
+ @property
+ def cpu_optimizer(self):
+ return self._cpu_model.optimizer
def sync_to_cpu(self):
"""Copy weights from the CPU, returning a synchronized CPU model."""
- if self._tpu_weights_initialized:
- with self.tpu_session():
- logging.info('Copying TPU weights to the CPU')
- tpu_weights = self._tpu_model.get_weights()
+ if not self._tpu_weights_initialized:
+ return self._cpu_model
- self._cpu_model.set_weights(tpu_weights)
+ logging.info('Copying TPU weights to the CPU')
+ tpu_weights = self._tpu_model.get_weights()
+
+ # TFOptimizers have no configurable options
+ if isinstance(self.cpu_optimizer, keras_optimizers.TFOptimizer):
+ tpu_optimizer_config = {}
+ else:
+ tpu_optimizer_config = self._tpu_model.optimizer.get_config()
+
+ self._cpu_model.set_weights(tpu_weights)
+ for k, v in six.iteritems(tpu_optimizer_config):
+ logging.info('TPU -> CPU %s: %s', k, v)
+ opt_var = getattr(self.cpu_optimizer, k)
+ if isinstance(opt_var, variables.Variable):
+ K.get_session().run(opt_var.assign(v))
+ else:
+ logging.warning('Cannot update non-variable config: %s', k)
return self._cpu_model
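With optimizer state now round-tripped via `get_config()` and per-variable `assign` ops in the single global session, an LR changed on either side survives the sync. A hedged sketch of the assignment pattern used in both directions (variable and values are illustrative):

```
import tensorflow as tf
from tensorflow.keras import backend as K

lr = tf.Variable(0.1, name='lr')
sess = K.get_session()
sess.run(lr.initializer)
sess.run(lr.assign(0.05))  # same pattern as the config-sync loops above
assert abs(K.get_value(lr) - 0.05) < 1e-7
```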
@@ -2005,26 +2022,6 @@ class KerasTPUModel(models.Model):
self._cpu_model.set_weights(weights)
self._tpu_weights_initialized = False
- @contextlib.contextmanager
- def tpu_session(self):
- """Yields a TPU session and sets it as the default Keras session."""
- with self._session.graph.as_default():
- default_session = K.get_session()
- # N.B. We have to call `K.set_session()` AND set our session as the
- # TF default. `K.get_session()` surprisingly does not return the value
- # supplied by K.set_session otherwise.
- K.set_session(self._session)
- with self._session.as_default():
- yield self._session
- K.set_session(default_session)
-
- def shutdown(self):
- # TODO(b/111364423): Actually shut down the system.
- logging.info('Skipping shutting down TPU system.')
- # with self.tpu_session() as session:
- # session.run(tpu.shutdown_system())
- self._session.close()
-
# pylint: disable=bad-continuation
def _validate_shapes(model):
@@ -2065,7 +2062,9 @@ Output shape: %(output_shape)s
@experimental
def tpu_model(model, strategy=None):
- """Copy `model` along with weights to the TPU. Returns a TPU model.
+ """Copy `model` along with weights to the TPU.
+
+ Returns a TPU model.
Usage:
```
@@ -2080,21 +2079,16 @@ def tpu_model(model, strategy=None):
model.compile(
optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0),
...)
- model.shutdown()
```
Args:
- model: A `KerasTPUModel`.
+ model: A `tf.keras.Model` instance.
strategy: `TPUDistributionStrategy`. The strategy to use for replicating
- model across multiple TPU cores.
+ the model across multiple TPU cores.
Returns:
A new `KerasTPUModel` instance.
"""
- # Force initialization of the CPU model.
- model.get_weights()
- model.reset_states()
-
_validate_shapes(model)
# TODO(xiejw): Validate TPU model. TPUModel only?
# TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset?
@@ -2108,4 +2102,34 @@ def tpu_model(model, strategy=None):
'`strategy` must have type `tf.contrib.tpu.TPUDistributionStrategy`. '
'Got: {}'.format(type(strategy)))
- return KerasTPUModel(cpu_model=model, strategy=strategy)
+ # If the model has already been initialized, grab the optimizer configuration
+ # and model weights before entering the TPU session.
+ if model.optimizer:
+ if (isinstance(model.optimizer, keras_optimizers.Optimizer) and not
+ isinstance(model.optimizer, keras_optimizers.TFOptimizer)):
+ optimizer_config = model.optimizer.get_config()
+ else:
+ optimizer_config = None
+ model_weights = model.get_weights()
+ else:
+ model_weights = None
+
+ setup_tpu_session(strategy._tpu_cluster_resolver)
+
+ # Force initialization of the CPU model in the TPU session.
+ cpu_model = models.clone_model(model)
+ if model.optimizer:
+ cpu_model.compile(
+ _clone_optimizer(model.optimizer, optimizer_config),
+ model.loss,
+ _clone_metrics(model.metrics),
+ model.loss_weights,
+ model.sample_weight_mode,
+ _clone_metrics(model.weighted_metrics),
+ )
+
+ if model_weights:
+ cpu_model.set_weights(model_weights)
+ cpu_model.reset_states()
+
+ return KerasTPUModel(cpu_model=cpu_model, strategy=strategy)
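Taken together, `tpu_model` now captures the optimizer config and weights, installs the TPU session, and recompiles the cloned CPU model inside it; there is no `shutdown()` or per-model session to manage. A usage sketch consistent with the updated docstring (the TPU address is hypothetical):

```
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import keras_support

cpu_model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
cpu_model.compile(optimizer=tf.train.GradientDescentOptimizer(1.0), loss='mse')

resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
    tpu='grpc://10.2.3.4:8470')  # hypothetical TPU worker
strategy = keras_support.TPUDistributionStrategy(resolver)
model = keras_support.tpu_model(cpu_model, strategy=strategy)
# Weights and optimizer config captured above are re-applied inside the
# now-global TPU session; dynamic-LR callbacks operate on that session too.
```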