author    Goutham Bhat <goutham@google.com>    2018-07-09 11:39:24 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-07-09 11:49:44 -0700
commit    90e66f2aa1015496c8f8e9be573ef83f542a2ad0 (patch)
tree      987098a458bbbd5239322c2033ed6804c6b813da /tensorflow/contrib/estimator
parent    bcf7e315b4031b3c355af12ca2a4961bcd25c248 (diff)
Early-stopping functionality for use with the tf.estimator API.
PiperOrigin-RevId: 203801553
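
For orientation, here is a minimal end-to-end sketch of how the hooks added by this commit are meant to be wired into `tf.estimator.train_and_evaluate`, following the usage examples in the docstrings below. The toy model, input function, model directory, and patience values are illustrative placeholders and are not part of this commit.

```python
import numpy as np
import tensorflow as tf
# Module added by this commit; its symbols are also re-exported via
# tf.contrib.estimator (see the __init__.py change below).
from tensorflow.contrib.estimator.python.estimator import early_stopping


def model_fn(features, labels, mode):
  # Toy linear regression; placeholder model for illustration only.
  predictions = tf.layers.dense(features['x'], 1)
  loss = tf.losses.mean_squared_error(labels, predictions)
  train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
      loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(
      mode, predictions=predictions, loss=loss, train_op=train_op)


def input_fn():
  # Placeholder data pipeline: y = 2x with random inputs.
  x = np.random.rand(100, 1).astype(np.float32)
  return tf.data.Dataset.from_tensor_slices(
      ({'x': x}, 2.0 * x)).repeat().batch(16)


estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir='/tmp/es_demo')

# Stop training once eval 'loss' has not decreased for 10000 steps,
# but never before step 1000. The hook runs on the chief; other workers
# get a hook that merely checks the shared stop variable.
hook = early_stopping.stop_if_no_decrease_hook(
    estimator,
    metric_name='loss',
    max_steps_without_decrease=10000,
    min_steps=1000)

train_spec = tf.estimator.TrainSpec(
    input_fn=input_fn, max_steps=100000, hooks=[hook])
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
```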
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r--  tensorflow/contrib/estimator/BUILD                                     29
-rw-r--r--  tensorflow/contrib/estimator/__init__.py                                7
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/early_stopping.py       468
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/early_stopping_test.py  233
4 files changed, 737 insertions, 0 deletions
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 30d297a5fb..11d40f5982 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -18,6 +18,7 @@ py_library(
":boosted_trees",
":dnn",
":dnn_linear_combined",
+ ":early_stopping",
":export",
":extenders",
":head",
@@ -590,3 +591,31 @@ py_test(
"@six_archive//:six",
],
)
+
+py_library(
+ name = "early_stopping",
+ srcs = ["python/estimator/early_stopping.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:training",
+ "//tensorflow/python/estimator",
+ ],
+)
+
+py_test(
+ name = "early_stopping_test",
+ srcs = ["python/estimator/early_stopping_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":early_stopping",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/estimator",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index 788ac5ca70..09fcfd66a1 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -23,6 +23,7 @@ from tensorflow.contrib.estimator.python.estimator.baseline import *
from tensorflow.contrib.estimator.python.estimator.boosted_trees import *
from tensorflow.contrib.estimator.python.estimator.dnn import *
from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import *
+from tensorflow.contrib.estimator.python.estimator.early_stopping import *
from tensorflow.contrib.estimator.python.estimator.export import *
from tensorflow.contrib.estimator.python.estimator.extenders import *
from tensorflow.contrib.estimator.python.estimator.head import *
@@ -63,6 +64,12 @@ _allowed_symbols = [
'RNNEstimator',
'export_saved_model_for_mode',
'export_all_saved_models',
+ 'make_early_stopping_hook',
+ 'read_eval_metrics',
+ 'stop_if_lower_hook',
+ 'stop_if_higher_hook',
+ 'stop_if_no_increase_hook',
+ 'stop_if_no_decrease_hook',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/estimator/python/estimator/early_stopping.py b/tensorflow/contrib/estimator/python/estimator/early_stopping.py
new file mode 100644
index 0000000000..af4855e91e
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/early_stopping.py
@@ -0,0 +1,468 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for early stopping."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import operator
+import os
+
+from tensorflow.python.estimator import estimator as estimator_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.summary import summary_iterator
+from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import training_util
+
+_EVENT_FILE_GLOB_PATTERN = 'events.out.tfevents.*'
+
+
+def make_early_stopping_hook(estimator,
+ should_stop_fn,
+ run_every_secs=60,
+ run_every_steps=None):
+ """Creates early-stopping hook.
+
+ Returns a `SessionRunHook` that stops training when `should_stop_fn` returns
+ `True`.
+
+ Usage example:
+
+ ```python
+ estimator = ...
+ hook = early_stopping.make_early_stopping_hook(
+ estimator, should_stop_fn=make_stop_fn(...))
+ train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
+ tf.estimator.train_and_evaluate(estimator, train_spec, ...)
+ ```
+
+ Args:
+ estimator: A `tf.estimator.Estimator` instance.
+ should_stop_fn: `callable`, function that takes no arguments and returns a
+ `bool`. If the function returns `True`, stopping will be initiated by the
+ chief.
+ run_every_secs: If specified, calls `should_stop_fn` at an interval of
+ `run_every_secs` seconds. Defaults to 60 seconds. Either this or
+ `run_every_steps` must be set.
+ run_every_steps: If specified, calls `should_stop_fn` every
+ `run_every_steps` steps. Either this or `run_every_secs` must be set.
+
+ Returns:
+ A `SessionRunHook` that periodically executes `should_stop_fn` and initiates
+ early stopping if the function returns `True`.
+
+ Raises:
+ TypeError: If `estimator` is not of type `tf.estimator.Estimator`.
+ ValueError: If both `run_every_secs` and `run_every_steps` are set.
+ """
+ if not isinstance(estimator, estimator_lib.Estimator):
+ raise TypeError('`estimator` must have type `tf.estimator.Estimator`. '
+ 'Got: {}'.format(type(estimator)))
+
+ if run_every_secs is not None and run_every_steps is not None:
+ raise ValueError('Only one of `run_every_secs` and `run_every_steps` must '
+ 'be set.')
+
+ if estimator.config.is_chief:
+ return _StopOnPredicateHook(should_stop_fn, run_every_secs, run_every_steps)
+ else:
+ return _CheckForStoppingHook()
+
+
+def stop_if_higher_hook(estimator,
+ metric_name,
+ threshold,
+ eval_dir=None,
+ min_steps=0,
+ run_every_secs=60,
+ run_every_steps=None):
+ """Creates hook to stop if the given metric is higher than the threshold.
+
+ Usage example:
+
+ ```python
+ estimator = ...
+ # Hook to stop training if accuracy becomes higher than 0.9.
+ hook = early_stopping.stop_if_higher_hook(estimator, "accuracy", 0.9)
+ train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
+ tf.estimator.train_and_evaluate(estimator, train_spec, ...)
+ ```
+
+ Args:
+ estimator: A `tf.estimator.Estimator` instance.
+ metric_name: `str`, metric to track. "loss", "accuracy", etc.
+ threshold: Numeric threshold for the given metric.
+ eval_dir: If set, directory containing summary files with eval metrics. By
+ default, `estimator.eval_dir()` will be used.
+ min_steps: `int`, stop is never requested if global step is less than this
+ value. Defaults to 0.
+ run_every_secs: If specified, calls `should_stop_fn` at an interval of
+ `run_every_secs` seconds. Defaults to 60 seconds. Either this or
+ `run_every_steps` must be set.
+ run_every_steps: If specified, calls `should_stop_fn` every
+ `run_every_steps` steps. Either this or `run_every_secs` must be set.
+
+ Returns:
+ An early-stopping hook of type `SessionRunHook` that periodically checks
+ if the given metric is higher than specified threshold and initiates
+ early stopping if true.
+ """
+ return _stop_if_threshold_crossed_hook(
+ estimator=estimator,
+ metric_name=metric_name,
+ threshold=threshold,
+ higher_is_better=True,
+ eval_dir=eval_dir,
+ min_steps=min_steps,
+ run_every_secs=run_every_secs,
+ run_every_steps=run_every_steps)
+
+
+def stop_if_lower_hook(estimator,
+ metric_name,
+ threshold,
+ eval_dir=None,
+ min_steps=0,
+ run_every_secs=60,
+ run_every_steps=None):
+ """Creates hook to stop if the given metric is lower than the threshold.
+
+ Usage example:
+
+ ```python
+ estimator = ...
+ # Hook to stop training if loss becomes lower than 100.
+ hook = early_stopping.stop_if_lower_hook(estimator, "loss", 100)
+ train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
+ tf.estimator.train_and_evaluate(estimator, train_spec, ...)
+ ```
+
+ Args:
+ estimator: A `tf.estimator.Estimator` instance.
+ metric_name: `str`, metric to track. "loss", "accuracy", etc.
+ threshold: Numeric threshold for the given metric.
+ eval_dir: If set, directory containing summary files with eval metrics. By
+ default, `estimator.eval_dir()` will be used.
+ min_steps: `int`, stop is never requested if global step is less than this
+ value. Defaults to 0.
+ run_every_secs: If specified, calls `should_stop_fn` at an interval of
+ `run_every_secs` seconds. Defaults to 60 seconds. Either this or
+ `run_every_steps` must be set.
+ run_every_steps: If specified, calls `should_stop_fn` every
+ `run_every_steps` steps. Either this or `run_every_secs` must be set.
+
+ Returns:
+ An early-stopping hook of type `SessionRunHook` that periodically checks
+ if the given metric is lower than specified threshold and initiates
+ early stopping if true.
+ """
+ return _stop_if_threshold_crossed_hook(
+ estimator=estimator,
+ metric_name=metric_name,
+ threshold=threshold,
+ higher_is_better=False,
+ eval_dir=eval_dir,
+ min_steps=min_steps,
+ run_every_secs=run_every_secs,
+ run_every_steps=run_every_steps)
+
+
+def stop_if_no_increase_hook(estimator,
+ metric_name,
+ max_steps_without_increase,
+ eval_dir=None,
+ min_steps=0,
+ run_every_secs=60,
+ run_every_steps=None):
+ """Creates hook to stop if metric does not increase within given max steps.
+
+ Usage example:
+
+ ```python
+ estimator = ...
+ # Hook to stop training if accuracy does not increase in over 100000 steps.
+ hook = early_stopping.stop_if_no_increase_hook(estimator, "accuracy", 100000)
+ train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
+ tf.estimator.train_and_evaluate(estimator, train_spec, ...)
+ ```
+
+ Args:
+ estimator: A `tf.estimator.Estimator` instance.
+ metric_name: `str`, metric to track. "loss", "accuracy", etc.
+ max_steps_without_increase: `int`, maximum number of training steps with no
+ increase in the given metric.
+ eval_dir: If set, directory containing summary files with eval metrics. By
+ default, `estimator.eval_dir()` will be used.
+ min_steps: `int`, stop is never requested if global step is less than this
+ value. Defaults to 0.
+ run_every_secs: If specified, calls `should_stop_fn` at an interval of
+ `run_every_secs` seconds. Defaults to 60 seconds. Either this or
+ `run_every_steps` must be set.
+ run_every_steps: If specified, calls `should_stop_fn` every
+ `run_every_steps` steps. Either this or `run_every_secs` must be set.
+
+ Returns:
+ An early-stopping hook of type `SessionRunHook` that periodically checks
+ if the given metric shows no increase over given maximum number of
+ training steps, and initiates early stopping if true.
+ """
+ return _stop_if_no_metric_improvement_hook(
+ estimator=estimator,
+ metric_name=metric_name,
+ max_steps_without_improvement=max_steps_without_increase,
+ higher_is_better=True,
+ eval_dir=eval_dir,
+ min_steps=min_steps,
+ run_every_secs=run_every_secs,
+ run_every_steps=run_every_steps)
+
+
+def stop_if_no_decrease_hook(estimator,
+ metric_name,
+ max_steps_without_decrease,
+ eval_dir=None,
+ min_steps=0,
+ run_every_secs=60,
+ run_every_steps=None):
+ """Creates hook to stop if metric does not decrease within given max steps.
+
+ Usage example:
+
+ ```python
+ estimator = ...
+ # Hook to stop training if loss does not decrease in over 100000 steps.
+ hook = early_stopping.stop_if_no_decrease_hook(estimator, "loss", 100000)
+ train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
+ tf.estimator.train_and_evaluate(estimator, train_spec, ...)
+ ```
+
+ Args:
+ estimator: A `tf.estimator.Estimator` instance.
+ metric_name: `str`, metric to track. "loss", "accuracy", etc.
+ max_steps_without_decrease: `int`, maximum number of training steps with no
+ decrease in the given metric.
+ eval_dir: If set, directory containing summary files with eval metrics. By
+ default, `estimator.eval_dir()` will be used.
+ min_steps: `int`, stop is never requested if global step is less than this
+ value. Defaults to 0.
+ run_every_secs: If specified, calls `should_stop_fn` at an interval of
+ `run_every_secs` seconds. Defaults to 60 seconds. Either this or
+ `run_every_steps` must be set.
+ run_every_steps: If specified, calls `should_stop_fn` every
+ `run_every_steps` steps. Either this or `run_every_secs` must be set.
+
+ Returns:
+ An early-stopping hook of type `SessionRunHook` that periodically checks
+ if the given metric shows no decrease over given maximum number of
+ training steps, and initiates early stopping if true.
+ """
+ return _stop_if_no_metric_improvement_hook(
+ estimator=estimator,
+ metric_name=metric_name,
+ max_steps_without_improvement=max_steps_without_decrease,
+ higher_is_better=False,
+ eval_dir=eval_dir,
+ min_steps=min_steps,
+ run_every_secs=run_every_secs,
+ run_every_steps=run_every_steps)
+
+
+def read_eval_metrics(eval_dir):
+ """Helper to read eval metrics from eval summary files.
+
+ Args:
+ eval_dir: Directory containing summary files with eval metrics.
+
+ Returns:
+ A `dict` with global steps mapping to `dict` of metric names and values.
+ """
+ eval_metrics_dict = {}
+ for event in _summaries(eval_dir):
+ if not event.HasField('summary'):
+ continue
+ metrics = {}
+ for value in event.summary.value:
+ if value.HasField('simple_value'):
+ metrics[value.tag] = value.simple_value
+ if metrics:
+ eval_metrics_dict[event.step] = metrics
+ return eval_metrics_dict
+
+
+def _stop_if_threshold_crossed_hook(estimator, metric_name, threshold,
+ higher_is_better, eval_dir, min_steps,
+ run_every_secs, run_every_steps):
+ """Creates early-stopping hook to stop training if threshold is crossed."""
+
+ if eval_dir is None:
+ eval_dir = estimator.eval_dir()
+
+ is_lhs_better = operator.gt if higher_is_better else operator.lt
+ greater_or_lesser = 'greater than' if higher_is_better else 'less than'
+
+ def stop_if_threshold_crossed_fn():
+ """Returns `True` if the given metric crosses specified threshold."""
+
+ eval_results = read_eval_metrics(eval_dir)
+
+ for step, metrics in eval_results.items():
+ if step < min_steps:
+ continue
+ val = metrics[metric_name]
+ if is_lhs_better(val, threshold):
+ tf_logging.info(
+ 'At step %s, metric "%s" has value %s which is %s the configured '
+ 'threshold (%s) for early stopping.', step, metric_name, val,
+ greater_or_lesser, threshold)
+ return True
+ return False
+
+ return make_early_stopping_hook(
+ estimator=estimator,
+ should_stop_fn=stop_if_threshold_crossed_fn,
+ run_every_secs=run_every_secs,
+ run_every_steps=run_every_steps)
+
+
+def _stop_if_no_metric_improvement_hook(
+ estimator, metric_name, max_steps_without_improvement, higher_is_better,
+ eval_dir, min_steps, run_every_secs, run_every_steps):
+ """Returns hook to stop training if given metric shows no improvement."""
+
+ if eval_dir is None:
+ eval_dir = estimator.eval_dir()
+
+ is_lhs_better = operator.gt if higher_is_better else operator.lt
+ increase_or_decrease = 'increase' if higher_is_better else 'decrease'
+
+ def stop_if_no_metric_improvement_fn():
+ """Returns `True` if metric does not improve within max steps."""
+
+ eval_results = read_eval_metrics(eval_dir)
+
+ best_val = None
+ best_val_step = None
+ for step, metrics in eval_results.items():
+ if step < min_steps:
+ continue
+ val = metrics[metric_name]
+ if best_val is None or is_lhs_better(val, best_val):
+ best_val = val
+ best_val_step = step
+ if step - best_val_step >= max_steps_without_improvement:
+ tf_logging.info(
+ 'No %s in metric "%s" for %s steps, which is greater than or equal '
+ 'to max steps (%s) configured for early stopping.',
+ increase_or_decrease, metric_name, step - best_val_step,
+ max_steps_without_improvement)
+ return True
+ return False
+
+ return make_early_stopping_hook(
+ estimator=estimator,
+ should_stop_fn=stop_if_no_metric_improvement_fn,
+ run_every_secs=run_every_secs,
+ run_every_steps=run_every_steps)
+
+
+def _summaries(eval_dir):
+ """Yields `tensorflow.Event` protos from event files in the eval dir.
+
+ Args:
+ eval_dir: Directory containing summary files with eval metrics.
+
+ Yields:
+ `tensorflow.Event` object read from the event files.
+ """
+ for event_file in gfile.Glob(
+ os.path.join(eval_dir, _EVENT_FILE_GLOB_PATTERN)):
+ for event in summary_iterator.summary_iterator(event_file):
+ yield event
+
+
+def _get_or_create_stop_var():
+ with variable_scope.variable_scope(
+ name_or_scope='signal_early_stopping',
+ values=[],
+ reuse=variable_scope.AUTO_REUSE):
+ return variable_scope.get_variable(
+ name='STOP',
+ shape=[],
+ dtype=dtypes.bool,
+ initializer=init_ops.constant_initializer(False),
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES],
+ trainable=False)
+
+
+class _StopOnPredicateHook(session_run_hook.SessionRunHook):
+ """Hook that requests stop when `should_stop_fn` returns `True`."""
+
+ def __init__(self, should_stop_fn, run_every_secs=60, run_every_steps=None):
+ if not callable(should_stop_fn):
+ raise TypeError('`should_stop_fn` must be callable.')
+
+ self._should_stop_fn = should_stop_fn
+ self._timer = basic_session_run_hooks.SecondOrStepTimer(
+ every_secs=run_every_secs, every_steps=run_every_steps)
+ self._global_step_tensor = None
+ self._stop_var = None
+ self._stop_op = None
+
+ def begin(self):
+ self._global_step_tensor = training_util.get_global_step()
+ self._stop_var = _get_or_create_stop_var()
+ self._stop_op = state_ops.assign(self._stop_var, True)
+
+ def before_run(self, run_context):
+ del run_context
+ return session_run_hook.SessionRunArgs(self._global_step_tensor)
+
+ def after_run(self, run_context, run_values):
+ global_step = run_values.results
+ if self._timer.should_trigger_for_step(global_step):
+ self._timer.update_last_triggered_step(global_step)
+ if self._should_stop_fn():
+ tf_logging.info('Requesting early stopping at global step %d',
+ global_step)
+ run_context.session.run(self._stop_op)
+ run_context.request_stop()
+
+
+class _CheckForStoppingHook(session_run_hook.SessionRunHook):
+ """Hook that requests stop if stop is requested by `_StopOnPredicateHook`."""
+
+ def __init__(self):
+ self._stop_var = None
+
+ def begin(self):
+ self._stop_var = _get_or_create_stop_var()
+
+ def before_run(self, run_context):
+ del run_context
+ return session_run_hook.SessionRunArgs(self._stop_var)
+
+ def after_run(self, run_context, run_values):
+ should_early_stop = run_values.results
+ if should_early_stop:
+ tf_logging.info('Early stopping requested, suspending run.')
+ run_context.request_stop()
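
For custom stopping conditions beyond the threshold and no-improvement helpers, `make_early_stopping_hook` accepts any zero-argument `should_stop_fn`. A hypothetical predicate built on `read_eval_metrics` from the file above might look like the following sketch; the metric name and the `window` patience parameter are made up for illustration.

```python
# Illustrative only: a custom stopping predicate for make_early_stopping_hook,
# built on read_eval_metrics from early_stopping.py above.
from tensorflow.contrib.estimator.python.estimator import early_stopping


def make_plateau_stop_fn(estimator, metric_name='accuracy', window=3):
  """Stops when the last `window` evaluations produce no new best value."""

  def should_stop_fn():
    # read_eval_metrics returns {global_step: {metric_name: value, ...}}.
    metrics = early_stopping.read_eval_metrics(estimator.eval_dir())
    values = [m[metric_name] for _, m in sorted(metrics.items())
              if metric_name in m]
    if len(values) <= window:
      return False
    # No improvement if the recent best does not beat the earlier best
    # (for a higher-is-better metric such as accuracy).
    return max(values[-window:]) <= max(values[:-window])

  return should_stop_fn


# Usage (estimator construction omitted):
# hook = early_stopping.make_early_stopping_hook(
#     estimator, should_stop_fn=make_plateau_stop_fn(estimator))
```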
diff --git a/tensorflow/contrib/estimator/python/estimator/early_stopping_test.py b/tensorflow/contrib/estimator/python/estimator/early_stopping_test.py
new file mode 100644
index 0000000000..b5eee818fa
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/early_stopping_test.py
@@ -0,0 +1,233 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for early_stopping."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+
+from absl.testing import parameterized
+from tensorflow.contrib.estimator.python.estimator import early_stopping
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator import run_config
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import monitored_session
+from tensorflow.python.training import training_util
+
+
+class _FakeRunConfig(run_config.RunConfig):
+
+ def __init__(self, is_chief):
+ super(_FakeRunConfig, self).__init__()
+ self._is_chief = is_chief
+
+ @property
+ def is_chief(self):
+ return self._is_chief
+
+
+def _dummy_model_fn(features, labels, params):
+ _, _, _ = features, labels, params
+
+
+class _FakeEstimator(estimator.Estimator):
+ """Fake estimator for testing."""
+
+ def __init__(self, config):
+ super(_FakeEstimator, self).__init__(
+ model_fn=_dummy_model_fn, config=config)
+
+
+def _write_events(eval_dir, params):
+ """Test helper to write events to summary files."""
+ for steps, loss, accuracy in params:
+ estimator._write_dict_to_summary(eval_dir, {
+ 'loss': loss,
+ 'accuracy': accuracy,
+ }, steps)
+
+
+class ReadEvalMetricsTest(test.TestCase):
+
+ def test_read_eval_metrics(self):
+ eval_dir = tempfile.mkdtemp()
+ _write_events(
+ eval_dir,
+ [
+ # steps, loss, accuracy
+ (1000, 1, 2),
+ (2000, 3, 4),
+ (3000, 5, 6),
+ ])
+ self.assertEqual({
+ 1000: {
+ 'loss': 1,
+ 'accuracy': 2
+ },
+ 2000: {
+ 'loss': 3,
+ 'accuracy': 4
+ },
+ 3000: {
+ 'loss': 5,
+ 'accuracy': 6
+ },
+ }, early_stopping.read_eval_metrics(eval_dir))
+
+
+class EarlyStoppingHooksTest(test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ config = _FakeRunConfig(is_chief=True)
+ self._estimator = _FakeEstimator(config=config)
+ eval_dir = self._estimator.eval_dir()
+ os.makedirs(eval_dir)
+ _write_events(
+ eval_dir,
+ [
+ # steps, loss, accuracy
+ (1000, 0.8, 0.5),
+ (2000, 0.7, 0.6),
+ (3000, 0.4, 0.7),
+ (3500, 0.41, 0.68),
+ ])
+
+ def run_session(self, hooks, should_stop):
+ hooks = hooks if isinstance(hooks, list) else [hooks]
+ with ops.Graph().as_default():
+ training_util.create_global_step()
+ no_op = control_flow_ops.no_op()
+ with monitored_session.SingularMonitoredSession(hooks=hooks) as mon_sess:
+ mon_sess.run(no_op)
+ self.assertEqual(mon_sess.should_stop(), should_stop)
+
+ @parameterized.parameters((0.8, 0, False), (0.6, 4000, False), (0.6, 0, True))
+ def test_stop_if_higher_hook(self, threshold, min_steps, should_stop):
+ self.run_session(
+ early_stopping.stop_if_higher_hook(
+ self._estimator,
+ metric_name='accuracy',
+ threshold=threshold,
+ min_steps=min_steps), should_stop)
+
+ @parameterized.parameters((0.3, 0, False), (0.5, 4000, False), (0.5, 0, True))
+ def test_stop_if_lower_hook(self, threshold, min_steps, should_stop):
+ self.run_session(
+ early_stopping.stop_if_lower_hook(
+ self._estimator,
+ metric_name='loss',
+ threshold=threshold,
+ min_steps=min_steps), should_stop)
+
+ @parameterized.parameters((1500, 0, False), (500, 4000, False),
+ (500, 0, True))
+ def test_stop_if_no_increase_hook(self, max_steps, min_steps, should_stop):
+ self.run_session(
+ early_stopping.stop_if_no_increase_hook(
+ self._estimator,
+ metric_name='accuracy',
+ max_steps_without_increase=max_steps,
+ min_steps=min_steps), should_stop)
+
+ @parameterized.parameters((1500, 0, False), (500, 4000, False),
+ (500, 0, True))
+ def test_stop_if_no_decrease_hook(self, max_steps, min_steps, should_stop):
+ self.run_session(
+ early_stopping.stop_if_no_decrease_hook(
+ self._estimator,
+ metric_name='loss',
+ max_steps_without_decrease=max_steps,
+ min_steps=min_steps), should_stop)
+
+ @parameterized.parameters((1500, 0.3, False), (1500, 0.5, True),
+ (500, 0.3, True))
+ def test_multiple_hooks(self, max_steps, loss_threshold, should_stop):
+ self.run_session([
+ early_stopping.stop_if_no_decrease_hook(
+ self._estimator,
+ metric_name='loss',
+ max_steps_without_decrease=max_steps),
+ early_stopping.stop_if_lower_hook(
+ self._estimator, metric_name='loss', threshold=loss_threshold)
+ ], should_stop)
+
+ @parameterized.parameters(False, True)
+ def test_make_early_stopping_hook(self, should_stop):
+ self.run_session([
+ early_stopping.make_early_stopping_hook(
+ self._estimator, should_stop_fn=lambda: should_stop)
+ ], should_stop)
+
+ def test_make_early_stopping_hook_typeerror(self):
+ with self.assertRaises(TypeError):
+ early_stopping.make_early_stopping_hook(
+ estimator=object(), should_stop_fn=lambda: True)
+
+ def test_make_early_stopping_hook_valueerror(self):
+ with self.assertRaises(ValueError):
+ early_stopping.make_early_stopping_hook(
+ self._estimator,
+ should_stop_fn=lambda: True,
+ run_every_secs=60,
+ run_every_steps=100)
+
+
+class StopOnPredicateHookTest(test.TestCase):
+
+ def test_stop(self):
+ hook = early_stopping._StopOnPredicateHook(
+ should_stop_fn=lambda: False, run_every_secs=0)
+ with ops.Graph().as_default():
+ training_util.create_global_step()
+ no_op = control_flow_ops.no_op()
+ with monitored_session.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+ mon_sess.run(no_op)
+ self.assertFalse(mon_sess.should_stop())
+ self.assertFalse(mon_sess.raw_session().run(hook._stop_var))
+
+ hook = early_stopping._StopOnPredicateHook(
+ should_stop_fn=lambda: True, run_every_secs=0)
+ with ops.Graph().as_default():
+ training_util.create_global_step()
+ no_op = control_flow_ops.no_op()
+ with monitored_session.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+ mon_sess.run(no_op)
+ self.assertTrue(mon_sess.should_stop())
+ self.assertTrue(mon_sess.raw_session().run(hook._stop_var))
+
+
+class CheckForStoppingHookTest(test.TestCase):
+
+ def test_stop(self):
+ hook = early_stopping._CheckForStoppingHook()
+ with ops.Graph().as_default():
+ no_op = control_flow_ops.no_op()
+ assign_op = state_ops.assign(early_stopping._get_or_create_stop_var(),
+ True)
+ with monitored_session.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+ mon_sess.run(no_op)
+ self.assertFalse(mon_sess.should_stop())
+ mon_sess.run(assign_op)
+ self.assertTrue(mon_sess.should_stop())
+
+
+if __name__ == '__main__':
+ test.main()