aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py')
-rw-r--r--tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py29
1 files changed, 12 insertions, 17 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py
index f370903b3c..c9af01c50b 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py
@@ -40,48 +40,43 @@ class UpdateFertileSlotsTest(test_util.TensorFlowTestCase):
self.node_map = [-1, -1, 0, -1, -1, -1, -1]
self.total_counts = [[80., 40., 40.]]
self.ops = training_ops.Load()
+ self.stale_leaves = []
def testSimple(self):
with self.test_session():
- (node_map_updates, accumulators_cleared, accumulators_allocated,
- new_nfl, new_nfl_scores) = self.ops.update_fertile_slots(
+ (node_map_updates, accumulators_cleared,
+ accumulators_allocated) = self.ops.update_fertile_slots(
self.finished, self.non_fertile_leaves, self.non_fertile_leaf_scores,
self.end_of_tree, self.depths,
- self.total_counts, self.node_map, max_depth=4)
+ self.total_counts, self.node_map, self.stale_leaves, max_depth=4)
self.assertAllEqual([[2, 4], [-1, 0]], node_map_updates.eval())
self.assertAllEqual([], accumulators_cleared.eval())
self.assertAllEqual([0], accumulators_allocated.eval())
- self.assertAllEqual([3, 5, 6], new_nfl.eval())
- self.assertAllEqual([10., 1., 1.], new_nfl_scores.eval())
def testReachedMaxDepth(self):
with self.test_session():
- (node_map_updates, accumulators_cleared, accumulators_allocated,
- new_nfl, new_nfl_scores) = self.ops.update_fertile_slots(
+ (node_map_updates, accumulators_cleared,
+ accumulators_allocated) = self.ops.update_fertile_slots(
self.finished, self.non_fertile_leaves, self.non_fertile_leaf_scores,
self.end_of_tree, self.depths,
- self.total_counts, self.node_map, max_depth=3)
+ self.total_counts, self.node_map, self.stale_leaves, max_depth=3)
self.assertAllEqual([[2], [-1]], node_map_updates.eval())
self.assertAllEqual([0], accumulators_cleared.eval())
self.assertAllEqual([], accumulators_allocated.eval())
- self.assertAllEqual([-1], new_nfl.eval())
- self.assertAllEqual([0.0], new_nfl_scores.eval())
def testNoFinished(self):
with self.test_session():
- (node_map_updates, accumulators_cleared, accumulators_allocated,
- new_nfl, new_nfl_scores) = self.ops.update_fertile_slots(
+ (node_map_updates, accumulators_cleared,
+ accumulators_allocated) = self.ops.update_fertile_slots(
[], self.non_fertile_leaves, self.non_fertile_leaf_scores,
self.end_of_tree, self.depths,
- self.total_counts, self.node_map, max_depth=4)
+ self.total_counts, self.node_map, self.stale_leaves, max_depth=4)
self.assertAllEqual((2, 0), node_map_updates.eval().shape)
self.assertAllEqual([], accumulators_cleared.eval())
self.assertAllEqual([], accumulators_allocated.eval())
- self.assertAllEqual([4, 3], new_nfl.eval())
- self.assertAllEqual([15., 10.], new_nfl_scores.eval())
def testBadInput(self):
del self.non_fertile_leaf_scores[-1]
@@ -89,10 +84,10 @@ class UpdateFertileSlotsTest(test_util.TensorFlowTestCase):
with self.assertRaisesOpError(
'Number of non fertile leaves should be the same in '
'non_fertile_leaves and non_fertile_leaf_scores.'):
- (node_map_updates, _, _, _, _) = self.ops.update_fertile_slots(
+ (node_map_updates, _, _) = self.ops.update_fertile_slots(
self.finished, self.non_fertile_leaves,
self.non_fertile_leaf_scores, self.end_of_tree, self.depths,
- self.total_counts, self.node_map, max_depth=4)
+ self.total_counts, self.node_map, self.stale_leaves, max_depth=4)
self.assertAllEqual((2, 0), node_map_updates.eval().shape)