aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py20
1 files changed, 10 insertions, 10 deletions
diff --git a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
index d55240297a..ea022820e4 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
@@ -32,7 +32,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowWithEmptyEnsemble(self):
"""Test growing an empty ensemble."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
tree_ensemble_handle = tree_ensemble.resource_handle
@@ -141,7 +141,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testBiasCenteringOnEmptyEnsemble(self):
"""Test growing with bias centering on an empty ensemble."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
tree_ensemble_handle = tree_ensemble.resource_handle
@@ -184,7 +184,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowExistingEnsembleTreeNotFinalized(self):
"""Test growing an existing ensemble with the last tree not finalized."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -368,7 +368,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowExistingEnsembleTreeFinalized(self):
"""Test growing an existing ensemble with the last tree finalized."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -517,7 +517,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testPrePruning(self):
"""Test growing an existing ensemble with pre-pruning."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -673,7 +673,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testMetadataWhenCantSplitDueToEmptySplits(self):
"""Test that the metadata is updated even though we can't split."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge(
"""
@@ -784,7 +784,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testMetadataWhenCantSplitDuePrePruning(self):
"""Test metadata is updated correctly when no split due to prepruning."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge(
"""
@@ -919,7 +919,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testPostPruningOfSomeNodes(self):
"""Test growing an ensemble with post-pruning."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
tree_ensemble = boosted_trees_ops.TreeEnsemble(
@@ -1253,7 +1253,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testPostPruningOfAllNodes(self):
"""Test growing an ensemble with post-pruning, with all nodes are pruned."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
# Create empty ensemble.
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
@@ -1436,7 +1436,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testPostPruningChangesNothing(self):
"""Test growing an ensemble with post-pruning with all gains >0."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
tree_ensemble = boosted_trees_ops.TreeEnsemble(