aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/estimator_test.py
diff options
context:
space:
mode:
authorGravatar Karmel Allison <karmel@google.com>2018-05-10 09:47:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-10 09:50:44 -0700
commit9c18251256a88e23c47f60f3597f9c764000fba4 (patch)
tree6d1bc0a0f1b450460389d3182e97f98ec6641ce8 /tensorflow/python/estimator/estimator_test.py
parente696dc1bd07f62c6621a7224e15c8d3fbc160054 (diff)
For Estimators, SavedModels for multiple modes should be exported into the same
file. PiperOrigin-RevId: 196128943
Diffstat (limited to 'tensorflow/python/estimator/estimator_test.py')
-rw-r--r--tensorflow/python/estimator/estimator_test.py170
1 files changed, 140 insertions, 30 deletions
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 02088e5134..c9c6bdfeb5 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -2013,12 +2013,9 @@ class EstimatorExportTest(test.TestCase):
input_receiver_fn_map = {
model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
- self.assertEqual(len(export_dirs), 1)
- # Restore, to validate that the export was well-formed.
- export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.SERVING], export_dir)
@@ -2035,12 +2032,9 @@ class EstimatorExportTest(test.TestCase):
input_receiver_fn_map = {
model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
- self.assertEqual(len(export_dirs), 1)
- # Restore, to validate that the export was well-formed.
- export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.TRAINING], export_dir)
@@ -2058,12 +2052,9 @@ class EstimatorExportTest(test.TestCase):
input_receiver_fn_map = {
model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn()
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
- self.assertEqual(len(export_dirs), 1)
- # Restore, to validate that the export was well-formed.
- export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.EVAL], export_dir)
@@ -2082,12 +2073,9 @@ class EstimatorExportTest(test.TestCase):
model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn()
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
- self.assertEqual(len(export_dirs), 2)
- # Restore, to validate that the export was well-formed.
- export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.TRAINING], export_dir)
@@ -2096,7 +2084,7 @@ class EstimatorExportTest(test.TestCase):
self.assertFalse('eval_multiplied' in graph_ops)
self.assertTrue('feature_x' in graph_ops)
self.assertTrue('weight' in graph_ops)
- export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL]
+
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.EVAL], export_dir)
@@ -2117,12 +2105,11 @@ class EstimatorExportTest(test.TestCase):
model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(),
model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
# Restore, to validate that the export was well-formed.
- for mode, tag_set in model_fn_lib.EXPORT_TAG_MAP.items():
- export_dir = export_dirs[mode]
+ for tag_set in model_fn_lib.EXPORT_TAG_MAP.values():
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, tag_set, export_dir)
@@ -2139,10 +2126,9 @@ class EstimatorExportTest(test.TestCase):
model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
- export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.TRAINING], export_dir)
@@ -2150,7 +2136,6 @@ class EstimatorExportTest(test.TestCase):
self.assertTrue('later_var' in graph_ops)
self.assertTrue('weight' in graph_ops)
- export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.SERVING], export_dir)
@@ -2166,10 +2151,9 @@ class EstimatorExportTest(test.TestCase):
model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
- export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.TRAINING], export_dir)
@@ -2179,7 +2163,6 @@ class EstimatorExportTest(test.TestCase):
collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertEqual(3, collection_vars[-1].eval())
- export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.SERVING], export_dir)
@@ -2207,16 +2190,15 @@ class EstimatorExportTest(test.TestCase):
# Perform the export.
export_dir_base = os.path.join(
compat.as_bytes(tmpdir), compat.as_bytes('export'))
- export_dirs = est._export_all_saved_models(
+ export_dir = est._export_all_saved_models(
export_dir_base, input_receiver_fn_map)
# Check that all the files are in the right places.
self.assertTrue(gfile.Exists(export_dir_base))
- for _, export_dir in export_dirs.items():
- self._validate_exported_files(export_dir)
+ self._validate_exported_files(export_dir)
- return export_dirs, tmpdir
+ return export_dir, tmpdir
def _validate_exported_files(self, export_dir):
self.assertTrue(gfile.Exists(export_dir))
@@ -2233,6 +2215,42 @@ class EstimatorExportTest(test.TestCase):
compat.as_bytes(export_dir),
compat.as_bytes('variables/variables.data-00000-of-00001'))))
+ def test_export_all_saved_models_var_not_found(self):
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
+ }
+
+ def _model_fn_with_predict_only_vars(features, labels, mode):
+ _, _ = features, labels
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ variables.Variable(1., name='only_in_predict')
+ else:
+ variables.Variable(1., name='otherwise')
+
+ prediction = constant_op.constant(1.)
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ predictions=prediction,
+ loss=constant_op.constant(1.),
+ train_op=state_ops.assign_add(training.get_global_step(), 1),
+ export_outputs={
+ 'test': export_output.PredictOutput({'prediction': prediction})
+ })
+
+ tmpdir = tempfile.mkdtemp()
+ est = estimator.Estimator(model_fn=_model_fn_with_predict_only_vars)
+ est.train(input_fn=_x_y_input_fn, steps=1)
+
+ # Perform the export.
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+
+ err_regex = r'Could not load all requested variables[\w\W]*infer'
+ with self.assertRaisesRegexp(ValueError, err_regex):
+ est._export_all_saved_models(export_dir_base, input_receiver_fn_map)
+
def test_export_savedmodel_with_saveables_proto_roundtrip(self):
tmpdir = tempfile.mkdtemp()
est = estimator.Estimator(
@@ -2464,6 +2482,43 @@ class EstimatorExportTest(test.TestCase):
self.assertTrue(self.mock_saver.restore.called)
+ def test_scaffold_is_used_for_saver_multiple_modes(self):
+ tmpdir = tempfile.mkdtemp()
+
+ def _model_fn_scaffold(features, labels, mode):
+ _, _ = features, labels
+ variables.Variable(1., name='weight')
+ real_saver = saver.Saver()
+ self.mock_saver = test.mock.Mock(
+ wraps=real_saver, saver_def=real_saver.saver_def)
+ scores = constant_op.constant([3.])
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ scaffold = training.Scaffold(saver=self.mock_saver)
+ else:
+ scaffold = training.Scaffold()
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ predictions=constant_op.constant([[1.]]),
+ loss=constant_op.constant(0.),
+ train_op=state_ops.assign_add(training.get_global_step(), 1),
+ scaffold=scaffold,
+ export_outputs={'test': export_output.ClassificationOutput(scores)})
+
+ est = estimator.Estimator(model_fn=_model_fn_scaffold)
+ est.train(dummy_input_fn, steps=1)
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
+ }
+
+ # Perform the export.
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+ est._export_all_saved_models(export_dir_base, input_receiver_fn_map)
+
+ self.assertTrue(self.mock_saver.restore.called)
+
def test_scaffold_is_used_for_local_init(self):
tmpdir = tempfile.mkdtemp()
@@ -2509,6 +2564,61 @@ class EstimatorExportTest(test.TestCase):
my_int_value = sess.run(my_int)
self.assertEqual(12345, my_int_value)
+ def test_scaffold_is_used_for_local_init_multiple_modes(self):
+ tmpdir = tempfile.mkdtemp()
+
+ def _model_fn_scaffold(features, labels, mode):
+ _, _ = features, labels
+ my_int = variables.Variable(1, name='my_int',
+ collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ scores = constant_op.constant([3.])
+ with ops.control_dependencies([
+ variables.local_variables_initializer(),
+ lookup_ops.tables_initializer()
+ ]):
+ assign_op = state_ops.assign(my_int, 12345)
+
+ custom_local_init_op = None
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ # local_initSop must be an Operation, not a Tensor.
+ custom_local_init_op = control_flow_ops.group(assign_op)
+
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ predictions=constant_op.constant([[1.]]),
+ loss=constant_op.constant(0.),
+ train_op=state_ops.assign_add(training.get_global_step(), 1),
+ scaffold=training.Scaffold(local_init_op=custom_local_init_op),
+ export_outputs={'test': export_output.ClassificationOutput(scores)})
+
+ est = estimator.Estimator(model_fn=_model_fn_scaffold)
+ est.train(dummy_input_fn, steps=1)
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
+ }
+
+ # Perform the export.
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+ export_dir = est._export_all_saved_models(
+ export_dir_base, input_receiver_fn_map)
+
+ # Restore, to validate that the custom local_init_op runs.
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.SERVING], export_dir)
+ my_int = graph.get_tensor_by_name('my_int:0')
+ my_int_value = sess.run(my_int)
+ self.assertEqual(12345, my_int_value)
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.TRAINING], export_dir)
+ my_int = graph.get_tensor_by_name('my_int:0')
+ my_int_value = sess.run(my_int)
+ self.assertEqual(1, my_int_value)
+
def test_features_labels_mode(self):
given_features = {'test-features': constant_op.constant([[1], [1]])}