aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-11 11:11:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-11 11:16:23 -0700
commite103c7b4fb2ac6ebf9472b3c2b01c35222872ae0 (patch)
tree079b6e9e51c575ce5a03f01c3a19546ff2410688 /tensorflow/contrib/estimator
parent158cd6220231fcf758a45c2dcd40d93cd0aec9e0 (diff)
Add average prediction, average label to regression head.
PiperOrigin-RevId: 204154837
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r--tensorflow/contrib/estimator/python/estimator/baseline_test.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py
index d0e3e670f7..505c94e971 100644
--- a/tensorflow/contrib/estimator/python/estimator/baseline_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/baseline_test.py
@@ -113,6 +113,8 @@ class BaselineEstimatorEvaluationTest(test.TestCase):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 18.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
@@ -141,6 +143,8 @@ class BaselineEstimatorEvaluationTest(test.TestCase):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 27.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
@@ -166,7 +170,9 @@ class BaselineEstimatorEvaluationTest(test.TestCase):
self.assertItemsEqual(
(metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,
- ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys())
+ metric_keys.MetricKeys.PREDICTION_MEAN,
+ metric_keys.MetricKeys.LABEL_MEAN, ops.GraphKeys.GLOBAL_STEP),
+ eval_metrics.keys())
# Logit is bias which is [46, 58]
self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS])