diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-21 19:53:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-21 20:00:41 -0700 |
commit | 47c0bda0e7f736a9328aaf76aba7c8006e24556f (patch) | |
tree | ad2a6ab71adddc0d07c7f306c270122937b6a5b0 /tensorflow/python/estimator/model_fn_test.py | |
parent | 1ab795b54274a26a92690f36eff65674fb500f91 (diff) |
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about:
* the fact that the session may be reused.
* the session is not closed even when doing a "with self.test_session()" statement.
PiperOrigin-RevId: 209703607
Diffstat (limited to 'tensorflow/python/estimator/model_fn_test.py')
-rw-r--r-- | tensorflow/python/estimator/model_fn_test.py | 104 |
1 files changed, 52 insertions, 52 deletions
diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py index 08e41fd414..b6f1b16a22 100644 --- a/tensorflow/python/estimator/model_fn_test.py +++ b/tensorflow/python/estimator/model_fn_test.py @@ -48,7 +48,7 @@ class EstimatorSpecTrainTest(test.TestCase): def testRequiredArgumentsSet(self): """Tests that no errors are raised when all required arguments are set.""" - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, loss=constant_op.constant(1.), @@ -56,7 +56,7 @@ class EstimatorSpecTrainTest(test.TestCase): def testAllArgumentsSet(self): """Tests that no errors are raised when all arguments are set.""" - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): loss = constant_op.constant(1.) predictions = {'loss': loss} classes = constant_op.constant('hello') @@ -77,7 +77,7 @@ class EstimatorSpecTrainTest(test.TestCase): def testLossNumber(self): """Tests that error is raised when loss is a number (not Tensor).""" - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): with self.assertRaisesRegexp(TypeError, 'loss must be Tensor'): model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, @@ -86,20 +86,20 @@ class EstimatorSpecTrainTest(test.TestCase): def testLoss1DTensor(self): """Tests that no errors are raised when loss is 1D tensor.""" - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, loss=constant_op.constant([1.]), train_op=control_flow_ops.no_op()) def testLossMissing(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): with self.assertRaisesRegexp(ValueError, 'Missing loss'): model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, train_op=control_flow_ops.no_op()) def testLossNotScalar(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): with self.assertRaisesRegexp(ValueError, 'Loss must be scalar'): model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, @@ -107,7 +107,7 @@ class EstimatorSpecTrainTest(test.TestCase): train_op=control_flow_ops.no_op()) def testLossSparseTensor(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): loss = sparse_tensor.SparseTensor( indices=[[0]], values=[0.], @@ -121,7 +121,7 @@ class EstimatorSpecTrainTest(test.TestCase): def testLossFromDifferentGraph(self): with ops.Graph().as_default(): loss = constant_op.constant(1.) - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): with self.assertRaisesRegexp( ValueError, 'must be from the default graph'): model_fn.EstimatorSpec( @@ -130,13 +130,13 @@ class EstimatorSpecTrainTest(test.TestCase): train_op=control_flow_ops.no_op()) def testTrainOpMissing(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): with self.assertRaisesRegexp(ValueError, 'Missing train_op'): model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, loss=constant_op.constant(1.)) def testTrainOpNotOperationAndTensor(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): with self.assertRaisesRegexp(TypeError, 'train_op must be Operation or Tensor'): model_fn.EstimatorSpec( @@ -147,7 +147,7 @@ class EstimatorSpecTrainTest(test.TestCase): def testTrainOpFromDifferentGraph(self): with ops.Graph().as_default(): train_op = control_flow_ops.no_op() - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): with self.assertRaisesRegexp( ValueError, 'must be from the default graph'): model_fn.EstimatorSpec( @@ -156,7 +156,7 @@ class EstimatorSpecTrainTest(test.TestCase): train_op=train_op) def testTrainingChiefHookInvalid(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): with self.assertRaisesRegexp( TypeError, 'All hooks must be SessionRunHook instances'): model_fn.EstimatorSpec( @@ -166,7 +166,7 @@ class EstimatorSpecTrainTest(test.TestCase): training_chief_hooks=[_InvalidHook()]) def testTrainingHookInvalid(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): with self.assertRaisesRegexp( TypeError, 'All hooks must be SessionRunHook instances'): model_fn.EstimatorSpec( @@ -176,7 +176,7 @@ class EstimatorSpecTrainTest(test.TestCase): training_hooks=[_InvalidHook()]) def testScaffoldInvalid(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): with self.assertRaisesRegexp( TypeError, r'scaffold must be tf\.train\.Scaffold'): model_fn.EstimatorSpec( @@ -186,7 +186,7 @@ class EstimatorSpecTrainTest(test.TestCase): scaffold=_InvalidScaffold()) def testReturnDefaultScaffold(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): estimator_spec = model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, loss=constant_op.constant(1.), @@ -199,7 +199,7 @@ class EstimatorSpecEvalTest(test.TestCase): def testRequiredArgumentsSet(self): """Tests that no errors are raised when all required arguments are set.""" - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): loss = constant_op.constant(1.) model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, @@ -208,7 +208,7 @@ class EstimatorSpecEvalTest(test.TestCase): def testAllArgumentsSet(self): """Tests that no errors are raised when all arguments are set.""" - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): loss = constant_op.constant(1.) predictions = {'loss': loss} classes = constant_op.constant('hello') @@ -227,7 +227,7 @@ class EstimatorSpecEvalTest(test.TestCase): evaluation_hooks=[_FakeHook()]) def testEvaluationHookInvalid(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): with self.assertRaisesRegexp( TypeError, 'All hooks must be SessionRunHook instances'): model_fn.EstimatorSpec( @@ -237,7 +237,7 @@ class EstimatorSpecEvalTest(test.TestCase): def testTupleMetric(self): """Tests that no errors are raised when a metric is tuple-valued.""" - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): loss = constant_op.constant(1.) model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, @@ -248,7 +248,7 @@ class EstimatorSpecEvalTest(test.TestCase): def testLoss1DTensor(self): """Tests that no errors are raised when loss is 1D tensor.""" - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): loss = constant_op.constant([1.]) model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, @@ -257,7 +257,7 @@ class EstimatorSpecEvalTest(test.TestCase): def testLossNumber(self): """Tests that error is raised when loss is a number (not Tensor).""" - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): with self.assertRaisesRegexp(TypeError, 'loss must be Tensor'): model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, @@ -265,14 +265,14 @@ class EstimatorSpecEvalTest(test.TestCase): loss=1.) def testLossMissing(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): with self.assertRaisesRegexp(ValueError, 'Missing loss'): model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, predictions={'loss': constant_op.constant(1.)}) def testLossNotScalar(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): loss = constant_op.constant([1., 2.]) with self.assertRaisesRegexp(ValueError, 'Loss must be scalar'): model_fn.EstimatorSpec( @@ -281,7 +281,7 @@ class EstimatorSpecEvalTest(test.TestCase): loss=loss) def testLossSparseTensor(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): loss = sparse_tensor.SparseTensor( indices=[[0]], values=[0.], @@ -296,7 +296,7 @@ class EstimatorSpecEvalTest(test.TestCase): def testLossFromDifferentGraph(self): with ops.Graph().as_default(): loss = constant_op.constant(1.) - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): with self.assertRaisesRegexp( ValueError, 'must be from the default graph'): model_fn.EstimatorSpec( @@ -305,7 +305,7 @@ class EstimatorSpecEvalTest(test.TestCase): loss=loss) def testReplaceRaisesConstructorChecks(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): loss = constant_op.constant(1.) spec = model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss) @@ -313,7 +313,7 @@ class EstimatorSpecEvalTest(test.TestCase): spec._replace(loss=constant_op.constant([1., 2.])) def testReplaceDoesReplace(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): loss = constant_op.constant(1.) spec = model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss) @@ -321,7 +321,7 @@ class EstimatorSpecEvalTest(test.TestCase): self.assertEqual(['m'], list(new_spec.predictions.keys())) def testReplaceNotAllowModeChange(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): loss = constant_op.constant(1.) spec = model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss) @@ -331,13 +331,13 @@ class EstimatorSpecEvalTest(test.TestCase): spec._replace(mode=model_fn.ModeKeys.TRAIN) def testPredictionsMissingIsOkay(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, loss=constant_op.constant(1.)) def testPredictionsTensor(self): """Tests that no error is raised when predictions is Tensor (not dict).""" - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): loss = constant_op.constant(1.) model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, @@ -345,7 +345,7 @@ class EstimatorSpecEvalTest(test.TestCase): loss=loss) def testPredictionsNumber(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): with self.assertRaisesRegexp( TypeError, r'predictions\[number\] must be Tensor'): model_fn.EstimatorSpec( @@ -354,7 +354,7 @@ class EstimatorSpecEvalTest(test.TestCase): loss=constant_op.constant(1.)) def testPredictionsSparseTensor(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): predictions = { 'sparse': sparse_tensor.SparseTensor( indices=[[0]], @@ -370,7 +370,7 @@ class EstimatorSpecEvalTest(test.TestCase): def testPredictionsFromDifferentGraph(self): with ops.Graph().as_default(): predictions = {'loss': constant_op.constant(1.)} - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): with self.assertRaisesRegexp( ValueError, 'must be from the default graph'): model_fn.EstimatorSpec( @@ -379,7 +379,7 @@ class EstimatorSpecEvalTest(test.TestCase): loss=constant_op.constant(1.)) def testEvalMetricOpsNoDict(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): loss = constant_op.constant(1.) with self.assertRaisesRegexp( TypeError, 'eval_metric_ops must be a dict'): @@ -390,7 +390,7 @@ class EstimatorSpecEvalTest(test.TestCase): eval_metric_ops=loss) def testEvalMetricOpsNoTuple(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): loss = constant_op.constant(1.) with self.assertRaisesRegexp( TypeError, @@ -403,7 +403,7 @@ class EstimatorSpecEvalTest(test.TestCase): eval_metric_ops={'loss': loss}) def testEvalMetricOpsNoTensorOrOperation(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): loss = constant_op.constant(1.) with self.assertRaisesRegexp(TypeError, 'must be Operation or Tensor'): model_fn.EstimatorSpec( @@ -413,7 +413,7 @@ class EstimatorSpecEvalTest(test.TestCase): eval_metric_ops={'loss': ('NonTensor', loss)}) def testEvalMetricNestedNoTensorOrOperation(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): loss = constant_op.constant(1.) with self.assertRaisesRegexp(TypeError, 'must be Operation or Tensor'): model_fn.EstimatorSpec( @@ -427,7 +427,7 @@ class EstimatorSpecEvalTest(test.TestCase): with ops.Graph().as_default(): eval_metric_ops = { 'loss': (control_flow_ops.no_op(), constant_op.constant(1.))} - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): loss = constant_op.constant(1.) with self.assertRaisesRegexp( ValueError, 'must be from the default graph'): @@ -443,14 +443,14 @@ class EstimatorSpecInferTest(test.TestCase): def testRequiredArgumentsSet(self): """Tests that no errors are raised when all required arguments are set.""" - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): model_fn.EstimatorSpec( mode=model_fn.ModeKeys.PREDICT, predictions={'loss': constant_op.constant(1.)}) def testAllArgumentsSet(self): """Tests that no errors are raised when all arguments are set.""" - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): loss = constant_op.constant(1.) predictions = {'loss': loss} classes = constant_op.constant('hello') @@ -470,7 +470,7 @@ class EstimatorSpecInferTest(test.TestCase): prediction_hooks=[_FakeHook()]) def testPredictionHookInvalid(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): with self.assertRaisesRegexp( TypeError, 'All hooks must be SessionRunHook instances'): model_fn.EstimatorSpec( @@ -479,25 +479,25 @@ class EstimatorSpecInferTest(test.TestCase): prediction_hooks=[_InvalidHook()]) def testPredictionsMissing(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): with self.assertRaisesRegexp(ValueError, 'Missing predictions'): model_fn.EstimatorSpec(mode=model_fn.ModeKeys.PREDICT) def testPredictionsTensor(self): """Tests that no error is raised when predictions is Tensor (not dict).""" - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): model_fn.EstimatorSpec( mode=model_fn.ModeKeys.PREDICT, predictions=constant_op.constant(1.)) def testPredictionsNumber(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): with self.assertRaisesRegexp( TypeError, r'predictions\[number\] must be Tensor'): model_fn.EstimatorSpec( mode=model_fn.ModeKeys.PREDICT, predictions={'number': 1.}) def testPredictionsSparseTensor(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): predictions = { 'sparse': sparse_tensor.SparseTensor( indices=[[0]], @@ -509,7 +509,7 @@ class EstimatorSpecInferTest(test.TestCase): mode=model_fn.ModeKeys.PREDICT, predictions=predictions) def testExportOutputsNoDict(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): predictions = {'loss': constant_op.constant(1.)} classes = constant_op.constant('hello') with self.assertRaisesRegexp( @@ -520,7 +520,7 @@ class EstimatorSpecInferTest(test.TestCase): export_outputs=export_output.ClassificationOutput(classes=classes)) def testExportOutputsValueNotExportOutput(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): predictions = {'loss': constant_op.constant(1.)} with self.assertRaisesRegexp( TypeError, @@ -533,7 +533,7 @@ class EstimatorSpecInferTest(test.TestCase): export_outputs={'head_name': predictions}) def testExportOutputsSingleheadMissingDefault(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): predictions = {'loss': constant_op.constant(1.)} output_1 = constant_op.constant([1.]) regression_output = export_output.RegressionOutput(value=output_1) @@ -552,7 +552,7 @@ class EstimatorSpecInferTest(test.TestCase): self.assertEqual(expected_export_outputs, estimator_spec.export_outputs) def testExportOutputsMultiheadWithDefault(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): predictions = {'loss': constant_op.constant(1.)} output_1 = constant_op.constant([1.]) output_2 = constant_op.constant(['2']) @@ -571,7 +571,7 @@ class EstimatorSpecInferTest(test.TestCase): self.assertEqual(export_outputs, estimator_spec.export_outputs) def testExportOutputsMultiheadMissingDefault(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): predictions = {'loss': constant_op.constant(1.)} output_1 = constant_op.constant([1.]) output_2 = constant_op.constant(['2']) @@ -594,13 +594,13 @@ class EstimatorSpecInferTest(test.TestCase): def testDefaultExportOutputCreated(self): """Ensure that a default PredictOutput is created for export.""" - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): predictions = constant_op.constant(1.) self._assertDefaultExportOutputForPredictions(predictions) def testDefaultExportOutputCreatedDict(self): """Ensure that a default PredictOutput is created for export for dicts.""" - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): predictions = {'loss': constant_op.constant(1.), 'score': constant_op.constant(10.)} self._assertDefaultExportOutputForPredictions(predictions) |