aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest/ops/stats_ops.cc
blob: 5be581aaec4cab342ef3fd49fa0294e5e702ba1c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {
using shape_inference::InferenceContext;

namespace tensorforest {

// Registers the resource-handle op for the FertileStatsResource type using the
// REGISTER_RESOURCE_HANDLE_OP macro (brought in via resource_mgr.h).
// NOTE(review): the `stats_handle: resource` inputs of the ops below
// presumably consume handles created by this op — confirm against the
// Python-side wrappers.
REGISTER_RESOURCE_HANDLE_OP(FertileStatsResource);

// Reports whether the stats resource referenced by `stats_handle` has been
// initialized. The single bool output is a scalar (ScalarShape).
REGISTER_OP("FertileStatsIsInitializedOp")
    .Input("stats_handle: resource")
    .Output("is_initialized: bool")
    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
    .Doc(R"doc(
Checks whether a stats resource has been initialized.

stats_handle: The handle to the stats.
is_initialized: True if the stats resource has been initialized.
)doc");

// Creates the stats model inside the resource referenced by `stats_handle`,
// initialized from `stats_config`. The op has no outputs (NoOutputs); the
// handle is supplied by the caller rather than returned.
REGISTER_OP("CreateFertileStatsVariable")
    .Attr("params: string")
    .Input("stats_handle: resource")
    .Input("stats_config: string")
    .SetShapeFn(tensorflow::shape_inference::NoOutputs)
    .Doc(R"doc(
Creates a stats model in the resource referenced by `stats_handle`.

params: A serialized TensorForestParams proto.
stats_handle: The handle to the stats resource to be created.
stats_config: Serialized proto of the stats.
)doc");

// Serializes the stats held by `stats_handle`; emits one scalar string.
REGISTER_OP("FertileStatsSerialize")
    .Attr("params: string")
    .Input("stats_handle: resource")
    .Output("stats_config: string")
    .Doc(R"doc(
Serializes the stats to a proto.

params: A serialized TensorForestParams proto.
stats_handle: The handle to the stats.
stats_config: Serialized proto of the stats.
)doc")
    .SetShapeFn(shape_inference::ScalarShape);

// Replaces the stats held by `stats_handle` with a deserialized config.
// Produces no outputs.
REGISTER_OP("FertileStatsDeserialize")
    .Attr("params: string")
    .Input("stats_handle: resource")
    .Input("stats_config: string")
    .Doc(R"doc(
Deserializes a serialized stats config and replaces current stats.

params: A serialized TensorForestParams proto.
stats_handle: The handle to the stats.
stats_config: Serialized proto of the stats.
)doc")
    .SetShapeFn(shape_inference::NoOutputs);

// Takes the tree and stats handles plus the node ids reported finished by
// ProcessInputV4, and grows the tree in place. Produces no outputs.
REGISTER_OP("GrowTreeV4")
    .Attr("params: string")
    .Input("tree_handle: resource")
    .Input("stats_handle: resource")
    .Input("finished_nodes: int32")
    .Doc(R"doc(
Grows the tree for finished nodes and allocates waiting nodes.

params: A serialized TensorForestParams proto.
tree_handle: The handle to the tree.
stats_handle: The handle to the stats.
finished_nodes: A 1-d Tensor of finished node ids from ProcessInput.
)doc")
    .SetShapeFn(shape_inference::NoOutputs);

// Feeds a training batch (dense and/or sparse features, labels, weights, and
// precomputed leaf ids) into the stats resource. Outputs a 1-d int32 vector of
// finished node ids whose length is only known at run time.
REGISTER_OP("ProcessInputV4")
    .Attr("random_seed: int")
    .Attr("input_spec: string")
    .Attr("params: string")
    .Input("tree_handle: resource")
    .Input("stats_handle: resource")
    .Input("input_data: float")
    .Input("sparse_input_indices: int64")
    .Input("sparse_input_values: float")
    .Input("sparse_input_shape: int64")
    .Input("input_labels: float")
    .Input("input_weights: float")
    .Input("leaf_ids: int32")
    .Output("finished_nodes: int32")
    .SetShapeFn([](InferenceContext* c) {
      // The number of finished nodes depends on the data, so only the rank
      // (vector) is known statically.
      c->set_output(0, c->Vector(c->UnknownDim()));
      return Status::OK();
    })
    .Doc(R"doc(
Add labels to stats after traversing the tree for each example.

Outputs node ids that are finished.

random_seed: Seed used for any randomness during processing.
input_spec: A serialized spec describing the format of the input data.
params: A serialized TensorForestParams proto.
tree_handle: The handle to the tree.
stats_handle: The handle to the stats.
input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
   gives the j-th feature of the i-th input.
sparse_input_indices: The indices tensor from the SparseTensor input.
sparse_input_values: The values tensor from the SparseTensor input.
sparse_input_shape: The shape tensor from the SparseTensor input.
input_labels: The training batch's labels as a 1 or 2-d tensor.
  'input_labels[i][j]' gives the j-th label/target for the i-th input.
input_weights: The training batch's weights as a 1-d tensor.
  'input_weights[i]' gives the weight for the i-th input.
leaf_ids: `leaf_ids[i]` is the leaf id for input i.
finished_nodes: A 1-d tensor of node ids that have finished and are ready to
  grow.
)doc");

// Converts the leaf models of the tree into their final form using the stats.
// The op has no outputs; it uses the shared NoOutputs shape function, like the
// other output-less ops in this file, instead of an equivalent hand-written
// lambda.
REGISTER_OP("FinalizeTree")
    .Attr("params: string")
    .Input("tree_handle: resource")
    .Input("stats_handle: resource")
    .SetShapeFn(tensorflow::shape_inference::NoOutputs)
    .Doc(R"doc(
Puts the leaf models inside the tree into their final form.

If `drop_final_class` is set in `params`, the per-class probability prediction
of the last class is not stored in the leaf models.

params: A serialized TensorForestParams proto.
tree_handle: The handle to the tree.
stats_handle: The handle to the stats.
)doc");
}  // namespace tensorforest
}  // namespace tensorflow