diff options
author | 2016-07-13 18:12:25 -0800 | |
---|---|---|
committer | 2016-07-13 19:18:10 -0700 | |
commit | 4ddfb7812e7a1b22b16baa1ab6f1c319e1289bc5 (patch) | |
tree | d0043d5b601c7c95f3dd623e55b043c3f6f0ad35 /tensorflow/contrib/tensor_forest/python/tensor_forest.py | |
parent | 6c7681fbbcc3c244f3e406abc4ea1287fd717752 (diff) |
Give names to training and inference ops in tensor_forest, which helps with integration with frameworks that identify them by name.
Change: 127387066
Diffstat (limited to 'tensorflow/contrib/tensor_forest/python/tensor_forest.py')
-rw-r--r-- | tensorflow/contrib/tensor_forest/python/tensor_forest.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index e3a4be1e9a..03ad655a05 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -383,7 +383,7 @@ class RandomForestGraphs(object): epoch=([0] if epoch is None else epoch), **tree_kwargs)) - return control_flow_ops.group(*tree_graphs) + return control_flow_ops.group(*tree_graphs, name='train') def inference_graph(self, input_data, data_spec=None): """Constructs a TF graph for evaluating a random forest. @@ -408,7 +408,9 @@ class RandomForestGraphs(object): data_spec)) with ops.device(self.device_assigner.get_device(0)): all_predict = array_ops.pack(probabilities) - return math_ops.reduce_sum(all_predict, 0) / self.params.num_trees + return math_ops.div( + math_ops.reduce_sum(all_predict, 0), self.params.num_trees, + name='probabilities') def average_size(self): """Constructs a TF graph for evaluating the average size of a forest. |