diff options
author | Jianwei Xie <xiejw@google.com> | 2017-01-25 14:09:48 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-01-25 14:25:42 -0800 |
commit | 59757b7afd5dd08d5651ca966f03511bb2aad7bd (patch) | |
tree | 5ffa4f80e83345d09855a6afa1cab487d5b1a4c7 | |
parent | 665d811195a3d8ab84f4ca598bbd717ea51431cf (diff) |
Remove continuous_eval_predicate_fn from Experiment class.
Change: 145594656
-rw-r--r-- | tensorflow/contrib/learn/python/learn/experiment.py | 58 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/experiment_test.py | 6 |
2 files changed, 33 insertions, 31 deletions
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 70b12631d4..8669616d87 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -74,8 +74,7 @@ class Experiment(object): continuous_eval_throttle_secs=60, min_eval_frequency=1, delay_workers_by_global_step=False, - export_strategies=None, - continuous_eval_predicate_fn=None): + export_strategies=None): """Constructor for `Experiment`. Creates an Experiment instance. None of the functions passed to this @@ -111,12 +110,6 @@ class Experiment(object): delay_workers_by_global_step: if `True` delays training workers based on global step instead of time. export_strategies: A list of `ExportStrategy`s, or a single one, or None. - continuous_eval_predicate_fn: A predicate function determining whether to - continue eval after each iteration. `predicate_fn` takes the evaluation - results as arguments. At the beginning of evaluation, the passed eval - results will be None so it's expected that the predicate function - handles that gracefully. When `predicate_fn` is not specified, - continuous eval will run in an infinite loop. Raises: ValueError: if `estimator` does not implement `Evaluable` and `Trainable`, @@ -142,8 +135,6 @@ class Experiment(object): self._train_monitors = train_monitors[:] if train_monitors else [] self._eval_hooks = eval_hooks[:] if eval_hooks else [] self._set_export_strategies(export_strategies) - # Mutable fields, using the setters. - self.continuous_eval_predicate_fn = continuous_eval_predicate_fn @property def estimator(self): @@ -161,17 +152,6 @@ class Experiment(object): def eval_steps(self): return self._eval_steps - @property - def continuous_eval_predicate_fn(self): - return self._continuous_eval_predicate_fn - - @continuous_eval_predicate_fn.setter - def continuous_eval_predicate_fn(self, value): - if value is not None and not callable(value): - raise ValueError( - "`continuous_eval_predicate_fn` must be a callable, or None.") - self._continuous_eval_predicate_fn = value - def _set_export_strategies(self, value): if value is None: self._export_strategies = [] @@ -288,12 +268,15 @@ class Experiment(object): self._min_eval_frequency = self._local_eval_frequency return self.train_and_evaluate() + # TODO(xiejw): Allow continuous_eval_predicate_fn to be passed via constructor + # once stopping all jobs is implemented. def _continuous_eval(self, input_fn, name, delay_secs, throttle_delay_secs, - evaluate_checkpoint_only_once=True): + evaluate_checkpoint_only_once=True, + continuous_eval_predicate_fn=None): """Run continuous eval. Runs infinite eval on the evaluation data set. This function starts @@ -311,7 +294,22 @@ class Experiment(object): self._continuous_eval_throttle_secs. evaluate_checkpoint_only_once: Whether to skip evaluation of checkpoints that have already been evaluated. Default is `True`. + continuous_eval_predicate_fn: A predicate function determining whether to + continue eval after each iteration. `predicate_fn` takes the evaluation + results as arguments. At the beginning of evaluation, the passed eval + results will be None so it's expected that the predicate function + handles that gracefully. When `predicate_fn` is not specified, + continuous eval will run in an infinite loop. + + Raises: + ValueError: if `continuous_eval_predicate_fn` is neither None nor + callable. """ + if (continuous_eval_predicate_fn is not None and + not callable(continuous_eval_predicate_fn)): + raise ValueError( + "`continuous_eval_predicate_fn` must be a callable, or None.") + if delay_secs is None: delay_secs = self._eval_delay_secs if throttle_delay_secs is None: @@ -324,8 +322,8 @@ class Experiment(object): previous_path = None eval_result = None last_warning_time = 0 - while (not self.continuous_eval_predicate_fn or - self.continuous_eval_predicate_fn(eval_result)): + while (not continuous_eval_predicate_fn or + continuous_eval_predicate_fn(eval_result)): start = time.time() error_msg = None @@ -370,22 +368,26 @@ class Experiment(object): def continuous_eval(self, delay_secs=None, throttle_delay_secs=None, - evaluate_checkpoint_only_once=True): + evaluate_checkpoint_only_once=True, + continuous_eval_predicate_fn=None): self._continuous_eval( self._eval_input_fn, name="continuous", delay_secs=delay_secs, throttle_delay_secs=throttle_delay_secs, - evaluate_checkpoint_only_once=evaluate_checkpoint_only_once) + evaluate_checkpoint_only_once=evaluate_checkpoint_only_once, + continuous_eval_predicate_fn=continuous_eval_predicate_fn) def continuous_eval_on_train_data(self, delay_secs=None, - throttle_delay_secs=None): + throttle_delay_secs=None, + continuous_eval_predicate_fn=None): self._continuous_eval( self._train_input_fn, name="continuous_on_train_data", delay_secs=delay_secs, - throttle_delay_secs=throttle_delay_secs) + throttle_delay_secs=throttle_delay_secs, + continuous_eval_predicate_fn=continuous_eval_predicate_fn) def train_and_evaluate(self): """Interleaves training and evaluation. diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index cf937443f0..398fc6e176 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -342,9 +342,9 @@ class ExperimentTest(test.TestCase): eval_metrics='eval_metrics', eval_hooks=[noop_hook], eval_delay_secs=0, - continuous_eval_throttle_secs=0, - continuous_eval_predicate_fn=_predicate_fn) - ex.continuous_eval(evaluate_checkpoint_only_once=False) + continuous_eval_throttle_secs=0) + ex.continuous_eval(evaluate_checkpoint_only_once=False, + continuous_eval_predicate_fn=_predicate_fn) self.assertEqual(0, est.fit_count) self.assertEqual(3, est.eval_count) self.assertEqual([noop_hook], est.eval_hooks) |