aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2016-09-16 12:41:33 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-16 13:47:30 -0700
commitc1cdc5c41f880a1d8fb3f24892246730f90f4051 (patch)
treeb46e934b3d15a4b63eb0162e9287befb874fb897
parent236a1c7f7d577d9758d06a6c382035065075578d (diff)
Added contract tests to wide-n-deep estimators.
Change: 133422226
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py9
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn_test.py9
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py44
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/linear.py3
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/linear_test.py9
6 files changed, 74 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
index 768e8045e6..c3685cc6f6 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
@@ -25,6 +25,7 @@ import numpy as np
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
+from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils
def _get_quantile_based_buckets(feature_values, num_buckets):
@@ -59,6 +60,10 @@ def _iris_input_logistic_fn():
class DNNLinearCombinedClassifierTest(tf.test.TestCase):
+ def testEstimatorContract(self):
+ estimator_test_utils.assert_estimator_contract(
+ self, tf.contrib.learn.DNNLinearCombinedClassifier)
+
def testLogisticRegression_MatrixData(self):
"""Tests binary classification using matrix data as input."""
iris = _prepare_iris_data_for_logistic_regression()
@@ -576,6 +581,10 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
class DNNLinearCombinedRegressorTest(tf.test.TestCase):
+ def testEstimatorContract(self):
+ estimator_test_utils.assert_estimator_contract(
+ self, tf.contrib.learn.DNNLinearCombinedRegressor)
+
def _input_fn_train(self):
# Create 4 rows of (y = x)
target = tf.constant([[100.], [3.], [2.], [2.]])
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py
index 617194c48a..64d2fe6d70 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py
@@ -26,6 +26,7 @@ import numpy as np
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
+from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils
# pylint: disable=g-import-not-at-top
try:
@@ -60,6 +61,10 @@ def _iris_input_multiclass_fn():
class DNNClassifierTest(tf.test.TestCase):
+ def testEstimatorContract(self):
+ estimator_test_utils.assert_estimator_contract(
+ self, tf.contrib.learn.DNNClassifier)
+
def testLogisticRegression_MatrixData(self):
"""Tests binary classification using matrix data as input."""
cont_features = [
@@ -508,6 +513,10 @@ class DNNClassifierTest(tf.test.TestCase):
class DNNRegressorTest(tf.test.TestCase):
+ def testEstimatorContract(self):
+ estimator_test_utils.assert_estimator_contract(
+ self, tf.contrib.learn.DNNRegressor)
+
def testRegression_MatrixData(self):
"""Tests regression using matrix data as input."""
cont_features = [
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index ac793706e2..5ed621794f 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -446,8 +446,6 @@ class BaseEstimator(
Returns:
Numpy array - value of the tensor.
"""
- if name.endswith(':0'):
- name = name[:-2]
return checkpoints.load_variable(self.model_dir, name)
def get_variable_names(self):
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py
new file mode 100644
index 0000000000..3ba74e1c91
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py
@@ -0,0 +1,44 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for Estimator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import inspect
+
+
+def assert_estimator_contract(tester, estimator_class):
+ """Asserts whether given estimator satisfies the expected contract.
+
+ This doesn't check every details of contract. This test is used for that a
+ function is not forgotten to implement in a precanned Estimator.
+
+ Args:
+ tester: A tf.test.TestCase.
+ estimator_class: 'type' object of pre-canned estimator.
+ """
+ attributes = inspect.getmembers(estimator_class)
+ attribute_names = [a[0] for a in attributes]
+
+ tester.assertTrue('config' in attribute_names)
+ tester.assertTrue('evaluate' in attribute_names)
+ tester.assertTrue('export' in attribute_names)
+ tester.assertTrue('fit' in attribute_names)
+ tester.assertTrue('get_variable_names' in attribute_names)
+ tester.assertTrue('get_variable_value' in attribute_names)
+ tester.assertTrue('model_dir' in attribute_names)
+ tester.assertTrue('predict' in attribute_names)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py
index a61d87e1c0..8106777616 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py
@@ -548,6 +548,9 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
def get_variable_names(self):
return [name for name, _ in checkpoints.list_variables(self._model_dir)]
+ def get_variable_value(self, name):
+ return checkpoints.load_variable(self.model_dir, name)
+
def export(self,
export_dir,
input_fn=None,
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
index 6d91c4c4e6..692c03c562 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
@@ -26,6 +26,7 @@ import numpy as np
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
+from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils
from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec
@@ -38,6 +39,10 @@ def _iris_input_fn():
class LinearClassifierTest(tf.test.TestCase):
+ def testEstimatorContract(self):
+ estimator_test_utils.assert_estimator_contract(
+ self, tf.contrib.learn.LinearClassifier)
+
def testTrain(self):
"""Tests that loss goes down with training."""
@@ -670,6 +675,10 @@ class LinearClassifierTest(tf.test.TestCase):
class LinearRegressorTest(tf.test.TestCase):
+ def testEstimatorContract(self):
+ estimator_test_utils.assert_estimator_contract(
+ self, tf.contrib.learn.LinearRegressor)
+
def testRegression(self):
"""Tests that loss goes down with training."""