diff options
author | Karmel Allison <karmel@google.com> | 2018-05-10 09:47:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-10 09:50:44 -0700 |
commit | 9c18251256a88e23c47f60f3597f9c764000fba4 (patch) | |
tree | 6d1bc0a0f1b450460389d3182e97f98ec6641ce8 /tensorflow/python/estimator/estimator_test.py | |
parent | e696dc1bd07f62c6621a7224e15c8d3fbc160054 (diff) |
For Estimators, SavedModels for multiple modes should be exported into the same
file.
PiperOrigin-RevId: 196128943
Diffstat (limited to 'tensorflow/python/estimator/estimator_test.py')
-rw-r--r-- | tensorflow/python/estimator/estimator_test.py | 170 |
1 files changed, 140 insertions, 30 deletions
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 02088e5134..c9c6bdfeb5 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -2013,12 +2013,9 @@ class EstimatorExportTest(test.TestCase): input_receiver_fn_map = { model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() } - export_dirs, tmpdir = self._test_export_all_saved_models( + export_dir, tmpdir = self._test_export_all_saved_models( input_receiver_fn_map) - self.assertEqual(len(export_dirs), 1) - # Restore, to validate that the export was well-formed. - export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: loader.load(sess, [tag_constants.SERVING], export_dir) @@ -2035,12 +2032,9 @@ class EstimatorExportTest(test.TestCase): input_receiver_fn_map = { model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), } - export_dirs, tmpdir = self._test_export_all_saved_models( + export_dir, tmpdir = self._test_export_all_saved_models( input_receiver_fn_map) - self.assertEqual(len(export_dirs), 1) - # Restore, to validate that the export was well-formed. - export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: loader.load(sess, [tag_constants.TRAINING], export_dir) @@ -2058,12 +2052,9 @@ class EstimatorExportTest(test.TestCase): input_receiver_fn_map = { model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() } - export_dirs, tmpdir = self._test_export_all_saved_models( + export_dir, tmpdir = self._test_export_all_saved_models( input_receiver_fn_map) - self.assertEqual(len(export_dirs), 1) - # Restore, to validate that the export was well-formed. - export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL] with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: loader.load(sess, [tag_constants.EVAL], export_dir) @@ -2082,12 +2073,9 @@ class EstimatorExportTest(test.TestCase): model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() } - export_dirs, tmpdir = self._test_export_all_saved_models( + export_dir, tmpdir = self._test_export_all_saved_models( input_receiver_fn_map) - self.assertEqual(len(export_dirs), 2) - # Restore, to validate that the export was well-formed. - export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: loader.load(sess, [tag_constants.TRAINING], export_dir) @@ -2096,7 +2084,7 @@ class EstimatorExportTest(test.TestCase): self.assertFalse('eval_multiplied' in graph_ops) self.assertTrue('feature_x' in graph_ops) self.assertTrue('weight' in graph_ops) - export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL] + with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: loader.load(sess, [tag_constants.EVAL], export_dir) @@ -2117,12 +2105,11 @@ class EstimatorExportTest(test.TestCase): model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() } - export_dirs, tmpdir = self._test_export_all_saved_models( + export_dir, tmpdir = self._test_export_all_saved_models( input_receiver_fn_map) # Restore, to validate that the export was well-formed. - for mode, tag_set in model_fn_lib.EXPORT_TAG_MAP.items(): - export_dir = export_dirs[mode] + for tag_set in model_fn_lib.EXPORT_TAG_MAP.values(): with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: loader.load(sess, tag_set, export_dir) @@ -2139,10 +2126,9 @@ class EstimatorExportTest(test.TestCase): model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() } - export_dirs, tmpdir = self._test_export_all_saved_models( + export_dir, tmpdir = self._test_export_all_saved_models( input_receiver_fn_map) - export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: loader.load(sess, [tag_constants.TRAINING], export_dir) @@ -2150,7 +2136,6 @@ class EstimatorExportTest(test.TestCase): self.assertTrue('later_var' in graph_ops) self.assertTrue('weight' in graph_ops) - export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: loader.load(sess, [tag_constants.SERVING], export_dir) @@ -2166,10 +2151,9 @@ class EstimatorExportTest(test.TestCase): model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() } - export_dirs, tmpdir = self._test_export_all_saved_models( + export_dir, tmpdir = self._test_export_all_saved_models( input_receiver_fn_map) - export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: loader.load(sess, [tag_constants.TRAINING], export_dir) @@ -2179,7 +2163,6 @@ class EstimatorExportTest(test.TestCase): collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) self.assertEqual(3, collection_vars[-1].eval()) - export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] with ops.Graph().as_default() as graph: with session.Session(graph=graph) as sess: loader.load(sess, [tag_constants.SERVING], export_dir) @@ -2207,16 +2190,15 @@ class EstimatorExportTest(test.TestCase): # Perform the export. export_dir_base = os.path.join( compat.as_bytes(tmpdir), compat.as_bytes('export')) - export_dirs = est._export_all_saved_models( + export_dir = est._export_all_saved_models( export_dir_base, input_receiver_fn_map) # Check that all the files are in the right places. self.assertTrue(gfile.Exists(export_dir_base)) - for _, export_dir in export_dirs.items(): - self._validate_exported_files(export_dir) + self._validate_exported_files(export_dir) - return export_dirs, tmpdir + return export_dir, tmpdir def _validate_exported_files(self, export_dir): self.assertTrue(gfile.Exists(export_dir)) @@ -2233,6 +2215,42 @@ class EstimatorExportTest(test.TestCase): compat.as_bytes(export_dir), compat.as_bytes('variables/variables.data-00000-of-00001')))) + def test_export_all_saved_models_var_not_found(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + + def _model_fn_with_predict_only_vars(features, labels, mode): + _, _ = features, labels + if mode == model_fn_lib.ModeKeys.PREDICT: + variables.Variable(1., name='only_in_predict') + else: + variables.Variable(1., name='otherwise') + + prediction = constant_op.constant(1.) + return model_fn_lib.EstimatorSpec( + mode, + predictions=prediction, + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + export_outputs={ + 'test': export_output.PredictOutput({'prediction': prediction}) + }) + + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_with_predict_only_vars) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + + err_regex = r'Could not load all requested variables[\w\W]*infer' + with self.assertRaisesRegexp(ValueError, err_regex): + est._export_all_saved_models(export_dir_base, input_receiver_fn_map) + def test_export_savedmodel_with_saveables_proto_roundtrip(self): tmpdir = tempfile.mkdtemp() est = estimator.Estimator( @@ -2464,6 +2482,43 @@ class EstimatorExportTest(test.TestCase): self.assertTrue(self.mock_saver.restore.called) + def test_scaffold_is_used_for_saver_multiple_modes(self): + tmpdir = tempfile.mkdtemp() + + def _model_fn_scaffold(features, labels, mode): + _, _ = features, labels + variables.Variable(1., name='weight') + real_saver = saver.Saver() + self.mock_saver = test.mock.Mock( + wraps=real_saver, saver_def=real_saver.saver_def) + scores = constant_op.constant([3.]) + if mode == model_fn_lib.ModeKeys.PREDICT: + scaffold = training.Scaffold(saver=self.mock_saver) + else: + scaffold = training.Scaffold() + return model_fn_lib.EstimatorSpec( + mode=mode, + predictions=constant_op.constant([[1.]]), + loss=constant_op.constant(0.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + scaffold=scaffold, + export_outputs={'test': export_output.ClassificationOutput(scores)}) + + est = estimator.Estimator(model_fn=_model_fn_scaffold) + est.train(dummy_input_fn, steps=1) + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + est._export_all_saved_models(export_dir_base, input_receiver_fn_map) + + self.assertTrue(self.mock_saver.restore.called) + def test_scaffold_is_used_for_local_init(self): tmpdir = tempfile.mkdtemp() @@ -2509,6 +2564,61 @@ class EstimatorExportTest(test.TestCase): my_int_value = sess.run(my_int) self.assertEqual(12345, my_int_value) + def test_scaffold_is_used_for_local_init_multiple_modes(self): + tmpdir = tempfile.mkdtemp() + + def _model_fn_scaffold(features, labels, mode): + _, _ = features, labels + my_int = variables.Variable(1, name='my_int', + collections=[ops.GraphKeys.LOCAL_VARIABLES]) + scores = constant_op.constant([3.]) + with ops.control_dependencies([ + variables.local_variables_initializer(), + lookup_ops.tables_initializer() + ]): + assign_op = state_ops.assign(my_int, 12345) + + custom_local_init_op = None + if mode == model_fn_lib.ModeKeys.PREDICT: + # local_initSop must be an Operation, not a Tensor. + custom_local_init_op = control_flow_ops.group(assign_op) + + return model_fn_lib.EstimatorSpec( + mode=mode, + predictions=constant_op.constant([[1.]]), + loss=constant_op.constant(0.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + scaffold=training.Scaffold(local_init_op=custom_local_init_op), + export_outputs={'test': export_output.ClassificationOutput(scores)}) + + est = estimator.Estimator(model_fn=_model_fn_scaffold) + est.train(dummy_input_fn, steps=1) + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = est._export_all_saved_models( + export_dir_base, input_receiver_fn_map) + + # Restore, to validate that the custom local_init_op runs. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + my_int = graph.get_tensor_by_name('my_int:0') + my_int_value = sess.run(my_int) + self.assertEqual(12345, my_int_value) + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + my_int = graph.get_tensor_by_name('my_int:0') + my_int_value = sess.run(my_int) + self.assertEqual(1, my_int_value) + def test_features_labels_mode(self): given_features = {'test-features': constant_op.constant([[1], [1]])} |