aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Michael Case <mikecase@google.com>2018-06-14 06:11:42 -0700
committerGravatar GitHub <noreply@github.com>2018-06-14 06:11:42 -0700
commitd9b82cdb028c61bceb2c26958eb3172e865d42a6 (patch)
tree5c110489fdc760df22d35644a9afd29cccf97cea
parenta519d1a8d23810233f3ae2368b9e63b6d33af12c (diff)
parentc9dca4aae99ab21a9317c093d42b18d94dba23e7 (diff)
Merge pull request #19976 from guillaumekln/cherry-pick-best-exporter-fix
BestExporter cherry-pick request for r1.9: Only calls compare function if values were read from event file
-rw-r--r--tensorflow/python/estimator/exporter.py7
-rw-r--r--tensorflow/python/estimator/exporter_test.py34
2 files changed, 38 insertions, 3 deletions
diff --git a/tensorflow/python/estimator/exporter.py b/tensorflow/python/estimator/exporter.py
index a7212bb83e..766ea23f2a 100644
--- a/tensorflow/python/estimator/exporter.py
+++ b/tensorflow/python/estimator/exporter.py
@@ -360,9 +360,10 @@ class BestExporter(Exporter):
for value in event.summary.value:
if value.HasField('simple_value'):
event_eval_result[value.tag] = value.simple_value
- if best_eval_result is None or self._compare_fn(
- best_eval_result, event_eval_result):
- best_eval_result = event_eval_result
+ if event_eval_result:
+ if best_eval_result is None or self._compare_fn(
+ best_eval_result, event_eval_result):
+ best_eval_result = event_eval_result
return best_eval_result
diff --git a/tensorflow/python/estimator/exporter_test.py b/tensorflow/python/estimator/exporter_test.py
index 4cb4bffc8d..c4b006955c 100644
--- a/tensorflow/python/estimator/exporter_test.py
+++ b/tensorflow/python/estimator/exporter_test.py
@@ -148,6 +148,40 @@ class BestExporterTest(test.TestCase):
"checkpoint_path", {"loss": 20}, False)
self.assertEqual(None, export_result)
+ def test_best_exporter_with_empty_event(self):
+
+ def _serving_input_receiver_fn():
+ pass
+
+ export_dir_base = tempfile.mkdtemp()
+ gfile.MkDir(export_dir_base)
+ gfile.MkDir(export_dir_base + "/export")
+ gfile.MkDir(export_dir_base + "/eval")
+
+ eval_dir_base = os.path.join(export_dir_base, "eval_continuous")
+ estimator_lib._write_dict_to_summary(eval_dir_base, {}, 1)
+ estimator_lib._write_dict_to_summary(eval_dir_base, {"loss": 60}, 2)
+
+ exporter = exporter_lib.BestExporter(
+ name="best_exporter",
+ serving_input_receiver_fn=_serving_input_receiver_fn,
+ event_file_pattern="eval_continuous/*.tfevents.*",
+ assets_extra={"from/path": "to/path"},
+ as_text=False,
+ exports_to_keep=1)
+
+ estimator = test.mock.Mock(spec=estimator_lib.Estimator)
+ estimator.model_dir = export_dir_base
+ estimator.export_savedmodel.return_value = "export_result_path"
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"loss": 100}, False)
+ self.assertEqual(None, export_result)
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"loss": 10}, False)
+ self.assertEqual("export_result_path", export_result)
+
def test_garbage_collect_exports(self):
export_dir_base = tempfile.mkdtemp()
gfile.MkDir(export_dir_base)