aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/model_fn_test.py
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2017-09-29 12:21:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-29 12:33:00 -0700
commit0fb83965a209eb03c1c090e3e540fd7c2c7d1025 (patch)
treec0a7de3131acbcfa1a2dd34531cea34d5f059cb5 /tensorflow/python/estimator/model_fn_test.py
parent76db7553ab2998116a62d6c242aa39373a362993 (diff)
Users can call EstimatorSpec._replace since it's a namedtuple. Calling _replace does not run validations. Here we provide a new 'replace' which does the validations.
PiperOrigin-RevId: 170516477
Diffstat (limited to 'tensorflow/python/estimator/model_fn_test.py')
-rw-r--r--tensorflow/python/estimator/model_fn_test.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py
index c41df41353..d67c4b7161 100644
--- a/tensorflow/python/estimator/model_fn_test.py
+++ b/tensorflow/python/estimator/model_fn_test.py
@@ -303,6 +303,32 @@ class EstimatorSpecEvalTest(test.TestCase):
predictions={'prediction': constant_op.constant(1.)},
loss=loss)
+ def testReplaceRaisesConstructorChecks(self):
+ with ops.Graph().as_default(), self.test_session():
+ loss = constant_op.constant(1.)
+ spec = model_fn.EstimatorSpec(
+ mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)
+ with self.assertRaisesRegexp(ValueError, 'Loss must be scalar'):
+ spec._replace(loss=constant_op.constant([1., 2.]))
+
+ def testReplaceDoesReplace(self):
+ with ops.Graph().as_default(), self.test_session():
+ loss = constant_op.constant(1.)
+ spec = model_fn.EstimatorSpec(
+ mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)
+ new_spec = spec._replace(predictions={'m': loss})
+ self.assertEqual(['m'], list(new_spec.predictions.keys()))
+
+ def testReplaceNotAllowModeChange(self):
+ with ops.Graph().as_default(), self.test_session():
+ loss = constant_op.constant(1.)
+ spec = model_fn.EstimatorSpec(
+ mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)
+ spec._replace(mode=model_fn.ModeKeys.EVAL)
+ with self.assertRaisesRegexp(ValueError,
+ 'mode of EstimatorSpec cannot be changed'):
+ spec._replace(mode=model_fn.ModeKeys.TRAIN)
+
def testPredictionsMissingIsOkay(self):
with ops.Graph().as_default(), self.test_session():
model_fn.EstimatorSpec(