diff options
author | 2017-06-09 12:15:45 -0700 | |
---|---|---|
committer | 2017-06-09 12:21:57 -0700 | |
commit | 50292fcc2b63093f81e150981cd0a6b7729540e0 (patch) | |
tree | a63602c4b9a1529c65cb5120d5f9c04e49cd103b | |
parent | 1274753032c863c87aec0c0770b80c8d526495e8 (diff) |
Add train-op test for dnn linear combined estimator.
PiperOrigin-RevId: 158546896
-rw-r--r-- | tensorflow/python/estimator/canned/dnn_linear_combined_test.py | 87 |
1 files changed, 74 insertions, 13 deletions
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py index f93e422b70..15264fa924 100644 --- a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py +++ b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py @@ -38,9 +38,13 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import nn from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import gfile from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import gradient_descent from tensorflow.python.training import input as input_lib +from tensorflow.python.training import optimizer as optimizer_lib try: # pylint: disable=g-import-not-at-top @@ -59,19 +63,18 @@ class DNNOnlyModelFnTest(dnn_testing_utils.BaseDNNModelFnTest, test.TestCase): test.TestCase.__init__(self, methodName) dnn_testing_utils.BaseDNNModelFnTest.__init__(self, self._dnn_only_model_fn) - def _dnn_only_model_fn( - self, - features, - labels, - mode, - head, - hidden_units, - feature_columns, - optimizer='Adagrad', - activation_fn=nn.relu, - dropout=None, # pylint: disable=redefined-outer-name - input_layer_partitioner=None, - config=None): + def _dnn_only_model_fn(self, + features, + labels, + mode, + head, + hidden_units, + feature_columns, + optimizer='Adagrad', + activation_fn=nn.relu, + dropout=None, + input_layer_partitioner=None, + config=None): return dnn_linear_combined._dnn_linear_combined_model_fn( features=features, labels=labels, @@ -535,5 +538,63 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase): batch_size=batch_size) +class DNNLinearCombinedTrainOpTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + 
shutil.rmtree(self._model_dir) + + def _mock_optimizer(self, real_optimizer, var_name_prefix): + """Verifies global_step is None and var_names start with given prefix.""" + + def _minimize(loss, global_step=None, var_list=None): + self.assertIsNone(global_step) + trainable_vars = var_list or ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES) + var_names = [var.name for var in trainable_vars] + self.assertTrue( + all([name.startswith(var_name_prefix) for name in var_names])) + # var is used to check that this op is called by training. + var = variables_lib.Variable(0., name=(var_name_prefix + '_called')) + with ops.control_dependencies([var.assign(100.)]): + return real_optimizer.minimize(loss, global_step, var_list) + + optimizer_mock = test.mock.NonCallableMagicMock( + spec=optimizer_lib.Optimizer, wraps=real_optimizer) + optimizer_mock.minimize = test.mock.MagicMock(wraps=_minimize) + + return optimizer_mock + + def test_train_op_calls_both_dnn_and_linear(self): + opt = gradient_descent.GradientDescentOptimizer(1.) + x_column = feature_column.numeric_column('x') + input_fn = numpy_io.numpy_input_fn( + x={'x': np.array([[0.], [1.]])}, + y=np.array([[0.], [1.]]), + batch_size=1, + shuffle=False) + est = dnn_linear_combined.DNNLinearCombinedClassifier( + linear_feature_columns=[x_column], + # verifies linear_optimizer is used only for linear part. + linear_optimizer=self._mock_optimizer(opt, 'linear'), + dnn_hidden_units=(2, 2), + dnn_feature_columns=[x_column], + # verifies dnn_optimizer is used only for dnn part. 
+ dnn_optimizer=self._mock_optimizer(opt, 'dnn'), + model_dir=self._model_dir) + est.train(input_fn, steps=1) + # verifies train_op fires linear minimize op + self.assertEqual(100., + checkpoint_utils.load_variable( + self._model_dir, 'binary_logistic_head/linear_called')) + # verifies train_op fires dnn minimize op + self.assertEqual(100., + checkpoint_utils.load_variable( + self._model_dir, 'binary_logistic_head/dnn_called')) + + if __name__ == '__main__': test.main() |