aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/estimator_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/estimator_test.py')
-rw-r--r--tensorflow/python/estimator/estimator_test.py42
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 2a0e4e7617..8bc410ba0b 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -28,6 +28,7 @@ import six
from google.protobuf import text_format
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
@@ -203,6 +204,10 @@ class EstimatorConstructorTest(test.TestCase):
est = estimator.Estimator(model_fn=model_fn)
self.assertTrue(isinstance(est.config, run_config.RunConfig))
+ self.assertTrue(est._session_config.allow_soft_placement)
+ rewrite_options = est._session_config.graph_options.rewrite_options
+ self.assertEqual(rewrite_options.meta_optimizer_iterations,
+ rewriter_config_pb2.RewriterConfig.ONE)
def test_default_model_dir(self):
@@ -2304,6 +2309,43 @@ class EstimatorExportTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, err_regex):
est._export_all_saved_models(export_dir_base, input_receiver_fn_map)
+ def test_export_all_saved_models_metric_operation(self):
+ """Ensures metrics ops.Operations can be expoerted (b/109740581)."""
+
+ def _model_fn(features, labels, mode):
+ del features, labels # Unused
+ metrics = {'metrics': (constant_op.constant([0]),
+ control_flow_ops.no_op())}
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ predictions=constant_op.constant(10.),
+ loss=constant_op.constant(1.),
+ train_op=state_ops.assign_add(training.get_global_step(), 1),
+ eval_metric_ops=metrics)
+
+ tmpdir = tempfile.mkdtemp()
+ est = estimator.Estimator(model_fn=_model_fn)
+ est.train(input_fn=dummy_input_fn, steps=1)
+
+ # Perform the export.
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('metric_operation_export'))
+
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn()}
+
+ export_dir = est._export_all_saved_models(
+ export_dir_base, input_receiver_fn_map)
+
+ # Restore, to validate that the export was well-formed.
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ meta_graph = loader.load(sess, [tag_constants.EVAL], export_dir)
+ sig_outputs = meta_graph.signature_def[
+ model_fn_lib.ModeKeys.EVAL].outputs
+ self.assertEqual(
+ sig_outputs['metrics/update_op'].name, 'metric_op_wrapper:0')
+
def test_export_savedmodel_with_saveables_proto_roundtrip(self):
tmpdir = tempfile.mkdtemp()
est = estimator.Estimator(