diff options
Diffstat (limited to 'tensorflow/python/keras/engine/training.py')
-rw-r--r-- | tensorflow/python/keras/engine/training.py | 315 |
1 files changed, 230 insertions, 85 deletions
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 2cdd00a48d..f71388cadb 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend as K from tensorflow.python.keras import losses +from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras import optimizers from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import distributed_training_utils @@ -39,6 +40,8 @@ from tensorflow.python.keras.engine import training_generator from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.engine.network import Network from tensorflow.python.keras.utils.generic_utils import slice_arrays +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import optimizer as tf_optimizer_module from tensorflow.python.training.checkpointable import base as checkpointable @@ -74,6 +77,7 @@ class Model(Network): class MyModel(tf.keras.Model): def __init__(self): + super(MyModel, self).__init__() self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) @@ -94,6 +98,7 @@ class Model(Network): class MyModel(tf.keras.Model): def __init__(self): + super(MyModel, self).__init__() self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) self.dropout = tf.keras.layers.Dropout(0.5) @@ -136,6 +141,167 @@ class Model(Network): if i not in skip_target_weighing_indices ] + def _get_metric_name(self, metric, output_index, weighted=False): + """Returns the metric name corresponding to the given metric input. + + Arguments: + metric: Metric function name or reference. + output_index: Index of the current output. + weighted: Boolean indicating if the given metric is weighted. + + Returns: + A metric name. + """ + metric_name_prefix = 'weighted_' if weighted else '' + if metric in ('accuracy', 'acc', 'crossentropy', 'ce'): + if metric in ('accuracy', 'acc'): + suffix = 'acc' + elif metric in ('crossentropy', 'ce'): + suffix = 'ce' + else: + metric_fn = metrics_module.get(metric) + # Get metric name as string + if hasattr(metric_fn, 'name'): + suffix = metric_fn.name + else: + suffix = metric_fn.__name__ + metric_name = metric_name_prefix + suffix + + if len(self.output_names) > 1: + metric_name = '%s_%s' % (self.output_names[output_index], metric_name) + j = 1 + base_metric_name = metric_name + while metric_name in self.metrics_names: + metric_name = '%s_%d' % (base_metric_name, j) + j += 1 + + return metric_name + + def _handle_per_output_metrics(self, + metrics, + y_true, + y_pred, + output_index, + output_shape, + loss_fn, + mask, + weights=None): + """Calls metric functions and sets metric attributes for a single output. + + Arguments: + metrics: List of metrics. + y_true: Target output. + y_pred: Predicted output. + output_index: Index of the current output. + output_shape: Shape of the current output. + loss_fn: Loss function corresponding to the current output. + mask: Computed mask value for the current output. + weights: Weights to be applied on the current output. + + Returns: + A list of metric result tensors. + """ + metric_results = [] + for metric in metrics: + metric_fn = training_utils.get_metric_function( + metric, output_shape=output_shape, loss_fn=loss_fn) + metric_name = self._get_metric_name( + metric, output_index, weighted=weights is not None) + + with K.name_scope(metric_name): + # If both outputs and targets are available, call the metric function. + if y_true is not None and y_pred is not None: + if isinstance(metric_fn, metrics_module.Metric): + # Call the stateful metric function. + if mask is not None: + mask = math_ops.cast(mask, y_pred.dtype) + # Update weights with mask. + if weights is None: + weights = mask + else: + # Update shape of weights if possible before adding mask. + # Update dimensions of weights to match with mask if possible. + mask, _, weights = metrics_module.squeeze_or_expand_dimensions( + mask, None, weights) + try: + # Broadcast weights if possible. + weights = weights_broadcast_ops.broadcast_weights( + weights, mask) + except ValueError: + pass + # TODO(psv): Handle case when mask and weight shapes are not + # compatible. + weights *= mask + + metric_result = metric_fn(y_true, y_pred, weights) + else: + # Call the stateless metric function. + weighted_metric_fn = training_utils.weighted_masked_objective( + metric_fn) + metric_result = weighted_metric_fn( + y_true, y_pred, weights=weights, mask=mask) + + if not context.executing_eagerly(): + # Keep track of metric result tensor. + self.metrics_tensors.append(metric_result) + metric_results.append(metric_result) + + # Keep track of metric name. + self.metrics_names.append(metric_name) + + # Keep track of stateful metric attributes (name and metric function). + if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful: + self.stateful_metric_names.append(metric_name) + self.stateful_metric_functions.append(metric_fn) + if not context.executing_eagerly(): + # Keep track of updates created by stateful metrics. + self.metrics_updates += metric_fn.updates + return metric_results + + def _handle_metrics(self, + outputs, + skip_target_indices=None, + targets=None, + sample_weights=None, + masks=None): + """Handles calling metric functions and setting model metric attributes. + + Arguments: + outputs: List of outputs (predictions). + skip_target_indices: Optional. List of target ids to skip. + targets: List of targets. + sample_weights: Optional list of sample weight arrays. + masks: List of computed output mask values. + + Returns: + A list of metric result tensors. + """ + skip_target_indices = skip_target_indices or [] + metric_results = [] + with K.name_scope('metrics'): + for i in range(len(outputs)): + if i in skip_target_indices: + continue + output = outputs[i] if outputs else None + target = targets[i] if targets else None + output_shape = None if output is None else output.get_shape().as_list() + output_mask = masks[i] if masks else None + metric_results.extend( + self._handle_per_output_metrics( + self.nested_metrics[i], target, output, i, output_shape, + self.loss_functions[i], output_mask)) + metric_results.extend( + self._handle_per_output_metrics( + self.nested_weighted_metrics[i], + target, + output, + i, + output_shape, + self.loss_functions[i], + output_mask, + weights=sample_weights[i])) + return metric_results + @checkpointable.no_automatic_dependency_tracking def compile(self, optimizer, @@ -151,9 +317,9 @@ class Model(Network): Arguments: optimizer: String (name of optimizer) or optimizer instance. - See [optimizers](/optimizers). + See [optimizers](/api_docs/python/tf/keras/optimizers). loss: String (name of objective function) or objective function. - See [losses](/losses). + See [losses](/api_docs/python/tf/losses). If the model has multiple outputs, you can use a different loss on each output by passing a dictionary or a list of losses. The loss value that will be minimized by the model @@ -231,8 +397,6 @@ class Model(Network): self.metrics = metrics or [] self.loss_weights = loss_weights self.sample_weight_mode = sample_weight_mode - if context.executing_eagerly() and weighted_metrics is not None: - raise ValueError('weighted_metrics is not supported in Eager mode.') self.weighted_metrics = weighted_metrics if context.executing_eagerly() and target_tensors is not None: raise ValueError('target_tensors is not supported in Eager mode.') @@ -335,6 +499,20 @@ class Model(Network): str(loss_weights) + ' - expected a list of dicts.') self.loss_weights_list = loss_weights_list + # Initialize model metric attributes. + self.metrics_names = ['loss'] + self.metrics_tensors = [] + self.metrics_updates = [] + self.stateful_metric_names = [] + self.stateful_metric_functions = [] + + # Nested metrics is a list of list of metrics. + # One list per output of the model. + self.nested_metrics = training_utils.collect_metrics( + metrics, self.output_names) + self.nested_weighted_metrics = training_utils.collect_metrics( + weighted_metrics, self.output_names) + # Initialization for Eager mode execution. if context.executing_eagerly(): # Prepare sample weights. @@ -345,19 +523,16 @@ class Model(Network): raise ValueError('target_tensors are not currently supported in Eager ' 'mode.') self.total_loss = None - self.metrics_tensors = [] - self.metrics_names = ['loss'] for i in range(len(self.outputs)): if len(self.outputs) > 1: self.metrics_names.append(self.output_names[i] + '_loss') - self.nested_metrics = training_utils.collect_metrics(metrics, - self.output_names) - # TODO(fchollet): support stateful metrics in eager execution. - self.stateful_metric_functions = [] - self.stateful_metric_names = [] - - with K.name_scope('metrics'): - training_utils.populate_metric_names(self) + + # Set metric attributes on model. + self._handle_metrics( + self.outputs, + skip_target_indices=skip_target_indices, + sample_weights=self.sample_weights) + self.targets = [] for i in range(len(self.outputs)): self._feed_output_names.append(self.output_names[i]) @@ -420,11 +595,6 @@ class Model(Network): self._set_sample_weight_attributes(sample_weight_mode, skip_target_weighing_indices) - # Prepare metrics. - self.weighted_metrics = weighted_metrics - self.metrics_names = ['loss'] - self.metrics_tensors = [] - # Compute total loss. total_loss = None with K.name_scope('loss'): @@ -458,55 +628,13 @@ class Model(Network): for loss_tensor in self.losses: total_loss += loss_tensor - # List of same size as output_names. - # contains tuples (metrics for output, names of metrics). - nested_metrics = training_utils.collect_metrics(metrics, self.output_names) - nested_weighted_metrics = training_utils.collect_metrics(weighted_metrics, - self.output_names) - self.metrics_updates = [] - self.stateful_metric_names = [] - self.stateful_metric_functions = [] - with K.name_scope('metrics'): - for i in range(len(self.outputs)): - if i in skip_target_indices: - continue - - y_true = self.targets[i] - y_pred = self.outputs[i] - weights = self.sample_weights[i] - output_metrics = nested_metrics[i] - output_weighted_metrics = nested_weighted_metrics[i] - output_shape = self.outputs[i].get_shape().as_list() - loss_fn = self.loss_functions[i] - - def handle_metrics(metrics, output_shape, loss_fn, weights=None): - """Invokes metric functions for the output.""" - - for metric in metrics: - metric_fn = training_utils.get_metric_function( - metric, output_shape=output_shape, loss_fn=loss_fn) - metric_name = training_utils.get_metric_name( - metric, weighted=weights is not None) - - with K.name_scope(metric_name): - weighted_metric_fn = training_utils.weighted_masked_objective( - metric_fn) - metric_result = weighted_metric_fn( - y_true, y_pred, weights=weights, mask=masks[i]) # pylint: disable=undefined-loop-variable - - metric_name = training_utils.add_metric_name(self, metric_name, i) # pylint: disable=undefined-loop-variable - self.metrics_tensors.append(metric_result) - - # Keep track of state updates created by - # stateful metrics (i.e. metrics layers). - if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful: - self.stateful_metric_names.append(metric_name) - self.stateful_metric_functions.append(metric_fn) - self.metrics_updates += metric_fn.updates - - handle_metrics(output_metrics, output_shape, loss_fn) - handle_metrics( - output_weighted_metrics, output_shape, loss_fn, weights=weights) + # Invoke metric functions for all the outputs. + self._handle_metrics( + self.outputs, + masks=masks, + targets=self.targets, + skip_target_indices=skip_target_indices, + sample_weights=self.sample_weights) # Prepare gradient updates and state updates. self.total_loss = total_loss @@ -717,8 +845,8 @@ class Model(Network): x_values, y_values = distributed_training_utils.\ validate_distributed_dataset_inputs(self._distribution_strategy, x, y) - _, _, sample_weights = self._standardize_weights(x_values[0], - y_values[0], + _, _, sample_weights = self._standardize_weights(x_values, + y_values, sample_weight, class_weight, batch_size) @@ -856,7 +984,7 @@ class Model(Network): all_inputs = [] is_build_called = False is_compile_called = False - if not self.built: + if not self.inputs: # We need to use `x` to set the model inputs. # We type-check that `x` and `y` are either single arrays # or lists of arrays. @@ -1067,22 +1195,13 @@ class Model(Network): 'in their call() signatures do not yet support shape inference. File ' 'a feature request if this limitation bothers you.') if self.__class__.__name__ == 'Sequential': - # Note: we can't test whether the model is `Sequential` via `isinstance` - # since `Sequential` depends on `Model`. - if isinstance(inputs, list): - assert len(inputs) == 1 - inputs = inputs[0] - if tensor_util.is_tensor(inputs): - if context.executing_eagerly(): - input_shape = (None,) + tuple(inputs.get_shape().as_list()[1:]) - self.build(input_shape=input_shape) - else: - self.symbolic_set_inputs(inputs) + input_shape = (None,) + tuple(inputs.get_shape().as_list()[1:]) + self.build(input_shape=input_shape) else: input_shape = (None,) + inputs.shape[1:] self.build(input_shape=input_shape) - elif context.executing_eagerly(): + if context.executing_eagerly(): self._eager_set_inputs(inputs) else: self._symbolic_set_inputs(inputs, training=training) @@ -1273,7 +1392,7 @@ class Model(Network): 0 = silent, 1 = progress bar, 2 = one line per epoch. callbacks: List of `keras.callbacks.Callback` instances. List of callbacks to apply during training. - See [callbacks](/callbacks). + See [callbacks](/api_docs/python/tf/keras/callbacks). validation_split: Float between 0 and 1. Fraction of the training data to be used as validation data. The model will set apart this fraction of the training data, @@ -1891,6 +2010,10 @@ class Model(Network): Raises: ValueError: In case the generator yields data in an invalid format. """ + if self._distribution_strategy: + raise NotImplementedError('`fit_generator` is not supported for ' + 'models compiled with DistributionStrategy.') + if not self.built and not self._is_graph_network: raise NotImplementedError( '`fit_generator` is not yet enabled for unbuilt Model subclasses') @@ -1958,6 +2081,10 @@ class Model(Network): Raises: ValueError: In case the generator yields data in an invalid format. """ + if self._distribution_strategy: + raise NotImplementedError('`evaluate_generator` is not supported for ' + 'models compiled with DistributionStrategy.') + if not self.built and not self._is_graph_network: raise NotImplementedError( '`evaluate_generator` is not yet enabled for ' @@ -2012,6 +2139,10 @@ class Model(Network): Raises: ValueError: In case the generator yields data in an invalid format. """ + if self._distribution_strategy: + raise NotImplementedError('`predict_generator` is not supported for ' + 'models compiled with DistributionStrategy.') + if not self.built and not self._is_graph_network: raise NotImplementedError( '`predict_generator` is not yet enabled for unbuilt Model subclasses') @@ -2025,6 +2156,21 @@ class Model(Network): use_multiprocessing=use_multiprocessing, verbose=verbose) + def _get_callback_model(self): + """Returns the Callback Model for this Model.""" + + if hasattr(self, '_replicated_model') and self._replicated_model: + # When using training_distributed, we set the callback model + # to an instance of the `DistributedModel` that we create in + # the `compile` call. The `DistributedModel` is initialized + # with the first replicated model. We need to set the callback + # model to a DistributedModel to allow us to override saving + # and loading weights when we checkpoint the model during training. + return self._replicated_model + if hasattr(self, 'callback_model') and self.callback_model: + return self.callback_model + return self + class DistributedCallbackModel(Model): """Model that is used for callbacks with DistributionStrategy.""" @@ -2065,4 +2211,3 @@ class DistributedCallbackModel(Model): logging.warning('You are accessing attribute ' + item + 'of the' 'DistributedCallbackModel that may not have been set' 'correctly.') - |