aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2017-06-09 12:15:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-09 12:21:57 -0700
commit50292fcc2b63093f81e150981cd0a6b7729540e0 (patch)
treea63602c4b9a1529c65cb5120d5f9c04e49cd103b
parent1274753032c863c87aec0c0770b80c8d526495e8 (diff)
Add train-op test for dnn linear combined estimator.
PiperOrigin-RevId: 158546896
-rw-r--r--tensorflow/python/estimator/canned/dnn_linear_combined_test.py87
1 files changed, 74 insertions, 13 deletions
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
index f93e422b70..15264fa924 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
@@ -38,9 +38,13 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.training import gradient_descent
from tensorflow.python.training import input as input_lib
+from tensorflow.python.training import optimizer as optimizer_lib
try:
# pylint: disable=g-import-not-at-top
@@ -59,19 +63,18 @@ class DNNOnlyModelFnTest(dnn_testing_utils.BaseDNNModelFnTest, test.TestCase):
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNModelFnTest.__init__(self, self._dnn_only_model_fn)
- def _dnn_only_model_fn(
- self,
- features,
- labels,
- mode,
- head,
- hidden_units,
- feature_columns,
- optimizer='Adagrad',
- activation_fn=nn.relu,
- dropout=None, # pylint: disable=redefined-outer-name
- input_layer_partitioner=None,
- config=None):
+ def _dnn_only_model_fn(self,
+ features,
+ labels,
+ mode,
+ head,
+ hidden_units,
+ feature_columns,
+ optimizer='Adagrad',
+ activation_fn=nn.relu,
+ dropout=None,
+ input_layer_partitioner=None,
+ config=None):
return dnn_linear_combined._dnn_linear_combined_model_fn(
features=features,
labels=labels,
@@ -535,5 +538,63 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
batch_size=batch_size)
+class DNNLinearCombinedTrainOpTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ shutil.rmtree(self._model_dir)
+
+ def _mock_optimizer(self, real_optimizer, var_name_prefix):
+ """Verifies global_step is None and var_names start with given prefix."""
+
+ def _minimize(loss, global_step=None, var_list=None):
+ self.assertIsNone(global_step)
+ trainable_vars = var_list or ops.get_collection(
+ ops.GraphKeys.TRAINABLE_VARIABLES)
+ var_names = [var.name for var in trainable_vars]
+ self.assertTrue(
+ all([name.startswith(var_name_prefix) for name in var_names]))
+ # var is used to check this op called by training.
+ var = variables_lib.Variable(0., name=(var_name_prefix + '_called'))
+ with ops.control_dependencies([var.assign(100.)]):
+ return real_optimizer.minimize(loss, global_step, var_list)
+
+ optimizer_mock = test.mock.NonCallableMagicMock(
+ spec=optimizer_lib.Optimizer, wraps=real_optimizer)
+ optimizer_mock.minimize = test.mock.MagicMock(wraps=_minimize)
+
+ return optimizer_mock
+
+ def test_train_op_calls_both_dnn_and_linear(self):
+ opt = gradient_descent.GradientDescentOptimizer(1.)
+ x_column = feature_column.numeric_column('x')
+ input_fn = numpy_io.numpy_input_fn(
+ x={'x': np.array([[0.], [1.]])},
+ y=np.array([[0.], [1.]]),
+ batch_size=1,
+ shuffle=False)
+ est = dnn_linear_combined.DNNLinearCombinedClassifier(
+ linear_feature_columns=[x_column],
+ # verifies linear_optimizer is used only for linear part.
+ linear_optimizer=self._mock_optimizer(opt, 'linear'),
+ dnn_hidden_units=(2, 2),
+ dnn_feature_columns=[x_column],
+ # verifies dnn_optimizer is used only for linear part.
+ dnn_optimizer=self._mock_optimizer(opt, 'dnn'),
+ model_dir=self._model_dir)
+ est.train(input_fn, steps=1)
+ # verifies train_op fires linear minimize op
+ self.assertEqual(100.,
+ checkpoint_utils.load_variable(
+ self._model_dir, 'binary_logistic_head/linear_called'))
+ # verifies train_op fires dnn minimize op
+ self.assertEqual(100.,
+ checkpoint_utils.load_variable(
+ self._model_dir, 'binary_logistic_head/dnn_called'))
+
+
if __name__ == '__main__':
test.main()