author    A. Unique TensorFlower <gardener@tensorflow.org>    2018-09-25 17:27:27 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>     2018-09-25 17:36:50 -0700
commit    c1e303ed8fa1bf11aaea16e68b14ba2f5ab5dde0 (patch)
tree      c35ec5ea71adf95fc180c4172f8ee4c26187c52b /tensorflow/contrib/tpu
parent    a7f14807417ea78aee8ea275536902f0aaa94fd4 (diff)
Support dynamic LR for Keras optimizer by setting the global Keras session.
PiperOrigin-RevId: 214532827
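
With the TPU session installed as the global Keras session, hyperparameters such as the learning rate can be updated mid-training through the usual Keras mechanisms: `K.set_value(model.optimizer.lr, ...)` now lands on the TPU optimizer's variable, and `sync_to_cpu()` copies the value back into the CPU model. A minimal sketch of the pattern this enables, under the TF 1.x contrib API this file lives in; the TPU address, model, and data below are illustrative assumptions, not part of this commit:

```
import numpy as np
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import keras_support

# Any compiled tf.keras model works; this one is a placeholder.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1), loss='mse')

resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
    tpu='grpc://10.0.0.1:8470')  # hypothetical TPU worker address
strategy = keras_support.TPUDistributionStrategy(resolver)
tpu_model = keras_support.tpu_model(model, strategy=strategy)

# LearningRateScheduler assigns via K.set_value(model.optimizer.lr, ...).
# The global Keras session is now the TPU session, so the assignment hits
# the TPU optimizer's `lr` variable; sync_to_cpu() later copies it back.
scheduler = tf.keras.callbacks.LearningRateScheduler(
    lambda epoch: 0.1 * 0.5**epoch)

x = np.random.rand(128, 8).astype(np.float32)
y = np.random.rand(128, 1).astype(np.float32)
tpu_model.fit(x, y, batch_size=32, epochs=3, callbacks=[scheduler])
```

`LearningRateScheduler` can resolve `optimizer.lr` here because the `__getattr__` deferral added to `KerasCrossShardOptimizer` in this change forwards unknown attributes to the wrapped optimizer.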
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/keras_support.py | 564
1 file changed, 294 insertions, 270 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 03e06b8142..f67e0e6aca 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -46,12 +46,12 @@ from __future__ import print_function
 
 import abc
 import collections
-import contextlib
 import re
 import sys
 import time
 
 import numpy as np
+import six
 
 from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver as tpu_cluster_resolver_lib
 from tensorflow.contrib.framework.python.framework import experimental
@@ -90,34 +90,34 @@ from tensorflow.python.ops import gen_linalg_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
 
-_SESSIONS = {}
-
-
-def tpu_session(cluster_resolver):
+def setup_tpu_session(cluster_resolver):
   """Construct or return a `tf.Session` connected to the given cluster."""
-  global _SESSIONS
   master = cluster_resolver.master()
-  if master not in _SESSIONS:
-    cluster_spec = cluster_resolver.cluster_spec()
-    config = config_pb2.ConfigProto(isolate_session_state=True)
-    if cluster_spec:
-      config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
-    logging.info('Connecting to: %s', master)
-    graph = ops.Graph()
-    session = tf_session.Session(graph=graph, target=master, config=config)
-    with graph.as_default():
-      session.run(tpu.initialize_system())
+
+  # Use the existing session if we're already connected to this TPU
+  if (K.get_session()._target == master and
+      getattr(K.get_session(), '_tpu_initialized', None)):
+    return
+
+  cluster_spec = cluster_resolver.cluster_spec()
+  config = config_pb2.ConfigProto(isolate_session_state=True)
+  if cluster_spec:
+    config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
 
-  _SESSIONS[master] = session
-  return _SESSIONS[master]
+  logging.info('Initialize')
+  tpu_session = tf_session.Session(target=master, config=config)
+  tpu_session.run(tpu.initialize_system())
+  tpu_session._tpu_initialized = True
 
+  # N.B. We have to call `K.set_session()` AND set our session as the
+  # TF default. `K.get_session()` surprisingly does not return the value
+  # supplied by K.set_session otherwise.
+  K.set_session(tpu_session)
 
-def reset_tpu_sessions():
-  _SESSIONS.clear()
 
 try:
   from scipy.sparse import issparse  # pylint: disable=g-import-not-at-top
@@ -134,9 +134,7 @@ def get_tpu_system_metadata(tpu_cluster_resolver):
   cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
   tpu_system_metadata = (
       tpu_system_metadata_lib._query_tpu_system_metadata(
-          master,
-          cluster_def=cluster_def,
-          query_topology=False))
+          master, cluster_def=cluster_def, query_topology=False))
 
   return tpu_system_metadata
 
@@ -157,6 +155,8 @@ class TPUDistributionStrategy(object):
        replication, typically using all available TPU cores. If set to
        `True`, force the model replication onto a single core, i.e., no
        replication.
+    Raises:
+      Exception: No TPU found on the given worker.
     """
     if tpu_cluster_resolver is None:
@@ -172,7 +172,8 @@ class TPUDistributionStrategy(object):
     for device in metadata.devices:
       if 'TPU:0' in device.name:
         self._worker_name = worker_re.search(device.name).group(1)
-        break
+        return
+    raise Exception('No TPU found on given worker.')
 
   def _make_assignment_for_model(self, cpu_model):
     """Makes a `TPUAssignment` for the passed in `cpu_model`."""
@@ -183,8 +184,7 @@ class TPUDistributionStrategy(object):
           'Degrading to a single core.')
       num_cores = 1
 
-    return TPUAssignment(
-        worker_name=self._worker_name, num_cores=num_cores)
+    return TPUAssignment(worker_name=self._worker_name, num_cores=num_cores)
 
 
 class TPUAssignment(object):
@@ -280,9 +280,9 @@ class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
     super(KerasCrossShardOptimizer, self).__init__()
     self._name = name
     self._opt = opt
+    logging.info('KerasCrossShard: %s %s', self._opt, self._opt.weights)
 
   def get_updates(self, loss, params):
-    logging.info('Get updates: %s', loss)
     self._opt.get_gradients = self.get_gradients
     return self._opt.get_updates(loss, params)
 
@@ -291,17 +291,15 @@ class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
     grads = super(KerasCrossShardOptimizer, self).get_gradients(loss, params)
     return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads]
 
-  def set_weights(self, weights):
-    # TODO(power): Figure out whether we really need this given there is no
-    # caller for this API yet.
-    self._opt.set_weights()
-
   def get_weights(self):
     return self._opt.get_weights()
 
-  @property
-  def lr(self):
-    return self._opt.lr
+  def get_config(self):
+    return self._opt.get_config()
+
+  # Defer remaining operations to the underlying optimizer
+  def __getattr__(self, key):
+    return getattr(self._opt, key)
 
 
 class TPUModelOp(
@@ -327,14 +325,34 @@ def _replicated_optimizer(opt):
   return KerasCrossShardOptimizer(opt)
 
 
-def clone_metrics(metrics):
+def _clone_metrics(metrics):
   """Returns a copy of metrics. A copy is created for stateful metrics."""
   if metrics is None:
     return None
-  return [
-      m.__class__.from_config(m.get_config())
-      if isinstance(m, metrics_module.Metric) else m for m in metrics
-  ]
+  with variable_scope.variable_scope(
+      'metrics', reuse=variable_scope.AUTO_REUSE):
+    return [
+        m.__class__.from_config(m.get_config()) if isinstance(
+            m, metrics_module.Metric) else m for m in metrics
+    ]
+
+
+def _clone_optimizer(optimizer, config=None):
+  """Returns a cloned optimizer with the provided optimizer.config or config."""
+  if not isinstance(optimizer, keras_optimizers.Optimizer):
+    # In the first call to tpu_model(model), Keras may not have wrapped the TF
+    # optimizer in the TFOptimizer helper, e.g., the given model isn't compiled
+    # or optimizer isn't set, and later generated tpu_model compiles with a TF
+    # optimizer.
+    return optimizer
+
+  if isinstance(optimizer, keras_optimizers.TFOptimizer):
+    return keras_optimizers.TFOptimizer(optimizer.optimizer)
+
+  if config is None:
+    config = optimizer.get_config()
+  logging.info('Cloning %s %s', optimizer.__class__.__name__, config)
+  return optimizer.__class__.from_config(config)
 
 
 class TPURewriteContext(object):
@@ -425,6 +443,7 @@ class TPURewriteContext(object):
         return (r, q)
       else:
         raise ValueError('Invalid shape passed to qr: %s' % input_shape)
+
     gen_linalg_ops.qr = qr
 
     ops.name_scope = _name_scope
@@ -440,9 +459,9 @@ class TPURewriteContext(object):
     gen_linalg_ops.qr = self._default_qr
 
 
-class SizedInfeed(collections.namedtuple('SizedInfeed',
-                                         ['sharded_infeed_tensors',
-                                          'infeed_ops'])):
+class SizedInfeed(
+    collections.namedtuple('SizedInfeed',
+                           ['sharded_infeed_tensors', 'infeed_ops'])):
  """Represents an instantiation of the infeed ops for a concrete input shape.
 
   sharded_infeed_tensors: A data structure of Tensors used to represent the
@@ -628,12 +647,13 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
           infeed_tensors, [spec.shape for spec in input_specs],
           name='infeed-enqueue-%s-%d' % (execution_mode, shard_id),
           device_ordinal=shard_id))
-    return SizedInfeed(infeed_ops=infeed_op,
-                       sharded_infeed_tensors=shard_infeed_tensors)
+    return SizedInfeed(
+        infeed_ops=infeed_op, sharded_infeed_tensors=shard_infeed_tensors)
 
 
 class TPUDatasetInfeedManager(TPUInfeedManager):
   """Manages infeed for a `tf.data.Dataset` into a TPU computation.
+
   """
 
   class DatasetInfeedInstance(TPUInfeedInstance):
@@ -651,16 +671,13 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
       return {}
 
   # pylint: disable=redefined-outer-name
-  def __init__(self, dataset, tpu_assignment, tpu_session, mode):
+  def __init__(self, dataset, tpu_assignment, mode):
     """Constructs a TPUDatasetInfeedManager.
 
-    Must be called within a `KerasTPUModel.tpu_session` context!
-
     Args:
       dataset: A `tf.data.Dataset` to infeed.
       tpu_assignment: The `TPUAssignment` used to configure the
        Keras TPU model.
-      tpu_session: The `tf.Session` object used for running the TPU model.
      mode: ModeKeys enum.
    """
     self._verify_dataset_shape(dataset)
@@ -672,7 +689,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
     dummy_y_shape = dataset.output_shapes[1].as_list()
     dummy_y_shape[0] *= tpu_assignment.num_towers
     self._iterator = dataset.make_initializable_iterator()
-    tpu_session.run(self._iterator.initializer)
+    K.get_session().run(self._iterator.initializer)
 
     self._get_next_ops = []
     ctrl_deps = []
@@ -685,10 +702,10 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
 
     # Use dummy numpy inputs for the rest of Keras' shape checking. We
     # intercept them when building the model.
-    self._dummy_x = np.zeros(dummy_x_shape,
-                             dtype=dataset.output_types[0].as_numpy_dtype)
-    self._dummy_y = np.zeros(dummy_y_shape,
-                             dtype=dataset.output_types[1].as_numpy_dtype)
+    self._dummy_x = np.zeros(
+        dummy_x_shape, dtype=dataset.output_types[0].as_numpy_dtype)
+    self._dummy_y = np.zeros(
+        dummy_y_shape, dtype=dataset.output_types[1].as_numpy_dtype)
 
     input_specs = []
     if isinstance(self._iterator.output_shapes, tuple):
@@ -719,9 +736,8 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
       raise ValueError('The dataset must return a tuple of tf.Tensors, '
                        'instead it returns: %s' % dataset.output_classes)
     if len(dataset.output_classes) != 2:
-      raise ValueError(
-          'The dataset must return a 2-element tuple, got '
-          '%s output classes instead.' % (dataset.output_classes,))
+      raise ValueError('The dataset must return a 2-element tuple, got '
+                       '%s output classes instead.' % (dataset.output_classes,))
     for i, cls in enumerate(dataset.output_classes):
       if cls != ops.Tensor:
         raise ValueError('The dataset returned a non-Tensor type (%s) at '
@@ -730,8 +746,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
       if not shape:
         raise ValueError('The dataset returns a scalar tensor in '
                          'tuple index %d. Did you forget to batch? '
-                         '(Output shapes: %s).' % (i,
-                                                   dataset.output_shapes))
+                         '(Output shapes: %s).' % (i, dataset.output_shapes))
       for j, dim in enumerate(shape):
         if dim.value is None:
           if j == 0:
@@ -771,8 +786,8 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
           [spec.shape for spec in input_specs],
           name='infeed-enqueue-%s-%d' % (execution_mode, shard_id),
           device_ordinal=shard_id))
-    return SizedInfeed(infeed_ops=infeed_ops,
-                       sharded_infeed_tensors=shard_infeed_tensors)
+    return SizedInfeed(
+        infeed_ops=infeed_ops, sharded_infeed_tensors=shard_infeed_tensors)
 
 
 def _inject_tpu_inputs_for_dataset(tpu_assignment, mode,
@@ -858,12 +873,7 @@ class TPUFunction(object):
     self._tpu_assignment = tpu_assignment
     self._compilation_cache = {}
     self._cloned_model = None
-
-    # Copy optimizer configuration. This is done prior to `_specialize_model`
-    # as the configuration may require evaluating variables in the CPU session.
-    self._optimizer_config = None
-    if not isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
-      self._optimizer_config = self.model.optimizer.get_config()
+    self._cloned_optimizer = None
 
   def _specialize_model(self, input_specs, infeed_manager):
     """Specialize `self.model` (a Keras model) for the given input shapes."""
@@ -909,53 +919,51 @@ class TPUFunction(object):
         tpu_targets.append(tensor)
 
       # Clone our CPU model, running within the TPU device context.
+      #
+      # We use the id of the original model as a key to avoid weight collisions
+      # (if a user re-runs the same model multiple times, in e.g. Colab).
       with TPURewriteContext(tpu_input_map):
-        with variable_scope.variable_scope('tpu_model_%s' % id(self.model)):
+        with variable_scope.variable_scope('tpu_%s' % id(self.model)):
           with keras_tpu_variables.replicated_scope(
               self._tpu_assignment.num_towers):
-            self._cloned_model = models.clone_model(self.model)
+            if not self._cloned_optimizer:
+              self._cloned_optimizer = _clone_optimizer(
+                  self.model.cpu_optimizer)
 
-      # When running on more than one core, concatenate outputs at the end of
-      # processing. In backprop stage, the gradients will be calculdated
-      # according to the local inputs as gradient of cross-replica-concat being
-      # zero for any outputs other than those from mlocal core so the loss
-      # calculation is identical.
-      num_towers = self.model._tpu_assignment.num_towers
-      if num_towers > 1 and (is_training or is_test):
-        new_outputs = [
-            _cross_replica_concat(
-                o, core_id, num_towers, name='model output ({})'.format(o.name))
-            for o in self._cloned_model.outputs
-        ]
-        self._cloned_model.outputs = new_outputs
-        tpu_targets = [
-            _cross_replica_concat(
-                tensor,
-                core_id,
-                num_towers,
-                name='model target ({})'.format(tensor.name))
-            for tensor in tpu_targets
-        ]
-
-      # Create a copy of the optimizer for this graph.
-      if isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
-        cloned_optimizer = keras_optimizers.TFOptimizer(
-            self.model.optimizer.optimizer)
-      else:
-        logging.info('Cloning %s %s', self.model.optimizer.__class__.__name__,
-                     self._optimizer_config)
-        cloned_optimizer = self.model.optimizer.__class__.from_config(
-            self._optimizer_config)
+            self._cloned_model = models.clone_model(self.model)
 
-      if is_training or is_test:
-        self._cloned_model.compile(
-            optimizer=_replicated_optimizer(cloned_optimizer),
-            loss=self.model.loss,
-            loss_weights=self.model.loss_weights,
-            metrics=clone_metrics(self.model.metrics),
-            weighted_metrics=clone_metrics(self.model.weighted_metrics),
-            target_tensors=tpu_targets,
-        )
+            # When running on more than one core, concatenate outputs at the
+            # end of processing. In the backprop stage, the gradients will be
+            # calculated according to the local inputs, as the gradient of
+            # cross-replica-concat is zero for any outputs other than those
+            # from the local core, so the loss calculation is identical.
+            num_towers = self.model._tpu_assignment.num_towers
+            if num_towers > 1 and (is_training or is_test):
+              new_outputs = [
+                  _cross_replica_concat(
+                      o, core_id, num_towers,
+                      name='model output ({})'.format(o.name))
+                  for o in self._cloned_model.outputs
+              ]
+              self._cloned_model.outputs = new_outputs
+              tpu_targets = [
+                  _cross_replica_concat(
+                      tensor,
+                      core_id,
+                      num_towers,
+                      name='model target ({})'.format(tensor.name))
+                  for tensor in tpu_targets
+              ]
+
+            if is_training or is_test:
+              self._cloned_model.compile(
+                  optimizer=_replicated_optimizer(self._cloned_optimizer),
+                  loss=self.model.loss,
+                  loss_weights=self.model.loss_weights,
+                  metrics=_clone_metrics(self.model.metrics),
+                  weighted_metrics=_clone_metrics(self.model.weighted_metrics),
+                  target_tensors=tpu_targets,
+              )
 
       # Compute our outfeed depending on the execution mode
       if is_training:
@@ -1089,15 +1097,14 @@ class TPUFunction(object):
     # unique input shape.
     shape_key = tuple([tuple(spec.shape.as_list()) for spec in input_specs])
     if shape_key not in self._compilation_cache:
-      with self.model.tpu_session():
-        logging.info(
-            'New input shapes; (re-)compiling: mode=%s '
-            '(# of cores %d), %s', self.execution_mode,
-            self._tpu_assignment.num_towers, input_specs)
-        new_tpu_model_ops = self._specialize_model(input_specs,
-                                                   infeed_manager)
-        self._compilation_cache[shape_key] = new_tpu_model_ops
-        self._test_model_compiles(new_tpu_model_ops)
+      logging.info(
+          'New input shapes; (re-)compiling: mode=%s '
+          '(# of cores %d), %s', self.execution_mode,
+          self._tpu_assignment.num_towers, input_specs)
+      new_tpu_model_ops = self._specialize_model(input_specs,
+                                                 infeed_manager)
+      self._compilation_cache[shape_key] = new_tpu_model_ops
+      self._test_model_compiles(new_tpu_model_ops)
 
     return self._compilation_cache[shape_key]
 
@@ -1195,11 +1202,10 @@ class TPUFunction(object):
     # Initialize our TPU weights on the first compile.
     self.model._initialize_weights(self._cloned_model)
 
-    with self.model.tpu_session() as session:
-      _, _, outfeed_outputs = session.run([
-          tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
-          tpu_model_ops.outfeed_op
-      ], infeed_dict)
+    _, _, outfeed_outputs = K.get_session().run([
+        tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
+        tpu_model_ops.outfeed_op
+    ], infeed_dict)
 
     return self._process_outputs(outfeed_outputs)
 
   def pipeline_run(self, cur_step_inputs, next_step_inputs):
@@ -1231,8 +1237,8 @@ class TPUFunction(object):
     next_step_infeed_manager = self._lookup_infeed_manager(next_step_inputs)
     cur_step_infeed_manager = self._lookup_infeed_manager(cur_step_inputs)
 
-    if (next_step_infeed_manager is not None
-        and cur_step_infeed_manager is not None):
+    if (next_step_infeed_manager is not None and
+        cur_step_infeed_manager is not None):
       assert type(next_step_infeed_manager) is type(cur_step_infeed_manager)
 
     next_input_tensors, next_step_inputs = (
@@ -1257,14 +1263,12 @@ class TPUFunction(object):
     infeed_dict = None
 
     if cur_infeed_instance and cur_input_tensors and cur_step_infeed_manager:
-      cur_input_specs = cur_infeed_instance.make_input_specs(
-          cur_input_tensors)
+      cur_input_specs = cur_infeed_instance.make_input_specs(cur_input_tensors)
       cur_tpu_model_ops = self._tpu_model_ops_for_input_specs(
           cur_input_specs, cur_step_infeed_manager)
 
-    if (next_infeed_instance
-        and next_input_tensors
-        and next_step_infeed_manager):
+    if (next_infeed_instance and next_input_tensors and
+        next_step_infeed_manager):
       next_input_specs = next_infeed_instance.make_input_specs(
           next_input_tensors)
       next_tpu_model_ops = self._tpu_model_ops_for_input_specs(
@@ -1275,26 +1279,24 @@ class TPUFunction(object):
       self.model._initialize_weights(self._cloned_model)
 
     if next_tpu_model_ops and cur_tpu_model_ops:
-      with self.model.tpu_session() as session:
-        _, _, outfeed_outputs = session.run([
-            next_tpu_model_ops.infeed_op, cur_tpu_model_ops.execute_op,
-            cur_tpu_model_ops.outfeed_op
-        ], infeed_dict)
+      _, _, outfeed_outputs = K.get_session().run([
+          next_tpu_model_ops.infeed_op, cur_tpu_model_ops.execute_op,
+          cur_tpu_model_ops.outfeed_op
+      ], infeed_dict)
       return self._process_outputs(outfeed_outputs)
+
     if cur_tpu_model_ops:
-      with self.model.tpu_session() as session:
-        _, outfeed_outputs = session.run([
-            cur_tpu_model_ops.execute_op, cur_tpu_model_ops.outfeed_op])
+      _, outfeed_outputs = K.get_session().run(
+          [cur_tpu_model_ops.execute_op, cur_tpu_model_ops.outfeed_op])
       return self._process_outputs(outfeed_outputs)
+
     if next_tpu_model_ops:
-      with self.model.tpu_session() as session:
-        session.run(next_tpu_model_ops.infeed_op, infeed_dict)
+      K.get_session().run(next_tpu_model_ops.infeed_op, infeed_dict)
       return None
+
     raise RuntimeError('Internal error: both current & next tpu_model_ops '
                        'were None')
-
 
 class KerasTPUModel(models.Model):
   """TPU compatible Keras model wrapper."""
@@ -1321,8 +1323,6 @@ class KerasTPUModel(models.Model):
     self._tpu_model = None
     self._tpu_weights_initialized = False
 
-    self._session = tpu_session(cluster_resolver)
-
     # If the input CPU model has already been compiled, compile our TPU model
     # immediately.
     if self._cpu_model.optimizer:
@@ -1359,15 +1359,20 @@ class KerasTPUModel(models.Model):
     if target_tensors:
       raise ValueError('target_tensors is not supported for TPU execution.')
 
+    self._cpu_model.compile(
+        _clone_optimizer(optimizer),
+        loss,
+        _clone_metrics(metrics),
+        loss_weights,
+        sample_weight_mode,
+        _clone_metrics(weighted_metrics),
+        target_tensors,
+        **kwargs)
+
     super(KerasTPUModel, self).compile(optimizer, loss, metrics, loss_weights,
                                        sample_weight_mode, weighted_metrics,
                                        target_tensors, **kwargs)
 
-    if not self._cpu_model.optimizer:
-      self._cpu_model.compile(optimizer, loss, metrics, loss_weights,
-                              sample_weight_mode, weighted_metrics,
-                              target_tensors, **kwargs)
-
   def fit(self,
           x=None,
           y=None,
@@ -1400,8 +1405,8 @@ class KerasTPUModel(models.Model):
           'https://github.com/tensorflow/tpu/tree/master/models/experimental'
           '/keras')
     if callable(x):
-      with self.tpu_session() as sess,\
-          ops.device('/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
+      with ops.device('/job:%s/device:CPU:0' %
+                      self._tpu_assignment.worker_name):
         dataset = x()
       if steps_per_epoch is None:
         raise ValueError('When using tf.data as input to a model, you '
@@ -1410,7 +1415,7 @@ class KerasTPUModel(models.Model):
        raise ValueError('When using tf.data as input to a model, y must be '
                          'None')
       infeed_manager = TPUDatasetInfeedManager(
-          dataset, self._tpu_assignment, sess, model_fn_lib.ModeKeys.TRAIN)
+          dataset, self._tpu_assignment, model_fn_lib.ModeKeys.TRAIN)
       # Use dummy numpy inputs for the rest of Keras' shape checking. We
       # intercept them when building the model.
       x = infeed_manager.dummy_x
@@ -1426,26 +1431,24 @@ class KerasTPUModel(models.Model):
           'https://github.com/tensorflow/tpu/tree/master/models/experimental'
          '/keras')
      if callable(validation_data):
-        with self.tpu_session() as sess:
-          dataset = validation_data()
-          if validation_steps is None:
-            raise ValueError('When using tf.data as validation for a model, you '
-                             'should specify the validation_steps argument.')
-          infeed_manager = TPUDatasetInfeedManager(
-              dataset, self._tpu_assignment, sess, model_fn_lib.ModeKeys.EVAL)
-          # Use dummy numpy inputs for the rest of Keras' shape checking. We
-          # intercept them when building the model.
-          val_x = infeed_manager.dummy_x
-          val_y = infeed_manager.dummy_y
-          infeed_managers.append((val_x, infeed_manager))
-          validation_data = (val_x, val_y)
+        dataset = validation_data()
+        if validation_steps is None:
+          raise ValueError('When using tf.data as validation for a model, you '
+                           'should specify the validation_steps argument.')
+        infeed_manager = TPUDatasetInfeedManager(
+            dataset, self._tpu_assignment, model_fn_lib.ModeKeys.EVAL)
+        # Use dummy numpy inputs for the rest of Keras' shape checking. We
+        # intercept them when building the model.
+        val_x = infeed_manager.dummy_x
+        val_y = infeed_manager.dummy_y
+        infeed_managers.append((val_x, infeed_manager))
+        validation_data = (val_x, val_y)
 
     self._numpy_to_infeed_manager_list = infeed_managers
     try:
       if not kwargs.get('_pipeline', True):
-        logging.info(
-            'Running non-pipelined training loop (`_pipeline=%s`).',
-            kwargs['_pipeline'])
+        logging.info('Running non-pipelined training loop (`_pipeline=%s`).',
+                     kwargs['_pipeline'])
         kwargs.pop('_pipeline')
         return super(KerasTPUModel, self).fit(
             x,
@@ -1501,50 +1504,32 @@ class KerasTPUModel(models.Model):
          'https://github.com/tensorflow/tpu/tree/master/models/experimental'
          '/keras')
     if callable(x):
-      with self.tpu_session() as sess:
-        dataset = x()
-        if steps is None:
-          raise ValueError('When using tf.data as input to a model, you '
-                           'should specify the steps argument.')
-        if y is not None:
-          raise ValueError('When using tf.data as input to a model, y must be '
-                           'None')
-        infeed_manager = TPUDatasetInfeedManager(
-            dataset, self._tpu_assignment, sess, model_fn_lib.ModeKeys.EVAL)
-        # Use dummy numpy inputs for the rest of Keras' shape checking. We
-        # intercept them when building the model.
-        x = infeed_manager.dummy_x
-        y = infeed_manager.dummy_y
-        infeed_managers.append((x, infeed_manager))
+      dataset = x()
+      if steps is None:
+        raise ValueError('When using tf.data as input to a model, you '
+                         'should specify the steps argument.')
+      if y is not None:
+        raise ValueError('When using tf.data as input to a model, y must be '
+                         'None')
+      infeed_manager = TPUDatasetInfeedManager(
+          dataset, self._tpu_assignment, model_fn_lib.ModeKeys.EVAL)
+      # Use dummy numpy inputs for the rest of Keras' shape checking. We
+      # intercept them when building the model.
+      x = infeed_manager.dummy_x
+      y = infeed_manager.dummy_y
+      infeed_managers.append((x, infeed_manager))
 
     self._numpy_to_infeed_manager_list = infeed_managers
     try:
-      return super(KerasTPUModel, self).evaluate(
-          x,
-          y,
-          batch_size,
-          verbose,
-          sample_weight,
-          steps)
+      return super(KerasTPUModel, self).evaluate(x, y, batch_size, verbose,
+                                                 sample_weight, steps)
     finally:
       self._numpy_to_infeed_manager_list = []
 
-  def _pipeline_fit(self,
-                    x,
-                    y,
-                    batch_size,
-                    epochs,
-                    verbose,
-                    callbacks,
-                    validation_split,
-                    validation_data,
-                    shuffle,
-                    class_weight,
-                    sample_weight,
-                    initial_epoch,
-                    steps_per_epoch,
-                    validation_steps,
-                    **kwargs):
+  def _pipeline_fit(self, x, y, batch_size, epochs, verbose, callbacks,
+                    validation_split, validation_data, shuffle, class_weight,
+                    sample_weight, initial_epoch, steps_per_epoch,
+                    validation_steps, **kwargs):
     # Similar to super.fit(...), but modified to support software pipelining.
 
     # Backwards compatibility
@@ -1572,13 +1557,8 @@ class KerasTPUModel(models.Model):
 
     # Prepare validation data
     val_x, val_y, val_sample_weights = self._prepare_validation_data(
-        validation_data,
-        validation_split,
-        validation_steps,
-        x,
-        y,
-        sample_weights,
-        batch_size)
+        validation_data, validation_split, validation_steps, x, y,
+        sample_weights, batch_size)
     return self._pipeline_fit_loop(
         x,
         y,
@@ -1751,8 +1731,8 @@ class KerasTPUModel(models.Model):
         for i in indices_for_conversion_to_dense:
           ins_batch[i] = ins_batch[i].toarray()
 
-        outs = f.pipeline_run(cur_step_inputs=ins_last_batch,
-                              next_step_inputs=ins_batch)
+        outs = f.pipeline_run(
+            cur_step_inputs=ins_last_batch, next_step_inputs=ins_batch)
         ins_last_batch = ins_batch
 
         if batch_index == 0:
@@ -1824,8 +1804,8 @@ class KerasTPUModel(models.Model):
           next_step_inputs = ins
         else:
          next_step_inputs = None
-        outs = f.pipeline_run(cur_step_inputs=ins,
-                              next_step_inputs=next_step_inputs)
+        outs = f.pipeline_run(
+            cur_step_inputs=ins, next_step_inputs=next_step_inputs)
       except errors.OutOfRangeError:
         logging.warning('Your dataset iterator ran out of data; '
                         'interrupting training. Make sure that your '
@@ -1845,25 +1825,21 @@ class KerasTPUModel(models.Model):
         break
 
     if do_validation:
-      val_outs = training_arrays.test_loop(self,
-                                           val_inputs,
-                                           val_targets,
-                                           sample_weights=val_sample_weights,
-                                           steps=validation_steps,
-                                           verbose=0)
+      val_outs = training_arrays.test_loop(
+          self,
+          val_inputs,
+          val_targets,
+          sample_weights=val_sample_weights,
+          steps=validation_steps,
+          verbose=0)
      if not isinstance(val_outs, list):
        val_outs = [val_outs]
      # Same labels assumed.
      for l, o in zip(self.metrics_names, val_outs):
        epoch_logs['val_' + l] = o
 
-  def _prepare_validation_data(self,
-                               validation_data,
-                               validation_split,
-                               validation_steps,
-                               x,
-                               y,
-                               sample_weights,
+  def _prepare_validation_data(self, validation_data, validation_split,
+                               validation_steps, x, y, sample_weights,
                                batch_size):
     """Prepares the validation dataset.
@@ -1921,8 +1897,10 @@ class KerasTPUModel(models.Model):
       x, val_x = (slice_arrays(x, 0, split_at), slice_arrays(x, split_at))
       y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at))
-      sample_weights, val_sample_weights = (slice_arrays(
-          sample_weights, 0, split_at), slice_arrays(sample_weights, split_at))
+      sample_weights, val_sample_weights = (
+          slice_arrays(sample_weights, 0, split_at),
+          slice_arrays(sample_weights, split_at)
+      )
     elif validation_steps:
       val_x = []
       val_y = []
@@ -1934,11 +1912,20 @@ class KerasTPUModel(models.Model):
 
     return val_x, val_y, val_sample_weights
 
+  @property
+  def optimizer(self):
+    if self._tpu_model:
+      return self._tpu_model.optimizer
+    return self._cpu_model.optimizer
+
+  @optimizer.setter
+  def optimizer(self, optimizer):
+    self._optimizer = optimizer
+
   def _make_train_function(self):
     if not self.train_function:
       self.train_function = TPUFunction(
-          self,
-          model_fn_lib.ModeKeys.TRAIN,
+          self, model_fn_lib.ModeKeys.TRAIN,
          tpu_assignment=self._tpu_assignment)
 
     return self.train_function
@@ -1973,18 +1960,48 @@ class KerasTPUModel(models.Model):
     self._tpu_weights_initialized = True
 
     weights = self._cpu_model.get_weights()
-    with self.tpu_session():
-      logging.info('Setting weights on TPU model.')
-      cloned_model.set_weights(weights)
+
+    if isinstance(self.cpu_optimizer, keras_optimizers.TFOptimizer):
+      cpu_optimizer_config = {}
+    else:
+      cpu_optimizer_config = self.cpu_optimizer.get_config()
+
+    logging.info('Setting weights on TPU model.')
+    cloned_model.set_weights(weights)
+    for k, v in six.iteritems(cpu_optimizer_config):
+      opt_var = getattr(self._tpu_model.optimizer, k)
+      if isinstance(opt_var, variables.Variable):
+        logging.info('CPU -> TPU %s: %s {%s}', k, v, K.get_value(opt_var))
+        K.get_session().run(opt_var.assign(v))
+      else:
+        logging.warning('Cannot update non-variable config: %s', k)
+
+  @property
+  def cpu_optimizer(self):
+    return self._cpu_model.optimizer
 
   def sync_to_cpu(self):
     """Copy weights from the TPU, returning a synchronized CPU model."""
-    if self._tpu_weights_initialized:
-      with self.tpu_session():
-        logging.info('Copying TPU weights to the CPU')
-        tpu_weights = self._tpu_model.get_weights()
+    if not self._tpu_weights_initialized:
+      return self._cpu_model
 
-      self._cpu_model.set_weights(tpu_weights)
+    logging.info('Copying TPU weights to the CPU')
+    tpu_weights = self._tpu_model.get_weights()
+
+    # TFOptimizers have no configurable options
+    if isinstance(self.cpu_optimizer, keras_optimizers.TFOptimizer):
+      tpu_optimizer_config = {}
+    else:
+      tpu_optimizer_config = self._tpu_model.optimizer.get_config()
+
+    self._cpu_model.set_weights(tpu_weights)
+    for k, v in six.iteritems(tpu_optimizer_config):
+      logging.info('TPU -> CPU %s: %s', k, v)
+      opt_var = getattr(self.cpu_optimizer, k)
+      if isinstance(opt_var, variables.Variable):
+        K.get_session().run(opt_var.assign(v))
+      else:
+        logging.warning('Cannot update non-variable config: %s', k)
 
     return self._cpu_model
@@ -2005,26 +2022,6 @@ class KerasTPUModel(models.Model):
       self._cpu_model.set_weights(weights)
     self._tpu_weights_initialized = False
 
-  @contextlib.contextmanager
-  def tpu_session(self):
-    """Yields a TPU session and sets it as the default Keras session."""
-    with self._session.graph.as_default():
-      default_session = K.get_session()
-      # N.B. We have to call `K.set_session()` AND set our session as the
-      # TF default. `K.get_session()` surprisingly does not return the value
-      # supplied by K.set_session otherwise.
-      K.set_session(self._session)
-      with self._session.as_default():
-        yield self._session
-      K.set_session(default_session)
-
-  def shutdown(self):
-    # TODO(b/111364423): Actually shut down the system.
-    logging.info('Skipping shutting down TPU system.')
-    # with self.tpu_session() as session:
-    #   session.run(tpu.shutdown_system())
-    self._session.close()
-
 
 # pylint: disable=bad-continuation
 def _validate_shapes(model):
@@ -2065,7 +2062,9 @@ Output shape: %(output_shape)s
 
 @experimental
 def tpu_model(model, strategy=None):
-  """Copy `model` along with weights to the TPU. Returns a TPU model.
+  """Copy `model` along with weights to the TPU.
+
+  Returns a TPU model.
 
   Usage:
   ```
@@ -2080,21 +2079,16 @@ def tpu_model(model, strategy=None):
   model.compile(
      optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0), ...)
-  model.shutdown()
   ```
 
   Args:
-    model: A `KerasTPUModel`.
+    model: A `tf.keras.Model` instance.
     strategy: `TPUDistributionStrategy`. The strategy to use for replicating
-              model across multiple TPU cores.
+      model across multiple TPU cores.
 
   Returns:
     A new `KerasTPUModel` instance.
   """
-  # Force initialization of the CPU model.
-  model.get_weights()
-  model.reset_states()
-
   _validate_shapes(model)
   # TODO(xiejw): Validate TPU model. TPUModel only?
   # TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset?
@@ -2108,4 +2102,34 @@ def tpu_model(model, strategy=None):
         '`strategy` must have type `tf.contrib.tpu.TPUDistributionStrategy`. '
         'Got: {}'.format(type(strategy)))
 
-  return KerasTPUModel(cpu_model=model, strategy=strategy)
+  # If the model has already been initialized, grab the optimizer configuration
+  # and model weights before entering the TPU session.
+  if model.optimizer:
+    if (isinstance(model.optimizer, keras_optimizers.Optimizer) and not
+        isinstance(model.optimizer, keras_optimizers.TFOptimizer)):
+      optimizer_config = model.optimizer.get_config()
+    else:
+      optimizer_config = None
+    model_weights = model.get_weights()
+  else:
+    model_weights = None
+
+  setup_tpu_session(strategy._tpu_cluster_resolver)
+
+  # Force initialization of the CPU model in the TPU session.
+  cpu_model = models.clone_model(model)
+  if model.optimizer:
+    cpu_model.compile(
+        _clone_optimizer(model.optimizer, optimizer_config),
+        model.loss,
+        _clone_metrics(model.metrics),
+        model.loss_weights,
+        model.sample_weight_mode,
+        _clone_metrics(model.weighted_metrics),
+    )
+
+  if model_weights:
+    cpu_model.set_weights(model_weights)
+    cpu_model.reset_states()
+
+  return KerasTPUModel(cpu_model=cpu_model, strategy=strategy)
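
The CPU-to-TPU hyperparameter copy in `_initialize_weights` and `sync_to_cpu` above reduces to the loop below. A standalone sketch, assuming a TF 1.x graph-mode session and a variable-backed Keras optimizer; the `SGD` instances and values are illustrative only:

```
import six
import tensorflow as tf

K = tf.keras.backend

src = tf.keras.optimizers.SGD(lr=0.05, momentum=0.9)
dst = tf.keras.optimizers.SGD(lr=0.1)

# src.get_config() evaluates the hyperparameter variables to Python scalars;
# assign each one onto the matching variable of the destination optimizer
# through the active session, mirroring the committed sync loop.
for k, v in six.iteritems(src.get_config()):
  opt_var = getattr(dst, k, None)
  if isinstance(opt_var, tf.Variable):
    K.get_session().run(opt_var.assign(v))

print(K.get_value(dst.lr))  # 0.05
```

Only config entries backed by a `tf.Variable` (e.g. `lr`, `momentum`, `decay`) can be assigned in the running graph; plain Python attributes such as `nesterov` are skipped, which is why the committed code logs "Cannot update non-variable config" for them.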