author     Anjali Sridhar <anjalisridhar@google.com>  2018-04-19 10:26:54 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-04-19 10:29:31 -0700
commit     ba3bc495bbf1140e9375e1ec03c3ff788b8ebc6e
tree       5654aca90213dfb01c24e0e1971c3e38ab329b92
parent     5fbd21e3bbd4f89dd2c6eed8a63b66ee2eff40a0
Add metric names to model.metrics_names in compile for keras models run in eager execution. This prevents us from dropping metrics when we run model.evaluate.
PiperOrigin-RevId: 193536341
 tensorflow/python/keras/_impl/keras/engine/training.py            | 29
 tensorflow/python/keras/_impl/keras/engine/training_eager.py      | 39
 tensorflow/python/keras/_impl/keras/engine/training_eager_test.py | 12
 tensorflow/python/keras/_impl/keras/engine/training_test.py       | 26
 tensorflow/python/keras/_impl/keras/engine/training_utils.py      | 62
 5 files changed, 109 insertions(+), 59 deletions(-)
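
For context, a minimal sketch of what this change means in practice (illustrative only, not part of the commit): under eager execution, compiling the two-output functional model from the new test_metric_names_are_identical_in_graph_and_eager test below now populates model.metrics_names at compile time, so evaluate() reports every metric instead of dropping them.

# Illustrative sketch; mirrors the model built in the new test added to
# training_test.py below. Internal _impl import paths match the test file.
from tensorflow.python.framework import ops
from tensorflow.python.keras._impl import keras
from tensorflow.python.training.rmsprop import RMSPropOptimizer

ops.enable_eager_execution()

a = keras.layers.Input(shape=(3,), name='input_a')
b = keras.layers.Input(shape=(3,), name='input_b')
dense = keras.layers.Dense(4, name='dense')
c = dense(a)
d = dense(b)
e = keras.layers.Dropout(0.5, name='dropout')(c)
model = keras.models.Model([a, b], [d, e])

model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse',
              metrics=['mae', 'acc'], loss_weights=[1., 0.5])
# With this change metrics_names is set at compile time in eager mode:
# ['loss', 'dense_loss', 'dropout_loss', 'dense_mean_absolute_error',
#  'dense_acc', 'dropout_mean_absolute_error', 'dropout_acc']
print(model.metrics_names)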
diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py
index 7c46743814..012d9ceea4 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training.py
@@ -276,6 +276,8 @@ class Model(Network):
self.metrics_names.append(self.output_names[i] + '_loss')
self.nested_metrics = training_utils.collect_metrics(metrics,
self.output_names)
+ with K.name_scope('metrics'):
+ training_utils.populate_metric_names(self)
self._feed_sample_weight_modes = []
for i in range(len(self.outputs)):
self._feed_sample_weight_modes.append(None)
@@ -462,7 +464,6 @@ class Model(Network):
output_weighted_metrics = nested_weighted_metrics[i]
def handle_metrics(metrics, weights=None):
- metric_name_prefix = 'weighted_' if weights is not None else ''
for metric in metrics:
if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
@@ -489,39 +490,19 @@ class Model(Network):
metric_fn = metrics_module.categorical_accuracy
elif metric in ('crossentropy', 'ce'):
metric_fn = metrics_module.categorical_crossentropy
- if metric in ('accuracy', 'acc'):
- suffix = 'acc'
- elif metric in ('crossentropy', 'ce'):
- suffix = 'ce'
weighted_metric_fn = training_utils.weighted_masked_objective(
metric_fn)
- metric_name = metric_name_prefix + suffix
else:
metric_fn = metrics_module.get(metric)
weighted_metric_fn = training_utils.weighted_masked_objective(
metric_fn)
- # Get metric name as string
- if hasattr(metric_fn, 'name'):
- metric_name = metric_fn.name
- else:
- metric_name = metric_fn.__name__
- metric_name = metric_name_prefix + metric_name
-
+ metric_name = training_utils.get_base_metric_name(
+ metric, weighted=weights is not None)
with K.name_scope(metric_name):
metric_result = weighted_metric_fn(
y_true, y_pred, weights=weights, mask=masks[i])
- # Append to self.metrics_names, self.metric_tensors,
- # self.stateful_metric_names
- if len(self.output_names) > 1:
- metric_name = '%s_%s' % (self.output_names[i], metric_name)
- # Dedupe name
- j = 1
- base_metric_name = metric_name
- while metric_name in self.metrics_names:
- metric_name = '%s_%d' % (base_metric_name, j)
- j += 1
- self.metrics_names.append(metric_name)
+ training_utils.add_metric_name(self, metric_name, i)
self.metrics_tensors.append(metric_result)
# Keep track of state updates created by
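
The hunk above drops the inline name construction in favour of two helpers added to training_utils.py (last hunk of this diff). A rough sketch of what those helpers return, illustrative rather than authoritative:

from tensorflow.python.keras._impl.keras.engine import training_utils

# Shorthand metrics keep their short suffix; anything else resolves through
# metrics_module.get() and uses the function's name.
print(training_utils.get_base_metric_name('acc'))                 # 'acc'
print(training_utils.get_base_metric_name('mae', weighted=True))  # 'weighted_mean_absolute_error'
# add_metric_name(self, metric_name, i) then prefixes the i-th output name when
# the model has several outputs, dedupes with an integer suffix, and appends
# the result to self.metrics_names (see the training_utils.py hunk below).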
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager.py b/tensorflow/python/keras/_impl/keras/engine/training_eager.py
index 695669d9ee..ad239d6151 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_eager.py
@@ -100,7 +100,7 @@ def _eager_metrics_fn(model, outputs, targets):
metric_names.append(metric_name)
metric_results.append(backend.mean(metric_result))
- return metric_names, metric_results
+ return metric_results
def _model_loss(model, inputs, targets, sample_weights=None, training=False):
@@ -151,7 +151,12 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False):
with backend.name_scope(model.output_names[i] + '_loss'):
output_loss = weighted_masked_fn(
targets[i], outs[i], weights, mask=mask)
- loss_metrics.append(backend.mean(output_loss))
+ # If the number of outputs is 1 then we don't append the loss metric
+ # associated with each model output. When there are multiple outputs
+ # associated with a model, each output's loss is calculated and returned
+ # as part of the loss_metrics.
+ if len(model.outputs) > 1:
+ loss_metrics.append(backend.mean(output_loss))
loss_weight = model.loss_weights_list[i]
if total_loss is None:
@@ -274,7 +279,7 @@ def train_on_batch(model, inputs, targets, sample_weights=None):
model, inputs, targets, sample_weights=sample_weights, training=True)
if not isinstance(outs, list):
outs = [outs]
- _, metrics_results = _eager_metrics_fn(
+ metrics_results = _eager_metrics_fn(
model, outs, targets)
if not isinstance(loss, list):
loss = [loss]
@@ -304,7 +309,7 @@ def test_on_batch(model, inputs, targets, sample_weights=None):
model, inputs, targets, sample_weights=sample_weights, training=False)
if not isinstance(outs, list):
outs = [outs]
- _, metrics_results = _eager_metrics_fn(
+ metrics_results = _eager_metrics_fn(
model, outs, targets)
if not isinstance(loss, list):
loss = [loss]
@@ -498,34 +503,12 @@ def fit_loop(
for l, o in zip(out_labels, outs):
batch_logs[l] = o
# Required for Eager mode
- metrics_names, metrics_results = _eager_metrics_fn(
- model, outs, targets_batch)
+ metrics_results = _eager_metrics_fn(model, outs, targets_batch)
batch_logs['loss'] = tensor_util.constant_value(backend.mean(loss))
- # TODO(anjalisridhar): Move this to compile to avoid duplicate code.
- # In graph mode we set the metric names in compile. However in
- # Eager mode we calculate the metrics for each batch in fit_loop.
- # We could calculate the metric names and functions in compile.
- # This would avoid setting the callback parameters separately.
- # We need to do this for the first iteration alone
- for m in metrics_names:
- if m not in callback_metrics:
- callback_metrics.append(m)
-
- callbacks.set_params({
- 'batch_size': batch_size,
- 'epochs': epochs,
- 'steps': steps_per_epoch,
- 'samples': num_train_samples,
- 'verbose': verbose,
- 'do_validation': do_validation,
- 'metrics': callback_metrics or [],
- })
-
for k, v in zip(model.metrics_names,
[backend.mean(loss)] + loss_metrics + metrics_results):
batch_logs[k] = tensor_util.constant_value(v)
-
callbacks.on_batch_end(batch_index, batch_logs)
if callback_model.stop_training:
break
@@ -611,7 +594,7 @@ def test_loop(model, inputs, targets,
targets_batch,
sample_weights=sample_weights_batch,
training=False)
- _, metrics_results = _eager_metrics_fn(model, loss_outs, targets_batch)
+ metrics_results = _eager_metrics_fn(model, loss_outs, targets_batch)
batch_outs = []
for _, v in zip(model.metrics_names,
[backend.mean(loss)] + loss_metrics + metrics_results):
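
Since _eager_metrics_fn no longer returns names, the per-batch logs in fit_loop and test_loop are built by zipping model.metrics_names (now fixed at compile time) with the computed values. A small self-contained illustration with dummy numbers, assuming the two-output model used in the tests:

# Names come from compile; values are computed per batch in eager mode.
metrics_names = ['loss', 'dense_loss', 'dropout_loss',
                 'dense_mean_absolute_error', 'dense_acc',
                 'dropout_mean_absolute_error', 'dropout_acc']
# Dummy per-batch values: mean total loss, two per-output losses, then the
# four metric results returned by _eager_metrics_fn.
values = [0.9, 0.6, 0.3, 0.25, 0.5, 0.31, 0.5]
batch_logs = dict(zip(metrics_names, values))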
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
index ed0f91ee1e..deaf1d1306 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
@@ -212,7 +212,7 @@ class TrainingTest(test.TestCase):
optimizer = RMSPropOptimizer(learning_rate=0.001)
loss = 'mse'
loss_weights = [1., 0.5]
- metrics = ['mae']
+ metrics = ['acc', 'mae']
model.compile(
optimizer,
loss,
@@ -231,20 +231,20 @@ class TrainingTest(test.TestCase):
[input_a_np, input_b_np], [output_d_np, output_e_np],
batch_size=5,
verbose=0)
- self.assertEqual(len(out), 5)
+ self.assertEqual(len(out), 7)
out = model.evaluate(
[input_a_np, input_b_np], [output_d_np, output_e_np],
batch_size=5,
verbose=1)
- self.assertEqual(len(out), 5)
+ self.assertEqual(len(out), 7)
out = model.evaluate(
[input_a_np, input_b_np], [output_d_np, output_e_np],
batch_size=5,
verbose=2)
- self.assertEqual(len(out), 5)
+ self.assertEqual(len(out), 7)
out = model.test_on_batch([input_a_np, input_b_np],
[output_d_np, output_e_np])
- self.assertEqual(len(out), 5)
+ self.assertEqual(len(out), 7)
# Test evaluate with dictionary inputs
model.evaluate(
@@ -625,7 +625,6 @@ class LossWeightingTest(test.TestCase):
bad_w_np = np.random.random((10, 2, 2))
model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np})
-
class CorrectnessTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes()
@@ -649,7 +648,6 @@ class CorrectnessTest(test.TestCase):
self.assertEqual(
np.around(history.history['loss'][-1], decimals=4), 0.6173)
-
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()
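
The 5 -> 7 change in the assertions above is plain bookkeeping: with two model outputs and metrics=['acc', 'mae'], evaluate and test_on_batch return 1 total loss + 2 per-output losses + 2 outputs * 2 metrics = 7 values, whereas the previous single-metric setup produced 1 + 2 + 2 * 1 = 5.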
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py
index 6699fd5212..d9281436de 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py
@@ -24,12 +24,15 @@ import unittest
import numpy as np
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras._impl import keras
from tensorflow.python.keras._impl.keras import testing_utils
from tensorflow.python.keras._impl.keras.engine.training_utils import weighted_masked_objective
from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
+from tensorflow.python.training.rmsprop import RMSPropOptimizer
+
try:
import scipy.sparse as scipy_sparse # pylint: disable=g-import-not-at-top
@@ -1684,6 +1687,29 @@ class TestTrainingWithDataTensors(test.TestCase):
model.train_on_batch([input_a_np, input_b_np],
[output_a_np, output_b_np])
+ @tf_test_util.run_in_graph_and_eager_modes()
+ def test_metric_names_are_identical_in_graph_and_eager(self):
+ a = keras.layers.Input(shape=(3,), name='input_a')
+ b = keras.layers.Input(shape=(3,), name='input_b')
+
+ dense = keras.layers.Dense(4, name='dense')
+ c = dense(a)
+ d = dense(b)
+ e = keras.layers.Dropout(0.5, name='dropout')(c)
+
+ model = keras.models.Model([a, b], [d, e])
+
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ loss_weights = [1., 0.5]
+ metrics = ['mae', 'acc']
+ model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights)
+ reference_metric_names = ['loss', 'dense_loss', 'dropout_loss',
+ 'dense_mean_absolute_error',
+ 'dense_acc',
+ 'dropout_mean_absolute_error',
+ 'dropout_acc']
+ self.assertEqual(reference_metric_names, model.metrics_names)
if __name__ == '__main__':
# Bazel sets these environment variables to very long paths.
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_utils.py b/tensorflow/python/keras/_impl/keras/engine/training_utils.py
index 48afe48e6c..662938f421 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_utils.py
@@ -26,6 +26,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import losses
+from tensorflow.python.keras._impl.keras import metrics as metrics_module
from tensorflow.python.ops import math_ops
@@ -552,3 +553,64 @@ def standardize_weights(y,
def has_symbolic_tensors(ls):
return (any(tensor_util.is_tensor(v) for v in ls)
and not context.executing_eagerly())
+
+
+def populate_metric_names(model):
+ for i in range(len(model.outputs)):
+ metrics = model.nested_metrics[i]
+ for metric in metrics:
+ base_metric_name = get_base_metric_name(metric)
+ add_metric_name(model, base_metric_name, i)
+
+
+def get_base_metric_name(metric, weighted=False):
+ """Returns the metric name given the metric function.
+
+ Arguments:
+ metric: Metric function name or reference.
+ weighted: Boolean indicating if the metric for which we are adding
+ names is weighted.
+
+ Returns:
+ a metric name.
+ """
+ metric_name_prefix = 'weighted_' if weighted else ''
+ if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
+ if metric in ('accuracy', 'acc'):
+ suffix = 'acc'
+ elif metric in ('crossentropy', 'ce'):
+ suffix = 'ce'
+ metric_name = metric_name_prefix + suffix
+ else:
+ metric_fn = metrics_module.get(metric)
+ # Get metric name as string
+ if hasattr(metric_fn, 'name'):
+ metric_name = metric_fn.name
+ else:
+ metric_name = metric_fn.__name__
+ metric_name = metric_name_prefix + metric_name
+
+ return metric_name
+
+
+def add_metric_name(model, metric_name, index):
+ """Makes the metric name unique and adds it to the model's metric name list.
+
+ If there are multiple outputs for which the metrics are calculated, the
+ metric names have to be made unique by appending an integer.
+
+ Arguments:
+ model: Model to which we are adding metric names.
+ metric_name: Metric name that corresponds to the metric specified by the
+ user. For example: 'acc'
+ index: The index of the model output for which the metric name is being
+ added.
+ """
+ if len(model.output_names) > 1:
+ metric_name = '%s_%s' % (model.output_names[index], metric_name)
+ j = 1
+ base_metric_name = metric_name
+ while metric_name in model.metrics_names:
+ metric_name = '%s_%d' % (base_metric_name, j)
+ j += 1
+ model.metrics_names.append(metric_name)
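
To make the dedupe behaviour of the new helper concrete, a minimal sketch (not part of the commit) that drives add_metric_name with a stand-in object rather than a real Model:

from tensorflow.python.keras._impl.keras.engine import training_utils

class _FakeModel(object):
  """Stand-in exposing only the attributes add_metric_name touches."""

  def __init__(self):
    self.output_names = ['dense', 'dropout']
    self.metrics_names = ['loss', 'dense_loss', 'dropout_loss']

m = _FakeModel()
training_utils.add_metric_name(m, 'acc', 0)  # appends 'dense_acc'
training_utils.add_metric_name(m, 'acc', 0)  # name taken, appends 'dense_acc_1'
print(m.metrics_names)
# ['loss', 'dense_loss', 'dropout_loss', 'dense_acc', 'dense_acc_1']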