diff options
Diffstat (limited to 'tensorflow/contrib/estimator/python/estimator/baseline_test.py')
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/baseline_test.py | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py index d0e3e670f7..505c94e971 100644 --- a/tensorflow/contrib/estimator/python/estimator/baseline_test.py +++ b/tensorflow/contrib/estimator/python/estimator/baseline_test.py @@ -113,6 +113,8 @@ class BaselineEstimatorEvaluationTest(test.TestCase): self.assertDictEqual({ metric_keys.MetricKeys.LOSS: 18., metric_keys.MetricKeys.LOSS_MEAN: 9., + metric_keys.MetricKeys.PREDICTION_MEAN: 13., + metric_keys.MetricKeys.LABEL_MEAN: 10., ops.GraphKeys.GLOBAL_STEP: 100 }, eval_metrics) @@ -141,6 +143,8 @@ class BaselineEstimatorEvaluationTest(test.TestCase): self.assertDictEqual({ metric_keys.MetricKeys.LOSS: 27., metric_keys.MetricKeys.LOSS_MEAN: 9., + metric_keys.MetricKeys.PREDICTION_MEAN: 13., + metric_keys.MetricKeys.LABEL_MEAN: 10., ops.GraphKeys.GLOBAL_STEP: 100 }, eval_metrics) @@ -166,7 +170,9 @@ class BaselineEstimatorEvaluationTest(test.TestCase): self.assertItemsEqual( (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN, - ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys()) + metric_keys.MetricKeys.PREDICTION_MEAN, + metric_keys.MetricKeys.LABEL_MEAN, ops.GraphKeys.GLOBAL_STEP), + eval_metrics.keys()) # Logit is bias which is [46, 58] self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS]) |