aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py')
-rw-r--r--tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py18
1 files changed, 10 insertions, 8 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py
index c5b5981adb..3641ab0ee0 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py
@@ -30,14 +30,16 @@ class BestSplitsClassificationTests(test_util.TensorFlowTestCase):
def setUp(self):
self.finished = [3, 5]
self.node_map = [-1, -1, -1, 0, -1, 3, -1, -1, -1]
- self.candidate_counts = [[[50., 60., 40., 3.], [70., 30., 70., 30.]],
- [[0., 0., 0., 0.], [0., 0., 0., 0.]],
- [[0., 0., 0., 0.], [0., 0., 0., 0.]],
- [[10., 10., 10., 10.], [10., 5., 5., 10.]]]
- self.total_counts = [[100., 100., 100., 100.],
- [0., 0., 0., 0.],
- [0., 0., 0., 0.],
- [100., 100., 100., 100.]]
+ self.candidate_counts = [[[153., 50., 60., 40., 3.],
+ [200., 70., 30., 70., 30.]],
+ [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]],
+ [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]],
+ [[40., 10., 10., 10., 10.],
+ [30., 10., 5., 5., 10.]]]
+ self.total_counts = [[400., 100., 100., 100., 100.],
+ [0., 0., 0., 0., 0.],
+ [0., 0., 0., 0., 0.],
+ [400., 100., 100., 100., 100.]]
self.squares = []
self.ops = training_ops.Load()