diff options
Diffstat (limited to 'tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py')
-rw-r--r-- | tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py index d9f03c3840..94ea7bc2eb 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py @@ -47,7 +47,7 @@ def get_empty_tensors(gradient_shape, hessian_shape): class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): def testGenerateFeatureSplitCandidates(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Feature ID | # i0 | (0.2, 0.12) | 0 | 1,2 | @@ -281,7 +281,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): gains[0], 0.00001) def testGenerateFeatureSplitCandidatesSumReduction(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Feature ID | # i0 | (0.2, 0.12) | 0 | 1,2 | @@ -404,7 +404,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(1, split_node.feature_id) def testGenerateFeatureSplitCandidatesMulticlass(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Batch size is 4, 2 gradients per each instance. gradients = array_ops.constant( [[0.2, 0.1], [-0.5, 0.2], [1.2, 3.4], [4.0, -3.5]], shape=[4, 2]) @@ -482,7 +482,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(1, split_node.feature_id) def testEmpty(self): - with self.test_session() as sess: + with self.cached_session() as sess: gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) partition_ids = [0, 0, 0, 1] @@ -530,7 +530,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(splits), 0) def testInactive(self): - with self.test_session() as sess: + with self.cached_session() as sess: gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) partition_ids = [0, 0, 0, 1] |