aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/canned/head.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/canned/head.py')
-rw-r--r--tensorflow/python/estimator/canned/head.py22
1 files changed, 14 insertions, 8 deletions
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index b74ef1015c..da9a64c2bc 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -1398,15 +1398,21 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
weights=weights,
processed_labels=labels)
- def _eval_metric_ops(self, weights, unreduced_loss, regularization_loss):
+ def _eval_metric_ops(self, predicted_value, labels, weights, unreduced_loss,
+ regularization_loss):
"""Returns the Eval metric ops."""
keys = metric_keys.MetricKeys
# Estimator already adds a metric for loss.
eval_metric_ops = {
_summary_key(self._name, keys.LOSS_MEAN):
- metrics_lib.mean(
- values=unreduced_loss,
- weights=weights)
+ metrics_lib.mean(values=unreduced_loss, weights=weights),
+ _summary_key(self._name, keys.PREDICTION_MEAN):
+ _predictions_mean(
+ predictions=predicted_value,
+ weights=weights,
+ name=keys.PREDICTION_MEAN),
+ _summary_key(self._name, keys.LABEL_MEAN):
+ metrics_lib.mean(values=labels, weights=weights)
}
if regularization_loss is not None:
regularization_loss_key = _summary_key(
@@ -1489,13 +1495,13 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
predictions=predictions,
loss=regularized_training_loss,
eval_metrics=_create_eval_metrics_tuple(
- self._eval_metric_ops,
- {
+ self._eval_metric_ops, {
+ 'predicted_value': predicted_value,
+ 'labels': labels,
'weights': weights,
'unreduced_loss': unreduced_loss,
'regularization_loss': regularization_loss,
- }
- ))
+ }))
# Train.
if optimizer is not None: