aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/estimators/estimator_test.py')
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator_test.py86
1 files changed, 84 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index 5ebc299b57..3405005327 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -91,7 +91,18 @@ def boston_eval_fn():
0)
+def extract(data, key):
+ if isinstance(data, dict):
+ assert key in data
+ return data[key]
+ else:
+ return data
+
+
def linear_model_params_fn(features, labels, mode, params):
+ features = extract(features, 'input')
+ labels = extract(labels, 'labels')
+
assert mode in (
tf.contrib.learn.ModeKeys.TRAIN,
tf.contrib.learn.ModeKeys.EVAL,
@@ -106,6 +117,8 @@ def linear_model_params_fn(features, labels, mode, params):
def linear_model_fn(features, labels, mode):
+ features = extract(features, 'input')
+ labels = extract(labels, 'labels')
assert mode in (
tf.contrib.learn.ModeKeys.TRAIN,
tf.contrib.learn.ModeKeys.EVAL,
@@ -140,8 +153,8 @@ def linear_model_fn_with_model_fn_ops(features, labels, mode):
def logistic_model_no_mode_fn(features, labels):
- if isinstance(labels, dict):
- labels = labels['labels']
+ features = extract(features, 'input')
+ labels = extract(labels, 'labels')
labels = tf.one_hot(labels, 3, 1, 0)
prediction, loss = (
tf.contrib.learn.models.logistic_regression_zero_init(features, labels)
@@ -346,6 +359,34 @@ class EstimatorTest(tf.test.TestCase):
with self.assertRaises(tf.contrib.learn.NotFittedError):
est.predict(x=boston.data)
+ def testContinueTrainingDictionaryInput(self):
+ boston = tf.contrib.learn.datasets.load_boston()
+ output_dir = tempfile.mkdtemp()
+ est = tf.contrib.learn.Estimator(model_fn=linear_model_fn,
+ model_dir=output_dir)
+ boston_input = {'input': boston.data}
+ float64_target = {'labels': boston.target.astype(np.float64)}
+ est.fit(x=boston_input, y=float64_target, steps=50)
+ scores = est.evaluate(
+ x=boston_input,
+ y=float64_target,
+ metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
+ del est
+ # Create another estimator object with the same output dir.
+ est2 = tf.contrib.learn.Estimator(model_fn=linear_model_fn,
+ model_dir=output_dir)
+
+ # Check we can evaluate and predict.
+ scores2 = est2.evaluate(
+ x=boston_input,
+ y=float64_target,
+ metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
+ self.assertAllClose(scores2['MSE'],
+ scores['MSE'])
+ predictions = np.array(list(est2.predict(x=boston_input)))
+ other_score = _sklearn.mean_squared_error(predictions, float64_target['labels'])
+ self.assertAllClose(other_score, scores['MSE'])
+
def testContinueTraining(self):
boston = tf.contrib.learn.datasets.load_boston()
output_dir = tempfile.mkdtemp()
@@ -405,6 +446,22 @@ class EstimatorTest(tf.test.TestCase):
self.assertTrue('global_step' in scores)
self.assertEqual(100, scores['global_step'])
+ def testBostonAllDictionaryInput(self):
+ boston = tf.contrib.learn.datasets.load_boston()
+ est = tf.contrib.learn.Estimator(model_fn=linear_model_fn)
+ boston_input = {'input': boston.data}
+ float64_target = {'labels': boston.target.astype(np.float64)}
+ est.fit(x=boston_input, y=float64_target, steps=100)
+ scores = est.evaluate(
+ x=boston_input,
+ y=float64_target,
+ metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
+ predictions = np.array(list(est.predict(x=boston_input)))
+ other_score = _sklearn.mean_squared_error(predictions, boston.target)
+ self.assertAllClose(other_score, scores['MSE'])
+ self.assertTrue('global_step' in scores)
+ self.assertEqual(scores['global_step'], 100)
+
def testIrisAll(self):
iris = tf.contrib.learn.datasets.load_iris()
est = tf.contrib.learn.SKCompat(
@@ -428,6 +485,31 @@ class EstimatorTest(tf.test.TestCase):
self.assertTrue('global_step' in scores)
self.assertEqual(100, scores['global_step'])
+ def testIrisAllDictionaryInput(self):
+ iris = tf.contrib.learn.datasets.load_iris()
+ est = tf.contrib.learn.Estimator(model_fn=logistic_model_no_mode_fn)
+ iris_data = {'input': iris.data}
+ iris_target = {'labels': iris.target}
+ est.fit(iris_data, iris_target, steps=100)
+ scores = est.evaluate(
+ x=iris_data,
+ y=iris_target,
+ metrics={('accuracy', 'class'): tf.contrib.metrics.streaming_accuracy})
+ predictions = list(est.predict(x=iris_data))
+ predictions_class = list(est.predict(x=iris_data, outputs=['class']))
+ self.assertEqual(len(predictions), iris.target.shape[0])
+ classes_batch = np.array([p['class'] for p in predictions])
+ self.assertAllClose(
+ classes_batch,
+ np.array([p['class'] for p in predictions_class]))
+ self.assertAllClose(
+ classes_batch,
+ np.argmax(np.array([p['prob'] for p in predictions]), axis=1))
+ other_score = _sklearn.accuracy_score(iris.target, classes_batch)
+ self.assertAllClose(other_score, scores['accuracy'])
+ self.assertTrue('global_step' in scores)
+ self.assertEqual(scores['global_step'], 100)
+
def testIrisInputFn(self):
iris = tf.contrib.learn.datasets.load_iris()
est = tf.contrib.learn.Estimator(model_fn=logistic_model_no_mode_fn)