aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest/python/ops
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-06-30 14:40:02 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-30 15:48:03 -0700
commitb70103502b41df370906e8988b6593e55caf69cf (patch)
tree3455ed439430bb6c0e739bb974a52a99a7bc6626 /tensorflow/contrib/tensor_forest/python/ops
parentd3067c338425bdf97fa782d834399b89bce18309 (diff)
Improvements to tensor_forest, including support for sparse and categorical inputs.
Add tf.learn.Estimator for random forests. Change: 126352221
Diffstat (limited to 'tensorflow/contrib/tensor_forest/python/ops')
-rw-r--r--tensorflow/contrib/tensor_forest/python/ops/inference_ops.py24
-rw-r--r--tensorflow/contrib/tensor_forest/python/ops/training_ops.py28
2 files changed, 29 insertions, 23 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py b/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py
index 6f4e6fff40..88f8112ed4 100644
--- a/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py
+++ b/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py
@@ -1,3 +1,4 @@
+# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,13 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os
import threading
-import tensorflow as tf
-
+from tensorflow.python.framework import load_library
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import tf_logging as logging
+
INFERENCE_OPS_FILE = '_inference_ops.so'
@@ -38,7 +40,11 @@ ops.NoGradient('TreePredictions')
def TreePredictions(op):
"""Shape function for TreePredictions Op."""
num_points = op.inputs[0].get_shape()[0].value
- num_classes = op.inputs[3].get_shape()[1].value
+ sparse_shape = op.inputs[3].get_shape()
+ if sparse_shape.ndims == 2:
+ num_points = sparse_shape[0].value
+ num_classes = op.inputs[7].get_shape()[1].value
+
# The output of TreePredictions is
# [node_pcw(evaluate_tree(x), c) for c in classes for x in input_data].
return [tensor_shape.TensorShape([num_points, num_classes - 1])]
@@ -49,16 +55,14 @@ def TreePredictions(op):
# there's not yet any guarantee that the shared object exists.
# In which case, "import tensorflow" will always crash, even for users that
# never use contrib.
-def Load(library_base_dir=''):
+def Load():
"""Load the inference ops library and return the loaded module."""
with _ops_lock:
global _inference_ops
if not _inference_ops:
- data_files_path = os.path.join(library_base_dir,
- tf.resource_loader.get_data_files_path())
- tf.logging.info('data path: %s', data_files_path)
- _inference_ops = tf.load_op_library(os.path.join(
- data_files_path, INFERENCE_OPS_FILE))
+ ops_path = resource_loader.get_path_to_datafile(INFERENCE_OPS_FILE)
+ logging.info('data path: %s', ops_path)
+ _inference_ops = load_library.load_op_library(ops_path)
assert _inference_ops, 'Could not load inference_ops.so'
return _inference_ops
diff --git a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
index 7a108baf42..d25d5ce50b 100644
--- a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
+++ b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
@@ -1,3 +1,4 @@
+# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,13 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os
import threading
-import tensorflow as tf
-
+from tensorflow.python.framework import load_library
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import tf_logging as logging
TRAINING_OPS_FILE = '_training_ops.so'
@@ -45,7 +46,10 @@ def _CountExtremelyRandomStatsShape(op):
"""Shape function for CountExtremelyRandomStats Op."""
regression = op.get_attr('regression')
num_points = op.inputs[0].get_shape()[0].value
- num_nodes = op.inputs[2].get_shape()[0].value
+ sparse_shape = op.inputs[3].get_shape()
+ if sparse_shape.ndims == 2:
+ num_points = sparse_shape[0].value
+ num_nodes = op.inputs[6].get_shape()[0].value
num_classes = op.get_attr('num_classes')
# The output of TraverseTree is [leaf_node_index(x) for x in input_data].
return [tensor_shape.TensorShape([num_nodes, num_classes]), # node sums
@@ -66,7 +70,7 @@ def _CountExtremelyRandomStatsShape(op):
@ops.RegisterShape('SampleInputs')
def _SampleInputsShape(op):
"""Shape function for SampleInputs Op."""
- num_splits = op.inputs[3].get_shape()[1].value
+ num_splits = op.inputs[6].get_shape()[1].value
return [[None], [None, num_splits], [None, num_splits]]
@@ -85,7 +89,7 @@ def _GrowTreeShape(unused_op):
@ops.RegisterShape('FinishedNodes')
def _FinishedNodesShape(unused_op):
"""Shape function for FinishedNodes Op."""
- return [[None]]
+ return [[None], [None]]
@ops.RegisterShape('ScatterAddNdim')
@@ -97,7 +101,7 @@ def _ScatterAddNdimShape(unused_op):
@ops.RegisterShape('UpdateFertileSlots')
def _UpdateFertileSlotsShape(unused_op):
"""Shape function for UpdateFertileSlots Op."""
- return [[None, 2], [None], [None], [None], [None]]
+ return [[None, 2], [None], [None]]
# Workaround for the fact that importing tensorflow imports contrib
@@ -105,16 +109,14 @@ def _UpdateFertileSlotsShape(unused_op):
# there's not yet any guarantee that the shared object exists.
# In which case, "import tensorflow" will always crash, even for users that
# never use contrib.
-def Load(library_base_dir=''):
+def Load():
"""Load training ops library and return the loaded module."""
with _ops_lock:
global _training_ops
if not _training_ops:
- data_files_path = os.path.join(library_base_dir,
- tf.resource_loader.get_data_files_path())
- tf.logging.info('data path: %s', data_files_path)
- _training_ops = tf.load_op_library(os.path.join(
- data_files_path, TRAINING_OPS_FILE))
+ ops_path = resource_loader.get_path_to_datafile(TRAINING_OPS_FILE)
+ logging.info('data path: %s', ops_path)
+ _training_ops = load_library.load_op_library(ops_path)
assert _training_ops, 'Could not load _training_ops.so'
return _training_ops