diff options
Diffstat (limited to 'tensorflow/python/estimator/model_fn_test.py')
-rw-r--r-- | tensorflow/python/estimator/model_fn_test.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py index c41df41353..d67c4b7161 100644 --- a/tensorflow/python/estimator/model_fn_test.py +++ b/tensorflow/python/estimator/model_fn_test.py @@ -303,6 +303,32 @@ class EstimatorSpecEvalTest(test.TestCase): predictions={'prediction': constant_op.constant(1.)}, loss=loss) + def testReplaceRaisesConstructorChecks(self): + with ops.Graph().as_default(), self.test_session(): + loss = constant_op.constant(1.) + spec = model_fn.EstimatorSpec( + mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss) + with self.assertRaisesRegexp(ValueError, 'Loss must be scalar'): + spec._replace(loss=constant_op.constant([1., 2.])) + + def testReplaceDoesReplace(self): + with ops.Graph().as_default(), self.test_session(): + loss = constant_op.constant(1.) + spec = model_fn.EstimatorSpec( + mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss) + new_spec = spec._replace(predictions={'m': loss}) + self.assertEqual(['m'], list(new_spec.predictions.keys())) + + def testReplaceNotAllowModeChange(self): + with ops.Graph().as_default(), self.test_session(): + loss = constant_op.constant(1.) + spec = model_fn.EstimatorSpec( + mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss) + spec._replace(mode=model_fn.ModeKeys.EVAL) + with self.assertRaisesRegexp(ValueError, + 'mode of EstimatorSpec cannot be changed'): + spec._replace(mode=model_fn.ModeKeys.TRAIN) + def testPredictionsMissingIsOkay(self): with ops.Graph().as_default(), self.test_session(): model_fn.EstimatorSpec( |