aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2017-01-25 14:09:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-25 14:25:42 -0800
commit59757b7afd5dd08d5651ca966f03511bb2aad7bd (patch)
tree5ffa4f80e83345d09855a6afa1cab487d5b1a4c7
parent665d811195a3d8ab84f4ca598bbd717ea51431cf (diff)
Remove continuous_eval_predicate_fn from Experiment class.
Change: 145594656
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment.py58
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment_test.py6
2 files changed, 33 insertions, 31 deletions
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py
index 70b12631d4..8669616d87 100644
--- a/tensorflow/contrib/learn/python/learn/experiment.py
+++ b/tensorflow/contrib/learn/python/learn/experiment.py
@@ -74,8 +74,7 @@ class Experiment(object):
continuous_eval_throttle_secs=60,
min_eval_frequency=1,
delay_workers_by_global_step=False,
- export_strategies=None,
- continuous_eval_predicate_fn=None):
+ export_strategies=None):
"""Constructor for `Experiment`.
Creates an Experiment instance. None of the functions passed to this
@@ -111,12 +110,6 @@ class Experiment(object):
delay_workers_by_global_step: if `True` delays training workers
based on global step instead of time.
export_strategies: A list of `ExportStrategy`s, or a single one, or None.
- continuous_eval_predicate_fn: A predicate function determining whether to
- continue eval after each iteration. `predicate_fn` takes the evaluation
- results as arguments. At the beginning of evaluation, the passed eval
- results will be None so it's expected that the predicate function
- handles that gracefully. When `predicate_fn` is not specified,
- continuous eval will run in an infinite loop.
Raises:
ValueError: if `estimator` does not implement `Evaluable` and `Trainable`,
@@ -142,8 +135,6 @@ class Experiment(object):
self._train_monitors = train_monitors[:] if train_monitors else []
self._eval_hooks = eval_hooks[:] if eval_hooks else []
self._set_export_strategies(export_strategies)
- # Mutable fields, using the setters.
- self.continuous_eval_predicate_fn = continuous_eval_predicate_fn
@property
def estimator(self):
@@ -161,17 +152,6 @@ class Experiment(object):
def eval_steps(self):
return self._eval_steps
- @property
- def continuous_eval_predicate_fn(self):
- return self._continuous_eval_predicate_fn
-
- @continuous_eval_predicate_fn.setter
- def continuous_eval_predicate_fn(self, value):
- if value is not None and not callable(value):
- raise ValueError(
- "`continuous_eval_predicate_fn` must be a callable, or None.")
- self._continuous_eval_predicate_fn = value
-
def _set_export_strategies(self, value):
if value is None:
self._export_strategies = []
@@ -288,12 +268,15 @@ class Experiment(object):
self._min_eval_frequency = self._local_eval_frequency
return self.train_and_evaluate()
+ # TODO(xiejw): Allow continuous_eval_predicate_fn to be passed via constructor
+ # once stopping all jobs is implemented.
def _continuous_eval(self,
input_fn,
name,
delay_secs,
throttle_delay_secs,
- evaluate_checkpoint_only_once=True):
+ evaluate_checkpoint_only_once=True,
+ continuous_eval_predicate_fn=None):
"""Run continuous eval.
Runs infinite eval on the evaluation data set. This function starts
@@ -311,7 +294,22 @@ class Experiment(object):
self._continuous_eval_throttle_secs.
evaluate_checkpoint_only_once: Whether to skip evaluation of checkpoints
that have already been evaluated. Default is `True`.
+ continuous_eval_predicate_fn: A predicate function determining whether to
+ continue eval after each iteration. `predicate_fn` takes the evaluation
+ results as arguments. At the beginning of evaluation, the passed eval
+ results will be None so it's expected that the predicate function
+ handles that gracefully. When `predicate_fn` is not specified,
+ continuous eval will run in an infinite loop.
+
+ Raises:
+ ValueError: if `continuous_eval_predicate_fn` is neither None nor
+ callable.
"""
+ if (continuous_eval_predicate_fn is not None and
+ not callable(continuous_eval_predicate_fn)):
+ raise ValueError(
+ "`continuous_eval_predicate_fn` must be a callable, or None.")
+
if delay_secs is None:
delay_secs = self._eval_delay_secs
if throttle_delay_secs is None:
@@ -324,8 +322,8 @@ class Experiment(object):
previous_path = None
eval_result = None
last_warning_time = 0
- while (not self.continuous_eval_predicate_fn or
- self.continuous_eval_predicate_fn(eval_result)):
+ while (not continuous_eval_predicate_fn or
+ continuous_eval_predicate_fn(eval_result)):
start = time.time()
error_msg = None
@@ -370,22 +368,26 @@ class Experiment(object):
def continuous_eval(self,
delay_secs=None,
throttle_delay_secs=None,
- evaluate_checkpoint_only_once=True):
+ evaluate_checkpoint_only_once=True,
+ continuous_eval_predicate_fn=None):
self._continuous_eval(
self._eval_input_fn,
name="continuous",
delay_secs=delay_secs,
throttle_delay_secs=throttle_delay_secs,
- evaluate_checkpoint_only_once=evaluate_checkpoint_only_once)
+ evaluate_checkpoint_only_once=evaluate_checkpoint_only_once,
+ continuous_eval_predicate_fn=continuous_eval_predicate_fn)
def continuous_eval_on_train_data(self,
delay_secs=None,
- throttle_delay_secs=None):
+ throttle_delay_secs=None,
+ continuous_eval_predicate_fn=None):
self._continuous_eval(
self._train_input_fn,
name="continuous_on_train_data",
delay_secs=delay_secs,
- throttle_delay_secs=throttle_delay_secs)
+ throttle_delay_secs=throttle_delay_secs,
+ continuous_eval_predicate_fn=continuous_eval_predicate_fn)
def train_and_evaluate(self):
"""Interleaves training and evaluation.
diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py
index cf937443f0..398fc6e176 100644
--- a/tensorflow/contrib/learn/python/learn/experiment_test.py
+++ b/tensorflow/contrib/learn/python/learn/experiment_test.py
@@ -342,9 +342,9 @@ class ExperimentTest(test.TestCase):
eval_metrics='eval_metrics',
eval_hooks=[noop_hook],
eval_delay_secs=0,
- continuous_eval_throttle_secs=0,
- continuous_eval_predicate_fn=_predicate_fn)
- ex.continuous_eval(evaluate_checkpoint_only_once=False)
+ continuous_eval_throttle_secs=0)
+ ex.continuous_eval(evaluate_checkpoint_only_once=False,
+ continuous_eval_predicate_fn=_predicate_fn)
self.assertEqual(0, est.fit_count)
self.assertEqual(3, est.eval_count)
self.assertEqual([noop_hook], est.eval_hooks)