diff options
author | Mustafa Ispir <ispir@google.com> | 2017-09-29 12:21:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-29 12:33:00 -0700 |
commit | 0fb83965a209eb03c1c090e3e540fd7c2c7d1025 (patch) | |
tree | c0a7de3131acbcfa1a2dd34531cea34d5f059cb5 /tensorflow/python/estimator/model_fn_test.py | |
parent | 76db7553ab2998116a62d6c242aa39373a362993 (diff) |
Users can call EstimatorSpec._replace since it's a namedtuple. Calling _replace does not run validations. Here we provide a new 'replace' which does the validations.
PiperOrigin-RevId: 170516477
Diffstat (limited to 'tensorflow/python/estimator/model_fn_test.py')
-rw-r--r-- | tensorflow/python/estimator/model_fn_test.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py index c41df41353..d67c4b7161 100644 --- a/tensorflow/python/estimator/model_fn_test.py +++ b/tensorflow/python/estimator/model_fn_test.py @@ -303,6 +303,32 @@ class EstimatorSpecEvalTest(test.TestCase): predictions={'prediction': constant_op.constant(1.)}, loss=loss) + def testReplaceRaisesConstructorChecks(self): + with ops.Graph().as_default(), self.test_session(): + loss = constant_op.constant(1.) + spec = model_fn.EstimatorSpec( + mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss) + with self.assertRaisesRegexp(ValueError, 'Loss must be scalar'): + spec._replace(loss=constant_op.constant([1., 2.])) + + def testReplaceDoesReplace(self): + with ops.Graph().as_default(), self.test_session(): + loss = constant_op.constant(1.) + spec = model_fn.EstimatorSpec( + mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss) + new_spec = spec._replace(predictions={'m': loss}) + self.assertEqual(['m'], list(new_spec.predictions.keys())) + + def testReplaceNotAllowModeChange(self): + with ops.Graph().as_default(), self.test_session(): + loss = constant_op.constant(1.) + spec = model_fn.EstimatorSpec( + mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss) + spec._replace(mode=model_fn.ModeKeys.EVAL) + with self.assertRaisesRegexp(ValueError, + 'mode of EstimatorSpec cannot be changed'): + spec._replace(mode=model_fn.ModeKeys.TRAIN) + def testPredictionsMissingIsOkay(self): with ops.Graph().as_default(), self.test_session(): model_fn.EstimatorSpec( |