diff options
author | Jianwei Xie <xiejw@google.com> | 2017-03-27 12:56:36 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-27 14:11:26 -0700 |
commit | 5c65a93856028ea9717c769bbc96338e1ea05c76 (patch) | |
tree | 490882fa09857f1505886a191e1733b518c8d642 | |
parent | b32a1b7294b9d8b4fe07d3b4ecde1a83a68d2ace (diff) |
added more checks for experiment.
Change: 151366403
-rw-r--r-- | tensorflow/contrib/learn/python/learn/experiment.py | 4 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/experiment_test.py | 20 |
2 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index a147915430..6067513293 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -524,6 +524,10 @@ class Experiment(object): raise ValueError( "`continuous_eval_predicate_fn` must be a callable, or None.") + if not isinstance(train_steps_per_iteration, int): + raise ValueError( + "`train_steps_per_iteration` must be an integer.") + eval_result = None # TODO(b/33295821): improve the way to determine the diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index c9d981cb0f..abd1e3e66f 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -592,6 +592,26 @@ class ExperimentTest(test.TestCase): self.assertEqual(0, est.eval_count) self.assertEqual(1, est.export_count) + def test_continuous_train_and_eval_with_invalid_predicate_fn(self): + for est in self._estimators_for_tests(): + ex = experiment.Experiment( + est, + train_input_fn='train_input', + eval_input_fn='eval_input') + with self.assertRaisesRegexp( + ValueError, '`continuous_eval_predicate_fn` must be a callable'): + ex.continuous_train_and_eval(continuous_eval_predicate_fn='fn') + + def test_continuous_train_and_eval_with_invalid_train_steps_iterations(self): + for est in self._estimators_for_tests(): + ex = experiment.Experiment( + est, + train_input_fn='train_input', + eval_input_fn='eval_input') + with self.assertRaisesRegexp( + ValueError, '`train_steps_per_iteration` must be an integer.'): + ex.continuous_train_and_eval(train_steps_per_iteration='123') + @test.mock.patch.object(server_lib, 'Server') def test_run_std_server(self, mock_server): # Arrange. |