aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/estimators/estimators_test.py')
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimators_test.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py
index 00032d9f91..7372bb7a1a 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py
@@ -30,6 +30,7 @@ import numpy as np
from tensorflow.contrib.learn.python import learn
from tensorflow.contrib.learn.python.learn import datasets
+from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn.estimators import estimator as estimator_lib
from tensorflow.contrib.learn.python.learn.estimators._sklearn import accuracy_score
from tensorflow.contrib.learn.python.learn.estimators._sklearn import train_test_split
@@ -74,6 +75,12 @@ class FeatureEngineeringFunctionTest(test.TestCase):
prediction = next(estimator.predict(input_fn=input_fn, as_iterable=True))
# predictions = transformed_x (9)
self.assertEqual(9., prediction)
+ metrics = estimator.evaluate(
+ input_fn=input_fn, steps=1,
+ metrics={"label":
+ metric_spec.MetricSpec(lambda predictions, labels: labels)})
+ # labels = transformed_y (99)
+ self.assertEqual(99., metrics["label"])
def testNoneFeatureEngineeringFn(self):