aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py')
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
index d9f03c3840..94ea7bc2eb 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
@@ -47,7 +47,7 @@ def get_empty_tensors(gradient_shape, hessian_shape):
class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
def testGenerateFeatureSplitCandidates(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Feature ID |
# i0 | (0.2, 0.12) | 0 | 1,2 |
@@ -281,7 +281,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
gains[0], 0.00001)
def testGenerateFeatureSplitCandidatesSumReduction(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Feature ID |
# i0 | (0.2, 0.12) | 0 | 1,2 |
@@ -404,7 +404,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(1, split_node.feature_id)
def testGenerateFeatureSplitCandidatesMulticlass(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Batch size is 4, 2 gradients per each instance.
gradients = array_ops.constant(
[[0.2, 0.1], [-0.5, 0.2], [1.2, 3.4], [4.0, -3.5]], shape=[4, 2])
@@ -482,7 +482,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(1, split_node.feature_id)
def testEmpty(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
partition_ids = [0, 0, 0, 1]
@@ -530,7 +530,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(len(splits), 0)
def testInactive(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
partition_ids = [0, 0, 0, 1]