diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-11 11:11:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-11 11:16:23 -0700 |
commit | e103c7b4fb2ac6ebf9472b3c2b01c35222872ae0 (patch) | |
tree | 079b6e9e51c575ce5a03f01c3a19546ff2410688 /tensorflow/contrib/estimator | |
parent | 158cd6220231fcf758a45c2dcd40d93cd0aec9e0 (diff) |
Add average prediction, average label to regression head.
PiperOrigin-RevId: 204154837
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/baseline_test.py | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py index d0e3e670f7..505c94e971 100644 --- a/tensorflow/contrib/estimator/python/estimator/baseline_test.py +++ b/tensorflow/contrib/estimator/python/estimator/baseline_test.py @@ -113,6 +113,8 @@ class BaselineEstimatorEvaluationTest(test.TestCase): self.assertDictEqual({ metric_keys.MetricKeys.LOSS: 18., metric_keys.MetricKeys.LOSS_MEAN: 9., + metric_keys.MetricKeys.PREDICTION_MEAN: 13., + metric_keys.MetricKeys.LABEL_MEAN: 10., ops.GraphKeys.GLOBAL_STEP: 100 }, eval_metrics) @@ -141,6 +143,8 @@ class BaselineEstimatorEvaluationTest(test.TestCase): self.assertDictEqual({ metric_keys.MetricKeys.LOSS: 27., metric_keys.MetricKeys.LOSS_MEAN: 9., + metric_keys.MetricKeys.PREDICTION_MEAN: 13., + metric_keys.MetricKeys.LABEL_MEAN: 10., ops.GraphKeys.GLOBAL_STEP: 100 }, eval_metrics) @@ -166,7 +170,9 @@ class BaselineEstimatorEvaluationTest(test.TestCase): self.assertItemsEqual( (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN, - ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys()) + metric_keys.MetricKeys.PREDICTION_MEAN, + metric_keys.MetricKeys.LABEL_MEAN, ops.GraphKeys.GLOBAL_STEP), + eval_metrics.keys()) # Logit is bias which is [46, 58] self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS]) |