aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2017-01-03 17:37:56 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-03 17:47:05 -0800
commit6703501903e1920b55fb76a2ee85398a6e296bf9 (patch)
tree7f126f51c77e7e8e31a7b0e75e61b2675c275300 /tensorflow
parent7c63520043f3589c7c7f9df84562b158ed6408ff (diff)
Added Experiment integration tests with custom Estimator, linear/dnn/combined Estimators.
Change: 143505386
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py25
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn_test.py26
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator_test.py8
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/linear_test.py26
4 files changed, 85 insertions, 0 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
index 404d2eb2a8..259c894d26 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
@@ -32,6 +32,7 @@ import numpy as np
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.layers.python.layers import feature_column
+from tensorflow.contrib.learn.python.learn import experiment
from tensorflow.contrib.learn.python.learn.datasets import base
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
from tensorflow.contrib.learn.python.learn.estimators import dnn_linear_combined
@@ -149,6 +150,18 @@ class DNNLinearCombinedClassifierTest(test.TestCase):
estimator_test_utils.assert_estimator_contract(
self, dnn_linear_combined.DNNLinearCombinedClassifier)
+ def testExperimentIntegration(self):
+ cont_features = [feature_column.real_valued_column('feature', dimension=4)]
+
+ exp = experiment.Experiment(
+ estimator=dnn_linear_combined.DNNLinearCombinedClassifier(
+ linear_feature_columns=cont_features,
+ dnn_feature_columns=cont_features,
+ dnn_hidden_units=[3, 3]),
+ train_input_fn=test_data.iris_input_logistic_fn,
+ eval_input_fn=test_data.iris_input_logistic_fn)
+ exp.test()
+
def testNoFeatureColumns(self):
with self.assertRaisesRegexp(
ValueError,
@@ -789,6 +802,18 @@ class DNNLinearCombinedClassifierTest(test.TestCase):
class DNNLinearCombinedRegressorTest(test.TestCase):
+ def testExperimentIntegration(self):
+ cont_features = [feature_column.real_valued_column('feature', dimension=4)]
+
+ exp = experiment.Experiment(
+ estimator=dnn_linear_combined.DNNLinearCombinedRegressor(
+ linear_feature_columns=cont_features,
+ dnn_feature_columns=cont_features,
+ dnn_hidden_units=[3, 3]),
+ train_input_fn=test_data.iris_input_logistic_fn,
+ eval_input_fn=test_data.iris_input_logistic_fn)
+ exp.test()
+
def testEstimatorContract(self):
estimator_test_utils.assert_estimator_contract(
self, dnn_linear_combined.DNNLinearCombinedRegressor)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py
index 6059f9f1d2..a48ba282c2 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py
@@ -31,6 +31,7 @@ if hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags'):
import numpy as np
from tensorflow.contrib.layers.python.layers import feature_column
+from tensorflow.contrib.learn.python.learn import experiment
from tensorflow.contrib.learn.python.learn.datasets import base
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
from tensorflow.contrib.learn.python.learn.estimators import dnn
@@ -132,6 +133,19 @@ class EmbeddingMultiplierTest(test.TestCase):
class DNNClassifierTest(test.TestCase):
+ def testExperimentIntegration(self):
+ exp = experiment.Experiment(
+ estimator=dnn.DNNClassifier(
+ n_classes=3,
+ feature_columns=[
+ feature_column.real_valued_column(
+ 'feature', dimension=4)
+ ],
+ hidden_units=[3, 3]),
+ train_input_fn=test_data.iris_input_multiclass_fn,
+ eval_input_fn=test_data.iris_input_multiclass_fn)
+ exp.test()
+
def _assertInRange(self, expected_min, expected_max, actual):
self.assertLessEqual(expected_min, actual)
self.assertGreaterEqual(expected_max, actual)
@@ -772,6 +786,18 @@ class DNNClassifierTest(test.TestCase):
class DNNRegressorTest(test.TestCase):
+ def testExperimentIntegration(self):
+ exp = experiment.Experiment(
+ estimator=dnn.DNNRegressor(
+ feature_columns=[
+ feature_column.real_valued_column(
+ 'feature', dimension=4)
+ ],
+ hidden_units=[3, 3]),
+ train_input_fn=test_data.iris_input_logistic_fn,
+ eval_input_fn=test_data.iris_input_logistic_fn)
+ exp.test()
+
def testEstimatorContract(self):
estimator_test_utils.assert_estimator_contract(self, dnn.DNNRegressor)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index c9155c73ac..eb682d0930 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -38,6 +38,7 @@ from tensorflow.contrib import learn
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib
from tensorflow.contrib.layers.python.layers import optimizers
+from tensorflow.contrib.learn.python.learn import experiment
from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn import models
from tensorflow.contrib.learn.python.learn import monitors as monitors_lib
@@ -258,6 +259,13 @@ class CheckCallsMonitor(monitors_lib.BaseMonitor):
class EstimatorTest(test.TestCase):
+ def testExperimentIntegration(self):
+ exp = experiment.Experiment(
+ estimator=estimator.Estimator(model_fn=linear_model_fn),
+ train_input_fn=boston_input_fn,
+ eval_input_fn=boston_input_fn)
+ exp.test()
+
def testModelFnArgs(self):
expected_param = {'some_param': 'some_value'}
expected_config = run_config.RunConfig()
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
index 040b7d9a07..a9794679ef 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
@@ -31,6 +31,7 @@ if hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags'):
import numpy as np
from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib
+from tensorflow.contrib.learn.python.learn import experiment
from tensorflow.contrib.learn.python.learn.datasets import base
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
from tensorflow.contrib.learn.python.learn.estimators import estimator
@@ -62,6 +63,19 @@ def _prepare_iris_data_for_logistic_regression():
class LinearClassifierTest(test.TestCase):
+ def testExperimentIntegration(self):
+ cont_features = [
+ feature_column_lib.real_valued_column(
+ 'feature', dimension=4)
+ ]
+
+ exp = experiment.Experiment(
+ estimator=linear.LinearClassifier(
+ n_classes=3, feature_columns=cont_features),
+ train_input_fn=test_data.iris_input_multiclass_fn,
+ eval_input_fn=test_data.iris_input_multiclass_fn)
+ exp.test()
+
def testEstimatorContract(self):
estimator_test_utils.assert_estimator_contract(self,
linear.LinearClassifier)
@@ -878,6 +892,18 @@ class LinearClassifierTest(test.TestCase):
class LinearRegressorTest(test.TestCase):
+ def testExperimentIntegration(self):
+ cont_features = [
+ feature_column_lib.real_valued_column(
+ 'feature', dimension=4)
+ ]
+
+ exp = experiment.Experiment(
+ estimator=linear.LinearRegressor(feature_columns=cont_features),
+ train_input_fn=test_data.iris_input_logistic_fn,
+ eval_input_fn=test_data.iris_input_logistic_fn)
+ exp.test()
+
def testEstimatorContract(self):
estimator_test_utils.assert_estimator_contract(self, linear.LinearRegressor)