aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/model_fn_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-21 19:53:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 20:00:41 -0700
commit47c0bda0e7f736a9328aaf76aba7c8006e24556f (patch)
treead2a6ab71adddc0d07c7f306c270122937b6a5b0 /tensorflow/python/estimator/model_fn_test.py
parent1ab795b54274a26a92690f36eff65674fb500f91 (diff)
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. PiperOrigin-RevId: 209703607
Diffstat (limited to 'tensorflow/python/estimator/model_fn_test.py')
-rw-r--r--tensorflow/python/estimator/model_fn_test.py104
1 files changed, 52 insertions, 52 deletions
diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py
index 08e41fd414..b6f1b16a22 100644
--- a/tensorflow/python/estimator/model_fn_test.py
+++ b/tensorflow/python/estimator/model_fn_test.py
@@ -48,7 +48,7 @@ class EstimatorSpecTrainTest(test.TestCase):
def testRequiredArgumentsSet(self):
"""Tests that no errors are raised when all required arguments are set."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
loss=constant_op.constant(1.),
@@ -56,7 +56,7 @@ class EstimatorSpecTrainTest(test.TestCase):
def testAllArgumentsSet(self):
"""Tests that no errors are raised when all arguments are set."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
predictions = {'loss': loss}
classes = constant_op.constant('hello')
@@ -77,7 +77,7 @@ class EstimatorSpecTrainTest(test.TestCase):
def testLossNumber(self):
"""Tests that error is raised when loss is a number (not Tensor)."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(TypeError, 'loss must be Tensor'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
@@ -86,20 +86,20 @@ class EstimatorSpecTrainTest(test.TestCase):
def testLoss1DTensor(self):
"""Tests that no errors are raised when loss is 1D tensor."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
loss=constant_op.constant([1.]),
train_op=control_flow_ops.no_op())
def testLossMissing(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(ValueError, 'Missing loss'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN, train_op=control_flow_ops.no_op())
def testLossNotScalar(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(ValueError, 'Loss must be scalar'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
@@ -107,7 +107,7 @@ class EstimatorSpecTrainTest(test.TestCase):
train_op=control_flow_ops.no_op())
def testLossSparseTensor(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = sparse_tensor.SparseTensor(
indices=[[0]],
values=[0.],
@@ -121,7 +121,7 @@ class EstimatorSpecTrainTest(test.TestCase):
def testLossFromDifferentGraph(self):
with ops.Graph().as_default():
loss = constant_op.constant(1.)
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
ValueError, 'must be from the default graph'):
model_fn.EstimatorSpec(
@@ -130,13 +130,13 @@ class EstimatorSpecTrainTest(test.TestCase):
train_op=control_flow_ops.no_op())
def testTrainOpMissing(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(ValueError, 'Missing train_op'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN, loss=constant_op.constant(1.))
def testTrainOpNotOperationAndTensor(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(TypeError,
'train_op must be Operation or Tensor'):
model_fn.EstimatorSpec(
@@ -147,7 +147,7 @@ class EstimatorSpecTrainTest(test.TestCase):
def testTrainOpFromDifferentGraph(self):
with ops.Graph().as_default():
train_op = control_flow_ops.no_op()
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
ValueError, 'must be from the default graph'):
model_fn.EstimatorSpec(
@@ -156,7 +156,7 @@ class EstimatorSpecTrainTest(test.TestCase):
train_op=train_op)
def testTrainingChiefHookInvalid(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, 'All hooks must be SessionRunHook instances'):
model_fn.EstimatorSpec(
@@ -166,7 +166,7 @@ class EstimatorSpecTrainTest(test.TestCase):
training_chief_hooks=[_InvalidHook()])
def testTrainingHookInvalid(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, 'All hooks must be SessionRunHook instances'):
model_fn.EstimatorSpec(
@@ -176,7 +176,7 @@ class EstimatorSpecTrainTest(test.TestCase):
training_hooks=[_InvalidHook()])
def testScaffoldInvalid(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, r'scaffold must be tf\.train\.Scaffold'):
model_fn.EstimatorSpec(
@@ -186,7 +186,7 @@ class EstimatorSpecTrainTest(test.TestCase):
scaffold=_InvalidScaffold())
def testReturnDefaultScaffold(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
estimator_spec = model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
loss=constant_op.constant(1.),
@@ -199,7 +199,7 @@ class EstimatorSpecEvalTest(test.TestCase):
def testRequiredArgumentsSet(self):
"""Tests that no errors are raised when all required arguments are set."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
@@ -208,7 +208,7 @@ class EstimatorSpecEvalTest(test.TestCase):
def testAllArgumentsSet(self):
"""Tests that no errors are raised when all arguments are set."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
predictions = {'loss': loss}
classes = constant_op.constant('hello')
@@ -227,7 +227,7 @@ class EstimatorSpecEvalTest(test.TestCase):
evaluation_hooks=[_FakeHook()])
def testEvaluationHookInvalid(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, 'All hooks must be SessionRunHook instances'):
model_fn.EstimatorSpec(
@@ -237,7 +237,7 @@ class EstimatorSpecEvalTest(test.TestCase):
def testTupleMetric(self):
"""Tests that no errors are raised when a metric is tuple-valued."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
@@ -248,7 +248,7 @@ class EstimatorSpecEvalTest(test.TestCase):
def testLoss1DTensor(self):
"""Tests that no errors are raised when loss is 1D tensor."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant([1.])
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
@@ -257,7 +257,7 @@ class EstimatorSpecEvalTest(test.TestCase):
def testLossNumber(self):
"""Tests that error is raised when loss is a number (not Tensor)."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(TypeError, 'loss must be Tensor'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
@@ -265,14 +265,14 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=1.)
def testLossMissing(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(ValueError, 'Missing loss'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
predictions={'loss': constant_op.constant(1.)})
def testLossNotScalar(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant([1., 2.])
with self.assertRaisesRegexp(ValueError, 'Loss must be scalar'):
model_fn.EstimatorSpec(
@@ -281,7 +281,7 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=loss)
def testLossSparseTensor(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = sparse_tensor.SparseTensor(
indices=[[0]],
values=[0.],
@@ -296,7 +296,7 @@ class EstimatorSpecEvalTest(test.TestCase):
def testLossFromDifferentGraph(self):
with ops.Graph().as_default():
loss = constant_op.constant(1.)
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
ValueError, 'must be from the default graph'):
model_fn.EstimatorSpec(
@@ -305,7 +305,7 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=loss)
def testReplaceRaisesConstructorChecks(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
spec = model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)
@@ -313,7 +313,7 @@ class EstimatorSpecEvalTest(test.TestCase):
spec._replace(loss=constant_op.constant([1., 2.]))
def testReplaceDoesReplace(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
spec = model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)
@@ -321,7 +321,7 @@ class EstimatorSpecEvalTest(test.TestCase):
self.assertEqual(['m'], list(new_spec.predictions.keys()))
def testReplaceNotAllowModeChange(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
spec = model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)
@@ -331,13 +331,13 @@ class EstimatorSpecEvalTest(test.TestCase):
spec._replace(mode=model_fn.ModeKeys.TRAIN)
def testPredictionsMissingIsOkay(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL, loss=constant_op.constant(1.))
def testPredictionsTensor(self):
"""Tests that no error is raised when predictions is Tensor (not dict)."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
@@ -345,7 +345,7 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=loss)
def testPredictionsNumber(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, r'predictions\[number\] must be Tensor'):
model_fn.EstimatorSpec(
@@ -354,7 +354,7 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=constant_op.constant(1.))
def testPredictionsSparseTensor(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {
'sparse': sparse_tensor.SparseTensor(
indices=[[0]],
@@ -370,7 +370,7 @@ class EstimatorSpecEvalTest(test.TestCase):
def testPredictionsFromDifferentGraph(self):
with ops.Graph().as_default():
predictions = {'loss': constant_op.constant(1.)}
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
ValueError, 'must be from the default graph'):
model_fn.EstimatorSpec(
@@ -379,7 +379,7 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=constant_op.constant(1.))
def testEvalMetricOpsNoDict(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
with self.assertRaisesRegexp(
TypeError, 'eval_metric_ops must be a dict'):
@@ -390,7 +390,7 @@ class EstimatorSpecEvalTest(test.TestCase):
eval_metric_ops=loss)
def testEvalMetricOpsNoTuple(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
with self.assertRaisesRegexp(
TypeError,
@@ -403,7 +403,7 @@ class EstimatorSpecEvalTest(test.TestCase):
eval_metric_ops={'loss': loss})
def testEvalMetricOpsNoTensorOrOperation(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
with self.assertRaisesRegexp(TypeError, 'must be Operation or Tensor'):
model_fn.EstimatorSpec(
@@ -413,7 +413,7 @@ class EstimatorSpecEvalTest(test.TestCase):
eval_metric_ops={'loss': ('NonTensor', loss)})
def testEvalMetricNestedNoTensorOrOperation(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
with self.assertRaisesRegexp(TypeError, 'must be Operation or Tensor'):
model_fn.EstimatorSpec(
@@ -427,7 +427,7 @@ class EstimatorSpecEvalTest(test.TestCase):
with ops.Graph().as_default():
eval_metric_ops = {
'loss': (control_flow_ops.no_op(), constant_op.constant(1.))}
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
with self.assertRaisesRegexp(
ValueError, 'must be from the default graph'):
@@ -443,14 +443,14 @@ class EstimatorSpecInferTest(test.TestCase):
def testRequiredArgumentsSet(self):
"""Tests that no errors are raised when all required arguments are set."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.PREDICT,
predictions={'loss': constant_op.constant(1.)})
def testAllArgumentsSet(self):
"""Tests that no errors are raised when all arguments are set."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
predictions = {'loss': loss}
classes = constant_op.constant('hello')
@@ -470,7 +470,7 @@ class EstimatorSpecInferTest(test.TestCase):
prediction_hooks=[_FakeHook()])
def testPredictionHookInvalid(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, 'All hooks must be SessionRunHook instances'):
model_fn.EstimatorSpec(
@@ -479,25 +479,25 @@ class EstimatorSpecInferTest(test.TestCase):
prediction_hooks=[_InvalidHook()])
def testPredictionsMissing(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(ValueError, 'Missing predictions'):
model_fn.EstimatorSpec(mode=model_fn.ModeKeys.PREDICT)
def testPredictionsTensor(self):
"""Tests that no error is raised when predictions is Tensor (not dict)."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.PREDICT, predictions=constant_op.constant(1.))
def testPredictionsNumber(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, r'predictions\[number\] must be Tensor'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.PREDICT, predictions={'number': 1.})
def testPredictionsSparseTensor(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {
'sparse': sparse_tensor.SparseTensor(
indices=[[0]],
@@ -509,7 +509,7 @@ class EstimatorSpecInferTest(test.TestCase):
mode=model_fn.ModeKeys.PREDICT, predictions=predictions)
def testExportOutputsNoDict(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {'loss': constant_op.constant(1.)}
classes = constant_op.constant('hello')
with self.assertRaisesRegexp(
@@ -520,7 +520,7 @@ class EstimatorSpecInferTest(test.TestCase):
export_outputs=export_output.ClassificationOutput(classes=classes))
def testExportOutputsValueNotExportOutput(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {'loss': constant_op.constant(1.)}
with self.assertRaisesRegexp(
TypeError,
@@ -533,7 +533,7 @@ class EstimatorSpecInferTest(test.TestCase):
export_outputs={'head_name': predictions})
def testExportOutputsSingleheadMissingDefault(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {'loss': constant_op.constant(1.)}
output_1 = constant_op.constant([1.])
regression_output = export_output.RegressionOutput(value=output_1)
@@ -552,7 +552,7 @@ class EstimatorSpecInferTest(test.TestCase):
self.assertEqual(expected_export_outputs, estimator_spec.export_outputs)
def testExportOutputsMultiheadWithDefault(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {'loss': constant_op.constant(1.)}
output_1 = constant_op.constant([1.])
output_2 = constant_op.constant(['2'])
@@ -571,7 +571,7 @@ class EstimatorSpecInferTest(test.TestCase):
self.assertEqual(export_outputs, estimator_spec.export_outputs)
def testExportOutputsMultiheadMissingDefault(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {'loss': constant_op.constant(1.)}
output_1 = constant_op.constant([1.])
output_2 = constant_op.constant(['2'])
@@ -594,13 +594,13 @@ class EstimatorSpecInferTest(test.TestCase):
def testDefaultExportOutputCreated(self):
"""Ensure that a default PredictOutput is created for export."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = constant_op.constant(1.)
self._assertDefaultExportOutputForPredictions(predictions)
def testDefaultExportOutputCreatedDict(self):
"""Ensure that a default PredictOutput is created for export for dicts."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {'loss': constant_op.constant(1.),
'score': constant_op.constant(10.)}
self._assertDefaultExportOutputForPredictions(predictions)