aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/model_fn_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/model_fn_test.py')
-rw-r--r--tensorflow/python/estimator/model_fn_test.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py
index c41df41353..d67c4b7161 100644
--- a/tensorflow/python/estimator/model_fn_test.py
+++ b/tensorflow/python/estimator/model_fn_test.py
@@ -303,6 +303,32 @@ class EstimatorSpecEvalTest(test.TestCase):
predictions={'prediction': constant_op.constant(1.)},
loss=loss)
+ def testReplaceRaisesConstructorChecks(self):
+ with ops.Graph().as_default(), self.test_session():
+ loss = constant_op.constant(1.)
+ spec = model_fn.EstimatorSpec(
+ mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)
+ with self.assertRaisesRegexp(ValueError, 'Loss must be scalar'):
+ spec._replace(loss=constant_op.constant([1., 2.]))
+
+ def testReplaceDoesReplace(self):
+ with ops.Graph().as_default(), self.test_session():
+ loss = constant_op.constant(1.)
+ spec = model_fn.EstimatorSpec(
+ mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)
+ new_spec = spec._replace(predictions={'m': loss})
+ self.assertEqual(['m'], list(new_spec.predictions.keys()))
+
+ def testReplaceNotAllowModeChange(self):
+ with ops.Graph().as_default(), self.test_session():
+ loss = constant_op.constant(1.)
+ spec = model_fn.EstimatorSpec(
+ mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)
+ spec._replace(mode=model_fn.ModeKeys.EVAL)
+ with self.assertRaisesRegexp(ValueError,
+ 'mode of EstimatorSpec cannot be changed'):
+ spec._replace(mode=model_fn.ModeKeys.TRAIN)
+
def testPredictionsMissingIsOkay(self):
with ops.Graph().as_default(), self.test_session():
model_fn.EstimatorSpec(