aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest/python/tensor_forest.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-13 18:12:25 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-13 19:18:10 -0700
commit4ddfb7812e7a1b22b16baa1ab6f1c319e1289bc5 (patch)
treed0043d5b601c7c95f3dd623e55b043c3f6f0ad35 /tensorflow/contrib/tensor_forest/python/tensor_forest.py
parent6c7681fbbcc3c244f3e406abc4ea1287fd717752 (diff)
Give names to training and inference ops in tensor_forest, which helps with integration with frameworks that identify them by name.
Change: 127387066
Diffstat (limited to 'tensorflow/contrib/tensor_forest/python/tensor_forest.py')
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index e3a4be1e9a..03ad655a05 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -383,7 +383,7 @@ class RandomForestGraphs(object):
epoch=([0] if epoch is None else epoch),
**tree_kwargs))
- return control_flow_ops.group(*tree_graphs)
+ return control_flow_ops.group(*tree_graphs, name='train')
def inference_graph(self, input_data, data_spec=None):
"""Constructs a TF graph for evaluating a random forest.
@@ -408,7 +408,9 @@ class RandomForestGraphs(object):
data_spec))
with ops.device(self.device_assigner.get_device(0)):
all_predict = array_ops.pack(probabilities)
- return math_ops.reduce_sum(all_predict, 0) / self.params.num_trees
+ return math_ops.div(
+ math_ops.reduce_sum(all_predict, 0), self.params.num_trees,
+ name='probabilities')
def average_size(self):
"""Constructs a TF graph for evaluating the average size of a forest.