aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/estimator_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/estimator_test.py')
-rw-r--r--tensorflow/python/estimator/estimator_test.py68
1 files changed, 68 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 733c7fb95d..8bc410ba0b 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -28,6 +28,7 @@ import six
from google.protobuf import text_format
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
@@ -38,6 +39,7 @@ from tensorflow.python.estimator.export import export_output
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
@@ -202,6 +204,10 @@ class EstimatorConstructorTest(test.TestCase):
est = estimator.Estimator(model_fn=model_fn)
self.assertTrue(isinstance(est.config, run_config.RunConfig))
+ self.assertTrue(est._session_config.allow_soft_placement)
+ rewrite_options = est._session_config.graph_options.rewrite_options
+ self.assertEqual(rewrite_options.meta_optimizer_iterations,
+ rewriter_config_pb2.RewriterConfig.ONE)
def test_default_model_dir(self):
@@ -1296,6 +1302,31 @@ class EstimatorEvaluateTest(test.TestCase):
dummy_input_fn, steps=1, checkpoint_path=est1.latest_checkpoint())
self.assertEqual(5, scores['global_step'])
+ def test_wrong_shape_throws_reasonable_error(self):
+ """Make sure we are helpful when model_fns change. See b/110263146."""
+ def _get_model_fn(val=1):
+ def _model_fn(features, labels, mode):
+ del features, labels # unused
+ variables.Variable(val, name='weight')
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ predictions=constant_op.constant([[1.]]),
+ loss=constant_op.constant(0.),
+ train_op=state_ops.assign_add(training.get_global_step(), 1))
+ return _model_fn
+
+ model_fn_1 = _get_model_fn()
+ model_fn_2 = _get_model_fn(val=[1])
+
+ est1 = estimator.Estimator(model_fn=model_fn_1)
+ est1.train(dummy_input_fn, steps=5)
+ est2 = estimator.Estimator(
+ model_fn=model_fn_2, model_dir=est1.model_dir)
+
+ expected_msg = 'Restoring from checkpoint failed.*a mismatch between'
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, expected_msg):
+ est2.train(dummy_input_fn, steps=1,)
+
def test_scaffold_is_used(self):
def _model_fn_scaffold(features, labels, mode):
@@ -2278,6 +2309,43 @@ class EstimatorExportTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, err_regex):
est._export_all_saved_models(export_dir_base, input_receiver_fn_map)
+ def test_export_all_saved_models_metric_operation(self):
+ """Ensures metrics ops.Operations can be expoerted (b/109740581)."""
+
+ def _model_fn(features, labels, mode):
+ del features, labels # Unused
+ metrics = {'metrics': (constant_op.constant([0]),
+ control_flow_ops.no_op())}
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ predictions=constant_op.constant(10.),
+ loss=constant_op.constant(1.),
+ train_op=state_ops.assign_add(training.get_global_step(), 1),
+ eval_metric_ops=metrics)
+
+ tmpdir = tempfile.mkdtemp()
+ est = estimator.Estimator(model_fn=_model_fn)
+ est.train(input_fn=dummy_input_fn, steps=1)
+
+ # Perform the export.
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('metric_operation_export'))
+
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn()}
+
+ export_dir = est._export_all_saved_models(
+ export_dir_base, input_receiver_fn_map)
+
+ # Restore, to validate that the export was well-formed.
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ meta_graph = loader.load(sess, [tag_constants.EVAL], export_dir)
+ sig_outputs = meta_graph.signature_def[
+ model_fn_lib.ModeKeys.EVAL].outputs
+ self.assertEqual(
+ sig_outputs['metrics/update_op'].name, 'metric_op_wrapper:0')
+
def test_export_savedmodel_with_saveables_proto_roundtrip(self):
tmpdir = tempfile.mkdtemp()
est = estimator.Estimator(