aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2017-03-27 12:56:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-27 14:11:26 -0700
commit5c65a93856028ea9717c769bbc96338e1ea05c76 (patch)
tree490882fa09857f1505886a191e1733b518c8d642
parentb32a1b7294b9d8b4fe07d3b4ecde1a83a68d2ace (diff)
added more checks for experiment.
Change: 151366403
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment_test.py20
2 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py
index a147915430..6067513293 100644
--- a/tensorflow/contrib/learn/python/learn/experiment.py
+++ b/tensorflow/contrib/learn/python/learn/experiment.py
@@ -524,6 +524,10 @@ class Experiment(object):
raise ValueError(
"`continuous_eval_predicate_fn` must be a callable, or None.")
+ if not isinstance(train_steps_per_iteration, int):
+ raise ValueError(
+ "`train_steps_per_iteration` must be an integer.")
+
eval_result = None
# TODO(b/33295821): improve the way to determine the
diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py
index c9d981cb0f..abd1e3e66f 100644
--- a/tensorflow/contrib/learn/python/learn/experiment_test.py
+++ b/tensorflow/contrib/learn/python/learn/experiment_test.py
@@ -592,6 +592,26 @@ class ExperimentTest(test.TestCase):
self.assertEqual(0, est.eval_count)
self.assertEqual(1, est.export_count)
+ def test_continuous_train_and_eval_with_invalid_predicate_fn(self):
+ for est in self._estimators_for_tests():
+ ex = experiment.Experiment(
+ est,
+ train_input_fn='train_input',
+ eval_input_fn='eval_input')
+ with self.assertRaisesRegexp(
+ ValueError, '`continuous_eval_predicate_fn` must be a callable'):
+ ex.continuous_train_and_eval(continuous_eval_predicate_fn='fn')
+
+ def test_continuous_train_and_eval_with_invalid_train_steps_iterations(self):
+ for est in self._estimators_for_tests():
+ ex = experiment.Experiment(
+ est,
+ train_input_fn='train_input',
+ eval_input_fn='eval_input')
+ with self.assertRaisesRegexp(
+ ValueError, '`train_steps_per_iteration` must be an integer.'):
+ ex.continuous_train_and_eval(train_steps_per_iteration='123')
+
@test.mock.patch.object(server_lib, 'Server')
def test_run_std_server(self, mock_server):
# Arrange.