From 90e66f2aa1015496c8f8e9be573ef83f542a2ad0 Mon Sep 17 00:00:00 2001 From: Goutham Bhat Date: Mon, 9 Jul 2018 11:39:24 -0700 Subject: Early-stopping functionality for use with tf.estimator API. PiperOrigin-RevId: 203801553 --- tensorflow/contrib/estimator/BUILD | 29 ++ tensorflow/contrib/estimator/__init__.py | 7 + .../estimator/python/estimator/early_stopping.py | 468 +++++++++++++++++++++ .../python/estimator/early_stopping_test.py | 233 ++++++++++ 4 files changed, 737 insertions(+) create mode 100644 tensorflow/contrib/estimator/python/estimator/early_stopping.py create mode 100644 tensorflow/contrib/estimator/python/estimator/early_stopping_test.py (limited to 'tensorflow/contrib/estimator') diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 30d297a5fb..11d40f5982 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -18,6 +18,7 @@ py_library( ":boosted_trees", ":dnn", ":dnn_linear_combined", + ":early_stopping", ":export", ":extenders", ":head", @@ -590,3 +591,31 @@ py_test( "@six_archive//:six", ], ) + +py_library( + name = "early_stopping", + srcs = ["python/estimator/early_stopping.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python/estimator", + ], +) + +py_test( + name = "early_stopping_test", + srcs = ["python/estimator/early_stopping_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":early_stopping", + "//tensorflow/python:client_testlib", + "//tensorflow/python/estimator", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index 788ac5ca70..09fcfd66a1 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -23,6 +23,7 @@ from tensorflow.contrib.estimator.python.estimator.baseline import * from tensorflow.contrib.estimator.python.estimator.boosted_trees import * from tensorflow.contrib.estimator.python.estimator.dnn import * from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * +from tensorflow.contrib.estimator.python.estimator.early_stopping import * from tensorflow.contrib.estimator.python.estimator.export import * from tensorflow.contrib.estimator.python.estimator.extenders import * from tensorflow.contrib.estimator.python.estimator.head import * @@ -63,6 +64,12 @@ _allowed_symbols = [ 'RNNEstimator', 'export_saved_model_for_mode', 'export_all_saved_models', + 'make_early_stopping_hook', + 'read_eval_metrics', + 'stop_if_lower_hook', + 'stop_if_higher_hook', + 'stop_if_no_increase_hook', + 'stop_if_no_decrease_hook', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/early_stopping.py b/tensorflow/contrib/estimator/python/estimator/early_stopping.py new file mode 100644 index 0000000000..af4855e91e --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/early_stopping.py @@ -0,0 +1,468 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for early stopping.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import operator +import os + +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import gfile +from tensorflow.python.platform import tf_logging +from tensorflow.python.summary import summary_iterator +from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import session_run_hook +from tensorflow.python.training import training_util + +_EVENT_FILE_GLOB_PATTERN = 'events.out.tfevents.*' + + +def make_early_stopping_hook(estimator, + should_stop_fn, + run_every_secs=60, + run_every_steps=None): + """Creates early-stopping hook. + + Returns a `SessionRunHook` that stops training when `should_stop_fn` returns + `True`. + + Usage example: + + ```python + estimator = ... + hook = early_stopping.make_early_stopping_hook( + estimator, should_stop_fn=make_stop_fn(...)) + train_spec = tf.estimator.TrainSpec(..., hooks=[hook]) + tf.estimator.train_and_evaluate(estimator, train_spec, ...) + ``` + + Args: + estimator: A `tf.estimator.Estimator` instance. + should_stop_fn: `callable`, function that takes no arguments and returns a + `bool`. If the function returns `True`, stopping will be initiated by the + chief. + run_every_secs: If specified, calls `should_stop_fn` at an interval of + `run_every_secs` seconds. Defaults to 60 seconds. Either this or + `run_every_steps` must be set. + run_every_steps: If specified, calls `should_stop_fn` every + `run_every_steps` steps. Either this or `run_every_secs` must be set. + + Returns: + A `SessionRunHook` that periodically executes `should_stop_fn` and initiates + early stopping if the function returns `True`. + + Raises: + TypeError: If `estimator` is not of type `tf.estimator.Estimator`. + ValueError: If both `run_every_secs` and `run_every_steps` are set. + """ + if not isinstance(estimator, estimator_lib.Estimator): + raise TypeError('`estimator` must have type `tf.estimator.Estimator`. ' + 'Got: {}'.format(type(estimator))) + + if run_every_secs is not None and run_every_steps is not None: + raise ValueError('Only one of `run_every_secs` and `run_every_steps` must ' + 'be set.') + + if estimator.config.is_chief: + return _StopOnPredicateHook(should_stop_fn, run_every_secs, run_every_steps) + else: + return _CheckForStoppingHook() + + +def stop_if_higher_hook(estimator, + metric_name, + threshold, + eval_dir=None, + min_steps=0, + run_every_secs=60, + run_every_steps=None): + """Creates hook to stop if the given metric is higher than the threshold. + + Usage example: + + ```python + estimator = ... + # Hook to stop training if accuracy becomes higher than 0.9. + hook = early_stopping.stop_if_higher_hook(estimator, "accuracy", 0.9) + train_spec = tf.estimator.TrainSpec(..., hooks=[hook]) + tf.estimator.train_and_evaluate(estimator, train_spec, ...) + ``` + + Args: + estimator: A `tf.estimator.Estimator` instance. + metric_name: `str`, metric to track. "loss", "accuracy", etc. + threshold: Numeric threshold for the given metric. + eval_dir: If set, directory containing summary files with eval metrics. By + default, `estimator.eval_dir()` will be used. + min_steps: `int`, stop is never requested if global step is less than this + value. Defaults to 0. + run_every_secs: If specified, calls `should_stop_fn` at an interval of + `run_every_secs` seconds. Defaults to 60 seconds. Either this or + `run_every_steps` must be set. + run_every_steps: If specified, calls `should_stop_fn` every + `run_every_steps` steps. Either this or `run_every_secs` must be set. + + Returns: + An early-stopping hook of type `SessionRunHook` that periodically checks + if the given metric is higher than specified threshold and initiates + early stopping if true. + """ + return _stop_if_threshold_crossed_hook( + estimator=estimator, + metric_name=metric_name, + threshold=threshold, + higher_is_better=True, + eval_dir=eval_dir, + min_steps=min_steps, + run_every_secs=run_every_secs, + run_every_steps=run_every_steps) + + +def stop_if_lower_hook(estimator, + metric_name, + threshold, + eval_dir=None, + min_steps=0, + run_every_secs=60, + run_every_steps=None): + """Creates hook to stop if the given metric is lower than the threshold. + + Usage example: + + ```python + estimator = ... + # Hook to stop training if loss becomes lower than 100. + hook = early_stopping.stop_if_lower_hook(estimator, "loss", 100) + train_spec = tf.estimator.TrainSpec(..., hooks=[hook]) + tf.estimator.train_and_evaluate(estimator, train_spec, ...) + ``` + + Args: + estimator: A `tf.estimator.Estimator` instance. + metric_name: `str`, metric to track. "loss", "accuracy", etc. + threshold: Numeric threshold for the given metric. + eval_dir: If set, directory containing summary files with eval metrics. By + default, `estimator.eval_dir()` will be used. + min_steps: `int`, stop is never requested if global step is less than this + value. Defaults to 0. + run_every_secs: If specified, calls `should_stop_fn` at an interval of + `run_every_secs` seconds. Defaults to 60 seconds. Either this or + `run_every_steps` must be set. + run_every_steps: If specified, calls `should_stop_fn` every + `run_every_steps` steps. Either this or `run_every_secs` must be set. + + Returns: + An early-stopping hook of type `SessionRunHook` that periodically checks + if the given metric is lower than specified threshold and initiates + early stopping if true. + """ + return _stop_if_threshold_crossed_hook( + estimator=estimator, + metric_name=metric_name, + threshold=threshold, + higher_is_better=False, + eval_dir=eval_dir, + min_steps=min_steps, + run_every_secs=run_every_secs, + run_every_steps=run_every_steps) + + +def stop_if_no_increase_hook(estimator, + metric_name, + max_steps_without_increase, + eval_dir=None, + min_steps=0, + run_every_secs=60, + run_every_steps=None): + """Creates hook to stop if metric does not increase within given max steps. + + Usage example: + + ```python + estimator = ... + # Hook to stop training if accuracy does not increase in over 100000 steps. + hook = early_stopping.stop_if_no_increase_hook(estimator, "accuracy", 100000) + train_spec = tf.estimator.TrainSpec(..., hooks=[hook]) + tf.estimator.train_and_evaluate(estimator, train_spec, ...) + ``` + + Args: + estimator: A `tf.estimator.Estimator` instance. + metric_name: `str`, metric to track. "loss", "accuracy", etc. + max_steps_without_increase: `int`, maximum number of training steps with no + increase in the given metric. + eval_dir: If set, directory containing summary files with eval metrics. By + default, `estimator.eval_dir()` will be used. + min_steps: `int`, stop is never requested if global step is less than this + value. Defaults to 0. + run_every_secs: If specified, calls `should_stop_fn` at an interval of + `run_every_secs` seconds. Defaults to 60 seconds. Either this or + `run_every_steps` must be set. + run_every_steps: If specified, calls `should_stop_fn` every + `run_every_steps` steps. Either this or `run_every_secs` must be set. + + Returns: + An early-stopping hook of type `SessionRunHook` that periodically checks + if the given metric shows no increase over given maximum number of + training steps, and initiates early stopping if true. + """ + return _stop_if_no_metric_improvement_hook( + estimator=estimator, + metric_name=metric_name, + max_steps_without_improvement=max_steps_without_increase, + higher_is_better=True, + eval_dir=eval_dir, + min_steps=min_steps, + run_every_secs=run_every_secs, + run_every_steps=run_every_steps) + + +def stop_if_no_decrease_hook(estimator, + metric_name, + max_steps_without_decrease, + eval_dir=None, + min_steps=0, + run_every_secs=60, + run_every_steps=None): + """Creates hook to stop if metric does not decrease within given max steps. + + Usage example: + + ```python + estimator = ... + # Hook to stop training if loss does not decrease in over 100000 steps. + hook = early_stopping.stop_if_no_decrease_hook(estimator, "loss", 100000) + train_spec = tf.estimator.TrainSpec(..., hooks=[hook]) + tf.estimator.train_and_evaluate(estimator, train_spec, ...) + ``` + + Args: + estimator: A `tf.estimator.Estimator` instance. + metric_name: `str`, metric to track. "loss", "accuracy", etc. + max_steps_without_decrease: `int`, maximum number of training steps with no + decrease in the given metric. + eval_dir: If set, directory containing summary files with eval metrics. By + default, `estimator.eval_dir()` will be used. + min_steps: `int`, stop is never requested if global step is less than this + value. Defaults to 0. + run_every_secs: If specified, calls `should_stop_fn` at an interval of + `run_every_secs` seconds. Defaults to 60 seconds. Either this or + `run_every_steps` must be set. + run_every_steps: If specified, calls `should_stop_fn` every + `run_every_steps` steps. Either this or `run_every_secs` must be set. + + Returns: + An early-stopping hook of type `SessionRunHook` that periodically checks + if the given metric shows no decrease over given maximum number of + training steps, and initiates early stopping if true. + """ + return _stop_if_no_metric_improvement_hook( + estimator=estimator, + metric_name=metric_name, + max_steps_without_improvement=max_steps_without_decrease, + higher_is_better=False, + eval_dir=eval_dir, + min_steps=min_steps, + run_every_secs=run_every_secs, + run_every_steps=run_every_steps) + + +def read_eval_metrics(eval_dir): + """Helper to read eval metrics from eval summary files. + + Args: + eval_dir: Directory containing summary files with eval metrics. + + Returns: + A `dict` with global steps mapping to `dict` of metric names and values. + """ + eval_metrics_dict = {} + for event in _summaries(eval_dir): + if not event.HasField('summary'): + continue + metrics = {} + for value in event.summary.value: + if value.HasField('simple_value'): + metrics[value.tag] = value.simple_value + if metrics: + eval_metrics_dict[event.step] = metrics + return eval_metrics_dict + + +def _stop_if_threshold_crossed_hook(estimator, metric_name, threshold, + higher_is_better, eval_dir, min_steps, + run_every_secs, run_every_steps): + """Creates early-stopping hook to stop training if threshold is crossed.""" + + if eval_dir is None: + eval_dir = estimator.eval_dir() + + is_lhs_better = operator.gt if higher_is_better else operator.lt + greater_or_lesser = 'greater than' if higher_is_better else 'less than' + + def stop_if_threshold_crossed_fn(): + """Returns `True` if the given metric crosses specified threshold.""" + + eval_results = read_eval_metrics(eval_dir) + + for step, metrics in eval_results.items(): + if step < min_steps: + continue + val = metrics[metric_name] + if is_lhs_better(val, threshold): + tf_logging.info( + 'At step %s, metric "%s" has value %s which is %s the configured ' + 'threshold (%s) for early stopping.', step, metric_name, val, + greater_or_lesser, threshold) + return True + return False + + return make_early_stopping_hook( + estimator=estimator, + should_stop_fn=stop_if_threshold_crossed_fn, + run_every_secs=run_every_secs, + run_every_steps=run_every_steps) + + +def _stop_if_no_metric_improvement_hook( + estimator, metric_name, max_steps_without_improvement, higher_is_better, + eval_dir, min_steps, run_every_secs, run_every_steps): + """Returns hook to stop training if given metric shows no improvement.""" + + if eval_dir is None: + eval_dir = estimator.eval_dir() + + is_lhs_better = operator.gt if higher_is_better else operator.lt + increase_or_decrease = 'increase' if higher_is_better else 'decrease' + + def stop_if_no_metric_improvement_fn(): + """Returns `True` if metric does not improve within max steps.""" + + eval_results = read_eval_metrics(eval_dir) + + best_val = None + best_val_step = None + for step, metrics in eval_results.items(): + if step < min_steps: + continue + val = metrics[metric_name] + if best_val is None or is_lhs_better(val, best_val): + best_val = val + best_val_step = step + if step - best_val_step >= max_steps_without_improvement: + tf_logging.info( + 'No %s in metric "%s" for %s steps, which is greater than or equal ' + 'to max steps (%s) configured for early stopping.', + increase_or_decrease, metric_name, step - best_val_step, + max_steps_without_improvement) + return True + return False + + return make_early_stopping_hook( + estimator=estimator, + should_stop_fn=stop_if_no_metric_improvement_fn, + run_every_secs=run_every_secs, + run_every_steps=run_every_steps) + + +def _summaries(eval_dir): + """Yields `tensorflow.Event` protos from event files in the eval dir. + + Args: + eval_dir: Directory containing summary files with eval metrics. + + Yields: + `tensorflow.Event` object read from the event files. + """ + for event_file in gfile.Glob( + os.path.join(eval_dir, _EVENT_FILE_GLOB_PATTERN)): + for event in summary_iterator.summary_iterator(event_file): + yield event + + +def _get_or_create_stop_var(): + with variable_scope.variable_scope( + name_or_scope='signal_early_stopping', + values=[], + reuse=variable_scope.AUTO_REUSE): + return variable_scope.get_variable( + name='STOP', + shape=[], + dtype=dtypes.bool, + initializer=init_ops.constant_initializer(False), + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + trainable=False) + + +class _StopOnPredicateHook(session_run_hook.SessionRunHook): + """Hook that requests stop when `should_stop_fn` returns `True`.""" + + def __init__(self, should_stop_fn, run_every_secs=60, run_every_steps=None): + if not callable(should_stop_fn): + raise TypeError('`should_stop_fn` must be callable.') + + self._should_stop_fn = should_stop_fn + self._timer = basic_session_run_hooks.SecondOrStepTimer( + every_secs=run_every_secs, every_steps=run_every_steps) + self._global_step_tensor = None + self._stop_var = None + self._stop_op = None + + def begin(self): + self._global_step_tensor = training_util.get_global_step() + self._stop_var = _get_or_create_stop_var() + self._stop_op = state_ops.assign(self._stop_var, True) + + def before_run(self, run_context): + del run_context + return session_run_hook.SessionRunArgs(self._global_step_tensor) + + def after_run(self, run_context, run_values): + global_step = run_values.results + if self._timer.should_trigger_for_step(global_step): + self._timer.update_last_triggered_step(global_step) + if self._should_stop_fn(): + tf_logging.info('Requesting early stopping at global step %d', + global_step) + run_context.session.run(self._stop_op) + run_context.request_stop() + + +class _CheckForStoppingHook(session_run_hook.SessionRunHook): + """Hook that requests stop if stop is requested by `_StopOnPredicateHook`.""" + + def __init__(self): + self._stop_var = None + + def begin(self): + self._stop_var = _get_or_create_stop_var() + + def before_run(self, run_context): + del run_context + return session_run_hook.SessionRunArgs(self._stop_var) + + def after_run(self, run_context, run_values): + should_early_stop = run_values.results + if should_early_stop: + tf_logging.info('Early stopping requested, suspending run.') + run_context.request_stop() diff --git a/tensorflow/contrib/estimator/python/estimator/early_stopping_test.py b/tensorflow/contrib/estimator/python/estimator/early_stopping_test.py new file mode 100644 index 0000000000..b5eee818fa --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/early_stopping_test.py @@ -0,0 +1,233 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for early_stopping.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +from absl.testing import parameterized +from tensorflow.contrib.estimator.python.estimator import early_stopping +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import run_config +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.platform import test +from tensorflow.python.training import monitored_session +from tensorflow.python.training import training_util + + +class _FakeRunConfig(run_config.RunConfig): + + def __init__(self, is_chief): + super(_FakeRunConfig, self).__init__() + self._is_chief = is_chief + + @property + def is_chief(self): + return self._is_chief + + +def _dummy_model_fn(features, labels, params): + _, _, _ = features, labels, params + + +class _FakeEstimator(estimator.Estimator): + """Fake estimator for testing.""" + + def __init__(self, config): + super(_FakeEstimator, self).__init__( + model_fn=_dummy_model_fn, config=config) + + +def _write_events(eval_dir, params): + """Test helper to write events to summary files.""" + for steps, loss, accuracy in params: + estimator._write_dict_to_summary(eval_dir, { + 'loss': loss, + 'accuracy': accuracy, + }, steps) + + +class ReadEvalMetricsTest(test.TestCase): + + def test_read_eval_metrics(self): + eval_dir = tempfile.mkdtemp() + _write_events( + eval_dir, + [ + # steps, loss, accuracy + (1000, 1, 2), + (2000, 3, 4), + (3000, 5, 6), + ]) + self.assertEqual({ + 1000: { + 'loss': 1, + 'accuracy': 2 + }, + 2000: { + 'loss': 3, + 'accuracy': 4 + }, + 3000: { + 'loss': 5, + 'accuracy': 6 + }, + }, early_stopping.read_eval_metrics(eval_dir)) + + +class EarlyStoppingHooksTest(test.TestCase, parameterized.TestCase): + + def setUp(self): + config = _FakeRunConfig(is_chief=True) + self._estimator = _FakeEstimator(config=config) + eval_dir = self._estimator.eval_dir() + os.makedirs(eval_dir) + _write_events( + eval_dir, + [ + # steps, loss, accuracy + (1000, 0.8, 0.5), + (2000, 0.7, 0.6), + (3000, 0.4, 0.7), + (3500, 0.41, 0.68), + ]) + + def run_session(self, hooks, should_stop): + hooks = hooks if isinstance(hooks, list) else [hooks] + with ops.Graph().as_default(): + training_util.create_global_step() + no_op = control_flow_ops.no_op() + with monitored_session.SingularMonitoredSession(hooks=hooks) as mon_sess: + mon_sess.run(no_op) + self.assertEqual(mon_sess.should_stop(), should_stop) + + @parameterized.parameters((0.8, 0, False), (0.6, 4000, False), (0.6, 0, True)) + def test_stop_if_higher_hook(self, threshold, min_steps, should_stop): + self.run_session( + early_stopping.stop_if_higher_hook( + self._estimator, + metric_name='accuracy', + threshold=threshold, + min_steps=min_steps), should_stop) + + @parameterized.parameters((0.3, 0, False), (0.5, 4000, False), (0.5, 0, True)) + def test_stop_if_lower_hook(self, threshold, min_steps, should_stop): + self.run_session( + early_stopping.stop_if_lower_hook( + self._estimator, + metric_name='loss', + threshold=threshold, + min_steps=min_steps), should_stop) + + @parameterized.parameters((1500, 0, False), (500, 4000, False), + (500, 0, True)) + def test_stop_if_no_increase_hook(self, max_steps, min_steps, should_stop): + self.run_session( + early_stopping.stop_if_no_increase_hook( + self._estimator, + metric_name='accuracy', + max_steps_without_increase=max_steps, + min_steps=min_steps), should_stop) + + @parameterized.parameters((1500, 0, False), (500, 4000, False), + (500, 0, True)) + def test_stop_if_no_decrease_hook(self, max_steps, min_steps, should_stop): + self.run_session( + early_stopping.stop_if_no_decrease_hook( + self._estimator, + metric_name='loss', + max_steps_without_decrease=max_steps, + min_steps=min_steps), should_stop) + + @parameterized.parameters((1500, 0.3, False), (1500, 0.5, True), + (500, 0.3, True)) + def test_multiple_hooks(self, max_steps, loss_threshold, should_stop): + self.run_session([ + early_stopping.stop_if_no_decrease_hook( + self._estimator, + metric_name='loss', + max_steps_without_decrease=max_steps), + early_stopping.stop_if_lower_hook( + self._estimator, metric_name='loss', threshold=loss_threshold) + ], should_stop) + + @parameterized.parameters(False, True) + def test_make_early_stopping_hook(self, should_stop): + self.run_session([ + early_stopping.make_early_stopping_hook( + self._estimator, should_stop_fn=lambda: should_stop) + ], should_stop) + + def test_make_early_stopping_hook_typeerror(self): + with self.assertRaises(TypeError): + early_stopping.make_early_stopping_hook( + estimator=object(), should_stop_fn=lambda: True) + + def test_make_early_stopping_hook_valueerror(self): + with self.assertRaises(ValueError): + early_stopping.make_early_stopping_hook( + self._estimator, + should_stop_fn=lambda: True, + run_every_secs=60, + run_every_steps=100) + + +class StopOnPredicateHookTest(test.TestCase): + + def test_stop(self): + hook = early_stopping._StopOnPredicateHook( + should_stop_fn=lambda: False, run_every_secs=0) + with ops.Graph().as_default(): + training_util.create_global_step() + no_op = control_flow_ops.no_op() + with monitored_session.SingularMonitoredSession(hooks=[hook]) as mon_sess: + mon_sess.run(no_op) + self.assertFalse(mon_sess.should_stop()) + self.assertFalse(mon_sess.raw_session().run(hook._stop_var)) + + hook = early_stopping._StopOnPredicateHook( + should_stop_fn=lambda: True, run_every_secs=0) + with ops.Graph().as_default(): + training_util.create_global_step() + no_op = control_flow_ops.no_op() + with monitored_session.SingularMonitoredSession(hooks=[hook]) as mon_sess: + mon_sess.run(no_op) + self.assertTrue(mon_sess.should_stop()) + self.assertTrue(mon_sess.raw_session().run(hook._stop_var)) + + +class CheckForStoppingHookTest(test.TestCase): + + def test_stop(self): + hook = early_stopping._CheckForStoppingHook() + with ops.Graph().as_default(): + no_op = control_flow_ops.no_op() + assign_op = state_ops.assign(early_stopping._get_or_create_stop_var(), + True) + with monitored_session.SingularMonitoredSession(hooks=[hook]) as mon_sess: + mon_sess.run(no_op) + self.assertFalse(mon_sess.should_stop()) + mon_sess.run(assign_op) + self.assertTrue(mon_sess.should_stop()) + + +if __name__ == '__main__': + test.main() -- cgit v1.2.3