Diffstat (limited to 'tensorflow/python/keras/engine/training.py')
-rw-r--r--  tensorflow/python/keras/engine/training.py  315
1 file changed, 230 insertions(+), 85 deletions(-)
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 2cdd00a48d..f71388cadb 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -29,6 +29,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import distributed_training_utils
@@ -39,6 +40,8 @@ from tensorflow.python.keras.engine import training_generator
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine.network import Network
from tensorflow.python.keras.utils.generic_utils import slice_arrays
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training.checkpointable import base as checkpointable
@@ -74,6 +77,7 @@ class Model(Network):
class MyModel(tf.keras.Model):
def __init__(self):
+ super(MyModel, self).__init__()
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
@@ -94,6 +98,7 @@ class Model(Network):
class MyModel(tf.keras.Model):
def __init__(self):
+ super(MyModel, self).__init__()
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
self.dropout = tf.keras.layers.Dropout(0.5)
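# For reference, a complete, runnable version of the subclassed-model
# example sketched in the docstring above. The `call` body is an
# assumption based on the layers shown (dropout applied only in training).
import tensorflow as tf

class MyModel(tf.keras.Model):

  def __init__(self):
    super(MyModel, self).__init__()
    self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
    self.dropout = tf.keras.layers.Dropout(0.5)

  def call(self, inputs, training=False):
    x = self.dense1(inputs)
    if training:
      x = self.dropout(x, training=training)
    return self.dense2(x)

model = MyModel()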
@@ -136,6 +141,167 @@ class Model(Network):
if i not in skip_target_weighing_indices
]
+ def _get_metric_name(self, metric, output_index, weighted=False):
+ """Returns the metric name corresponding to the given metric input.
+
+ Arguments:
+ metric: Metric function name or reference.
+ output_index: Index of the current output.
+ weighted: Boolean indicating if the given metric is weighted.
+
+ Returns:
+ A metric name.
+ """
+ metric_name_prefix = 'weighted_' if weighted else ''
+ if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
+ if metric in ('accuracy', 'acc'):
+ suffix = 'acc'
+ elif metric in ('crossentropy', 'ce'):
+ suffix = 'ce'
+ else:
+ metric_fn = metrics_module.get(metric)
+ # Get metric name as string
+ if hasattr(metric_fn, 'name'):
+ suffix = metric_fn.name
+ else:
+ suffix = metric_fn.__name__
+ metric_name = metric_name_prefix + suffix
+
+ if len(self.output_names) > 1:
+ metric_name = '%s_%s' % (self.output_names[output_index], metric_name)
+ j = 1
+ base_metric_name = metric_name
+ while metric_name in self.metrics_names:
+ metric_name = '%s_%d' % (base_metric_name, j)
+ j += 1
+
+ return metric_name
+
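# A minimal standalone sketch of the uniquification loop above: when two
# outputs (or repeated metrics) produce the same name, later occurrences
# get a numeric suffix. The names here are illustrative.
def unique_metric_name(name, existing_names):
  base_name, j = name, 1
  while name in existing_names:
    name = '%s_%d' % (base_name, j)
    j += 1
  return name

print(unique_metric_name('dense_acc', ['loss', 'dense_acc']))  # dense_acc_1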
+ def _handle_per_output_metrics(self,
+ metrics,
+ y_true,
+ y_pred,
+ output_index,
+ output_shape,
+ loss_fn,
+ mask,
+ weights=None):
+ """Calls metric functions and sets metric attributes for a single output.
+
+ Arguments:
+ metrics: List of metrics.
+ y_true: Target output.
+ y_pred: Predicted output.
+ output_index: Index of the current output.
+ output_shape: Shape of the current output.
+ loss_fn: Loss function corresponding to the current output.
+ mask: Computed mask value for the current output.
+ weights: Weights to be applied on the current output.
+
+ Returns:
+ A list of metric result tensors.
+ """
+ metric_results = []
+ for metric in metrics:
+ metric_fn = training_utils.get_metric_function(
+ metric, output_shape=output_shape, loss_fn=loss_fn)
+ metric_name = self._get_metric_name(
+ metric, output_index, weighted=weights is not None)
+
+ with K.name_scope(metric_name):
+ # If both outputs and targets are available, call the metric function.
+ if y_true is not None and y_pred is not None:
+ if isinstance(metric_fn, metrics_module.Metric):
+ # Call the stateful metric function.
+ if mask is not None:
+ mask = math_ops.cast(mask, y_pred.dtype)
+ # Update weights with mask.
+ if weights is None:
+ weights = mask
+ else:
+ # Update dimensions of weights to match with mask if possible.
+ mask, _, weights = metrics_module.squeeze_or_expand_dimensions(
+ mask, None, weights)
+ try:
+ # Broadcast weights if possible.
+ weights = weights_broadcast_ops.broadcast_weights(
+ weights, mask)
+ except ValueError:
+ pass
+ # TODO(psv): Handle case when mask and weight shapes are not
+ # compatible.
+ weights *= mask
+
+ metric_result = metric_fn(y_true, y_pred, weights)
+ else:
+ # Call the stateless metric function.
+ weighted_metric_fn = training_utils.weighted_masked_objective(
+ metric_fn)
+ metric_result = weighted_metric_fn(
+ y_true, y_pred, weights=weights, mask=mask)
+
+ if not context.executing_eagerly():
+ # Keep track of metric result tensor.
+ self.metrics_tensors.append(metric_result)
+ metric_results.append(metric_result)
+
+ # Keep track of metric name.
+ self.metrics_names.append(metric_name)
+
+ # Keep track of stateful metric attributes (name and metric function).
+ if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful:
+ self.stateful_metric_names.append(metric_name)
+ self.stateful_metric_functions.append(metric_fn)
+ if not context.executing_eagerly():
+ # Keep track of updates created by stateful metrics.
+ self.metrics_updates += metric_fn.updates
+ return metric_results
+
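# An illustrative sketch of the mask/weight merge above using only public
# TF ops (the method itself relies on the internal
# weights_broadcast_ops.broadcast_weights helper); it assumes a TF version
# where tf.broadcast_to is available.
import tensorflow as tf

def merge_mask_into_weights(mask, weights):
  mask = tf.cast(mask, tf.float32)
  if weights is None:
    return mask
  weights = tf.cast(weights, tf.float32)
  try:
    # Broadcast weights up to the mask's shape when the shapes allow it.
    weights = tf.broadcast_to(weights, tf.shape(mask))
  except (ValueError, tf.errors.InvalidArgumentError):
    pass  # incompatible shapes: leave weights as-is (see TODO above)
  return weights * mask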
+ def _handle_metrics(self,
+ outputs,
+ skip_target_indices=None,
+ targets=None,
+ sample_weights=None,
+ masks=None):
+ """Handles calling metric functions and setting model metric attributes.
+
+ Arguments:
+ outputs: List of outputs (predictions).
+ skip_target_indices: Optional. List of target ids to skip.
+ targets: List of targets.
+ sample_weights: Optional list of sample weight arrays.
+ masks: List of computed output mask values.
+
+ Returns:
+ A list of metric result tensors.
+ """
+ skip_target_indices = skip_target_indices or []
+ metric_results = []
+ with K.name_scope('metrics'):
+ for i in range(len(outputs)):
+ if i in skip_target_indices:
+ continue
+ output = outputs[i] if outputs else None
+ target = targets[i] if targets else None
+ output_shape = None if output is None else output.get_shape().as_list()
+ output_mask = masks[i] if masks else None
+ metric_results.extend(
+ self._handle_per_output_metrics(
+ self.nested_metrics[i], target, output, i, output_shape,
+ self.loss_functions[i], output_mask))
+ metric_results.extend(
+ self._handle_per_output_metrics(
+ self.nested_weighted_metrics[i],
+ target,
+ output,
+ i,
+ output_shape,
+ self.loss_functions[i],
+ output_mask,
+ weights=sample_weights[i]))
+ return metric_results
+
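# Usage sketch: given a compiled model with two outputs named 'main' and
# 'aux' (names here are illustrative), metrics and weighted_metrics flow
# through _handle_metrics at compile time, and metrics_names ends up
# roughly as commented below.
model.compile(
    optimizer='rmsprop',
    loss={'main': 'binary_crossentropy', 'aux': 'mse'},
    metrics={'main': ['acc'], 'aux': ['mae']},
    weighted_metrics={'main': ['acc']})
# model.metrics_names ~ ['loss', 'main_loss', 'aux_loss',
#                        'main_acc', 'main_weighted_acc', 'aux_mae']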
@checkpointable.no_automatic_dependency_tracking
def compile(self,
optimizer,
@@ -151,9 +317,9 @@ class Model(Network):
Arguments:
optimizer: String (name of optimizer) or optimizer instance.
- See [optimizers](/optimizers).
+ See [optimizers](/api_docs/python/tf/keras/optimizers).
loss: String (name of objective function) or objective function.
- See [losses](/losses).
+ See [losses](/api_docs/python/tf/losses).
If the model has multiple outputs, you can use a different loss
on each output by passing a dictionary or a list of losses.
The loss value that will be minimized by the model
@@ -231,8 +397,6 @@ class Model(Network):
self.metrics = metrics or []
self.loss_weights = loss_weights
self.sample_weight_mode = sample_weight_mode
- if context.executing_eagerly() and weighted_metrics is not None:
- raise ValueError('weighted_metrics is not supported in Eager mode.')
self.weighted_metrics = weighted_metrics
if context.executing_eagerly() and target_tensors is not None:
raise ValueError('target_tensors is not supported in Eager mode.')
@@ -335,6 +499,20 @@ class Model(Network):
str(loss_weights) + ' - expected a list of dicts.')
self.loss_weights_list = loss_weights_list
+ # Initialize model metric attributes.
+ self.metrics_names = ['loss']
+ self.metrics_tensors = []
+ self.metrics_updates = []
+ self.stateful_metric_names = []
+ self.stateful_metric_functions = []
+
+ # Nested metrics is a list of lists of metrics, one list per
+ # model output.
+ self.nested_metrics = training_utils.collect_metrics(
+ metrics, self.output_names)
+ self.nested_weighted_metrics = training_utils.collect_metrics(
+ weighted_metrics, self.output_names)
+
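# For intuition (a sketch of the observed behavior, not the helper's
# contract): collect_metrics maps a flat metric spec to one list per
# output, e.g. with output_names=['main', 'aux']:
#   collect_metrics(['acc'], output_names)           ~ [['acc'], ['acc']]
#   collect_metrics({'main': ['acc']}, output_names) ~ [['acc'], []]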
# Initialization for Eager mode execution.
if context.executing_eagerly():
# Prepare sample weights.
@@ -345,19 +523,16 @@ class Model(Network):
raise ValueError('target_tensors are not currently supported in Eager '
'mode.')
self.total_loss = None
- self.metrics_tensors = []
- self.metrics_names = ['loss']
for i in range(len(self.outputs)):
if len(self.outputs) > 1:
self.metrics_names.append(self.output_names[i] + '_loss')
- self.nested_metrics = training_utils.collect_metrics(metrics,
- self.output_names)
- # TODO(fchollet): support stateful metrics in eager execution.
- self.stateful_metric_functions = []
- self.stateful_metric_names = []
-
- with K.name_scope('metrics'):
- training_utils.populate_metric_names(self)
+
+ # Set metric attributes on model.
+ self._handle_metrics(
+ self.outputs,
+ skip_target_indices=skip_target_indices,
+ sample_weights=self.sample_weights)
+
self.targets = []
for i in range(len(self.outputs)):
self._feed_output_names.append(self.output_names[i])
@@ -420,11 +595,6 @@ class Model(Network):
self._set_sample_weight_attributes(sample_weight_mode,
skip_target_weighing_indices)
- # Prepare metrics.
- self.weighted_metrics = weighted_metrics
- self.metrics_names = ['loss']
- self.metrics_tensors = []
-
# Compute total loss.
total_loss = None
with K.name_scope('loss'):
@@ -458,55 +628,13 @@ class Model(Network):
for loss_tensor in self.losses:
total_loss += loss_tensor
- # List of same size as output_names.
- # contains tuples (metrics for output, names of metrics).
- nested_metrics = training_utils.collect_metrics(metrics, self.output_names)
- nested_weighted_metrics = training_utils.collect_metrics(weighted_metrics,
- self.output_names)
- self.metrics_updates = []
- self.stateful_metric_names = []
- self.stateful_metric_functions = []
- with K.name_scope('metrics'):
- for i in range(len(self.outputs)):
- if i in skip_target_indices:
- continue
-
- y_true = self.targets[i]
- y_pred = self.outputs[i]
- weights = self.sample_weights[i]
- output_metrics = nested_metrics[i]
- output_weighted_metrics = nested_weighted_metrics[i]
- output_shape = self.outputs[i].get_shape().as_list()
- loss_fn = self.loss_functions[i]
-
- def handle_metrics(metrics, output_shape, loss_fn, weights=None):
- """Invokes metric functions for the output."""
-
- for metric in metrics:
- metric_fn = training_utils.get_metric_function(
- metric, output_shape=output_shape, loss_fn=loss_fn)
- metric_name = training_utils.get_metric_name(
- metric, weighted=weights is not None)
-
- with K.name_scope(metric_name):
- weighted_metric_fn = training_utils.weighted_masked_objective(
- metric_fn)
- metric_result = weighted_metric_fn(
- y_true, y_pred, weights=weights, mask=masks[i]) # pylint: disable=undefined-loop-variable
-
- metric_name = training_utils.add_metric_name(self, metric_name, i) # pylint: disable=undefined-loop-variable
- self.metrics_tensors.append(metric_result)
-
- # Keep track of state updates created by
- # stateful metrics (i.e. metrics layers).
- if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful:
- self.stateful_metric_names.append(metric_name)
- self.stateful_metric_functions.append(metric_fn)
- self.metrics_updates += metric_fn.updates
-
- handle_metrics(output_metrics, output_shape, loss_fn)
- handle_metrics(
- output_weighted_metrics, output_shape, loss_fn, weights=weights)
+ # Invoke metric functions for all the outputs.
+ self._handle_metrics(
+ self.outputs,
+ masks=masks,
+ targets=self.targets,
+ skip_target_indices=skip_target_indices,
+ sample_weights=self.sample_weights)
# Prepare gradient updates and state updates.
self.total_loss = total_loss
@@ -717,8 +845,8 @@ class Model(Network):
x_values, y_values = distributed_training_utils.\
validate_distributed_dataset_inputs(self._distribution_strategy, x, y)
- _, _, sample_weights = self._standardize_weights(x_values[0],
- y_values[0],
+ _, _, sample_weights = self._standardize_weights(x_values,
+ y_values,
sample_weight,
class_weight,
batch_size)
@@ -856,7 +984,7 @@ class Model(Network):
all_inputs = []
is_build_called = False
is_compile_called = False
- if not self.built:
+ if not self.inputs:
# We need to use `x` to set the model inputs.
# We type-check that `x` and `y` are either single arrays
# or lists of arrays.
@@ -1067,22 +1195,13 @@ class Model(Network):
'in their call() signatures do not yet support shape inference. File '
'a feature request if this limitation bothers you.')
if self.__class__.__name__ == 'Sequential':
- # Note: we can't test whether the model is `Sequential` via `isinstance`
- # since `Sequential` depends on `Model`.
- if isinstance(inputs, list):
- assert len(inputs) == 1
- inputs = inputs[0]
-
if tensor_util.is_tensor(inputs):
- if context.executing_eagerly():
- input_shape = (None,) + tuple(inputs.get_shape().as_list()[1:])
- self.build(input_shape=input_shape)
- else:
- self.symbolic_set_inputs(inputs)
+ input_shape = (None,) + tuple(inputs.get_shape().as_list()[1:])
+ self.build(input_shape=input_shape)
else:
input_shape = (None,) + inputs.shape[1:]
self.build(input_shape=input_shape)
- elif context.executing_eagerly():
+ if context.executing_eagerly():
self._eager_set_inputs(inputs)
else:
self._symbolic_set_inputs(inputs, training=training)
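# Illustrative: deriving a batch-agnostic input shape from a concrete
# batch before calling build(), as in the branch above.
import numpy as np
batch = np.zeros((32, 10), dtype='float32')
input_shape = (None,) + batch.shape[1:]  # -> (None, 10)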
@@ -1273,7 +1392,7 @@ class Model(Network):
0 = silent, 1 = progress bar, 2 = one line per epoch.
callbacks: List of `keras.callbacks.Callback` instances.
List of callbacks to apply during training.
- See [callbacks](/callbacks).
+ See [callbacks](/api_docs/python/tf/keras/callbacks).
validation_split: Float between 0 and 1.
Fraction of the training data to be used as validation data.
The model will set apart this fraction of the training data,
@@ -1891,6 +2010,10 @@ class Model(Network):
Raises:
ValueError: In case the generator yields data in an invalid format.
"""
+ if self._distribution_strategy:
+ raise NotImplementedError('`fit_generator` is not supported for '
+ 'models compiled with DistributionStrategy.')
+
if not self.built and not self._is_graph_network:
raise NotImplementedError(
'`fit_generator` is not yet enabled for unbuilt Model subclasses')
@@ -1958,6 +2081,10 @@ class Model(Network):
Raises:
ValueError: In case the generator yields data in an invalid format.
"""
+ if self._distribution_strategy:
+ raise NotImplementedError('`evaluate_generator` is not supported for '
+ 'models compiled with DistributionStrategy.')
+
if not self.built and not self._is_graph_network:
raise NotImplementedError(
'`evaluate_generator` is not yet enabled for '
@@ -2012,6 +2139,10 @@ class Model(Network):
Raises:
ValueError: In case the generator yields data in an invalid format.
"""
+ if self._distribution_strategy:
+ raise NotImplementedError('`predict_generator` is not supported for '
+ 'models compiled with DistributionStrategy.')
+
if not self.built and not self._is_graph_network:
raise NotImplementedError(
'`predict_generator` is not yet enabled for unbuilt Model subclasses')
@@ -2025,6 +2156,21 @@ class Model(Network):
use_multiprocessing=use_multiprocessing,
verbose=verbose)
+ def _get_callback_model(self):
+ """Returns the Callback Model for this Model."""
+
+ if hasattr(self, '_replicated_model') and self._replicated_model:
+ # When using training_distributed, we set the callback model
+ # to an instance of the `DistributedModel` that we create in
+ # the `compile` call. The `DistributedModel` is initialized
+ # with the first replicated model. We need to set the callback
+ # model to a DistributedModel to allow us to override saving
+ # and loading weights when we checkpoint the model during training.
+ return self._replicated_model
+ if hasattr(self, 'callback_model') and self.callback_model:
+ return self.callback_model
+ return self
+
class DistributedCallbackModel(Model):
"""Model that is used for callbacks with DistributionStrategy."""
@@ -2065,4 +2211,3 @@ class DistributedCallbackModel(Model):
logging.warning('You are accessing attribute ' + item + ' of the '
'DistributedCallbackModel that may not have been set '
'correctly.')
-