aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest/python
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensor_forest/python')
-rw-r--r--tensorflow/contrib/tensor_forest/python/__init__.py1
-rw-r--r--tensorflow/contrib/tensor_forest/python/constants.py26
-rw-r--r--tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py18
-rw-r--r--tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py101
-rw-r--r--tensorflow/contrib/tensor_forest/python/kernel_tests/finished_nodes_op_test.py56
-rw-r--r--tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py34
-rw-r--r--tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py70
-rw-r--r--tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py29
-rw-r--r--tensorflow/contrib/tensor_forest/python/ops/inference_ops.py24
-rw-r--r--tensorflow/contrib/tensor_forest/python/ops/training_ops.py28
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest.py510
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest_test.py41
12 files changed, 639 insertions, 299 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/__init__.py b/tensorflow/contrib/tensor_forest/python/__init__.py
index 0f692bbe97..a9dd599c97 100644
--- a/tensorflow/contrib/tensor_forest/python/__init__.py
+++ b/tensorflow/contrib/tensor_forest/python/__init__.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.tensor_forest.python import constants
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.contrib.tensor_forest.python.ops import inference_ops
from tensorflow.contrib.tensor_forest.python.ops import training_ops
diff --git a/tensorflow/contrib/tensor_forest/python/constants.py b/tensorflow/contrib/tensor_forest/python/constants.py
new file mode 100644
index 0000000000..029c782461
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/constants.py
@@ -0,0 +1,26 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Constants used by tensorforest. Some of these map to values in C++ ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# If tree[i][0] equals this value, then i is a leaf node.
+LEAF_NODE = -1
+
+# Data column types for indicating categorical or other non-float values.
+DATA_FLOAT = 0
+DATA_CATEGORICAL = 1
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py
index c5b5981adb..3641ab0ee0 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py
@@ -30,14 +30,16 @@ class BestSplitsClassificationTests(test_util.TensorFlowTestCase):
def setUp(self):
self.finished = [3, 5]
self.node_map = [-1, -1, -1, 0, -1, 3, -1, -1, -1]
- self.candidate_counts = [[[50., 60., 40., 3.], [70., 30., 70., 30.]],
- [[0., 0., 0., 0.], [0., 0., 0., 0.]],
- [[0., 0., 0., 0.], [0., 0., 0., 0.]],
- [[10., 10., 10., 10.], [10., 5., 5., 10.]]]
- self.total_counts = [[100., 100., 100., 100.],
- [0., 0., 0., 0.],
- [0., 0., 0., 0.],
- [100., 100., 100., 100.]]
+ self.candidate_counts = [[[153., 50., 60., 40., 3.],
+ [200., 70., 30., 70., 30.]],
+ [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]],
+ [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]],
+ [[40., 10., 10., 10., 10.],
+ [30., 10., 5., 5., 10.]]]
+ self.total_counts = [[400., 100., 100., 100., 100.],
+ [0., 0., 0., 0., 0.],
+ [0., 0., 0., 0., 0.],
+ [400., 100., 100., 100., 100.]]
self.squares = []
self.ops = training_ops.Load()
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py
index eb61573f24..a50eb22795 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import tensorflow as tf
+from tensorflow.contrib.tensor_forest.python import constants
from tensorflow.contrib.tensor_forest.python.ops import training_ops
from tensorflow.python.framework import test_util
@@ -37,16 +38,20 @@ class CountExtremelyRandomStatsClassificationTest(test_util.TensorFlowTestCase):
self.split_features = [[1], [-1]]
self.split_thresholds = [[1.], [0.]]
self.ops = training_ops.Load()
+ self.epochs = [0, 1, 1]
+ self.current_epoch = [1]
+ self.data_spec = [constants.DATA_FLOAT] * 2
def testSimple(self):
with self.test_session():
(pcw_node_sums, _, pcw_splits_indices, pcw_splits_sums, _,
pcw_totals_indices, pcw_totals_sums, _, leaves) = (
self.ops.count_extremely_random_stats(
- self.input_data, self.input_labels, self.tree,
- self.tree_thresholds, self.node_map,
- self.split_features, self.split_thresholds, num_classes=5,
- regression=False))
+ self.input_data, [], [], [], self.data_spec, self.input_labels,
+ self.tree, self.tree_thresholds, self.node_map,
+ self.split_features, self.split_thresholds, self.epochs,
+ self.current_epoch,
+ num_classes=5, regression=False))
self.assertAllEqual(
[[4., 1., 1., 1., 1.], [2., 1., 1., 0., 0.], [2., 0., 0., 1., 1.]],
@@ -57,15 +62,68 @@ class CountExtremelyRandomStatsClassificationTest(test_util.TensorFlowTestCase):
self.assertAllEqual([1., 2., 1.], pcw_totals_sums.eval())
self.assertAllEqual([1, 1, 2, 2], leaves.eval())
+ def testSparseInput(self):
+ sparse_shape = [4, 10]
+ sparse_indices = [[0, 0], [0, 4], [0, 9],
+ [1, 0], [1, 7],
+ [2, 0],
+ [3, 1], [3, 4]]
+ sparse_values = [3.0, -1.0, 0.5,
+ 1.5, 6.0,
+ -2.0,
+ -0.5, 2.0]
+ with self.test_session():
+ (pcw_node_sums, _, pcw_splits_indices, pcw_splits_sums, _,
+ pcw_totals_indices, pcw_totals_sums, _, leaves) = (
+ self.ops.count_extremely_random_stats(
+ [], sparse_indices, sparse_values, sparse_shape, self.data_spec,
+ self.input_labels, self.tree,
+ self.tree_thresholds, self.node_map,
+ self.split_features, self.split_thresholds, self.epochs,
+ self.current_epoch,
+ num_classes=5, regression=False))
+
+ self.assertAllEqual(
+ [[4., 1., 1., 1., 1.],
+ [2., 0., 0., 1., 1.],
+ [2., 1., 1., 0., 0.]],
+ pcw_node_sums.eval())
+ self.assertAllEqual([[0, 0, 4], [0, 0, 0], [0, 0, 3]],
+ pcw_splits_indices.eval())
+ self.assertAllEqual([1., 2., 1.], pcw_splits_sums.eval())
+ self.assertAllEqual([[0, 4], [0, 0], [0, 3]], pcw_totals_indices.eval())
+ self.assertAllEqual([1., 2., 1.], pcw_totals_sums.eval())
+ self.assertAllEqual([2, 2, 1, 1], leaves.eval())
+
+ def testFutureEpoch(self):
+ current_epoch = [3]
+ with self.test_session():
+ (pcw_node_sums, _, _, pcw_splits_sums, _,
+ _, pcw_totals_sums, _, leaves) = (
+ self.ops.count_extremely_random_stats(
+ self.input_data, [], [], [], self.data_spec, self.input_labels,
+ self.tree, self.tree_thresholds, self.node_map,
+ self.split_features, self.split_thresholds, self.epochs,
+ current_epoch, num_classes=5, regression=False))
+
+ self.assertAllEqual(
+ [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]],
+ pcw_node_sums.eval())
+ self.assertAllEqual([], pcw_splits_sums.eval())
+ self.assertAllEqual([], pcw_totals_sums.eval())
+ self.assertAllEqual([1, 1, 2, 2], leaves.eval())
+
def testThreaded(self):
with self.test_session(
config=tf.ConfigProto(intra_op_parallelism_threads=2)):
(pcw_node_sums, _, pcw_splits_indices, pcw_splits_sums, _,
pcw_totals_indices, pcw_totals_sums, _, leaves) = (
self.ops.count_extremely_random_stats(
- self.input_data, self.input_labels, self.tree,
- self.tree_thresholds, self.node_map, self.split_features,
- self.split_thresholds, num_classes=5, regression=False))
+ self.input_data, [], [], [], self.data_spec, self.input_labels,
+ self.tree, self.tree_thresholds, self.node_map,
+ self.split_features,
+ self.split_thresholds, self.epochs, self.current_epoch,
+ num_classes=5, regression=False))
self.assertAllEqual([[4., 1., 1., 1., 1.], [2., 1., 1., 0., 0.],
[2., 0., 0., 1., 1.]],
@@ -81,10 +139,10 @@ class CountExtremelyRandomStatsClassificationTest(test_util.TensorFlowTestCase):
(pcw_node_sums, _, pcw_splits_indices, pcw_splits_sums, _,
pcw_totals_indices, pcw_totals_sums, _, leaves) = (
self.ops.count_extremely_random_stats(
- self.input_data, self.input_labels, self.tree,
- self.tree_thresholds, [-1] * 3,
- self.split_features, self.split_thresholds, num_classes=5,
- regression=False))
+ self.input_data, [], [], [], self.data_spec, self.input_labels,
+ self.tree, self.tree_thresholds, [-1] * 3,
+ self.split_features, self.split_thresholds, self.epochs,
+ self.current_epoch, num_classes=5, regression=False))
self.assertAllEqual([[4., 1., 1., 1., 1.], [2., 1., 1., 0., 0.],
[2., 0., 0., 1., 1.]],
@@ -101,13 +159,13 @@ class CountExtremelyRandomStatsClassificationTest(test_util.TensorFlowTestCase):
with self.test_session():
with self.assertRaisesOpError(
'Number of nodes should be the same in '
- 'tree, tree_thresholds, and node_to_accumulator'):
+ 'tree, tree_thresholds, node_to_accumulator, and birth_epoch.'):
pcw_node, _, _, _, _, _, _, _, _ = (
self.ops.count_extremely_random_stats(
- self.input_data, self.input_labels, self.tree,
- self.tree_thresholds, self.node_map,
- self.split_features, self.split_thresholds, num_classes=5,
- regression=False))
+ self.input_data, [], [], [], self.data_spec, self.input_labels,
+ self.tree, self.tree_thresholds, self.node_map,
+ self.split_features, self.split_thresholds, self.epochs,
+ self.current_epoch, num_classes=5, regression=False))
self.assertAllEqual([], pcw_node.eval())
@@ -124,6 +182,9 @@ class CountExtremelyRandomStatsRegressionTest(test_util.TensorFlowTestCase):
self.split_features = [[1], [-1]]
self.split_thresholds = [[1.], [0.]]
self.ops = training_ops.Load()
+ self.epochs = [0, 1, 1]
+ self.current_epoch = [1]
+ self.data_spec = [constants.DATA_FLOAT] * 2
def testSimple(self):
with self.test_session():
@@ -131,10 +192,10 @@ class CountExtremelyRandomStatsRegressionTest(test_util.TensorFlowTestCase):
pcw_splits_squares, pcw_totals_indices,
pcw_totals_sums, pcw_totals_squares, leaves) = (
self.ops.count_extremely_random_stats(
- self.input_data, self.input_labels, self.tree,
- self.tree_thresholds, self.node_map,
- self.split_features, self.split_thresholds, num_classes=2,
- regression=True))
+ self.input_data, [], [], [], self.data_spec, self.input_labels,
+ self.tree, self.tree_thresholds, self.node_map,
+ self.split_features, self.split_thresholds, self.epochs,
+ self.current_epoch, num_classes=2, regression=True))
self.assertAllEqual(
[[4., 14.], [2., 9.], [2., 5.]], pcw_node_sums.eval())
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/finished_nodes_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/finished_nodes_op_test.py
index 24fbe2c11d..222ef2b2eb 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/finished_nodes_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/finished_nodes_op_test.py
@@ -30,35 +30,71 @@ class FinishedNodesTest(test_util.TensorFlowTestCase):
def setUp(self):
self.leaves = [1, 3, 4]
self.node_map = [-1, -1, -1, 0, 1, -1]
- self.pcw_total_splits = [[6, 3, 3], [11, 4, 7], [0, 0, 0], [0, 0, 0],
+ self.split_sums = [
+ # Accumulator 1
+ [[3, 0, 3], [2, 1, 1], [3, 1, 2]],
+ # Accumulator 2
+ [[6, 3, 3], [6, 2, 4], [5, 0, 5]],
+ # Accumulator 3
+ [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
+ # Accumulator 4
+ [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
+ # Accumulator 5
+ [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
+ ]
+ self.split_squares = []
+ self.accumulator_sums = [[6, 3, 3], [11, 4, 7], [0, 0, 0], [0, 0, 0],
[0, 0, 0]]
+ self.accumulator_squares = []
self.ops = training_ops.Load()
+ self.birth_epochs = [0, 0, 0, 1, 1, 1]
+ self.current_epoch = [1]
def testSimple(self):
with self.test_session():
- finished = self.ops.finished_nodes(self.leaves, self.node_map,
- self.pcw_total_splits,
- num_split_after_samples=10)
+ finished, stale = self.ops.finished_nodes(
+ self.leaves, self.node_map, self.split_sums,
+ self.split_squares, self.accumulator_sums, self.accumulator_squares,
+ self.birth_epochs, self.current_epoch,
+ regression=False, num_split_after_samples=10, min_split_samples=10)
self.assertAllEqual([4], finished.eval())
+ self.assertAllEqual([], stale.eval())
def testNoAccumulators(self):
with self.test_session():
- finished = self.ops.finished_nodes(self.leaves, [-1] * 6,
- self.pcw_total_splits,
- num_split_after_samples=10)
+ finished, stale = self.ops.finished_nodes(
+ self.leaves, [-1] * 6, self.split_sums,
+ self.split_squares, self.accumulator_sums, self.accumulator_squares,
+ self.birth_epochs, self.current_epoch,
+ regression=False, num_split_after_samples=10, min_split_samples=10)
self.assertAllEqual([], finished.eval())
+ self.assertAllEqual([], stale.eval())
def testBadInput(self):
with self.test_session():
with self.assertRaisesOpError(
'leaf_tensor should be one-dimensional'):
- finished = self.ops.finished_nodes([self.leaves], self.node_map,
- self.pcw_total_splits,
- num_split_after_samples=10)
+ finished, stale = self.ops.finished_nodes(
+ [self.leaves], self.node_map, self.split_sums,
+ self.split_squares, self.accumulator_sums, self.accumulator_squares,
+ self.birth_epochs, self.current_epoch,
+ regression=False, num_split_after_samples=10, min_split_samples=10)
self.assertAllEqual([], finished.eval())
+ self.assertAllEqual([], stale.eval())
+
+ def testEarlyDominates(self):
+ with self.test_session():
+ finished, stale = self.ops.finished_nodes(
+ self.leaves, self.node_map, self.split_sums,
+ self.split_squares, self.accumulator_sums, self.accumulator_squares,
+ self.birth_epochs, self.current_epoch,
+ regression=False, num_split_after_samples=10, min_split_samples=5)
+
+ self.assertAllEqual([4], finished.eval())
+ self.assertAllEqual([], stale.eval())
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py
index 0bbd94a2a4..9830651a5d 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py
@@ -41,7 +41,8 @@ class SampleInputsTest(test_util.TensorFlowTestCase):
tf.initialize_all_variables().run()
indices, feature_updates, threshold_updates = (
self.ops.sample_inputs(
- self.input_data, self.node_map, self.leaves, self.split_features,
+ self.input_data, [], [], [],
+ self.node_map, self.leaves, self.split_features,
self.split_thresholds, split_initializations_per_input=1,
split_sampling_random_seed=3))
self.assertAllEqual([1, 0], indices.eval())
@@ -50,12 +51,38 @@ class SampleInputsTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[5., -2., 50.], [-1., -10., 0.]],
threshold_updates.eval())
+ def testSparse(self):
+ sparse_shape = [4, 10]
+ sparse_indices = [[0, 0], [0, 4], [0, 9],
+ [1, 0], [1, 7],
+ [2, 0],
+ [3, 1], [3, 4]]
+ sparse_values = [3.0, -1.0, 0.5,
+ 1.5, 6.0,
+ -2.0,
+ -0.5, 2.0]
+
+ with self.test_session():
+ tf.initialize_all_variables().run()
+ indices, feature_updates, threshold_updates = (
+ self.ops.sample_inputs(
+ [], sparse_indices, sparse_values, sparse_shape,
+ self.node_map, self.leaves, self.split_features,
+ self.split_thresholds, split_initializations_per_input=1,
+ split_sampling_random_seed=3))
+ self.assertAllEqual([1, 0], indices.eval())
+ self.assertAllEqual([[1, 0, 0], [4, 7, -1]],
+ feature_updates.eval())
+ self.assertAllEqual([[5., -2., -2.], [-1., 6., 0.]],
+ threshold_updates.eval())
+
def testNoAccumulators(self):
with self.test_session():
tf.initialize_all_variables().run()
indices, feature_updates, threshold_updates = (
self.ops.sample_inputs(
- self.input_data, [-1] * 3, self.leaves, self.split_features,
+ self.input_data, [], [], [],
+ [-1] * 3, self.leaves, self.split_features,
self.split_thresholds, split_initializations_per_input=1,
split_sampling_random_seed=3))
self.assertAllEqual([], indices.eval())
@@ -69,7 +96,8 @@ class SampleInputsTest(test_util.TensorFlowTestCase):
with self.assertRaisesOpError(
'split_features and split_thresholds should be the same shape.'):
indices, _, _ = self.ops.sample_inputs(
- self.input_data, self.node_map, self.leaves, self.split_features,
+ self.input_data, [], [], [],
+ self.node_map, self.leaves, self.split_features,
self.split_thresholds, split_initializations_per_input=1,
split_sampling_random_seed=3)
self.assertAllEqual([], indices.eval())
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py
index e61085657a..aaead5610f 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import tensorflow # pylint: disable=unused-import
+from tensorflow.contrib.tensor_forest.python import constants
from tensorflow.contrib.tensor_forest.python.ops import inference_ops
from tensorflow.python.framework import test_util
@@ -29,6 +30,7 @@ class TreePredictionsTest(test_util.TensorFlowTestCase):
def setUp(self):
self.ops = inference_ops.Load()
+ self.data_spec = [constants.DATA_FLOAT] * 2
def testSimple(self):
input_data = [[-1., 0.], [-1., 2.], # node 1
@@ -41,13 +43,65 @@ class TreePredictionsTest(test_util.TensorFlowTestCase):
with self.test_session():
predictions = self.ops.tree_predictions(
- input_data, tree, tree_thresholds, node_pcw,
- valid_leaf_threshold=1)
+ input_data, [], [], [], self.data_spec, tree, tree_thresholds,
+ node_pcw, valid_leaf_threshold=1)
self.assertAllClose([[0.1, 0.1, 0.8], [0.1, 0.1, 0.8],
[0.5, 0.25, 0.25], [0.5, 0.25, 0.25]],
predictions.eval())
+ def testSparseInput(self):
+ sparse_shape = [3, 10]
+ sparse_indices = [[0, 0], [0, 4], [0, 9],
+ [1, 0], [1, 7],
+ [2, 0]]
+ sparse_values = [3.0, -1.0, 0.5,
+ 1.5, 6.0,
+ -2.0]
+ sparse_data_spec = [constants.DATA_FLOAT] * 10
+
+ tree = [[1, 0], [-1, 0], [-1, 0]]
+ tree_thresholds = [0., 0., 0.]
+ node_pcw = [[1.0, 0.3, 0.4, 0.3], [1.0, 0.1, 0.1, 0.8],
+ [1.0, 0.5, 0.25, 0.25]]
+
+ with self.test_session():
+ predictions = self.ops.tree_predictions(
+ [], sparse_indices, sparse_values, sparse_shape, sparse_data_spec,
+ tree, tree_thresholds, node_pcw,
+ valid_leaf_threshold=1)
+
+ self.assertAllClose([[0.5, 0.25, 0.25],
+ [0.5, 0.25, 0.25],
+ [0.1, 0.1, 0.8]],
+ predictions.eval())
+
+ def testSparseInputDefaultIsZero(self):
+ sparse_shape = [3, 10]
+ sparse_indices = [[0, 0], [0, 4], [0, 9],
+ [1, 0], [1, 7],
+ [2, 0]]
+ sparse_values = [3.0, -1.0, 0.5,
+ 1.5, 6.0,
+ -2.0]
+ sparse_data_spec = [constants.DATA_FLOAT] * 10
+
+ tree = [[1, 7], [-1, 0], [-1, 0]]
+ tree_thresholds = [3.0, 0., 0.]
+ node_pcw = [[1.0, 0.3, 0.4, 0.3], [1.0, 0.1, 0.1, 0.8],
+ [1.0, 0.5, 0.25, 0.25]]
+
+ with self.test_session():
+ predictions = self.ops.tree_predictions(
+ [], sparse_indices, sparse_values, sparse_shape, sparse_data_spec,
+ tree, tree_thresholds, node_pcw,
+ valid_leaf_threshold=1)
+
+ self.assertAllClose([[0.1, 0.1, 0.8],
+ [0.5, 0.25, 0.25],
+ [0.1, 0.1, 0.8]],
+ predictions.eval())
+
def testBackoffToParent(self):
input_data = [[-1., 0.], [-1., 2.], # node 1
[1., 0.], [1., -2.]] # node 2
@@ -59,8 +113,8 @@ class TreePredictionsTest(test_util.TensorFlowTestCase):
with self.test_session():
predictions = self.ops.tree_predictions(
- input_data, tree, tree_thresholds, node_pcw,
- valid_leaf_threshold=10)
+ input_data, [], [], [], self.data_spec, tree, tree_thresholds,
+ node_pcw, valid_leaf_threshold=10)
# Node 2 has enough data, but Node 1 needs to combine with the parent
# counts.
@@ -78,8 +132,8 @@ class TreePredictionsTest(test_util.TensorFlowTestCase):
with self.test_session():
predictions = self.ops.tree_predictions(
- input_data, tree, tree_thresholds, node_pcw,
- valid_leaf_threshold=10)
+ input_data, [], [], [], self.data_spec, tree, tree_thresholds,
+ node_pcw, valid_leaf_threshold=10)
self.assertEquals((0, 3), predictions.eval().shape)
@@ -97,8 +151,8 @@ class TreePredictionsTest(test_util.TensorFlowTestCase):
'Number of nodes should be the same in tree, tree_thresholds '
'and node_pcw.'):
predictions = self.ops.tree_predictions(
- input_data, tree, tree_thresholds, node_pcw,
- valid_leaf_threshold=10)
+ input_data, [], [], [], self.data_spec, tree, tree_thresholds,
+ node_pcw, valid_leaf_threshold=10)
self.assertEquals((0, 3), predictions.eval().shape)
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py
index f370903b3c..c9af01c50b 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py
@@ -40,48 +40,43 @@ class UpdateFertileSlotsTest(test_util.TensorFlowTestCase):
self.node_map = [-1, -1, 0, -1, -1, -1, -1]
self.total_counts = [[80., 40., 40.]]
self.ops = training_ops.Load()
+ self.stale_leaves = []
def testSimple(self):
with self.test_session():
- (node_map_updates, accumulators_cleared, accumulators_allocated,
- new_nfl, new_nfl_scores) = self.ops.update_fertile_slots(
+ (node_map_updates, accumulators_cleared,
+ accumulators_allocated) = self.ops.update_fertile_slots(
self.finished, self.non_fertile_leaves, self.non_fertile_leaf_scores,
self.end_of_tree, self.depths,
- self.total_counts, self.node_map, max_depth=4)
+ self.total_counts, self.node_map, self.stale_leaves, max_depth=4)
self.assertAllEqual([[2, 4], [-1, 0]], node_map_updates.eval())
self.assertAllEqual([], accumulators_cleared.eval())
self.assertAllEqual([0], accumulators_allocated.eval())
- self.assertAllEqual([3, 5, 6], new_nfl.eval())
- self.assertAllEqual([10., 1., 1.], new_nfl_scores.eval())
def testReachedMaxDepth(self):
with self.test_session():
- (node_map_updates, accumulators_cleared, accumulators_allocated,
- new_nfl, new_nfl_scores) = self.ops.update_fertile_slots(
+ (node_map_updates, accumulators_cleared,
+ accumulators_allocated) = self.ops.update_fertile_slots(
self.finished, self.non_fertile_leaves, self.non_fertile_leaf_scores,
self.end_of_tree, self.depths,
- self.total_counts, self.node_map, max_depth=3)
+ self.total_counts, self.node_map, self.stale_leaves, max_depth=3)
self.assertAllEqual([[2], [-1]], node_map_updates.eval())
self.assertAllEqual([0], accumulators_cleared.eval())
self.assertAllEqual([], accumulators_allocated.eval())
- self.assertAllEqual([-1], new_nfl.eval())
- self.assertAllEqual([0.0], new_nfl_scores.eval())
def testNoFinished(self):
with self.test_session():
- (node_map_updates, accumulators_cleared, accumulators_allocated,
- new_nfl, new_nfl_scores) = self.ops.update_fertile_slots(
+ (node_map_updates, accumulators_cleared,
+ accumulators_allocated) = self.ops.update_fertile_slots(
[], self.non_fertile_leaves, self.non_fertile_leaf_scores,
self.end_of_tree, self.depths,
- self.total_counts, self.node_map, max_depth=4)
+ self.total_counts, self.node_map, self.stale_leaves, max_depth=4)
self.assertAllEqual((2, 0), node_map_updates.eval().shape)
self.assertAllEqual([], accumulators_cleared.eval())
self.assertAllEqual([], accumulators_allocated.eval())
- self.assertAllEqual([4, 3], new_nfl.eval())
- self.assertAllEqual([15., 10.], new_nfl_scores.eval())
def testBadInput(self):
del self.non_fertile_leaf_scores[-1]
@@ -89,10 +84,10 @@ class UpdateFertileSlotsTest(test_util.TensorFlowTestCase):
with self.assertRaisesOpError(
'Number of non fertile leaves should be the same in '
'non_fertile_leaves and non_fertile_leaf_scores.'):
- (node_map_updates, _, _, _, _) = self.ops.update_fertile_slots(
+ (node_map_updates, _, _) = self.ops.update_fertile_slots(
self.finished, self.non_fertile_leaves,
self.non_fertile_leaf_scores, self.end_of_tree, self.depths,
- self.total_counts, self.node_map, max_depth=4)
+ self.total_counts, self.node_map, self.stale_leaves, max_depth=4)
self.assertAllEqual((2, 0), node_map_updates.eval().shape)
diff --git a/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py b/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py
index 6f4e6fff40..88f8112ed4 100644
--- a/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py
+++ b/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py
@@ -1,3 +1,4 @@
+# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,13 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os
import threading
-import tensorflow as tf
-
+from tensorflow.python.framework import load_library
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import tf_logging as logging
+
INFERENCE_OPS_FILE = '_inference_ops.so'
@@ -38,7 +40,11 @@ ops.NoGradient('TreePredictions')
def TreePredictions(op):
"""Shape function for TreePredictions Op."""
num_points = op.inputs[0].get_shape()[0].value
- num_classes = op.inputs[3].get_shape()[1].value
+ sparse_shape = op.inputs[3].get_shape()
+ if sparse_shape.ndims == 2:
+ num_points = sparse_shape[0].value
+ num_classes = op.inputs[7].get_shape()[1].value
+
# The output of TreePredictions is
# [node_pcw(evaluate_tree(x), c) for c in classes for x in input_data].
return [tensor_shape.TensorShape([num_points, num_classes - 1])]
@@ -49,16 +55,14 @@ def TreePredictions(op):
# there's not yet any guarantee that the shared object exists.
# In which case, "import tensorflow" will always crash, even for users that
# never use contrib.
-def Load(library_base_dir=''):
+def Load():
"""Load the inference ops library and return the loaded module."""
with _ops_lock:
global _inference_ops
if not _inference_ops:
- data_files_path = os.path.join(library_base_dir,
- tf.resource_loader.get_data_files_path())
- tf.logging.info('data path: %s', data_files_path)
- _inference_ops = tf.load_op_library(os.path.join(
- data_files_path, INFERENCE_OPS_FILE))
+ ops_path = resource_loader.get_path_to_datafile(INFERENCE_OPS_FILE)
+ logging.info('data path: %s', ops_path)
+ _inference_ops = load_library.load_op_library(ops_path)
assert _inference_ops, 'Could not load inference_ops.so'
return _inference_ops
diff --git a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
index 7a108baf42..d25d5ce50b 100644
--- a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
+++ b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
@@ -1,3 +1,4 @@
+# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,13 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os
import threading
-import tensorflow as tf
-
+from tensorflow.python.framework import load_library
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import tf_logging as logging
TRAINING_OPS_FILE = '_training_ops.so'
@@ -45,7 +46,10 @@ def _CountExtremelyRandomStatsShape(op):
"""Shape function for CountExtremelyRandomStats Op."""
regression = op.get_attr('regression')
num_points = op.inputs[0].get_shape()[0].value
- num_nodes = op.inputs[2].get_shape()[0].value
+ sparse_shape = op.inputs[3].get_shape()
+ if sparse_shape.ndims == 2:
+ num_points = sparse_shape[0].value
+ num_nodes = op.inputs[6].get_shape()[0].value
num_classes = op.get_attr('num_classes')
# The output of TraverseTree is [leaf_node_index(x) for x in input_data].
return [tensor_shape.TensorShape([num_nodes, num_classes]), # node sums
@@ -66,7 +70,7 @@ def _CountExtremelyRandomStatsShape(op):
@ops.RegisterShape('SampleInputs')
def _SampleInputsShape(op):
"""Shape function for SampleInputs Op."""
- num_splits = op.inputs[3].get_shape()[1].value
+ num_splits = op.inputs[6].get_shape()[1].value
return [[None], [None, num_splits], [None, num_splits]]
@@ -85,7 +89,7 @@ def _GrowTreeShape(unused_op):
@ops.RegisterShape('FinishedNodes')
def _FinishedNodesShape(unused_op):
"""Shape function for FinishedNodes Op."""
- return [[None]]
+ return [[None], [None]]
@ops.RegisterShape('ScatterAddNdim')
@@ -97,7 +101,7 @@ def _ScatterAddNdimShape(unused_op):
@ops.RegisterShape('UpdateFertileSlots')
def _UpdateFertileSlotsShape(unused_op):
"""Shape function for UpdateFertileSlots Op."""
- return [[None, 2], [None], [None], [None], [None]]
+ return [[None, 2], [None], [None]]
# Workaround for the fact that importing tensorflow imports contrib
@@ -105,16 +109,14 @@ def _UpdateFertileSlotsShape(unused_op):
# there's not yet any guarantee that the shared object exists.
# In which case, "import tensorflow" will always crash, even for users that
# never use contrib.
-def Load(library_base_dir=''):
+def Load():
"""Load training ops library and return the loaded module."""
with _ops_lock:
global _training_ops
if not _training_ops:
- data_files_path = os.path.join(library_base_dir,
- tf.resource_loader.get_data_files_path())
- tf.logging.info('data path: %s', data_files_path)
- _training_ops = tf.load_op_library(os.path.join(
- data_files_path, TRAINING_OPS_FILE))
+ ops_path = resource_loader.get_path_to_datafile(TRAINING_OPS_FILE)
+ logging.info('data path: %s', ops_path)
+ _training_ops = load_library.load_op_library(ops_path)
assert _training_ops, 'Could not load _training_ops.so'
return _training_ops
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index f48efaa5db..791954c51f 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -1,3 +1,4 @@
+# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,14 +21,22 @@ from __future__ import print_function
import math
import random
-import tensorflow as tf
-
+from tensorflow.contrib.tensor_forest.python import constants
from tensorflow.contrib.tensor_forest.python.ops import inference_ops
from tensorflow.contrib.tensor_forest.python.ops import training_ops
-
-# If tree[i][0] equals this value, then i is a leaf node.
-LEAF_NODE = -1
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.platform import tf_logging as logging
# A convenience class for holding random forest hyperparameters.
@@ -49,6 +58,7 @@ class ForestHParams(object):
max_depth=0, num_splits_to_consider=0,
feature_bagging_fraction=1.0,
max_fertile_nodes=0, split_after_samples=250,
+ min_split_samples=5,
valid_leaf_threshold=1, **kwargs):
self.num_trees = num_trees
self.max_nodes = max_nodes
@@ -58,6 +68,7 @@ class ForestHParams(object):
self.num_splits_to_consider = num_splits_to_consider
self.max_fertile_nodes = max_fertile_nodes
self.split_after_samples = split_after_samples
+ self.min_split_samples = min_split_samples
self.valid_leaf_threshold = valid_leaf_threshold
for name, value in kwargs.items():
@@ -72,11 +83,6 @@ class ForestHParams(object):
_ = getattr(self, 'num_classes')
_ = getattr(self, 'num_features')
- self.training_library_base_dir = getattr(
- self, 'training_library_base_dir', '')
- self.inference_library_base_dir = getattr(
- self, 'inference_library_base_dir', '')
-
self.bagged_num_features = int(self.feature_bagging_fraction *
self.num_features)
@@ -147,92 +153,86 @@ class TreeTrainingVariables(object):
"""
def __init__(self, params, tree_num, training):
- self.tree = tf.get_variable(
- name=self.get_tree_name('tree', tree_num), dtype=tf.int32,
- initializer=tf.constant(
- [[-1, -1]] + [[-2, -1]] * (params.max_nodes - 1)))
- self.tree_thresholds = tf.get_variable(
+ self.tree = variable_scope.get_variable(
+ name=self.get_tree_name('tree', tree_num), dtype=dtypes.int32,
+ shape=[params.max_nodes, 2],
+ initializer=init_ops.constant_initializer(-2))
+ self.tree_thresholds = variable_scope.get_variable(
name=self.get_tree_name('tree_thresholds', tree_num),
shape=[params.max_nodes],
- initializer=tf.constant_initializer(-1.0))
- self.tree_depths = tf.get_variable(
+ initializer=init_ops.constant_initializer(-1.0))
+ self.tree_depths = variable_scope.get_variable(
name=self.get_tree_name('tree_depths', tree_num),
shape=[params.max_nodes],
- dtype=tf.int32,
- initializer=tf.constant_initializer(1))
- self.end_of_tree = tf.get_variable(
+ dtype=dtypes.int32,
+ initializer=init_ops.constant_initializer(1))
+ self.end_of_tree = variable_scope.get_variable(
name=self.get_tree_name('end_of_tree', tree_num),
- dtype=tf.int32,
- initializer=tf.constant([1]))
+ dtype=dtypes.int32,
+ initializer=constant_op.constant([1]))
+ self.start_epoch = tf_variables.Variable(
+ [0] * (params.max_nodes), name='start_epoch')
if training:
- self.non_fertile_leaves = tf.get_variable(
- name=self.get_tree_name('non_fertile_leaves', tree_num),
- dtype=tf.int32,
- initializer=tf.constant([0]))
- self.non_fertile_leaf_scores = tf.get_variable(
- name=self.get_tree_name('non_fertile_leaf_scores', tree_num),
- initializer=tf.constant([1.0]))
-
- self.node_to_accumulator_map = tf.get_variable(
+ self.node_to_accumulator_map = variable_scope.get_variable(
name=self.get_tree_name('node_to_accumulator_map', tree_num),
shape=[params.max_nodes],
- dtype=tf.int32,
- initializer=tf.constant_initializer(-1))
+ dtype=dtypes.int32,
+ initializer=init_ops.constant_initializer(-1))
- self.candidate_split_features = tf.get_variable(
+ self.candidate_split_features = variable_scope.get_variable(
name=self.get_tree_name('candidate_split_features', tree_num),
shape=[params.max_fertile_nodes, params.num_splits_to_consider],
- dtype=tf.int32,
- initializer=tf.constant_initializer(-1))
- self.candidate_split_thresholds = tf.get_variable(
+ dtype=dtypes.int32,
+ initializer=init_ops.constant_initializer(-1))
+ self.candidate_split_thresholds = variable_scope.get_variable(
name=self.get_tree_name('candidate_split_thresholds', tree_num),
shape=[params.max_fertile_nodes, params.num_splits_to_consider],
- initializer=tf.constant_initializer(0.0))
+ initializer=init_ops.constant_initializer(0.0))
# Statistics shared by classification and regression.
- self.node_sums = tf.get_variable(
+ self.node_sums = variable_scope.get_variable(
name=self.get_tree_name('node_sums', tree_num),
shape=[params.max_nodes, params.num_output_columns],
- initializer=tf.constant_initializer(0.0))
+ initializer=init_ops.constant_initializer(0.0))
if training:
- self.candidate_split_sums = tf.get_variable(
+ self.candidate_split_sums = variable_scope.get_variable(
name=self.get_tree_name('candidate_split_sums', tree_num),
shape=[params.max_fertile_nodes, params.num_splits_to_consider,
params.num_output_columns],
- initializer=tf.constant_initializer(0.0))
- self.accumulator_sums = tf.get_variable(
+ initializer=init_ops.constant_initializer(0.0))
+ self.accumulator_sums = variable_scope.get_variable(
name=self.get_tree_name('accumulator_sums', tree_num),
shape=[params.max_fertile_nodes, params.num_output_columns],
- initializer=tf.constant_initializer(-1.0))
+ initializer=init_ops.constant_initializer(-1.0))
# Regression also tracks second order stats.
if params.regression:
- self.node_squares = tf.get_variable(
+ self.node_squares = variable_scope.get_variable(
name=self.get_tree_name('node_squares', tree_num),
shape=[params.max_nodes, params.num_output_columns],
- initializer=tf.constant_initializer(0.0))
+ initializer=init_ops.constant_initializer(0.0))
- self.candidate_split_squares = tf.get_variable(
+ self.candidate_split_squares = variable_scope.get_variable(
name=self.get_tree_name('candidate_split_squares', tree_num),
shape=[params.max_fertile_nodes, params.num_splits_to_consider,
params.num_output_columns],
- initializer=tf.constant_initializer(0.0))
+ initializer=init_ops.constant_initializer(0.0))
- self.accumulator_squares = tf.get_variable(
+ self.accumulator_squares = variable_scope.get_variable(
name=self.get_tree_name('accumulator_squares', tree_num),
shape=[params.max_fertile_nodes, params.num_output_columns],
- initializer=tf.constant_initializer(-1.0))
+ initializer=init_ops.constant_initializer(-1.0))
else:
- self.node_squares = tf.constant(
+ self.node_squares = constant_op.constant(
0.0, name=self.get_tree_name('node_squares', tree_num))
- self.candidate_split_squares = tf.constant(
+ self.candidate_split_squares = constant_op.constant(
0.0, name=self.get_tree_name('candidate_split_squares', tree_num))
- self.accumulator_squares = tf.constant(
+ self.accumulator_squares = constant_op.constant(
0.0, name=self.get_tree_name('accumulator_squares', tree_num))
def get_tree_name(self, name, num):
@@ -273,11 +273,11 @@ class ForestTrainingVariables(object):
"""
def __init__(self, params, device_assigner, training=True,
- tree_variable_class=TreeTrainingVariables):
+ tree_variables_class=TreeTrainingVariables):
self.variables = []
for i in range(params.num_trees):
- with tf.device(device_assigner.get_device(i)):
- self.variables.append(tree_variable_class(params, i, training))
+ with ops.device(device_assigner.get_device(i)):
+ self.variables.append(tree_variables_class(params, i, training))
def __setitem__(self, t, val):
self.variables[t] = val
@@ -299,7 +299,7 @@ class RandomForestDeviceAssigner(object):
def get_device(self, unused_tree_num):
if not self.cached:
- dummy = tf.constant(0)
+ dummy = constant_op.constant(0)
self.cached = dummy.device
return self.cached
@@ -308,43 +308,51 @@ class RandomForestDeviceAssigner(object):
class RandomForestGraphs(object):
"""Builds TF graphs for random forest training and inference."""
- def __init__(self, params, device_assigner=None, variables=None,
- tree_graphs=None,
+ def __init__(self, params, device_assigner=None,
+ variables=None, tree_variables_class=TreeTrainingVariables,
+ tree_graphs=None, training=True,
t_ops=training_ops,
i_ops=inference_ops):
self.params = params
self.device_assigner = device_assigner or RandomForestDeviceAssigner()
- tf.logging.info('Constructing forest with params = ')
- tf.logging.info(self.params.__dict__)
+ logging.info('Constructing forest with params = ')
+ logging.info(self.params.__dict__)
self.variables = variables or ForestTrainingVariables(
- self.params, device_assigner=self.device_assigner)
+ self.params, device_assigner=self.device_assigner, training=training,
+ tree_variables_class=tree_variables_class)
tree_graph_class = tree_graphs or RandomTreeGraphs
self.trees = [
tree_graph_class(
self.variables[i], self.params,
- t_ops.Load(self.params.training_library_base_dir),
- i_ops.Load(self.params.inference_library_base_dir), i)
+ t_ops.Load(), i_ops.Load(), i)
for i in range(self.params.num_trees)]
def _bag_features(self, tree_num, input_data):
- split_data = tf.split(1, self.params.num_features, input_data)
- return tf.concat(1, [split_data[ind]
- for ind in self.params.bagged_features[tree_num]])
+ split_data = array_ops.split(1, self.params.num_features, input_data)
+ return array_ops.concat(
+ 1, [split_data[ind] for ind in self.params.bagged_features[tree_num]])
- def training_graph(self, input_data, input_labels):
+ def training_graph(self, input_data, input_labels, data_spec=None,
+ epoch=None, **tree_kwargs):
"""Constructs a TF graph for training a random forest.
Args:
- input_data: A tensor or placeholder for input data.
+ input_data: A tensor or SparseTensor or placeholder for input data.
input_labels: A tensor or placeholder for labels associated with
input_data.
+ data_spec: A list of tf.dtype values specifying the original types of
+ each column.
+ epoch: A tensor or placeholder for the epoch the training data comes from.
+ **tree_kwargs: Keyword arguments passed to each tree's training_graph.
Returns:
The last op in the random forest training graph.
"""
+ data_spec = ([constants.DATA_FLOAT] * self.params.num_features
+ if data_spec is None else data_spec)
tree_graphs = []
for i in range(self.params.num_trees):
- with tf.device(self.device_assigner.get_device(i)):
+ with ops.device(self.device_assigner.get_device(i)):
seed = self.params.base_random_seed
if seed != 0:
seed += i
@@ -354,40 +362,54 @@ class RandomForestGraphs(object):
if self.params.bagging_fraction < 1.0:
          # TODO(thomaswc): This does sampling without replacement. Consider
# also allowing sampling with replacement as an option.
- batch_size = tf.slice(tf.shape(input_data), [0], [1])
- r = tf.random_uniform(batch_size, seed=seed)
- mask = tf.less(r, tf.ones_like(r) * self.params.bagging_fraction)
- gather_indices = tf.squeeze(tf.where(mask), squeeze_dims=[1])
+ batch_size = array_ops.slice(array_ops.shape(input_data), [0], [1])
+ r = random_ops.random_uniform(batch_size, seed=seed)
+ mask = math_ops.less(
+ r, array_ops.ones_like(r) * self.params.bagging_fraction)
+ gather_indices = array_ops.squeeze(
+ array_ops.where(mask), squeeze_dims=[1])
# TODO(thomaswc): Calculate out-of-bag data and labels, and store
# them for use in calculating statistics later.
- tree_data = tf.gather(input_data, gather_indices)
- tree_labels = tf.gather(input_labels, gather_indices)
+ tree_data = array_ops.gather(input_data, gather_indices)
+ tree_labels = array_ops.gather(input_labels, gather_indices)
if self.params.bagged_features:
tree_data = self._bag_features(i, tree_data)
- tree_graphs.append(
- self.trees[i].training_graph(tree_data, tree_labels, seed))
- return tf.group(*tree_graphs)
+ initialization = self.trees[i].tree_initialization()
+
+ with ops.control_dependencies([initialization]):
+ tree_graphs.append(
+ self.trees[i].training_graph(
+ tree_data, tree_labels, seed, data_spec=data_spec,
+ epoch=([0] if epoch is None else epoch),
+ **tree_kwargs))
- def inference_graph(self, input_data):
+ return control_flow_ops.group(*tree_graphs)
+
+ def inference_graph(self, input_data, data_spec=None):
"""Constructs a TF graph for evaluating a random forest.
Args:
- input_data: A tensor or placeholder for input data.
+ input_data: A tensor or SparseTensor or placeholder for input data.
+ data_spec: A list of tf.dtype values specifying the original types of
+ each column.
Returns:
The last op in the random forest inference graph.
"""
+ data_spec = ([constants.DATA_FLOAT] * self.params.num_features
+ if data_spec is None else data_spec)
probabilities = []
for i in range(self.params.num_trees):
- with tf.device(self.device_assigner.get_device(i)):
+ with ops.device(self.device_assigner.get_device(i)):
tree_data = input_data
if self.params.bagged_features:
tree_data = self._bag_features(i, input_data)
- probabilities.append(self.trees[i].inference_graph(tree_data))
- with tf.device(self.device_assigner.get_device(0)):
- all_predict = tf.pack(probabilities)
- return tf.reduce_sum(all_predict, 0) / self.params.num_trees
+ probabilities.append(self.trees[i].inference_graph(tree_data,
+ data_spec))
+ with ops.device(self.device_assigner.get_device(0)):
+ all_predict = array_ops.pack(probabilities)
+ return math_ops.reduce_sum(all_predict, 0) / self.params.num_trees
def average_size(self):
"""Constructs a TF graph for evaluating the average size of a forest.
@@ -397,9 +419,16 @@ class RandomForestGraphs(object):
"""
sizes = []
for i in range(self.params.num_trees):
- with tf.device(self.device_assigner.get_device(i)):
+ with ops.device(self.device_assigner.get_device(i)):
sizes.append(self.trees[i].size())
- return tf.reduce_mean(tf.pack(sizes))
+ return math_ops.reduce_mean(array_ops.pack(sizes))
+
+ def training_loss(self):
+ return math_ops.neg(self.average_size())
+
+ # pylint: disable=unused-argument
+ def validation_loss(self, features, labels):
+ return math_ops.neg(self.average_size())
def average_impurity(self):
"""Constructs a TF graph for evaluating the leaf impurity of a forest.
@@ -409,14 +438,14 @@ class RandomForestGraphs(object):
"""
impurities = []
for i in range(self.params.num_trees):
- with tf.device(self.device_assigner.get_device(i)):
+ with ops.device(self.device_assigner.get_device(i)):
impurities.append(self.trees[i].average_impurity())
- return tf.reduce_mean(tf.pack(impurities))
+ return math_ops.reduce_mean(array_ops.pack(impurities))
def get_stats(self, session):
tree_stats = []
for i in range(self.params.num_trees):
- with tf.device(self.device_assigner.get_device(i)):
+ with ops.device(self.device_assigner.get_device(i)):
tree_stats.append(self.trees[i].get_stats(session))
return ForestStats(tree_stats, self.params)
@@ -431,6 +460,18 @@ class RandomTreeGraphs(object):
self.params = params
self.tree_num = tree_num
+ def tree_initialization(self):
+ def _init_tree():
+ return state_ops.scatter_update(self.variables.tree, [0], [[-1, -1]]).op
+
+ def _nothing():
+ return control_flow_ops.no_op()
+
+ return control_flow_ops.cond(
+ math_ops.equal(array_ops.squeeze(array_ops.slice(
+ self.variables.tree, [0, 0], [1, 1])), -2),
+ _init_tree, _nothing)
+
def _gini(self, class_counts):
"""Calculate the Gini impurity.
@@ -444,9 +485,9 @@ class RandomTreeGraphs(object):
Returns:
A 1-D tensor of the Gini impurities for each row in the input.
"""
- smoothed = 1.0 + tf.slice(class_counts, [0, 1], [-1, -1])
- sums = tf.reduce_sum(smoothed, 1)
- sum_squares = tf.reduce_sum(tf.square(smoothed), 1)
+ smoothed = 1.0 + array_ops.slice(class_counts, [0, 1], [-1, -1])
+ sums = math_ops.reduce_sum(smoothed, 1)
+ sum_squares = math_ops.reduce_sum(math_ops.square(smoothed), 1)
return 1.0 - sum_squares / (sums * sums)
@@ -463,9 +504,9 @@ class RandomTreeGraphs(object):
Returns:
A 1-D tensor of the Gini impurities for each row in the input.
"""
- smoothed = 1.0 + tf.slice(class_counts, [0, 1], [-1, -1])
- sums = tf.reduce_sum(smoothed, 1)
- sum_squares = tf.reduce_sum(tf.square(smoothed), 1)
+ smoothed = 1.0 + array_ops.slice(class_counts, [0, 1], [-1, -1])
+ sums = math_ops.reduce_sum(smoothed, 1)
+ sum_squares = math_ops.reduce_sum(math_ops.square(smoothed), 1)
return sums - sum_squares / sums
@@ -483,40 +524,58 @@ class RandomTreeGraphs(object):
Returns:
A 1-D tensor of the variances for each row in the input.
"""
- total_count = tf.slice(sums, [0, 0], [-1, 1])
+ total_count = array_ops.slice(sums, [0, 0], [-1, 1])
e_x = sums / total_count
e_x2 = squares / total_count
- return tf.reduce_sum(e_x2 - tf.square(e_x), 1)
+ return math_ops.reduce_sum(e_x2 - math_ops.square(e_x), 1)
+
+ def training_graph(self, input_data, input_labels, random_seed,
+ data_spec, epoch=None):
- def training_graph(self, input_data, input_labels, random_seed):
"""Constructs a TF graph for training a random tree.
Args:
- input_data: A tensor or placeholder for input data.
+ input_data: A tensor or SparseTensor or placeholder for input data.
input_labels: A tensor or placeholder for labels associated with
input_data.
random_seed: The random number generator seed to use for this tree. 0
means use the current time as the seed.
+ data_spec: A list of tf.dtype values specifying the original types of
+ each column.
+ epoch: A tensor or placeholder for the epoch the training data comes from.
Returns:
The last op in the random tree training graph.
"""
+ epoch = [0] if epoch is None else epoch
+
+ sparse_indices = []
+ sparse_values = []
+ sparse_shape = []
+ if isinstance(input_data, ops.SparseTensor):
+ sparse_indices = input_data.indices
+ sparse_values = input_data.values
+ sparse_shape = input_data.shape
+ input_data = []
+
# Count extremely random stats.
(node_sums, node_squares, splits_indices, splits_sums,
splits_squares, totals_indices, totals_sums,
totals_squares, input_leaves) = (
self.training_ops.count_extremely_random_stats(
- input_data, input_labels, self.variables.tree,
+ input_data, sparse_indices, sparse_values, sparse_shape,
+ data_spec, input_labels, self.variables.tree,
self.variables.tree_thresholds,
self.variables.node_to_accumulator_map,
self.variables.candidate_split_features,
self.variables.candidate_split_thresholds,
+ self.variables.start_epoch, epoch,
num_classes=self.params.num_output_columns,
regression=self.params.regression))
node_update_ops = []
node_update_ops.append(
- tf.assign_add(self.variables.node_sums, node_sums))
+ state_ops.assign_add(self.variables.node_sums, node_sums))
splits_update_ops = []
splits_update_ops.append(self.training_ops.scatter_add_ndim(
@@ -527,8 +586,8 @@ class RandomTreeGraphs(object):
totals_sums))
if self.params.regression:
- node_update_ops.append(tf.assign_add(self.variables.node_squares,
- node_squares))
+ node_update_ops.append(state_ops.assign_add(self.variables.node_squares,
+ node_squares))
splits_update_ops.append(self.training_ops.scatter_add_ndim(
self.variables.candidate_split_squares,
splits_indices, splits_squares))
@@ -539,63 +598,56 @@ class RandomTreeGraphs(object):
# Sample inputs.
update_indices, feature_updates, threshold_updates = (
self.training_ops.sample_inputs(
- input_data, self.variables.node_to_accumulator_map,
+ input_data, sparse_indices, sparse_values, sparse_shape,
+ self.variables.node_to_accumulator_map,
input_leaves, self.variables.candidate_split_features,
self.variables.candidate_split_thresholds,
split_initializations_per_input=(
self.params.split_initializations_per_input),
split_sampling_random_seed=random_seed))
- update_features_op = tf.scatter_update(
+ update_features_op = state_ops.scatter_update(
self.variables.candidate_split_features, update_indices,
feature_updates)
- update_thresholds_op = tf.scatter_update(
+ update_thresholds_op = state_ops.scatter_update(
self.variables.candidate_split_thresholds, update_indices,
threshold_updates)
# Calculate finished nodes.
- with tf.control_dependencies(splits_update_ops):
- children = tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1]),
- squeeze_dims=[1])
- is_leaf = tf.equal(LEAF_NODE, children)
- leaves = tf.to_int32(tf.squeeze(tf.where(is_leaf), squeeze_dims=[1]))
- finished = self.training_ops.finished_nodes(
+ with ops.control_dependencies(splits_update_ops):
+ children = array_ops.squeeze(array_ops.slice(
+ self.variables.tree, [0, 0], [-1, 1]), squeeze_dims=[1])
+ is_leaf = math_ops.equal(constants.LEAF_NODE, children)
+ leaves = math_ops.to_int32(array_ops.squeeze(array_ops.where(is_leaf),
+ squeeze_dims=[1]))
+ finished, stale = self.training_ops.finished_nodes(
leaves, self.variables.node_to_accumulator_map,
+ self.variables.candidate_split_sums,
+ self.variables.candidate_split_squares,
self.variables.accumulator_sums,
- num_split_after_samples=self.params.split_after_samples)
+ self.variables.accumulator_squares,
+ self.variables.start_epoch, epoch,
+ num_split_after_samples=self.params.split_after_samples,
+ min_split_samples=self.params.min_split_samples)
# Update leaf scores.
- # TODO(gilberth): Optimize this. It currently calculates counts for
- # every non-fertile leaf.
- with tf.control_dependencies(node_update_ops):
- def dont_update_leaf_scores():
- return self.variables.non_fertile_leaf_scores
-
- def update_leaf_scores_regression():
- sums = tf.gather(self.variables.node_sums,
- self.variables.non_fertile_leaves)
- squares = tf.gather(self.variables.node_squares,
- self.variables.non_fertile_leaves)
- new_scores = self._variance(sums, squares)
- return tf.assign(self.variables.non_fertile_leaf_scores, new_scores)
-
- def update_leaf_scores_classification():
- counts = tf.gather(self.variables.node_sums,
- self.variables.non_fertile_leaves)
- new_scores = self._weighted_gini(counts)
- return tf.assign(self.variables.non_fertile_leaf_scores, new_scores)
-
- # Because we can't have tf.self.variables of size 0, we have to put in a
- # garbage value of -1 in there. Here we check for that so we don't
- # try to index into node_per_class_weights in a tf.gather with a negative
- # number.
- update_nonfertile_leaves_scores_op = tf.cond(
- tf.less(self.variables.non_fertile_leaves[0], 0),
- dont_update_leaf_scores,
- update_leaf_scores_regression if self.params.regression else
- update_leaf_scores_classification)
+ non_fertile_leaves = array_ops.boolean_mask(
+ leaves, math_ops.less(array_ops.gather(
+ self.variables.node_to_accumulator_map, leaves), 0))
+
+ # TODO(gilberth): It should be possible to limit the number of non
+ # fertile leaves we calculate scores for, especially since we can only take
+ # at most array_ops.shape(finished)[0] of them.
+ with ops.control_dependencies(node_update_ops):
+ sums = array_ops.gather(self.variables.node_sums, non_fertile_leaves)
+ if self.params.regression:
+ squares = array_ops.gather(self.variables.node_squares,
+ non_fertile_leaves)
+ non_fertile_leaf_scores = self._variance(sums, squares)
+ else:
+ non_fertile_leaf_scores = self._weighted_gini(sums)
# Calculate best splits.
- with tf.control_dependencies(splits_update_ops):
+ with ops.control_dependencies(splits_update_ops):
split_indices = self.training_ops.best_splits(
finished, self.variables.node_to_accumulator_map,
self.variables.candidate_split_sums,
@@ -605,7 +657,7 @@ class RandomTreeGraphs(object):
regression=self.params.regression)
# Grow tree.
- with tf.control_dependencies([update_features_op, update_thresholds_op]):
+ with ops.control_dependencies([update_features_op, update_thresholds_op]):
(tree_update_indices, tree_children_updates,
tree_threshold_updates, tree_depth_updates, new_eot) = (
self.training_ops.grow_tree(
@@ -613,110 +665,138 @@ class RandomTreeGraphs(object):
self.variables.node_to_accumulator_map, finished, split_indices,
self.variables.candidate_split_features,
self.variables.candidate_split_thresholds))
- tree_update_op = tf.scatter_update(
+ tree_update_op = state_ops.scatter_update(
self.variables.tree, tree_update_indices, tree_children_updates)
- threhsolds_update_op = tf.scatter_update(
+ thresholds_update_op = state_ops.scatter_update(
self.variables.tree_thresholds, tree_update_indices,
tree_threshold_updates)
- depth_update_op = tf.scatter_update(
+ depth_update_op = state_ops.scatter_update(
self.variables.tree_depths, tree_update_indices, tree_depth_updates)
+ # TODO(thomaswc): Only update the epoch on the new leaves.
+ new_epoch_updates = epoch * array_ops.ones_like(tree_depth_updates)
+ epoch_update_op = state_ops.scatter_update(
+ self.variables.start_epoch, tree_update_indices,
+ new_epoch_updates)
# Update fertile slots.
- with tf.control_dependencies([update_nonfertile_leaves_scores_op,
- depth_update_op]):
- (node_map_updates, accumulators_cleared, accumulators_allocated,
- new_nonfertile_leaves, new_nonfertile_leaves_scores) = (
- self.training_ops.update_fertile_slots(
- finished, self.variables.non_fertile_leaves,
- self.variables.non_fertile_leaf_scores,
- self.variables.end_of_tree, self.variables.tree_depths,
- self.variables.accumulator_sums,
- self.variables.node_to_accumulator_map,
- max_depth=self.params.max_depth,
- regression=self.params.regression))
+ with ops.control_dependencies([depth_update_op]):
+ (node_map_updates, accumulators_cleared, accumulators_allocated) = (
+ self.training_ops.update_fertile_slots(
+ finished, non_fertile_leaves,
+ non_fertile_leaf_scores,
+ self.variables.end_of_tree, self.variables.tree_depths,
+ self.variables.accumulator_sums,
+ self.variables.node_to_accumulator_map,
+ stale,
+ max_depth=self.params.max_depth,
+ regression=self.params.regression))
# Ensure end_of_tree doesn't get updated until UpdateFertileSlots has
# used it to calculate new leaves.
- gated_new_eot, = tf.tuple([new_eot], control_inputs=[new_nonfertile_leaves])
- eot_update_op = tf.assign(self.variables.end_of_tree, gated_new_eot)
+ gated_new_eot, = control_flow_ops.tuple([new_eot],
+ control_inputs=[node_map_updates])
+ eot_update_op = state_ops.assign(self.variables.end_of_tree, gated_new_eot)
updates = []
updates.append(eot_update_op)
updates.append(tree_update_op)
- updates.append(threhsolds_update_op)
- updates.append(tf.assign(
- self.variables.non_fertile_leaves, new_nonfertile_leaves,
- validate_shape=False))
- updates.append(tf.assign(
- self.variables.non_fertile_leaf_scores,
- new_nonfertile_leaves_scores, validate_shape=False))
-
- updates.append(tf.scatter_update(
+ updates.append(thresholds_update_op)
+ updates.append(epoch_update_op)
+
+ updates.append(state_ops.scatter_update(
self.variables.node_to_accumulator_map,
- tf.squeeze(tf.slice(node_map_updates, [0, 0], [1, -1]),
- squeeze_dims=[0]),
- tf.squeeze(tf.slice(node_map_updates, [1, 0], [1, -1]),
- squeeze_dims=[0])))
+ array_ops.squeeze(array_ops.slice(node_map_updates, [0, 0], [1, -1]),
+ squeeze_dims=[0]),
+ array_ops.squeeze(array_ops.slice(node_map_updates, [1, 0], [1, -1]),
+ squeeze_dims=[0])))
- cleared_and_allocated_accumulators = tf.concat(
+ cleared_and_allocated_accumulators = array_ops.concat(
0, [accumulators_cleared, accumulators_allocated])
# Calculate values to put into scatter update for candidate counts.
# Candidate split counts are always reset back to 0 for both cleared
# and allocated accumulators. This means some accumulators might be doubly
    # reset to 0 if they were released and not allocated, then later allocated.
- split_values = tf.tile(
- tf.expand_dims(tf.expand_dims(
- tf.zeros_like(cleared_and_allocated_accumulators, dtype=tf.float32),
- 1), 2),
+ split_values = array_ops.tile(
+ array_ops.expand_dims(array_ops.expand_dims(
+ array_ops.zeros_like(cleared_and_allocated_accumulators,
+ dtype=dtypes.float32), 1), 2),
[1, self.params.num_splits_to_consider, self.params.num_output_columns])
- updates.append(tf.scatter_update(
+ updates.append(state_ops.scatter_update(
self.variables.candidate_split_sums,
cleared_and_allocated_accumulators, split_values))
if self.params.regression:
- updates.append(tf.scatter_update(
+ updates.append(state_ops.scatter_update(
self.variables.candidate_split_squares,
cleared_and_allocated_accumulators, split_values))
# Calculate values to put into scatter update for total counts.
- total_cleared = tf.tile(
- tf.expand_dims(
- tf.neg(tf.ones_like(accumulators_cleared, dtype=tf.float32)), 1),
+ total_cleared = array_ops.tile(
+ array_ops.expand_dims(
+ math_ops.neg(array_ops.ones_like(accumulators_cleared,
+ dtype=dtypes.float32)), 1),
[1, self.params.num_output_columns])
- total_reset = tf.tile(
- tf.expand_dims(
- tf.zeros_like(accumulators_allocated, dtype=tf.float32), 1),
+ total_reset = array_ops.tile(
+ array_ops.expand_dims(
+ array_ops.zeros_like(accumulators_allocated,
+ dtype=dtypes.float32), 1),
[1, self.params.num_output_columns])
- accumulator_updates = tf.concat(0, [total_cleared, total_reset])
- updates.append(tf.scatter_update(
+ accumulator_updates = array_ops.concat(0, [total_cleared, total_reset])
+ updates.append(state_ops.scatter_update(
self.variables.accumulator_sums,
cleared_and_allocated_accumulators, accumulator_updates))
if self.params.regression:
- updates.append(tf.scatter_update(
+ updates.append(state_ops.scatter_update(
self.variables.accumulator_squares,
cleared_and_allocated_accumulators, accumulator_updates))
# Calculate values to put into scatter update for candidate splits.
- split_features_updates = tf.tile(
- tf.expand_dims(
- tf.neg(tf.ones_like(cleared_and_allocated_accumulators)), 1),
+ split_features_updates = array_ops.tile(
+ array_ops.expand_dims(
+ math_ops.neg(array_ops.ones_like(
+ cleared_and_allocated_accumulators)), 1),
[1, self.params.num_splits_to_consider])
- updates.append(tf.scatter_update(
+ updates.append(state_ops.scatter_update(
self.variables.candidate_split_features,
cleared_and_allocated_accumulators, split_features_updates))
- return tf.group(*updates)
+ updates += self.finish_iteration()
+
+ return control_flow_ops.group(*updates)
+
+ def finish_iteration(self):
+ """Perform any operations that should be done at the end of an iteration.
+
+ This is mostly useful for subclasses that need to reset variables after
+ an iteration, such as ones that are used to finish nodes.
+
+ Returns:
+ A list of operations.
+ """
+ return []
- def inference_graph(self, input_data):
+ def inference_graph(self, input_data, data_spec):
"""Constructs a TF graph for evaluating a random tree.
Args:
- input_data: A tensor or placeholder for input data.
+ input_data: A tensor or SparseTensor or placeholder for input data.
+ data_spec: A list of tf.dtype values specifying the original types of
+ each column.
Returns:
The last op in the random tree inference graph.
"""
+ sparse_indices = []
+ sparse_values = []
+ sparse_shape = []
+ if isinstance(input_data, ops.SparseTensor):
+ sparse_indices = input_data.indices
+ sparse_values = input_data.values
+ sparse_shape = input_data.shape
+ input_data = []
return self.inference_ops.tree_predictions(
- input_data, self.variables.tree, self.variables.tree_thresholds,
+ input_data, sparse_indices, sparse_values, sparse_shape, data_spec,
+ self.variables.tree,
+ self.variables.tree_thresholds,
self.variables.node_sums,
valid_leaf_threshold=self.params.valid_leaf_threshold)
@@ -729,13 +809,22 @@ class RandomTreeGraphs(object):
Returns:
The last op in the graph.
"""
- children = tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1]),
- squeeze_dims=[1])
- is_leaf = tf.equal(LEAF_NODE, children)
- leaves = tf.to_int32(tf.squeeze(tf.where(is_leaf), squeeze_dims=[1]))
- counts = tf.gather(self.variables.node_sums, leaves)
- impurity = self._weighted_gini(counts)
- return tf.reduce_sum(impurity) / tf.reduce_sum(counts + 1.0)
+ children = array_ops.squeeze(array_ops.slice(
+ self.variables.tree, [0, 0], [-1, 1]), squeeze_dims=[1])
+ is_leaf = math_ops.equal(constants.LEAF_NODE, children)
+ leaves = math_ops.to_int32(array_ops.squeeze(array_ops.where(is_leaf),
+ squeeze_dims=[1]))
+ counts = array_ops.gather(self.variables.node_sums, leaves)
+ gini = self._weighted_gini(counts)
+ # Guard against step 1, when there often are no leaves yet.
+ def impurity():
+ return gini
+ # Since average impurity can be used for loss, when there's no data just
+ # return a big number so that loss always decreases.
+ def big():
+ return array_ops.ones_like(gini, dtype=dtypes.float32) * 10000000.
+ return control_flow_ops.cond(math_ops.greater(
+ array_ops.shape(leaves)[0], 0), impurity, big)
def size(self):
"""Constructs a TF graph for evaluating the current number of nodes.
@@ -747,7 +836,8 @@ class RandomTreeGraphs(object):
def get_stats(self, session):
num_nodes = self.variables.end_of_tree.eval(session=session) - 1
- num_leaves = tf.where(
- tf.equal(tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1])),
- LEAF_NODE)).eval(session=session).shape[0]
+ num_leaves = array_ops.where(
+ math_ops.equal(array_ops.squeeze(array_ops.slice(
+ self.variables.tree, [0, 0], [-1, 1])), constants.LEAF_NODE)
+ ).eval(session=session).shape[0]
return TreeStats(num_nodes, num_leaves)
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
index c3e1c8520d..4e4cfcd1e8 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
@@ -105,6 +105,47 @@ class TensorForestTest(test_util.TensorFlowTestCase):
graph = graph_builder.average_impurity()
self.assertTrue(isinstance(graph, tf.Tensor))
+ def testTrainingConstructionClassificationSparse(self):
+ input_data = tf.SparseTensor(
+ indices=[[0, 0], [0, 3],
+ [1, 0], [1, 7],
+ [2, 1],
+ [3, 9]],
+ values=[-1.0, 0.0,
+ -1., 2.,
+ 1.,
+ -2.0],
+ shape=[4, 10])
+ input_labels = [0, 1, 2, 3]
+
+ params = tensor_forest.ForestHParams(
+ num_classes=4, num_features=10, num_trees=10, max_nodes=1000,
+ split_after_samples=25).fill()
+
+ graph_builder = tensor_forest.RandomForestGraphs(params)
+ graph = graph_builder.training_graph(input_data, input_labels)
+ self.assertTrue(isinstance(graph, tf.Operation))
+
+ def testInferenceConstructionSparse(self):
+ input_data = tf.SparseTensor(
+ indices=[[0, 0], [0, 3],
+ [1, 0], [1, 7],
+ [2, 1],
+ [3, 9]],
+ values=[-1.0, 0.0,
+ -1., 2.,
+ 1.,
+ -2.0],
+ shape=[4, 10])
+
+ params = tensor_forest.ForestHParams(
+ num_classes=4, num_features=10, num_trees=10, max_nodes=1000,
+ split_after_samples=25).fill()
+
+ graph_builder = tensor_forest.RandomForestGraphs(params)
+ graph = graph_builder.inference_graph(input_data)
+ self.assertTrue(isinstance(graph, tf.Tensor))
+
if __name__ == '__main__':
googletest.main()