diff options
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/estimators/estimators_test.py')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/estimators_test.py | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py index 00032d9f91..7372bb7a1a 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py @@ -30,6 +30,7 @@ import numpy as np from tensorflow.contrib.learn.python import learn from tensorflow.contrib.learn.python.learn import datasets +from tensorflow.contrib.learn.python.learn import metric_spec from tensorflow.contrib.learn.python.learn.estimators import estimator as estimator_lib from tensorflow.contrib.learn.python.learn.estimators._sklearn import accuracy_score from tensorflow.contrib.learn.python.learn.estimators._sklearn import train_test_split @@ -74,6 +75,12 @@ class FeatureEngineeringFunctionTest(test.TestCase): prediction = next(estimator.predict(input_fn=input_fn, as_iterable=True)) # predictions = transformed_x (9) self.assertEqual(9., prediction) + metrics = estimator.evaluate( + input_fn=input_fn, steps=1, + metrics={"label": + metric_spec.MetricSpec(lambda predictions, labels: labels)}) + # labels = transformed_y (99) + self.assertEqual(99., metrics["label"]) def testNoneFeatureEngineeringFn(self): |