diff options
author | Mustafa Ispir <ispir@google.com> | 2016-09-16 12:41:33 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-09-16 13:47:30 -0700 |
commit | c1cdc5c41f880a1d8fb3f24892246730f90f4051 (patch) | |
tree | b46e934b3d15a4b63eb0162e9287befb874fb897 | |
parent | 236a1c7f7d577d9758d06a6c382035065075578d (diff) |
Added contract tests to wide-n-deep estimators.
Change: 133422226
6 files changed, 74 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py index 768e8045e6..c3685cc6f6 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py @@ -25,6 +25,7 @@ import numpy as np import tensorflow as tf from tensorflow.contrib.learn.python.learn.estimators import _sklearn +from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils def _get_quantile_based_buckets(feature_values, num_buckets): @@ -59,6 +60,10 @@ def _iris_input_logistic_fn(): class DNNLinearCombinedClassifierTest(tf.test.TestCase): + def testEstimatorContract(self): + estimator_test_utils.assert_estimator_contract( + self, tf.contrib.learn.DNNLinearCombinedClassifier) + def testLogisticRegression_MatrixData(self): """Tests binary classification using matrix data as input.""" iris = _prepare_iris_data_for_logistic_regression() @@ -576,6 +581,10 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase): class DNNLinearCombinedRegressorTest(tf.test.TestCase): + def testEstimatorContract(self): + estimator_test_utils.assert_estimator_contract( + self, tf.contrib.learn.DNNLinearCombinedRegressor) + def _input_fn_train(self): # Create 4 rows of (y = x) target = tf.constant([[100.], [3.], [2.], [2.]]) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py index 617194c48a..64d2fe6d70 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py @@ -26,6 +26,7 @@ import numpy as np import tensorflow as tf from tensorflow.contrib.learn.python.learn.estimators import _sklearn +from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils # pylint: disable=g-import-not-at-top try: @@ -60,6 +61,10 @@ def _iris_input_multiclass_fn(): class DNNClassifierTest(tf.test.TestCase): + def testEstimatorContract(self): + estimator_test_utils.assert_estimator_contract( + self, tf.contrib.learn.DNNClassifier) + def testLogisticRegression_MatrixData(self): """Tests binary classification using matrix data as input.""" cont_features = [ @@ -508,6 +513,10 @@ class DNNClassifierTest(tf.test.TestCase): class DNNRegressorTest(tf.test.TestCase): + def testEstimatorContract(self): + estimator_test_utils.assert_estimator_contract( + self, tf.contrib.learn.DNNRegressor) + def testRegression_MatrixData(self): """Tests regression using matrix data as input.""" cont_features = [ diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index ac793706e2..5ed621794f 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -446,8 +446,6 @@ class BaseEstimator( Returns: Numpy array - value of the tensor. """ - if name.endswith(':0'): - name = name[:-2] return checkpoints.load_variable(self.model_dir, name) def get_variable_names(self): diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py new file mode 100644 index 0000000000..3ba74e1c91 --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py @@ -0,0 +1,44 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utils for Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import inspect + + +def assert_estimator_contract(tester, estimator_class): + """Asserts whether given estimator satisfies the expected contract. + + This doesn't check every details of contract. This test is used for that a + function is not forgotten to implement in a precanned Estimator. + + Args: + tester: A tf.test.TestCase. + estimator_class: 'type' object of pre-canned estimator. + """ + attributes = inspect.getmembers(estimator_class) + attribute_names = [a[0] for a in attributes] + + tester.assertTrue('config' in attribute_names) + tester.assertTrue('evaluate' in attribute_names) + tester.assertTrue('export' in attribute_names) + tester.assertTrue('fit' in attribute_names) + tester.assertTrue('get_variable_names' in attribute_names) + tester.assertTrue('get_variable_value' in attribute_names) + tester.assertTrue('model_dir' in attribute_names) + tester.assertTrue('predict' in attribute_names) diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index a61d87e1c0..8106777616 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -548,6 +548,9 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable): def get_variable_names(self): return [name for name, _ in checkpoints.list_variables(self._model_dir)] + def get_variable_value(self, name): + return checkpoints.load_variable(self.model_dir, name) + def export(self, export_dir, input_fn=None, diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py index 6d91c4c4e6..692c03c562 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py @@ -26,6 +26,7 @@ import numpy as np import tensorflow as tf from tensorflow.contrib.learn.python.learn.estimators import _sklearn +from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec @@ -38,6 +39,10 @@ def _iris_input_fn(): class LinearClassifierTest(tf.test.TestCase): + def testEstimatorContract(self): + estimator_test_utils.assert_estimator_contract( + self, tf.contrib.learn.LinearClassifier) + def testTrain(self): """Tests that loss goes down with training.""" @@ -670,6 +675,10 @@ class LinearClassifierTest(tf.test.TestCase): class LinearRegressorTest(tf.test.TestCase): + def testEstimatorContract(self): + estimator_test_utils.assert_estimator_contract( + self, tf.contrib.learn.LinearRegressor) + def testRegression(self): """Tests that loss goes down with training.""" |