diff options
-rw-r--r-- | tensorflow/contrib/learn/python/learn/experiment.py | 38 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/experiment_test.py | 80 |
2 files changed, 97 insertions, 21 deletions
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 6067513293..a8f8d995fe 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -78,7 +78,8 @@ class Experiment(object): continuous_eval_throttle_secs=60, min_eval_frequency=1, delay_workers_by_global_step=False, - export_strategies=None): + export_strategies=None, + train_steps_per_iteration=None): """Constructor for `Experiment`. Creates an Experiment instance. None of the functions passed to this @@ -117,6 +118,11 @@ class Experiment(object): delay_workers_by_global_step: if `True` delays training workers based on global step instead of time. export_strategies: A list of `ExportStrategy`s, or a single one, or None. + train_steps_per_iteration: (applies only to continuous_train_and_eval). + Perform this many (integer) number of train steps for each + training-evaluation iteration. With a small value, the model will be + evaluated more frequently with more checkpoints saved. If `None`, will + use a default value (which is smaller than `train_steps` if provided). Raises: ValueError: if `estimator` does not implement Estimator interface, @@ -155,6 +161,12 @@ class Experiment(object): self._eval_hooks = eval_hooks[:] if eval_hooks else [] self._set_export_strategies(export_strategies) + self._train_steps_per_iteration = train_steps_per_iteration + if (self._train_steps_per_iteration is not None and + not isinstance(self._train_steps_per_iteration, int)): + raise ValueError( + "`train_steps_per_iteration` must be an integer.") + @property def estimator(self): return self._estimator @@ -478,12 +490,11 @@ class Experiment(object): @experimental def continuous_train_and_eval(self, - train_steps_per_iteration=1000, continuous_eval_predicate_fn=None): """Interleaves training and evaluation. - The frequency of evaluation is controlled by the - `train_steps_per_iteration`. The model will be first trained for + The frequency of evaluation is controlled by the `train_steps_per_iteration` + (via constructor). The model will be first trained for `train_steps_per_iteration`, and then be evaluated in turns. This differs from `train_and_evaluate` as follows: @@ -499,10 +510,6 @@ class Experiment(object): is generated at the end of each small trainning iteration. Args: - train_steps_per_iteration: The (integer) number of train steps for - each training-evaluation iteration. With a small - `train_steps_per_iteration`, the model will be evaluated more frequently - with more checkpoints saved. continuous_eval_predicate_fn: A predicate function determining whether to continue after each iteration. `predicate_fn` takes the evaluation results as its arguments. At the beginning of evaluation, the passed @@ -524,16 +531,15 @@ class Experiment(object): raise ValueError( "`continuous_eval_predicate_fn` must be a callable, or None.") - if not isinstance(train_steps_per_iteration, int): - raise ValueError( - "`train_steps_per_iteration` must be an integer.") - eval_result = None - # TODO(b/33295821): improve the way to determine the - # train_steps_per_iteration. - if self._train_steps and train_steps_per_iteration > self._train_steps: - train_steps_per_iteration = self._train_steps + # Set the default value for train_steps_per_iteration, which will be + # overriden by other settings. + train_steps_per_iteration = 1000 + if self._train_steps_per_iteration is not None: + train_steps_per_iteration = self._train_steps_per_iteration + elif self._train_steps is not None: + train_steps_per_iteration = int(self._train_steps / 10) while (not continuous_eval_predicate_fn or continuous_eval_predicate_fn(eval_result)): diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index abd1e3e66f..00ed062b0a 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -592,6 +592,76 @@ class ExperimentTest(test.TestCase): self.assertEqual(0, est.eval_count) self.assertEqual(1, est.export_count) + def test_continuous_train_and_eval_with_adapted_steps_per_iteration(self): + mock_estimator = test.mock.Mock(core_estimator.Estimator) + type(mock_estimator).model_dir = test.mock.PropertyMock( + return_value='test_dir') + + total_steps = 100000000000000 + ex = experiment.Experiment( + mock_estimator, + train_input_fn='train_input', + eval_input_fn='eval_input', + train_steps=total_steps) + + def predicate_fn(eval_result): + # Allows the first invoke only. + return eval_result is None + + ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn) + mock_estimator.train.assert_called_once_with( + input_fn='train_input', + steps=int(total_steps/10), + max_steps=test.mock.ANY, + hooks=test.mock.ANY) + + def test_continuous_train_and_eval_with_steps_per_iteration_from_user(self): + mock_estimator = test.mock.Mock(core_estimator.Estimator) + type(mock_estimator).model_dir = test.mock.PropertyMock( + return_value='test_dir') + + total_steps = 100000000000000 + ex = experiment.Experiment( + mock_estimator, + train_input_fn='train_input', + eval_input_fn='eval_input', + train_steps_per_iteration=1234, + train_steps=total_steps) + + def predicate_fn(eval_result): + # Allows the first invoke only. + return eval_result is None + + ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn) + mock_estimator.train.assert_called_once_with( + input_fn='train_input', + steps=1234, + max_steps=test.mock.ANY, + hooks=test.mock.ANY) + + def test_continuous_train_and_eval_with_default_steps_per_iteration(self): + mock_estimator = test.mock.Mock(core_estimator.Estimator) + type(mock_estimator).model_dir = test.mock.PropertyMock( + return_value='test_dir') + + ex = experiment.Experiment( + mock_estimator, + train_input_fn='train_input', + eval_input_fn='eval_input', + train_steps_per_iteration=None, + train_steps=None) + + def predicate_fn(eval_result): + # Allows the first invoke only. + return eval_result is None + + ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn) + mock_estimator.train.assert_called_once_with( + input_fn='train_input', + steps=1000, + max_steps=test.mock.ANY, + hooks=test.mock.ANY) + def test_continuous_train_and_eval_with_invalid_predicate_fn(self): for est in self._estimators_for_tests(): ex = experiment.Experiment( @@ -604,13 +674,13 @@ class ExperimentTest(test.TestCase): def test_continuous_train_and_eval_with_invalid_train_steps_iterations(self): for est in self._estimators_for_tests(): - ex = experiment.Experiment( - est, - train_input_fn='train_input', - eval_input_fn='eval_input') with self.assertRaisesRegexp( ValueError, '`train_steps_per_iteration` must be an integer.'): - ex.continuous_train_and_eval(train_steps_per_iteration='123') + experiment.Experiment( + est, + train_input_fn='train_input', + eval_input_fn='eval_input', + train_steps_per_iteration='123') @test.mock.patch.object(server_lib, 'Server') def test_run_std_server(self, mock_server): |