aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment.py38
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment_test.py80
2 files changed, 97 insertions, 21 deletions
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py
index 6067513293..a8f8d995fe 100644
--- a/tensorflow/contrib/learn/python/learn/experiment.py
+++ b/tensorflow/contrib/learn/python/learn/experiment.py
@@ -78,7 +78,8 @@ class Experiment(object):
continuous_eval_throttle_secs=60,
min_eval_frequency=1,
delay_workers_by_global_step=False,
- export_strategies=None):
+ export_strategies=None,
+ train_steps_per_iteration=None):
"""Constructor for `Experiment`.
Creates an Experiment instance. None of the functions passed to this
@@ -117,6 +118,11 @@ class Experiment(object):
delay_workers_by_global_step: if `True` delays training workers
based on global step instead of time.
export_strategies: A list of `ExportStrategy`s, or a single one, or None.
+ train_steps_per_iteration: (applies only to continuous_train_and_eval).
+ Perform this many (integer) number of train steps for each
+ training-evaluation iteration. With a small value, the model will be
+ evaluated more frequently with more checkpoints saved. If `None`, will
+ use a default value (which is smaller than `train_steps` if provided).
Raises:
ValueError: if `estimator` does not implement Estimator interface,
@@ -155,6 +161,12 @@ class Experiment(object):
self._eval_hooks = eval_hooks[:] if eval_hooks else []
self._set_export_strategies(export_strategies)
+ self._train_steps_per_iteration = train_steps_per_iteration
+ if (self._train_steps_per_iteration is not None and
+ not isinstance(self._train_steps_per_iteration, int)):
+ raise ValueError(
+ "`train_steps_per_iteration` must be an integer.")
+
@property
def estimator(self):
return self._estimator
@@ -478,12 +490,11 @@ class Experiment(object):
@experimental
def continuous_train_and_eval(self,
- train_steps_per_iteration=1000,
continuous_eval_predicate_fn=None):
"""Interleaves training and evaluation.
- The frequency of evaluation is controlled by the
- `train_steps_per_iteration`. The model will be first trained for
+ The frequency of evaluation is controlled by the `train_steps_per_iteration`
+ (via constructor). The model will be first trained for
`train_steps_per_iteration`, and then be evaluated in turns.
This differs from `train_and_evaluate` as follows:
@@ -499,10 +510,6 @@ class Experiment(object):
is generated at the end of each small trainning iteration.
Args:
- train_steps_per_iteration: The (integer) number of train steps for
- each training-evaluation iteration. With a small
- `train_steps_per_iteration`, the model will be evaluated more frequently
- with more checkpoints saved.
continuous_eval_predicate_fn: A predicate function determining whether to
continue after each iteration. `predicate_fn` takes the evaluation
results as its arguments. At the beginning of evaluation, the passed
@@ -524,16 +531,15 @@ class Experiment(object):
raise ValueError(
"`continuous_eval_predicate_fn` must be a callable, or None.")
- if not isinstance(train_steps_per_iteration, int):
- raise ValueError(
- "`train_steps_per_iteration` must be an integer.")
-
eval_result = None
- # TODO(b/33295821): improve the way to determine the
- # train_steps_per_iteration.
- if self._train_steps and train_steps_per_iteration > self._train_steps:
- train_steps_per_iteration = self._train_steps
+ # Set the default value for train_steps_per_iteration, which will be
+ # overriden by other settings.
+ train_steps_per_iteration = 1000
+ if self._train_steps_per_iteration is not None:
+ train_steps_per_iteration = self._train_steps_per_iteration
+ elif self._train_steps is not None:
+ train_steps_per_iteration = int(self._train_steps / 10)
while (not continuous_eval_predicate_fn or
continuous_eval_predicate_fn(eval_result)):
diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py
index abd1e3e66f..00ed062b0a 100644
--- a/tensorflow/contrib/learn/python/learn/experiment_test.py
+++ b/tensorflow/contrib/learn/python/learn/experiment_test.py
@@ -592,6 +592,76 @@ class ExperimentTest(test.TestCase):
self.assertEqual(0, est.eval_count)
self.assertEqual(1, est.export_count)
+ def test_continuous_train_and_eval_with_adapted_steps_per_iteration(self):
+ mock_estimator = test.mock.Mock(core_estimator.Estimator)
+ type(mock_estimator).model_dir = test.mock.PropertyMock(
+ return_value='test_dir')
+
+ total_steps = 100000000000000
+ ex = experiment.Experiment(
+ mock_estimator,
+ train_input_fn='train_input',
+ eval_input_fn='eval_input',
+ train_steps=total_steps)
+
+ def predicate_fn(eval_result):
+ # Allows the first invoke only.
+ return eval_result is None
+
+ ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn)
+ mock_estimator.train.assert_called_once_with(
+ input_fn='train_input',
+ steps=int(total_steps/10),
+ max_steps=test.mock.ANY,
+ hooks=test.mock.ANY)
+
+ def test_continuous_train_and_eval_with_steps_per_iteration_from_user(self):
+ mock_estimator = test.mock.Mock(core_estimator.Estimator)
+ type(mock_estimator).model_dir = test.mock.PropertyMock(
+ return_value='test_dir')
+
+ total_steps = 100000000000000
+ ex = experiment.Experiment(
+ mock_estimator,
+ train_input_fn='train_input',
+ eval_input_fn='eval_input',
+ train_steps_per_iteration=1234,
+ train_steps=total_steps)
+
+ def predicate_fn(eval_result):
+ # Allows the first invoke only.
+ return eval_result is None
+
+ ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn)
+ mock_estimator.train.assert_called_once_with(
+ input_fn='train_input',
+ steps=1234,
+ max_steps=test.mock.ANY,
+ hooks=test.mock.ANY)
+
+ def test_continuous_train_and_eval_with_default_steps_per_iteration(self):
+ mock_estimator = test.mock.Mock(core_estimator.Estimator)
+ type(mock_estimator).model_dir = test.mock.PropertyMock(
+ return_value='test_dir')
+
+ ex = experiment.Experiment(
+ mock_estimator,
+ train_input_fn='train_input',
+ eval_input_fn='eval_input',
+ train_steps_per_iteration=None,
+ train_steps=None)
+
+ def predicate_fn(eval_result):
+ # Allows the first invoke only.
+ return eval_result is None
+
+ ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn)
+ mock_estimator.train.assert_called_once_with(
+ input_fn='train_input',
+ steps=1000,
+ max_steps=test.mock.ANY,
+ hooks=test.mock.ANY)
+
def test_continuous_train_and_eval_with_invalid_predicate_fn(self):
for est in self._estimators_for_tests():
ex = experiment.Experiment(
@@ -604,13 +674,13 @@ class ExperimentTest(test.TestCase):
def test_continuous_train_and_eval_with_invalid_train_steps_iterations(self):
for est in self._estimators_for_tests():
- ex = experiment.Experiment(
- est,
- train_input_fn='train_input',
- eval_input_fn='eval_input')
with self.assertRaisesRegexp(
ValueError, '`train_steps_per_iteration` must be an integer.'):
- ex.continuous_train_and_eval(train_steps_per_iteration='123')
+ experiment.Experiment(
+ est,
+ train_input_fn='train_input',
+ eval_input_fn='eval_input',
+ train_steps_per_iteration='123')
@test.mock.patch.object(server_lib, 'Server')
def test_run_std_server(self, mock_server):