diff options
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/estimators/estimator_test.py')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/estimator_test.py | 86 |
1 files changed, 84 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index 5ebc299b57..3405005327 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -91,7 +91,18 @@ def boston_eval_fn(): 0) +def extract(data, key): + if isinstance(data, dict): + assert key in data + return data[key] + else: + return data + + def linear_model_params_fn(features, labels, mode, params): + features = extract(features, 'input') + labels = extract(labels, 'labels') + assert mode in ( tf.contrib.learn.ModeKeys.TRAIN, tf.contrib.learn.ModeKeys.EVAL, @@ -106,6 +117,8 @@ def linear_model_params_fn(features, labels, mode, params): def linear_model_fn(features, labels, mode): + features = extract(features, 'input') + labels = extract(labels, 'labels') assert mode in ( tf.contrib.learn.ModeKeys.TRAIN, tf.contrib.learn.ModeKeys.EVAL, @@ -140,8 +153,8 @@ def linear_model_fn_with_model_fn_ops(features, labels, mode): def logistic_model_no_mode_fn(features, labels): - if isinstance(labels, dict): - labels = labels['labels'] + features = extract(features, 'input') + labels = extract(labels, 'labels') labels = tf.one_hot(labels, 3, 1, 0) prediction, loss = ( tf.contrib.learn.models.logistic_regression_zero_init(features, labels) @@ -346,6 +359,34 @@ class EstimatorTest(tf.test.TestCase): with self.assertRaises(tf.contrib.learn.NotFittedError): est.predict(x=boston.data) + def testContinueTrainingDictionaryInput(self): + boston = tf.contrib.learn.datasets.load_boston() + output_dir = tempfile.mkdtemp() + est = tf.contrib.learn.Estimator(model_fn=linear_model_fn, + model_dir=output_dir) + boston_input = {'input': boston.data} + float64_target = {'labels': boston.target.astype(np.float64)} + est.fit(x=boston_input, y=float64_target, steps=50) + scores = est.evaluate( + x=boston_input, + y=float64_target, + metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error}) + del est + # Create another estimator object with the same output dir. + est2 = tf.contrib.learn.Estimator(model_fn=linear_model_fn, + model_dir=output_dir) + + # Check we can evaluate and predict. + scores2 = est2.evaluate( + x=boston_input, + y=float64_target, + metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error}) + self.assertAllClose(scores2['MSE'], + scores['MSE']) + predictions = np.array(list(est2.predict(x=boston_input))) + other_score = _sklearn.mean_squared_error(predictions, float64_target['labels']) + self.assertAllClose(other_score, scores['MSE']) + def testContinueTraining(self): boston = tf.contrib.learn.datasets.load_boston() output_dir = tempfile.mkdtemp() @@ -405,6 +446,22 @@ class EstimatorTest(tf.test.TestCase): self.assertTrue('global_step' in scores) self.assertEqual(100, scores['global_step']) + def testBostonAllDictionaryInput(self): + boston = tf.contrib.learn.datasets.load_boston() + est = tf.contrib.learn.Estimator(model_fn=linear_model_fn) + boston_input = {'input': boston.data} + float64_target = {'labels': boston.target.astype(np.float64)} + est.fit(x=boston_input, y=float64_target, steps=100) + scores = est.evaluate( + x=boston_input, + y=float64_target, + metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error}) + predictions = np.array(list(est.predict(x=boston_input))) + other_score = _sklearn.mean_squared_error(predictions, boston.target) + self.assertAllClose(other_score, scores['MSE']) + self.assertTrue('global_step' in scores) + self.assertEqual(scores['global_step'], 100) + def testIrisAll(self): iris = tf.contrib.learn.datasets.load_iris() est = tf.contrib.learn.SKCompat( @@ -428,6 +485,31 @@ class EstimatorTest(tf.test.TestCase): self.assertTrue('global_step' in scores) self.assertEqual(100, scores['global_step']) + def testIrisAllDictionaryInput(self): + iris = tf.contrib.learn.datasets.load_iris() + est = tf.contrib.learn.Estimator(model_fn=logistic_model_no_mode_fn) + iris_data = {'input': iris.data} + iris_target = {'labels': iris.target} + est.fit(iris_data, iris_target, steps=100) + scores = est.evaluate( + x=iris_data, + y=iris_target, + metrics={('accuracy', 'class'): tf.contrib.metrics.streaming_accuracy}) + predictions = list(est.predict(x=iris_data)) + predictions_class = list(est.predict(x=iris_data, outputs=['class'])) + self.assertEqual(len(predictions), iris.target.shape[0]) + classes_batch = np.array([p['class'] for p in predictions]) + self.assertAllClose( + classes_batch, + np.array([p['class'] for p in predictions_class])) + self.assertAllClose( + classes_batch, + np.argmax(np.array([p['prob'] for p in predictions]), axis=1)) + other_score = _sklearn.accuracy_score(iris.target, classes_batch) + self.assertAllClose(other_score, scores['accuracy']) + self.assertTrue('global_step' in scores) + self.assertEqual(scores['global_step'], 100) + def testIrisInputFn(self): iris = tf.contrib.learn.datasets.load_iris() est = tf.contrib.learn.Estimator(model_fn=logistic_model_no_mode_fn) |