path: root/tensorflow/contrib/tensor_forest
author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-01-30 10:43:03 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-01-30 12:33:54 -0800
commit     4463d105a8a4a83642b9709ba79310e8f4ddf577 (patch)
tree       240e9a0a9a6b9ad956c704776a33126ba00cbfe8 /tensorflow/contrib/tensor_forest
parent     8f0e7207774279f4fe50f4d6c4fbd576e2941463 (diff)
Cleanup: Ran clang-format on all *.{cc,h} files in tensorflow/contrib/.../*.{hh,c}.
PiperOrigin-RevId: 183855242
Diffstat (limited to 'tensorflow/contrib/tensor_forest')
-rw-r--r--  tensorflow/contrib/tensor_forest/hybrid/core/ops/hard_routing_function_op.cc | 27
-rw-r--r--  tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_gradient_op.cc | 56
-rw-r--r--  tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_routing_function_op.cc | 49
-rw-r--r--  tensorflow/contrib/tensor_forest/hybrid/core/ops/routing_function_op.cc | 29
-rw-r--r--  tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_function_op.cc | 38
-rw-r--r--  tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_gradient_op.cc | 18
-rw-r--r--  tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc | 10
-rw-r--r--  tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc | 11
-rw-r--r--  tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h | 13
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc | 15
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/scatter_add_ndim_op.cc | 19
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/tree_utils.cc | 78
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/tree_utils.h | 25
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/tree_utils_test.cc | 128
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc | 13
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h | 13
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h | 1
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc | 5
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h | 9
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc | 36
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h | 25
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc | 47
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc | 8
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/input_data.h | 4
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/input_target.h | 4
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc | 1
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc | 24
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/params.h | 1
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/params_test.cc | 2
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc | 7
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h | 4
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.cc | 22
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h | 5
33 files changed, 310 insertions(+), 437 deletions(-)
diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/hard_routing_function_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/hard_routing_function_op.cc
index 76cfb4c9ca..cf0db788a4 100644
--- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/hard_routing_function_op.cc
+++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/hard_routing_function_op.cc
@@ -99,18 +99,17 @@ class HardRoutingFunction : public OpKernel {
const Tensor& tree_biases_tensor = context->input(2);
if (input_data.shape().dim_size(0) > 0) {
- OP_REQUIRES(context, input_data.shape().dims() == 2,
- errors::InvalidArgument(
- "input_data should be two-dimensional"));
+ OP_REQUIRES(
+ context, input_data.shape().dims() == 2,
+ errors::InvalidArgument("input_data should be two-dimensional"));
}
// Check tensor bounds.
if (!CheckTensorBounds(context, input_data)) return;
- const int32 num_data = static_cast<int32>(
- input_data.shape().dim_size(0));
- const int32 num_features = static_cast<int32>(
- input_data.shape().dim_size(1));
+ const int32 num_data = static_cast<int32>(input_data.shape().dim_size(0));
+ const int32 num_features =
+ static_cast<int32>(input_data.shape().dim_size(1));
Tensor* output_probability = nullptr;
TensorShape output_probability_shape;
@@ -125,9 +124,8 @@ class HardRoutingFunction : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_output(0, output_probability_shape,
&output_probability));
- OP_REQUIRES_OK(context,
- context->allocate_output(1, output_path_shape,
- &output_path));
+ OP_REQUIRES_OK(
+ context, context->allocate_output(1, output_path_shape, &output_path));
auto out_probability = output_probability->tensor<float, 2>();
auto out_path = output_path->tensor<int32, 2>();
@@ -144,12 +142,11 @@ class HardRoutingFunction : public OpKernel {
out_probability(i, 0) = 1.0;
out_path(i, 0) = 0;
for (int j = 0; j < tree_depth_ - 1; j++) {
- float left_prob = LeftProbability(point,
- tree_parameters_tensor.Slice(j, j+1),
- tree_biases(j),
- num_features);
+ float left_prob =
+ LeftProbability(point, tree_parameters_tensor.Slice(j, j + 1),
+ tree_biases(j), num_features);
- int32 left_child = 2*node + 1;
+ int32 left_child = 2 * node + 1;
int32 right_child = left_child + 1;
float dot_product = 0.0;
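The `2 * node + 1` expressions reformatted above encode the kernels' implicit tree layout: nodes live in a flat array in complete-binary-tree (heap) order, so node n's children sit at indices 2n + 1 and 2n + 2 and no pointer structure is needed. A minimal standalone sketch of that indexing (names here are illustrative, not the kernel's):

```cpp
#include <cstdio>

// Children of node n in a complete binary tree stored as a flat array.
inline int LeftChild(int n) { return 2 * n + 1; }
inline int RightChild(int n) { return 2 * n + 2; }

int main() {
  // Walk a depth-3 tree from the root, always taking the left branch,
  // mirroring the loop structure of HardRoutingFunction::Compute.
  int node = 0;
  for (int j = 0; j < 3 - 1; ++j) {
    std::printf("node %d -> children %d, %d\n", node, LeftChild(node),
                RightChild(node));
    node = LeftChild(node);
  }
  return 0;
}
```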
diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_gradient_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_gradient_op.cc
index 28f50f1a32..f64155fa55 100644
--- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_gradient_op.cc
+++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_gradient_op.cc
@@ -85,12 +85,9 @@ REGISTER_OP("KFeatureGradient")
class KFeatureGradient : public OpKernel {
public:
- explicit KFeatureGradient(OpKernelConstruction* context)
- : OpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("layer_num",
- &layer_num_));
- OP_REQUIRES_OK(context, context->GetAttr("random_seed",
- &random_seed_));
+ explicit KFeatureGradient(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("layer_num", &layer_num_));
+ OP_REQUIRES_OK(context, context->GetAttr("random_seed", &random_seed_));
}
void Compute(OpKernelContext* context) override {
@@ -101,14 +98,14 @@ class KFeatureGradient : public OpKernel {
const Tensor& routing_tensor = context->input(3);
// Extract dimensions from input tensors.
- const int32 num_data = static_cast<int32>(
- input_data_tensor.shape().dim_size(0));
- const int32 num_features = static_cast<int32>(
- input_data_tensor.shape().dim_size(1));
- const int32 num_nodes = static_cast<int32>(
- tree_parameters_tensor.shape().dim_size(0));
- const int32 num_features_per_node = static_cast<int32>(
- tree_parameters_tensor.shape().dim_size(1));
+ const int32 num_data =
+ static_cast<int32>(input_data_tensor.shape().dim_size(0));
+ const int32 num_features =
+ static_cast<int32>(input_data_tensor.shape().dim_size(1));
+ const int32 num_nodes =
+ static_cast<int32>(tree_parameters_tensor.shape().dim_size(0));
+ const int32 num_features_per_node =
+ static_cast<int32>(tree_parameters_tensor.shape().dim_size(1));
// Construct output tensors.
Tensor* out_routes = nullptr;
@@ -127,12 +124,12 @@ class KFeatureGradient : public OpKernel {
out_weights_shape.AddDim(num_nodes);
out_weights_shape.AddDim(num_features_per_node);
- OP_REQUIRES_OK(context, context->allocate_output(
- 0, out_routes_shape, &out_routes));
- OP_REQUIRES_OK(context, context->allocate_output(
- 1, out_data_shape, &out_data));
- OP_REQUIRES_OK(context, context->allocate_output(
- 2, out_weights_shape, &out_weights));
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, out_routes_shape, &out_routes));
+ OP_REQUIRES_OK(context,
+ context->allocate_output(1, out_data_shape, &out_data));
+ OP_REQUIRES_OK(
+ context, context->allocate_output(2, out_weights_shape, &out_weights));
tensorforest::Initialize(*out_data, 0.0f);
@@ -148,18 +145,13 @@ class KFeatureGradient : public OpKernel {
std::vector<int32> feature_set;
for (int i = 0; i < num_data; i++) {
- const Tensor point = input_data_tensor.Slice(i, i+1);
+ const Tensor point = input_data_tensor.Slice(i, i + 1);
feature_set.clear();
// Traverse the tree from the bottom up.
for (int j = num_nodes - 1; j >= 0; j--) {
- tensorforest::GetFeatureSet(
- layer_num_,
- j,
- random_seed_,
- num_features,
- num_features_per_node,
- &feature_set);
+ tensorforest::GetFeatureSet(layer_num_, j, random_seed_, num_features,
+ num_features_per_node, &feature_set);
// Compute routing gradient.
// j is a leaf node.
@@ -170,12 +162,8 @@ class KFeatureGradient : public OpKernel {
int32 right_child = left_child + 1;
float left_prob = LeftProbabilityK(
- point,
- feature_set,
- tree_parameters_tensor.Slice(j, j+1),
- tree_biases(j),
- num_features,
- num_features_per_node);
+ point, feature_set, tree_parameters_tensor.Slice(j, j + 1),
+ tree_biases(j), num_features, num_features_per_node);
float right_prob = 1.0f - left_prob;
diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_routing_function_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_routing_function_op.cc
index 9bc42eb61f..e7cafb144d 100644
--- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_routing_function_op.cc
+++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_routing_function_op.cc
@@ -43,7 +43,6 @@ using shape_inference::ShapeHandle;
using tensorforest::CheckTensorBounds;
using tensorforest::LeftProbabilityK;
-
// The term 'routing function' is synonymous with 'the probability
// that an instance is routed to each leaf node.' It is defined in
// 'Deep Neural Decision Forests' by Kontschieder et al.
@@ -96,10 +95,8 @@ class KFeatureRoutingFunction : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("max_nodes", &max_nodes_));
OP_REQUIRES_OK(context, context->GetAttr("num_features_per_node",
&num_features_per_node_));
- OP_REQUIRES_OK(context, context->GetAttr("layer_num",
- &layer_num_));
- OP_REQUIRES_OK(context, context->GetAttr("random_seed",
- &random_seed_));
+ OP_REQUIRES_OK(context, context->GetAttr("layer_num", &layer_num_));
+ OP_REQUIRES_OK(context, context->GetAttr("random_seed", &random_seed_));
}
void Compute(OpKernelContext* context) override {
@@ -108,27 +105,25 @@ class KFeatureRoutingFunction : public OpKernel {
const Tensor& tree_biases_tensor = context->input(2);
if (input_data.shape().dim_size(0) > 0) {
- OP_REQUIRES(context, input_data.shape().dims() == 2,
- errors::InvalidArgument(
- "input_data should be two-dimensional"));
+ OP_REQUIRES(
+ context, input_data.shape().dims() == 2,
+ errors::InvalidArgument("input_data should be two-dimensional"));
}
// Check tensor bounds.
if (!CheckTensorBounds(context, input_data)) return;
- const int32 num_data = static_cast<int32>(
- input_data.shape().dim_size(0));
- const int32 num_features = static_cast<int32>(
- input_data.shape().dim_size(1));
+ const int32 num_data = static_cast<int32>(input_data.shape().dim_size(0));
+ const int32 num_features =
+ static_cast<int32>(input_data.shape().dim_size(1));
Tensor* output_probabilities = nullptr;
TensorShape output_shape;
output_shape.AddDim(num_data);
output_shape.AddDim(max_nodes_);
- OP_REQUIRES_OK(context,
- context->allocate_output(0, output_shape,
- &output_probabilities));
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape,
+ &output_probabilities));
auto out_probs = output_probabilities->tensor<float, 2>();
const auto tree_biases = tree_biases_tensor.tensor<float, 1>();
@@ -136,30 +131,22 @@ class KFeatureRoutingFunction : public OpKernel {
// Iteratively compute the probability of reaching each leaf.
std::vector<int32> feature_set;
for (int i = 0; i < num_data; i++) {
- const Tensor point = input_data.Slice(i, i+1);
+ const Tensor point = input_data.Slice(i, i + 1);
out_probs(i, 0) = 1.0f;
for (int j = 0; j < max_nodes_ / 2; j++) {
feature_set.clear();
- tensorforest::GetFeatureSet(
- layer_num_,
- i,
- random_seed_,
- num_features,
- num_features_per_node_,
- &feature_set);
-
- int32 left_child = 2*j + 1;
+ tensorforest::GetFeatureSet(layer_num_, i, random_seed_, num_features,
+ num_features_per_node_, &feature_set);
+
+ int32 left_child = 2 * j + 1;
int32 right_child = left_child + 1;
float prob = out_probs(i, j);
- float left_prob = LeftProbabilityK(point,
- feature_set,
- tree_parameters_tensor.Slice(j, j+1),
- tree_biases(j),
- num_features,
- num_features_per_node_);
+ float left_prob = LeftProbabilityK(
+ point, feature_set, tree_parameters_tensor.Slice(j, j + 1),
+ tree_biases(j), num_features, num_features_per_node_);
out_probs(i, left_child) = prob * left_prob;
out_probs(i, right_child) = prob * (1.0f - left_prob);
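The comment in this file defines the routing function via 'Deep Neural Decision Forests' (Kontschieder et al.), and the loop above implements it as a recurrence: the root gets probability 1, and each internal node j passes `prob * left_prob` to child 2j + 1 and `prob * (1 - left_prob)` to child 2j + 2. A self-contained sketch of that recurrence, with hard-coded per-node probabilities standing in for `LeftProbabilityK` (the values are hypothetical):

```cpp
#include <cstdio>
#include <vector>

int main() {
  const int max_nodes = 7;  // depth-3 complete tree: nodes 0-2 internal, 3-6 leaves
  // Stand-in for LeftProbabilityK(point, feature_set, params, bias, ...).
  const std::vector<float> left_prob = {0.9f, 0.5f, 0.2f};

  std::vector<float> probs(max_nodes, 0.0f);
  probs[0] = 1.0f;  // all probability mass starts at the root
  for (int j = 0; j < max_nodes / 2; ++j) {
    probs[2 * j + 1] = probs[j] * left_prob[j];
    probs[2 * j + 2] = probs[j] * (1.0f - left_prob[j]);
  }
  for (int n = max_nodes / 2; n < max_nodes; ++n) {
    std::printf("leaf %d: %.2f\n", n, probs[n]);  // leaf probabilities sum to 1
  }
  return 0;
}
```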
diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/routing_function_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/routing_function_op.cc
index 4027e732b3..0c2eaabe8f 100644
--- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/routing_function_op.cc
+++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/routing_function_op.cc
@@ -90,46 +90,43 @@ class RoutingFunction : public OpKernel {
const Tensor& tree_biases_tensor = context->input(2);
if (input_data.shape().dim_size(0) > 0) {
- OP_REQUIRES(context, input_data.shape().dims() == 2,
- errors::InvalidArgument(
- "input_data should be two-dimensional"));
+ OP_REQUIRES(
+ context, input_data.shape().dims() == 2,
+ errors::InvalidArgument("input_data should be two-dimensional"));
}
// Check tensor bounds.
if (!CheckTensorBounds(context, input_data)) return;
- const int32 num_data = static_cast<int32>(
- input_data.shape().dim_size(0));
- const int32 num_features = static_cast<int32>(
- input_data.shape().dim_size(1));
+ const int32 num_data = static_cast<int32>(input_data.shape().dim_size(0));
+ const int32 num_features =
+ static_cast<int32>(input_data.shape().dim_size(1));
Tensor* output_probabilities = nullptr;
TensorShape output_shape;
output_shape.AddDim(num_data);
output_shape.AddDim(max_nodes_);
- OP_REQUIRES_OK(context,
- context->allocate_output(0, output_shape,
- &output_probabilities));
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape,
+ &output_probabilities));
auto out_probs = output_probabilities->tensor<float, 2>();
const auto tree_biases = tree_biases_tensor.tensor<float, 1>();
// Iteratively compute the probability of reaching each leaf.
for (int i = 0; i < num_data; i++) {
- const Tensor point = input_data.Slice(i, i+1);
+ const Tensor point = input_data.Slice(i, i + 1);
out_probs(i, 0) = 1.0;
for (int j = 0; j < max_nodes_ / 2; j++) {
- int32 left_child = 2*j + 1;
+ int32 left_child = 2 * j + 1;
int32 right_child = left_child + 1;
float prob = out_probs(i, j);
- float left_prob = LeftProbability(point,
- tree_parameters_tensor.Slice(j, j+1),
- tree_biases(j),
- num_features);
+ float left_prob =
+ LeftProbability(point, tree_parameters_tensor.Slice(j, j + 1),
+ tree_biases(j), num_features);
out_probs(i, left_child) = prob * left_prob;
out_probs(i, right_child) = prob * (1.0 - left_prob);
diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_function_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_function_op.cc
index 66aa293dc1..c9df09bfda 100644
--- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_function_op.cc
+++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_function_op.cc
@@ -96,10 +96,9 @@ class StochasticHardRoutingFunction : public OpKernel {
explicit StochasticHardRoutingFunction(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("tree_depth", &tree_depth_));
- OP_REQUIRES_OK(context, context->GetAttr("random_seed",
- &random_seed_));
+ OP_REQUIRES_OK(context, context->GetAttr("random_seed", &random_seed_));
single_rand_ = std::unique_ptr<random::PhiloxRandom>(
- new random::PhiloxRandom(random_seed_));
+ new random::PhiloxRandom(random_seed_));
rng_ = std::unique_ptr<random::SimplePhilox>(
new random::SimplePhilox(single_rand_.get()));
}
@@ -111,20 +110,19 @@ class StochasticHardRoutingFunction : public OpKernel {
const Tensor& tree_biases_tensor = context->input(2);
if (input_data.shape().dim_size(0) > 0) {
- OP_REQUIRES(context, input_data.shape().dims() == 2,
- errors::InvalidArgument(
- "input_data should be two-dimensional"));
+ OP_REQUIRES(
+ context, input_data.shape().dims() == 2,
+ errors::InvalidArgument("input_data should be two-dimensional"));
}
// Check tensor bounds.
if (!CheckTensorBounds(context, input_data)) return;
- const int32 num_data = static_cast<int32>(
- input_data.shape().dim_size(0));
- const int32 num_features = static_cast<int32>(
- input_data.shape().dim_size(1));
- const int32 num_nodes = static_cast<int32>(
- tree_parameters_tensor.shape().dim_size(0));
+ const int32 num_data = static_cast<int32>(input_data.shape().dim_size(0));
+ const int32 num_features =
+ static_cast<int32>(input_data.shape().dim_size(1));
+ const int32 num_nodes =
+ static_cast<int32>(tree_parameters_tensor.shape().dim_size(0));
Tensor* output_probability = nullptr;
TensorShape output_probability_shape;
@@ -139,9 +137,8 @@ class StochasticHardRoutingFunction : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_output(0, output_probability_shape,
&output_probability));
- OP_REQUIRES_OK(context,
- context->allocate_output(1, output_path_shape,
- &output_path));
+ OP_REQUIRES_OK(
+ context, context->allocate_output(1, output_path_shape, &output_path));
auto out_probability = output_probability->tensor<float, 2>();
auto out_path = output_path->tensor<int32, 2>();
@@ -150,19 +147,18 @@ class StochasticHardRoutingFunction : public OpKernel {
// Stochastically traverse the tree to a leaf.
for (int i = 0; i < num_data; i++) {
- const Tensor point = input_data.Slice(i, i+1);
+ const Tensor point = input_data.Slice(i, i + 1);
int32 node = 0;
out_probability(i, 0) = 1.0;
out_path(i, 0) = 0;
for (int j = 0; j < tree_depth_ - 1; j++) {
- int32 left_child = 2*node + 1;
+ int32 left_child = 2 * node + 1;
int32 right_child = left_child + 1;
- float left_prob = LeftProbability(point,
- tree_parameters_tensor.Slice(j, j+1),
- tree_biases(j),
- num_features);
+ float left_prob =
+ LeftProbability(point, tree_parameters_tensor.Slice(j, j + 1),
+ tree_biases(j), num_features);
if (left_prob < rng_->RandFloat()) {
CHECK_LT(i, num_data);
diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_gradient_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_gradient_op.cc
index 0b5afe464f..b0d8b832b5 100644
--- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_gradient_op.cc
+++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_gradient_op.cc
@@ -149,14 +149,14 @@ class StochasticHardRoutingGradient : public OpKernel {
TensorShape output_bias_shape;
output_bias_shape.AddDim(num_data);
- OP_REQUIRES_OK(context, context->allocate_output(
- 0, output_routing_shape, &output_routing));
- OP_REQUIRES_OK(context, context->allocate_output(
- 1, output_data_shape, &output_data));
- OP_REQUIRES_OK(context, context->allocate_output(
- 2, output_parameters_shape, &output_parameters));
- OP_REQUIRES_OK(context, context->allocate_output(
- 3, output_bias_shape, &output_bias));
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_routing_shape,
+ &output_routing));
+ OP_REQUIRES_OK(
+ context, context->allocate_output(1, output_data_shape, &output_data));
+ OP_REQUIRES_OK(context, context->allocate_output(2, output_parameters_shape,
+ &output_parameters));
+ OP_REQUIRES_OK(
+ context, context->allocate_output(3, output_bias_shape, &output_bias));
tensorforest::Initialize(*output_routing, 0.0);
tensorforest::Initialize(*output_data, 0.0);
@@ -178,7 +178,7 @@ class StochasticHardRoutingGradient : public OpKernel {
const Tensor point = input_data.Slice(i, i + 1);
// Traverses the tree from the bottom up.
- for (int j = tree_depth_-1; j > -1; j--) {
+ for (int j = tree_depth_ - 1; j > -1; j--) {
int32 node = path(i, j);
CHECK_LT(node, num_nodes);
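Unlike the soft routing above, the stochastic kernels sample one concrete root-to-leaf path: at each node the RNG draws a uniform float and the walk goes left with probability `left_prob`, recording the visited nodes in `out_path`. A sketch of that walk, substituting `std::mt19937` for the kernel's `random::PhiloxRandom`/`SimplePhilox` pair (that substitution, and the fixed `left_prob`, are assumptions for illustration):

```cpp
#include <cstdio>
#include <random>

int main() {
  std::mt19937 rng(42);  // stand-in for SimplePhilox seeded with random_seed
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);

  const int tree_depth = 4;
  int node = 0;
  for (int j = 0; j < tree_depth - 1; ++j) {
    const float left_prob = 0.5f;  // stand-in for LeftProbability(point, ...)
    const int left_child = 2 * node + 1;
    // Take the left branch with probability left_prob, the right otherwise.
    node = (uniform(rng) < left_prob) ? left_child : left_child + 1;
    std::printf("step %d -> node %d\n", j, node);
  }
  return 0;
}
```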
diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc
index cacad03e27..25825a78a1 100644
--- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc
+++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc
@@ -64,8 +64,7 @@ REGISTER_OP("UnpackPath")
class UnpackPath : public OpKernel {
public:
- explicit UnpackPath(OpKernelConstruction* context)
- : OpKernel(context) {}
+ explicit UnpackPath(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
VLOG(1) << "unpack start";
@@ -73,8 +72,8 @@ class UnpackPath : public OpKernel {
const Tensor& path_values_tensor = context->input(1);
const int32 num_data = static_cast<int32>(path_tensor.shape().dim_size(0));
- const int32 tree_depth = static_cast<int32>(
- path_tensor.shape().dim_size(1));
+ const int32 tree_depth =
+ static_cast<int32>(path_tensor.shape().dim_size(1));
const int32 num_nodes = MathUtil::IPow(2, tree_depth) - 1;
@@ -107,7 +106,6 @@ class UnpackPath : public OpKernel {
}
};
-REGISTER_KERNEL_BUILDER(Name("UnpackPath").Device(DEVICE_CPU),
- UnpackPath);
+REGISTER_KERNEL_BUILDER(Name("UnpackPath").Device(DEVICE_CPU), UnpackPath);
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc
index c091a73c4e..34388fe1aa 100644
--- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc
+++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc
@@ -25,9 +25,7 @@ namespace tensorforest {
using tensorflow::Tensor;
-float LeftProbability(const Tensor& point,
- const Tensor& weight,
- float bias,
+float LeftProbability(const Tensor& point, const Tensor& weight, float bias,
int num_features) {
const auto p = point.unaligned_flat<float>();
const auto w = weight.unaligned_flat<float>();
@@ -41,11 +39,8 @@ float LeftProbability(const Tensor& point,
return 1.0 / (1.0 + exp(-dot_product + bias));
}
-float LeftProbabilityK(const Tensor& point,
- std::vector<int32> feature_set,
- const Tensor& weight,
- float bias,
- int num_features,
+float LeftProbabilityK(const Tensor& point, std::vector<int32> feature_set,
+ const Tensor& weight, float bias, int num_features,
int k) {
const auto p = point.unaligned_flat<float>();
const auto w = weight.unaligned_flat<float>();
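The `return 1.0 / (1.0 + exp(-dot_product + bias));` visible in this hunk shows that `LeftProbability` is a logistic (sigmoid) routing decision on the node's hyperplane:

$$p_{\text{left}}(x) \;=\; \sigma(w \cdot x - b) \;=\; \frac{1}{1 + e^{-(w \cdot x - b)}}$$

Judging by its signature, `LeftProbabilityK` computes the same sigmoid but accumulates the dot product only over the node's k selected features (the `feature_set`).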
diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h b/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h
index c5902184f9..69a0143a4e 100644
--- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h
+++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h
@@ -24,16 +24,11 @@ namespace tensorflow {
namespace tensorforest {
// Returns the probability that the point falls to the left.
-float LeftProbability(const Tensor& point,
- const Tensor& weight,
- float bias,
+float LeftProbability(const Tensor& point, const Tensor& weight, float bias,
int num_features);
-float LeftProbabilityK(const Tensor& point,
- std::vector<int32> feature_set,
- const Tensor& weight,
- float bias,
- int num_features,
+float LeftProbabilityK(const Tensor& point, std::vector<int32> feature_set,
+ const Tensor& weight, float bias, int num_features,
int k);
// Returns a random set of num_features_to_pick features in the
@@ -49,5 +44,3 @@ void GetFeatureSet(int32 tree_num, int32 node_num, int32 random_seed,
} // namespace tensorflow
#endif // LEARNING_LIB_TENSOR_FOREST_HYBRID_CORE_OPS_UTILS_H_
-
-
diff --git a/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc b/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc
index 47b49a379c..b21a917977 100644
--- a/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc
@@ -30,15 +30,13 @@ namespace tensorflow {
using tensorforest::CheckTensorBounds;
-
float Convert(const string& in) {
const std::size_t intval = std::hash<string>()(in);
return static_cast<float>(intval);
}
-
-void Evaluate(const Tensor& input_data, Tensor output_data,
- int32 start, int32 end) {
+void Evaluate(const Tensor& input_data, Tensor output_data, int32 start,
+ int32 end) {
auto out_data = output_data.unaligned_flat<float>();
const auto in_data = input_data.unaligned_flat<string>();
@@ -59,9 +57,8 @@ class ReinterpretStringToFloat : public OpKernel {
if (!CheckTensorBounds(context, input_data)) return;
Tensor* output_data = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(0, input_data.shape(),
- &output_data));
+ OP_REQUIRES_OK(
+ context, context->allocate_output(0, input_data.shape(), &output_data));
// Evaluate input data in parallel.
const int32 num_data = static_cast<int32>(input_data.NumElements());
@@ -73,8 +70,8 @@ class ReinterpretStringToFloat : public OpKernel {
auto work = [&input_data, output_data, num_data](int64 start, int64 end) {
CHECK(start <= end);
CHECK(end <= num_data);
- Evaluate(input_data, *output_data,
- static_cast<int32>(start), static_cast<int32>(end));
+ Evaluate(input_data, *output_data, static_cast<int32>(start),
+ static_cast<int32>(end));
};
Shard(num_threads, worker_threads->workers, num_data, 100, work);
}
diff --git a/tensorflow/contrib/tensor_forest/kernels/scatter_add_ndim_op.cc b/tensorflow/contrib/tensor_forest/kernels/scatter_add_ndim_op.cc
index dd2a98b08c..60740c2be3 100644
--- a/tensorflow/contrib/tensor_forest/kernels/scatter_add_ndim_op.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/scatter_add_ndim_op.cc
@@ -22,7 +22,6 @@
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/platform/logging.h"
-
namespace tensorflow {
using tensorforest::CheckTensorBounds;
@@ -38,20 +37,19 @@ class ScatterAddNdim : public OpKernel {
if (indices_tensor.shape().dim_size(0) > 0) {
OP_REQUIRES(context, indices_tensor.shape().dims() == 2,
- errors::InvalidArgument(
- "indices should be two-dimensional"));
+ errors::InvalidArgument("indices should be two-dimensional"));
const int32 delta_dims = deltas_tensor.shape().dims();
OP_REQUIRES(
context,
indices_tensor.shape().dim_size(1) + delta_dims ==
- input_tensor.shape().dims() + 1,
+ input_tensor.shape().dims() + 1,
errors::InvalidArgument(
"Number of indices dimensions should be the same as input "
"rank."));
OP_REQUIRES(
context,
indices_tensor.shape().dim_size(0) ==
- deltas_tensor.shape().dim_size(0),
+ deltas_tensor.shape().dim_size(0),
errors::InvalidArgument(
"Number of updates should be same as number of indices."));
} else {
@@ -68,8 +66,8 @@ class ScatterAddNdim : public OpKernel {
const auto indices = indices_tensor.tensor<int32, 2>();
const auto deltas = deltas_tensor.unaligned_flat<float>();
- const int32 num_dims = static_cast<int32>(
- indices_tensor.shape().dim_size(1));
+ const int32 num_dims =
+ static_cast<int32>(indices_tensor.shape().dim_size(1));
// Figure out if indices don't specify a complete position in the
// input tensor.
@@ -80,10 +78,9 @@ class ScatterAddNdim : public OpKernel {
// Calculate index multipliers.
std::vector<int32> multipliers;
- OP_REQUIRES(
- context, input.size() < std::numeric_limits<int32>::max(),
- errors::InvalidArgument(
- "Input must contain less than 2^31 total elements"));
+ OP_REQUIRES(context, input.size() < std::numeric_limits<int32>::max(),
+ errors::InvalidArgument(
+ "Input must contain less than 2^31 total elements"));
int32 last_size = static_cast<int32>(input.size());
for (int32 j = 0; j < num_dims; j++) {
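The `multipliers` loop that begins at the end of this hunk computes row-major strides: starting from `input.size()`, each dimension's multiplier is the number of elements spanned by one step in that dimension, which is how an N-d index row is flattened into a single offset for the add. A minimal sketch of that stride arithmetic under the same row-major assumption (the helper is mine; the kernel's loop body is cut off above):

```cpp
#include <cstdio>
#include <vector>

// Row-major strides: each dimension's stride is the product of all later dims.
std::vector<int> Strides(const std::vector<int>& dims) {
  std::vector<int> strides(dims.size(), 1);
  for (int j = static_cast<int>(dims.size()) - 2; j >= 0; --j) {
    strides[j] = strides[j + 1] * dims[j + 1];
  }
  return strides;
}

int main() {
  const std::vector<int> dims = {2, 3, 4};  // a [2, 3, 4] input tensor
  std::vector<float> flat(2 * 3 * 4, 0.0f);
  const std::vector<int> strides = Strides(dims);  // {12, 4, 1}

  // ScatterAddNdim-style update: input[1][2][3] += 5.0f.
  const int index[] = {1, 2, 3};
  int offset = 0;
  for (int j = 0; j < 3; ++j) offset += index[j] * strides[j];
  flat[offset] += 5.0f;
  std::printf("offset=%d value=%g\n", offset, flat[offset]);  // offset 23
  return 0;
}
```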
diff --git a/tensorflow/contrib/tensor_forest/kernels/tree_utils.cc b/tensorflow/contrib/tensor_forest/kernels/tree_utils.cc
index 94e12cea5a..44997ec5d6 100644
--- a/tensorflow/contrib/tensor_forest/kernels/tree_utils.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/tree_utils.cc
@@ -65,8 +65,8 @@ void GetTwoBest(int max, const std::function<float(int)>& score_fn,
float ClassificationSplitScore(
const Eigen::Tensor<float, 1, Eigen::RowMajor>& splits,
- const Eigen::Tensor<float, 1, Eigen::RowMajor>& rights,
- int32 num_classes, int i) {
+ const Eigen::Tensor<float, 1, Eigen::RowMajor>& rights, int32 num_classes,
+ int i) {
Eigen::array<int, 1> offsets;
// Class counts are stored with the total in [0], so the length of each
// count vector is num_classes + 1.
@@ -74,7 +74,7 @@ float ClassificationSplitScore(
Eigen::array<int, 1> extents;
extents[0] = num_classes;
return WeightedGiniImpurity(splits.slice(offsets, extents)) +
- WeightedGiniImpurity(rights.slice(offsets, extents));
+ WeightedGiniImpurity(rights.slice(offsets, extents));
}
void GetTwoBestClassification(const Tensor& total_counts,
@@ -90,29 +90,28 @@ void GetTwoBestClassification(const Tensor& total_counts,
// in seg faults, so we have to go with flat views of these tensors. However,
// it is still pretty efficient because we put off evaluation until the
// score is actually returned.
- const auto tc = total_counts.Slice(
- accumulator, accumulator + 1).unaligned_flat<float>();
+ const auto tc =
+ total_counts.Slice(accumulator, accumulator + 1).unaligned_flat<float>();
// TODO(gilberth): See if we can delay evaluation here by templating the
// arguments to ClassificationSplitScore.
- const Eigen::Tensor<float, 1, Eigen::RowMajor> splits = split_counts.Slice(
- accumulator, accumulator + 1).unaligned_flat<float>();
+ const Eigen::Tensor<float, 1, Eigen::RowMajor> splits =
+ split_counts.Slice(accumulator, accumulator + 1).unaligned_flat<float>();
Eigen::array<int, 1> bcast;
bcast[0] = num_splits;
const Eigen::Tensor<float, 1, Eigen::RowMajor> rights =
tc.broadcast(bcast) - splits;
- std::function<float(int)> score_fn = std::bind(
- ClassificationSplitScore, splits, rights, num_classes,
- std::placeholders::_1);
+ std::function<float(int)> score_fn =
+ std::bind(ClassificationSplitScore, splits, rights, num_classes,
+ std::placeholders::_1);
GetTwoBest(num_splits, score_fn, best_score, best_index, second_best_score,
second_best_index);
}
-int32 BestFeatureClassification(
- const Tensor& total_counts, const Tensor& split_counts,
- int32 accumulator) {
+int32 BestFeatureClassification(const Tensor& total_counts,
+ const Tensor& split_counts, int32 accumulator) {
float best_score;
float second_best_score;
int best_feature_index;
@@ -130,8 +129,7 @@ float RegressionSplitScore(
const Eigen::Tensor<float, 1, Eigen::RowMajor>& splits_square,
const Eigen::Tensor<float, 1, Eigen::RowMajor>& right_sums,
const Eigen::Tensor<float, 1, Eigen::RowMajor>& right_squares,
- int32 accumulator,
- int32 num_regression_dims, int i) {
+ int32 accumulator, int32 num_regression_dims, int i) {
Eigen::array<int, 1> offsets = {i * num_regression_dims + 1};
Eigen::array<int, 1> extents = {num_regression_dims - 1};
float left_count = splits_count_accessor(accumulator, i, 0);
@@ -141,15 +139,15 @@ float RegressionSplitScore(
// Guard against divide-by-zero.
if (left_count > 0) {
- score += WeightedVariance(
- splits_sum.slice(offsets, extents),
- splits_square.slice(offsets, extents), left_count);
+ score +=
+ WeightedVariance(splits_sum.slice(offsets, extents),
+ splits_square.slice(offsets, extents), left_count);
}
if (right_count > 0) {
- score += WeightedVariance(right_sums.slice(offsets, extents),
- right_squares.slice(offsets, extents),
- right_count);
+ score +=
+ WeightedVariance(right_sums.slice(offsets, extents),
+ right_squares.slice(offsets, extents), right_count);
}
return score;
}
@@ -159,20 +157,20 @@ void GetTwoBestRegression(const Tensor& total_sums, const Tensor& total_squares,
int32 accumulator, float* best_score, int* best_index,
float* second_best_score, int* second_best_index) {
const int32 num_splits = static_cast<int32>(split_sums.shape().dim_size(1));
- const int32 num_regression_dims = static_cast<int32>(
- split_sums.shape().dim_size(2));
+ const int32 num_regression_dims =
+ static_cast<int32>(split_sums.shape().dim_size(2));
// Ideally, Eigen::Tensor::chip would be best to use here but it results
// in seg faults, so we have to go with flat views of these tensors. However,
// it is still pretty efficient because we put off evaluation until the
// score is actually returned.
- const auto tc_sum = total_sums.Slice(
- accumulator, accumulator + 1).unaligned_flat<float>();
- const auto tc_square = total_squares.Slice(
- accumulator, accumulator + 1).unaligned_flat<float>();
- const auto splits_sum = split_sums.Slice(
- accumulator, accumulator + 1).unaligned_flat<float>();
- const auto splits_square = split_squares.Slice(
- accumulator, accumulator + 1).unaligned_flat<float>();
+ const auto tc_sum =
+ total_sums.Slice(accumulator, accumulator + 1).unaligned_flat<float>();
+ const auto tc_square =
+ total_squares.Slice(accumulator, accumulator + 1).unaligned_flat<float>();
+ const auto splits_sum =
+ split_sums.Slice(accumulator, accumulator + 1).unaligned_flat<float>();
+ const auto splits_square =
+ split_squares.Slice(accumulator, accumulator + 1).unaligned_flat<float>();
// Eigen is infuriating to work with, usually resulting in all kinds of
// unhelpful compiler errors when trying something that seems sane. This
// helps us do a simple thing like access the first element (the counts)
@@ -193,10 +191,10 @@ void GetTwoBestRegression(const Tensor& total_sums, const Tensor& total_squares,
best_score, best_index, second_best_score, second_best_index);
}
-int32 BestFeatureRegression(
- const Tensor& total_sums, const Tensor& total_squares,
- const Tensor& split_sums, const Tensor& split_squares,
- int32 accumulator) {
+int32 BestFeatureRegression(const Tensor& total_sums,
+ const Tensor& total_squares,
+ const Tensor& split_sums,
+ const Tensor& split_squares, int32 accumulator) {
float best_score;
float second_best_score;
int best_feature_index;
@@ -207,10 +205,11 @@ int32 BestFeatureRegression(
return best_feature_index;
}
-bool BestSplitDominatesRegression(
- const Tensor& total_sums, const Tensor& total_squares,
- const Tensor& split_sums, const Tensor& split_squares,
- int32 accumulator) {
+bool BestSplitDominatesRegression(const Tensor& total_sums,
+ const Tensor& total_squares,
+ const Tensor& split_sums,
+ const Tensor& split_squares,
+ int32 accumulator) {
// TODO(thomaswc): Implement this, probably as part of v3.
return false;
}
@@ -599,7 +598,6 @@ bool Decide(float value, float bias, DataColumnTypes type) {
}
}
-
void GetParentWeightedMean(float leaf_sum, const float* leaf_data,
float parent_sum, const float* parent_data,
float valid_leaf_threshold, int num_outputs,
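`ClassificationSplitScore` above scores a candidate split as the weighted Gini impurity of the left counts plus that of the right counts, and `GetTwoBest*` keeps the two lowest scores. For class counts c_1, ..., c_K with s = sum of c_i, the weighted impurity is the Gini impurity scaled by the example count:

$$s\left(1 - \sum_{i=1}^{K}\frac{c_i^2}{s^2}\right) \;=\; s - \frac{\sum_i c_i^2}{s}$$

(`WeightedGiniImpurity` in the header applies a small smoothing to the counts first, per the `smoothed` variable visible in its hunk below.)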
diff --git a/tensorflow/contrib/tensor_forest/kernels/tree_utils.h b/tensorflow/contrib/tensor_forest/kernels/tree_utils.h
index dad9df4898..edbac67006 100644
--- a/tensorflow/contrib/tensor_forest/kernels/tree_utils.h
+++ b/tensorflow/contrib/tensor_forest/kernels/tree_utils.h
@@ -45,13 +45,10 @@ const int32 LEAF_NODE = -1;
const int32 FREE_NODE = -2;
// Used to indicate column types, e.g. categorical vs. float
-enum DataColumnTypes {
- kDataFloat = 0,
- kDataCategorical = 1
-};
+enum DataColumnTypes { kDataFloat = 0, kDataCategorical = 1 };
// Calculates the sum of a tensor.
-template<typename T>
+template <typename T>
T Sum(Tensor counts) {
Eigen::Tensor<T, 0, Eigen::RowMajor> count_sum =
counts.unaligned_flat<T>().sum();
@@ -97,7 +94,7 @@ float WeightedGiniImpurity(const T& counts) {
return RawWeightedGiniImpurity(smoothed);
}
-template<typename T1, typename T2>
+template <typename T1, typename T2>
float WeightedVariance(const T1& sums, const T2& squares, float count) {
const auto e_x = sums / count;
const auto e_x2 = squares / count;
@@ -120,10 +117,11 @@ int32 BestFeatureRegression(const Tensor& total_sums,
// Returns true if the best split's variance is sufficiently smaller than
// that of the next best split.
-bool BestSplitDominatesRegression(
- const Tensor& total_sums, const Tensor& total_squares,
- const Tensor& split_sums, const Tensor& split_squares,
- int32 accumulator);
+bool BestSplitDominatesRegression(const Tensor& total_sums,
+ const Tensor& total_squares,
+ const Tensor& split_sums,
+ const Tensor& split_squares,
+ int32 accumulator);
// Performs booststrap_samples bootstrap samples of the best split's class
// counts and the second best splits's class counts, and returns true if at
@@ -178,10 +176,8 @@ bool DecideNode(const GetFeatureFnType& get_dense,
// isn't present in sparse_input_indices. sparse_input_indices is assumed
// to be sorted.
template <typename T1, typename T2>
-float FindSparseValue(
- const T1& sparse_input_indices,
- const T2& sparse_input_values,
- int32 i, int32 j) {
+float FindSparseValue(const T1& sparse_input_indices,
+ const T2& sparse_input_values, int32 i, int32 j) {
int32 low = 0;
int32 high = sparse_input_values.dimension(0);
while (low < high) {
@@ -273,7 +269,6 @@ int32 GetNumSparseFeatures(const T1& indices, int32 input_index,
// categorical data, it is value != bias.
bool Decide(float value, float bias, DataColumnTypes type = kDataFloat);
-
// Returns true if all the splits are initialized. Since they get initialized
// in order, we can simply infer this from the last split.
// This should only be called for a single allocator's candidate features
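`WeightedVariance` forms `e_x = sums / count` and `e_x2 = squares / count`, so the regression split score rests on computing variance from running first and second moments:

$$\operatorname{Var}(x) \;=\; E[x^2] - (E[x])^2 \;=\; \frac{S_2}{n} - \left(\frac{S_1}{n}\right)^2$$

where S_1 and S_2 are the per-output sums and sums of squares accumulated by the grow stats. `RegressionSplitScore` in tree_utils.cc adds the left and right partitions' variances, guarding each against a zero count, and the best split minimizes the total.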
diff --git a/tensorflow/contrib/tensor_forest/kernels/tree_utils_test.cc b/tensorflow/contrib/tensor_forest/kernels/tree_utils_test.cc
index 7485a695df..0855354550 100644
--- a/tensorflow/contrib/tensor_forest/kernels/tree_utils_test.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/tree_utils_test.cc
@@ -44,11 +44,13 @@ TEST(TestWeightedVariance, Basic) {
Tensor squares = test::AsTensor<float>({29, 12}, {2});
EXPECT_FLOAT_EQ(WeightedVariance(sums.unaligned_flat<float>(),
- squares.unaligned_flat<float>(), 3), 2.0);
+ squares.unaligned_flat<float>(), 3),
+ 2.0);
Tensor zero = test::AsTensor<float>({0}, {1});
EXPECT_FLOAT_EQ(WeightedVariance(zero.unaligned_flat<float>(),
- zero.unaligned_flat<float>(), 1), 0);
+ zero.unaligned_flat<float>(), 1),
+ 0);
}
TEST(TestInitialize, Basic) {
@@ -94,17 +96,16 @@ TEST(BestFeatureClassification, Basic) {
const int32 num_accumulators = 4;
const int32 num_splits = 3;
const int32 num_classes = 4;
- Tensor totals = test::AsTensor<float>({1, 5, 6, 7,
- 0, 0, 0, 0,
- 30, 10, 10, 10, // this one
- -1, -1, -1, -1},
- {num_accumulators, num_classes});
- Tensor splits = test::AsTensor<float>(
- {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 30, 10, 10, 10, 10, 0, 0, 10, 19, 5, 6, 8, // this one
- -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
- {num_accumulators, num_splits, num_classes});
+ Tensor totals = test::AsTensor<float>(
+ {1, 5, 6, 7, 0, 0, 0, 0, 30, 10, 10, 10, // this one
+ -1, -1, -1, -1},
+ {num_accumulators, num_classes});
+ Tensor splits =
+ test::AsTensor<float>({1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30, 10,
+ 10, 10, 10, 0, 0, 10, 19, 5, 6, 8, // this one
+ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
+ {num_accumulators, num_splits, num_classes});
EXPECT_EQ(BestFeatureClassification(totals, splits, 2), 1);
}
@@ -114,17 +115,16 @@ TEST(BestFeatureClassification, NoWinner) {
const int32 num_splits = 3;
const int32 num_classes = 4;
// When counts are all the same, the most reasonable thing to do is pick 0.
- Tensor totals = test::AsTensor<float>({1, 5, 6, 7,
- 0, 0, 0, 0,
- 18, 6, 6, 6, // this one
- -1, -1, -1, -1},
- {num_accumulators, num_classes});
- Tensor splits = test::AsTensor<float>(
- {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 9, 3, 3, 3, 9, 3, 3, 3, 9, 3, 3, 3, // this one
- -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
- {num_accumulators, num_splits, num_classes});
+ Tensor totals =
+ test::AsTensor<float>({1, 5, 6, 7, 0, 0, 0, 0, 18, 6, 6, 6, // this one
+ -1, -1, -1, -1},
+ {num_accumulators, num_classes});
+ Tensor splits =
+ test::AsTensor<float>({1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 3,
+ 3, 3, 9, 3, 3, 3, 9, 3, 3, 3, // this one
+ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
+ {num_accumulators, num_splits, num_classes});
EXPECT_EQ(BestFeatureClassification(totals, splits, 2), 0);
}
@@ -133,36 +133,34 @@ TEST(BestFeatureRegression, Basic) {
const int32 num_accumulators = 4;
const int32 num_splits = 3;
const int32 num_classes = 4;
- Tensor total_sums = test::AsTensor<float>(
- {1, 5, 6, 7,
- 0, 0, 0, 0,
- 10, 8, 6, 9, // this one
- -1, -1, -1, -1},
- {num_accumulators, num_classes});
+ Tensor total_sums =
+ test::AsTensor<float>({1, 5, 6, 7, 0, 0, 0, 0, 10, 8, 6, 9, // this one
+ -1, -1, -1, -1},
+ {num_accumulators, num_classes});
Tensor total_squares = test::AsTensor<float>(
- {1, 5, 6, 7,
- 0, 0, 0, 0,
- 100, 50, 40, 45, // this one
+ {1, 5, 6, 7, 0, 0, 0, 0, 100, 50, 40, 45, // this one
-1, -1, -1, -1},
{num_accumulators, num_classes});
- Tensor split_sums = test::AsTensor<float>(
- {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 10, 8, 6, 9, 9, 8, 5, 9, 0, 0, 0, 0, // this one
- -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
- {num_accumulators, num_splits, num_classes});
+ Tensor split_sums =
+ test::AsTensor<float>({1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 8,
+ 6, 9, 9, 8, 5, 9, 0, 0, 0, 0, // this one
+ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
+ {num_accumulators, num_splits, num_classes});
// lower the variance by lowering one of the squares just a little.
- Tensor split_squares = test::AsTensor<float>(
- {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 100, 50, 40, 45, 100, 50, 40, 43, 0, 0, 0, 0, // this one
- -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
- {num_accumulators, num_splits, num_classes});
+ Tensor split_squares =
+ test::AsTensor<float>(
+ {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 100, 50, 40, 45, 100, 50, 40, 43, 0, 0, 0, 0, // this one
+ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
+ {num_accumulators, num_splits, num_classes});
EXPECT_EQ(BestFeatureRegression(total_sums, total_squares, split_sums,
- split_squares, 2), 1);
+ split_squares, 2),
+ 1);
}
TEST(BestFeatureRegression, NoWinner) {
@@ -170,37 +168,33 @@ TEST(BestFeatureRegression, NoWinner) {
const int32 num_splits = 3;
const int32 num_classes = 4;
// when counts are all the same, the most reasonable thing to do is pick 0.
- Tensor total_sums = test::AsTensor<float>(
- {1, 5, 6, 7,
- 0, 0, 0, 0,
- 10, 8, 6, 9, // this one
- -1, -1, -1, -1},
- {num_accumulators, num_classes});
+ Tensor total_sums =
+ test::AsTensor<float>({1, 5, 6, 7, 0, 0, 0, 0, 10, 8, 6, 9, // this one
+ -1, -1, -1, -1},
+ {num_accumulators, num_classes});
Tensor total_squares = test::AsTensor<float>(
- {1, 5, 6, 7,
- 0, 0, 0, 0,
- 100, 50, 40, 45, // this one
+ {1, 5, 6, 7, 0, 0, 0, 0, 100, 50, 40, 45, // this one
-1, -1, -1, -1},
{num_accumulators, num_classes});
- Tensor split_sums = test::AsTensor<float>(
- {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 10, 8, 6, 9, 10, 8, 6, 9, 10, 8, 6, 9, // this one
- -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
- {num_accumulators, num_splits, num_classes});
+ Tensor split_sums =
+ test::AsTensor<float>({1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 8,
+ 6, 9, 10, 8, 6, 9, 10, 8, 6, 9, // this one
+ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
+ {num_accumulators, num_splits, num_classes});
Tensor split_squares = test::AsTensor<float>(
- {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 100, 50, 40, 45, 100, 50, 40, 45, 100, 50, 40, 45, // this one
- -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
+ {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 100, 50, 40, 45, 100, 50, 40, 45, 100, 50, 40, 45, // this one
+ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{num_accumulators, num_splits, num_classes});
EXPECT_EQ(BestFeatureRegression(total_sums, total_squares, split_sums,
- split_squares, 2), 0);
+ split_squares, 2),
+ 0);
}
} // namespace tensorforest
} // namespace tensorflow
-
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc b/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc
index 81e2a1b2a1..f4a7058ddb 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc
@@ -14,8 +14,8 @@
// =============================================================================
#include "tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h"
-#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
namespace tensorflow {
@@ -58,8 +58,7 @@ CandidateGraphRunner::CandidateGraphRunner(
// Features don't change, store them in a tensor.
const auto& oblique = split.inequality_left_child_test().oblique();
const int32 feat_size = oblique.features_size();
- features_.reset(
- new Tensor(tensorflow::DT_INT32, TensorShape({feat_size})));
+ features_.reset(new Tensor(tensorflow::DT_INT32, TensorShape({feat_size})));
auto feat = features_->flat<int32>();
int i = 0;
for (const auto& id : oblique.features()) {
@@ -67,10 +66,10 @@ CandidateGraphRunner::CandidateGraphRunner(
}
}
-void CandidateGraphRunner::RunOp(
- const string& name, const TensorNameValueList& inputs,
- const std::vector<string>& output_tensor_names,
- std::vector<Tensor>* outputs) {
+void CandidateGraphRunner::RunOp(const string& name,
+ const TensorNameValueList& inputs,
+ const std::vector<string>& output_tensor_names,
+ std::vector<Tensor>* outputs) {
std::vector<string> op_name;
if (name != kNoOp) {
op_name.push_back(name);
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h
index cced26b903..328af28725 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h
@@ -26,7 +26,6 @@
namespace tensorflow {
namespace tensorforest {
-
// Keep a tree ensemble in memory for efficient evaluation and mutation.
class DecisionTreeResource : public ResourceBase {
public:
@@ -35,15 +34,12 @@ class DecisionTreeResource : public ResourceBase {
string DebugString() override {
return strings::StrCat("DecisionTree[size=",
- decision_tree_->decision_tree().nodes_size(),
- "]");
+ decision_tree_->decision_tree().nodes_size(), "]");
}
void MaybeInitialize();
- const decision_trees::Model& decision_tree() const {
- return *decision_tree_;
- }
+ const decision_trees::Model& decision_tree() const { return *decision_tree_; }
decision_trees::Model* mutable_decision_tree() {
return decision_tree_.get();
@@ -59,9 +55,7 @@ class DecisionTreeResource : public ResourceBase {
// Resets the resource and frees the proto.
// Caller needs to hold the mutex lock while calling this.
- void Reset() {
- decision_tree_.reset(new decision_trees::Model());
- }
+ void Reset() { decision_tree_.reset(new decision_trees::Model()); }
mutex* get_mutex() { return &mu_; }
@@ -84,7 +78,6 @@ class DecisionTreeResource : public ResourceBase {
std::vector<std::unique_ptr<DecisionNodeEvaluator>> node_evaluators_;
};
-
} // namespace tensorforest
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h
index 85ce7b825b..bf2b2aaa3c 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h
@@ -22,7 +22,6 @@
namespace tensorflow {
namespace tensorforest {
-
// Base class for evaluators of decision nodes that effectively copy proto
// contents into C++ structures for faster execution.
class DecisionNodeEvaluator {
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc
index 5c49b87443..af5cf72a3c 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc
@@ -20,11 +20,11 @@
namespace tensorflow {
namespace {
+using tensorflow::decision_trees::InequalityTest;
+using tensorflow::decision_trees::MatchingValuesTest;
using tensorflow::tensorforest::InequalityDecisionNodeEvaluator;
using tensorflow::tensorforest::MatchingValuesDecisionNodeEvaluator;
using tensorflow::tensorforest::ObliqueInequalityDecisionNodeEvaluator;
-using tensorflow::decision_trees::InequalityTest;
-using tensorflow::decision_trees::MatchingValuesTest;
TEST(InequalityDecisionNodeEvaluatorTest, TestLessOrEqual) {
InequalityTest test;
@@ -124,4 +124,3 @@ TEST(ObliqueDecisionNodeEvaluatorTest, Basic) {
} // namespace
} // namespace tensorflow
-
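The reordered using-declarations above cover the node evaluators, which (per the header's comment) copy proto contents into plain C++ structures for faster repeated evaluation. A minimal stand-in for the inequality case, assuming, as the `TestLessOrEqual` name suggests, a less-or-equal comparison, and assuming a passing test routes to the left child (struct and field names are mine, not the proto's):

```cpp
#include <cstdio>
#include <vector>

// Simplified stand-in for InequalityDecisionNodeEvaluator.
struct InequalityNode {
  int feature_id;
  float threshold;
  // 0 = left child on a passing (<=) test, 1 = right child otherwise.
  int Decide(const std::vector<float>& example) const {
    return example[feature_id] <= threshold ? 0 : 1;
  }
};

int main() {
  const InequalityNode node{/*feature_id=*/1, /*threshold=*/3.0f};
  std::printf("%d\n", node.Decide({10.0f, 2.5f}));  // 0: 2.5 <= 3.0
  std::printf("%d\n", node.Decide({10.0f, 7.0f}));  // 1: 7.0 >  3.0
  return 0;
}
```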
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h
index 0d6712e9e5..eea0be27ca 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h
@@ -40,9 +40,7 @@ class FertileStatsResource : public ResourceBase {
model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(params_);
}
- string DebugString() override {
- return "FertileStats";
- }
+ string DebugString() override { return "FertileStats"; }
void ExtractFromProto(const FertileStats& stats);
@@ -50,8 +48,7 @@ class FertileStatsResource : public ResourceBase {
// Resets the resource and frees the proto.
// Caller needs to hold the mutex lock while calling this.
- void Reset() {
- }
+ void Reset() {}
// Reset the stats for a node, but leave the leaf_stats intact.
void ResetSplitStats(int32 node_id, int32 depth) {
@@ -84,7 +81,6 @@ class FertileStatsResource : public ResourceBase {
// was found.
bool BestSplit(int32 node_id, SplitCandidate* best, int32* depth);
-
private:
mutex mu_;
std::shared_ptr<LeafModelOperator> model_op_;
@@ -94,7 +90,6 @@ class FertileStatsResource : public ResourceBase {
void AllocateNode(int32 node_id, int32 depth);
};
-
} // namespace tensorforest
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc
index 3ce630e3a9..da600d34ea 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc
@@ -20,7 +20,6 @@
#include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h"
#include "tensorflow/core/lib/random/distribution_sampler.h"
-
namespace tensorflow {
namespace tensorforest {
@@ -454,14 +453,14 @@ void DenseClassificationGrowStats::PackToProto(FertileSlot* slot) const {
class_stats->add_value()->set_float_value(total_counts_[i]);
}
- for (int split_num = 0; split_num < num_splits(); ++split_num) {
+ for (int split_num = 0; split_num < num_splits(); ++split_num) {
auto* cand = slot->add_candidates();
*cand->mutable_split() = splits_[split_num];
auto* left_stats = cand->mutable_left_stats()
->mutable_classification()
->mutable_dense_counts();
for (int i = 0; i < num_outputs_; ++i) {
- left_stats->add_value()->set_float_value(left_count(split_num, i));
+ left_stats->add_value()->set_float_value(left_count(split_num, i));
}
}
}
@@ -546,7 +545,7 @@ void SparseClassificationGrowStats::PackToProto(FertileSlot* slot) const {
(*class_stats)[entry.first] = val;
}
- for (int split_num = 0; split_num < num_splits(); ++split_num) {
+ for (int split_num = 0; split_num < num_splits(); ++split_num) {
auto* cand = slot->add_candidates();
*cand->mutable_split() = splits_[split_num];
auto* left_stats = cand->mutable_left_stats()
@@ -561,8 +560,8 @@ void SparseClassificationGrowStats::PackToProto(FertileSlot* slot) const {
}
}
-float SparseClassificationGrowStats::GiniScore(
- int split, float* left_sum, float* right_sum) const {
+float SparseClassificationGrowStats::GiniScore(int split, float* left_sum,
+ float* right_sum) const {
float left_square = 0, right_square = 0;
*left_sum = 0;
*right_sum = 0;
@@ -844,12 +843,11 @@ void LeastSquaresRegressionGrowStats::PackToProto(FertileSlot* slot) const {
total_squares->add_value()->set_float_value(total_sum_squares_[i]);
}
- for (int split_num = 0; split_num < num_splits(); ++split_num) {
+ for (int split_num = 0; split_num < num_splits(); ++split_num) {
auto* cand = slot->add_candidates();
*cand->mutable_split() = splits_[split_num];
- auto* sums = cand->mutable_left_stats()
- ->mutable_regression()
- ->mutable_mean_output();
+ auto* sums =
+ cand->mutable_left_stats()->mutable_regression()->mutable_mean_output();
auto* squares = cand->mutable_left_stats()
->mutable_regression()
->mutable_mean_output_squares();
@@ -891,20 +889,17 @@ float LeastSquaresRegressionGrowStats::SplitVariance(int split) const {
float total_variance = 0;
for (int i = 0; i < params_.num_outputs(); ++i) {
// Left side
- const float le_x =
- left_sum(split, i) / left_counts_[split];
+ const float le_x = left_sum(split, i) / left_counts_[split];
- const float le_x2 =
- left_square(split, i) / left_counts_[split];
+ const float le_x2 = left_square(split, i) / left_counts_[split];
total_variance += le_x2 - le_x * le_x;
// Right side
const float re_x = (total_sum_[i] - left_sum(split, i)) /
(weight_sum_ - left_counts_[split]);
- const float re_x2 =
- (total_sum_squares_[i] - left_square(split, i)) /
- (weight_sum_ - left_counts_[split]);
+ const float re_x2 = (total_sum_squares_[i] - left_square(split, i)) /
+ (weight_sum_ - left_counts_[split]);
total_variance += re_x2 - re_x * re_x;
}
return total_variance;
@@ -937,8 +932,7 @@ bool LeastSquaresRegressionGrowStats::BestSplit(SplitCandidate* best) const {
left->set_weight_sum(left_counts_[best_index]);
auto* left_output_sum = left_reg_stats->mutable_mean_output();
for (int i = 0; i < num_outputs; ++i) {
- left_output_sum->add_value()->set_float_value(
- left_sum(best_index, i));
+ left_output_sum->add_value()->set_float_value(left_sum(best_index, i));
}
// Right
@@ -947,8 +941,8 @@ bool LeastSquaresRegressionGrowStats::BestSplit(SplitCandidate* best) const {
right->set_weight_sum(weight_sum_ - left_counts_[best_index]);
auto* right_output_sum = right_reg_stats->mutable_mean_output();
for (int i = 0; i < num_outputs; ++i) {
- right_output_sum->add_value()->set_float_value(
- total_sum_[i] - left_sum(best_index, i));
+ right_output_sum->add_value()->set_float_value(total_sum_[i] -
+ left_sum(best_index, i));
}
return true;
}
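The classification grow stats store only per-split left-branch counts; right counts are implicit, recovered as node totals minus lefts (see `rights = tc.broadcast(bcast) - splits` in tree_utils.cc, and `ClassificationAddRightExample` in grow_stats.h below, a no-op by default). A sketch of that bookkeeping (names are mine):

```cpp
#include <cstdio>
#include <vector>

int main() {
  const int num_classes = 3;
  std::vector<float> total(num_classes, 0.0f);  // all examples seen at the node
  std::vector<float> left(num_classes, 0.0f);   // examples a split routes left

  struct Obs { int label; float weight; bool goes_left; };
  const std::vector<Obs> stream = {
      {0, 1.0f, true}, {2, 1.0f, false}, {0, 1.0f, true}};
  for (const Obs& o : stream) {
    total[o.label] += o.weight;
    if (o.goes_left) left[o.label] += o.weight;  // only left counts are stored
  }
  for (int c = 0; c < num_classes; ++c) {
    // Right counts never need their own storage: right = total - left.
    std::printf("class %d: left=%g right=%g\n", c, left[c], total[c] - left[c]);
  }
  return 0;
}
```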
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h
index 02c0fc687f..04e6b0a735 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h
@@ -73,21 +73,15 @@ class GrowStats {
const InputTarget* target, int example) {}
void RemoveSplit(int split_num);
- int num_splits() const {
- return splits_.size();
- }
+ int num_splits() const { return splits_.size(); }
- float weight_sum() const {
- return weight_sum_;
- }
+ float weight_sum() const { return weight_sum_; }
virtual bool IsInitialized() const {
return weight_sum_ > 0 || splits_.size() == num_splits_to_consider_;
}
- int32 depth() const {
- return depth_;
- }
+ int32 depth() const { return depth_; }
protected:
GrowStats(const TensorForestParams& params, int32 depth);
@@ -206,8 +200,8 @@ class ClassificationStats : public GrowStats {
virtual float left_count(int split, int class_num) const = 0;
virtual float right_count(int split, int class_num) const = 0;
- virtual void ClassificationAddLeftExample(
- int split, int64 int_label, float weight) = 0;
+ virtual void ClassificationAddLeftExample(int split, int64 int_label,
+ float weight) = 0;
virtual void ClassificationAddRightExample(int split, int64 int_label,
float weight) {
// Does nothing by default, but sub-classes can override.
@@ -375,9 +369,7 @@ class SparseClassificationGrowStats : public ClassificationStats {
SparseClassificationGrowStats(const TensorForestParams& params, int32 depth)
: ClassificationStats(params, depth) {}
- void Initialize() override {
- Clear();
- }
+ void Initialize() override { Clear(); }
void ExtractFromProto(const FertileSlot& slot) override;
void PackToProto(FertileSlot* slot) const override;
@@ -562,9 +554,9 @@ class LeastSquaresRegressionGrowStats : public GrowStats {
}
void RemoveSplitStats(int split_num) override {
left_sums_.erase(left_sums_.begin() + num_outputs_ * split_num,
- left_sums_.begin() + num_outputs_ * (split_num + 1));
+ left_sums_.begin() + num_outputs_ * (split_num + 1));
left_squares_.erase(left_squares_.begin() + num_outputs_ * split_num,
- left_squares_.begin() + num_outputs_ * (split_num + 1));
+ left_squares_.begin() + num_outputs_ * (split_num + 1));
left_counts_.erase(left_counts_.begin() + split_num,
left_counts_.begin() + (split_num + 1));
}
@@ -605,7 +597,6 @@ class LeastSquaresRegressionGrowStats : public GrowStats {
std::vector<int64> left_counts_;
};
-
} // namespace tensorforest
} // namespace tensorflow
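
[Note] RemoveSplitStats above drops one candidate split from accumulators that store a num_splits x num_outputs table flattened row-major into a single vector; the parallel sums and squares vectors must be erased together to stay aligned. A small sketch of the row erase, with a hypothetical helper name:

#include <vector>

// Erase row `split_num` from a table stored row-major in a flat vector,
// where entry (split, output) lives at index split * num_outputs + output.
// Hypothetical helper; the class above inlines this for each vector.
void EraseSplitRow(std::vector<float>* flat, int num_outputs, int split_num) {
  flat->erase(flat->begin() + num_outputs * split_num,
              flat->begin() + num_outputs * (split_num + 1));
}

The counts vector keeps one entry per split, so its erase above is the same operation with a row width of 1.
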
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc
index ceb58d2ead..26e989928e 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc
@@ -24,21 +24,21 @@
namespace tensorflow {
namespace {
-using tensorflow::tensorforest::GrowStats;
-using tensorflow::tensorforest::TestableInputTarget;
-using tensorflow::tensorforest::FertileSlot;
+using tensorflow::decision_trees::BinaryNode;
+using tensorflow::decision_trees::FeatureId;
+using tensorflow::decision_trees::InequalityTest;
using tensorflow::tensorforest::DenseClassificationGrowStats;
-using tensorflow::tensorforest::SparseClassificationGrowStats;
+using tensorflow::tensorforest::FertileSlot;
using tensorflow::tensorforest::FixedSizeClassStats;
using tensorflow::tensorforest::FixedSizeSparseClassificationGrowStats;
+using tensorflow::tensorforest::GrowStats;
using tensorflow::tensorforest::LeastSquaresRegressionGrowStats;
-using tensorflow::tensorforest::TensorForestParams;
+using tensorflow::tensorforest::SparseClassificationGrowStats;
using tensorflow::tensorforest::SPLIT_FINISH_BASIC;
using tensorflow::tensorforest::SPLIT_FINISH_DOMINATE_HOEFFDING;
using tensorflow::tensorforest::SPLIT_PRUNE_HOEFFDING;
-using tensorflow::decision_trees::BinaryNode;
-using tensorflow::decision_trees::InequalityTest;
-using tensorflow::decision_trees::FeatureId;
+using tensorflow::tensorforest::TensorForestParams;
+using tensorflow::tensorforest::TestableInputTarget;
BinaryNode MakeSplit(const string& feat, float val) {
BinaryNode split;
@@ -52,8 +52,7 @@ BinaryNode MakeSplit(const string& feat, float val) {
return split;
}
-void RunBatch(GrowStats* stats,
- const TestableInputTarget* target) {
+void RunBatch(GrowStats* stats, const TestableInputTarget* target) {
std::unique_ptr<tensorflow::tensorforest::TensorDataSet> dataset(
new tensorflow::tensorforest::TestableDataSet(
{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, 2));
@@ -102,18 +101,10 @@ class TestableRunningStats : public DenseClassificationGrowStats {
TestableRunningStats(const TensorForestParams& params, int32 depth)
: DenseClassificationGrowStats(params, depth) {}
- float test_left_sum(int split) {
- return get_left_gini()->sum(split);
- }
- float test_left_square(int split) {
- return get_left_gini()->square(split);
- }
- float test_right_sum(int split) {
- return get_right_gini()->sum(split);
- }
- float test_right_square(int split) {
- return get_right_gini()->square(split);
- }
+ float test_left_sum(int split) { return get_left_gini()->sum(split); }
+ float test_left_square(int split) { return get_left_gini()->square(split); }
+ float test_right_sum(int split) { return get_right_gini()->sum(split); }
+ float test_right_square(int split) { return get_right_gini()->square(split); }
};
TEST(GrowStatsDenseClassificationTest, BasicRunningStats) {
@@ -166,9 +157,7 @@ class TestableFinishEarly : public DenseClassificationGrowStats {
int num_times_called_;
protected:
- void CheckFinishEarlyHoeffding() override {
- ++num_times_called_;
- }
+ void CheckFinishEarlyHoeffding() override { ++num_times_called_; }
};
TEST(GrowStatsDenseClassificationTest, TestFinishEarly) {
@@ -212,7 +201,6 @@ TEST(GrowStatsDenseClassificationTest, TestFinishEarly) {
ASSERT_EQ(stat->num_times_called_, 9);
}
-
TEST(GrowStatsDenseClassificationTest, TestCheckPruneHoeffding) {
TensorForestParams params;
params.set_num_outputs(2);
@@ -224,7 +212,8 @@ TEST(GrowStatsDenseClassificationTest, TestCheckPruneHoeffding) {
finish->set_type(SPLIT_FINISH_BASIC);
finish->mutable_check_every_steps()->set_constant_value(100);
params.mutable_pruning_type()->set_type(SPLIT_PRUNE_HOEFFDING);
- params.mutable_pruning_type()->mutable_prune_every_samples()
+ params.mutable_pruning_type()
+ ->mutable_prune_every_samples()
->set_constant_value(1);
// On each iteration, we add two examples, one of class 0 and one
@@ -234,8 +223,8 @@ TEST(GrowStatsDenseClassificationTest, TestCheckPruneHoeffding) {
std::vector<float> weights = {1, 1};
TestableInputTarget target(labels, weights, 1);
std::unique_ptr<tensorflow::tensorforest::TensorDataSet> dataset(
- new tensorflow::tensorforest::TestableDataSet(
- {-1.0, -1.0, 1.0, -1.0}, 2));
+ new tensorflow::tensorforest::TestableDataSet({-1.0, -1.0, 1.0, -1.0},
+ 2));
DenseClassificationGrowStats stats(params, 1);
stats.Initialize();
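
[Note] The TestCheckPruneHoeffding setup above exercises SPLIT_PRUNE_HOEFFDING and, earlier, CheckFinishEarlyHoeffding. This diff does not show the bound itself, so purely as a reference, here is the standard one-sided Hoeffding bound that streaming decision trees ("Hoeffding trees") use for such checks; the exact form and constants tensor_forest uses may differ.

#include <cmath>

// One-sided Hoeffding bound: with probability at least 1 - delta, the
// observed mean over n samples of a variable with range R is within
// epsilon of its true mean. Reference form only; not taken from this diff.
float HoeffdingEpsilon(float range, float delta, float n) {
  return std::sqrt(range * range * std::log(1.0f / delta) / (2.0f * n));
}

In Hoeffding-tree style checks, a node can finish early once the leading split's advantage over the runner-up exceeds epsilon, and a candidate trailing the leader by more than epsilon can be pruned.
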
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
index bf0fb92450..d43884481a 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
@@ -109,10 +109,10 @@ void TensorDataSet::set_input_tensors(const Tensor& dense,
dense_data_.reset(new DenseStorageType(dense.tensor<float, 2>()));
}
if (sparse_indices.shape().dims() == 2) {
- sparse_indices_.reset(new SparseIndicesStorageType(
- sparse_indices.tensor<int64, 2>()));
- sparse_values_.reset(new SparseValuesStorageType(
- sparse_values.tensor<float, 1>()));
+ sparse_indices_.reset(
+ new SparseIndicesStorageType(sparse_indices.tensor<int64, 2>()));
+ sparse_values_.reset(
+ new SparseValuesStorageType(sparse_values.tensor<float, 1>()));
sparse_batch_size_ = sparse_shape.tensor<int64, 1>()(0);
}
original_dense_tensor_ = dense;
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
index eafad6b591..c544a8c75e 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
@@ -93,9 +93,7 @@ class TensorDataSet {
// an int32 you can avoid the atoi32.
virtual float GetExampleValue(int example, int32 feature_id) const;
- int num_features() {
- return available_features_.size();
- }
+ int num_features() { return available_features_.size(); }
const Tensor& original_tensor() const { return original_dense_tensor_; }
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h
index 44ec09c50e..d4402b6055 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h
@@ -79,9 +79,7 @@ class TensorInputTarget : public StoredInputTarget<SingleDimStorageType> {
return (*target_)(example_index * num_targets_ + target_index);
}
- const Tensor& original_tensor() const {
- return original_tensor_;
- }
+ const Tensor& original_tensor() const { return original_tensor_; }
protected:
Tensor original_tensor_;
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc
index d43c068e46..83614a2531 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc
@@ -160,6 +160,5 @@ void RegressionLeafModelOperator::ExportModel(
}
}
-
} // namespace tensorforest
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc
index ffd92c01f9..ab4191809b 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc
@@ -26,19 +26,19 @@ namespace {
using tensorflow::decision_trees::Leaf;
using tensorflow::tensorforest::DenseClassificationLeafModelOperator;
using tensorflow::tensorforest::LeafModelOperator;
-using tensorflow::tensorforest::SparseClassificationLeafModelOperator;
-using tensorflow::tensorforest::SparseOrDenseClassificationLeafModelOperator;
using tensorflow::tensorforest::LeafStat;
using tensorflow::tensorforest::RegressionLeafModelOperator;
-using tensorflow::tensorforest::TestableInputTarget;
+using tensorflow::tensorforest::SparseClassificationLeafModelOperator;
+using tensorflow::tensorforest::SparseOrDenseClassificationLeafModelOperator;
using tensorflow::tensorforest::TensorForestParams;
+using tensorflow::tensorforest::TestableInputTarget;
const int32 kNumClasses = 3;
constexpr char kRegressionStatProto[] =
- "weight_sum: 3 "
- "regression { "
- "mean_output { "
+ "weight_sum: 3 "
+ "regression { "
+ "mean_output { "
"value { "
" float_value: 27 "
"} "
@@ -48,8 +48,8 @@ constexpr char kRegressionStatProto[] =
"value { "
" float_value: 10 "
"} "
- "} "
- "mean_output_squares { "
+ "} "
+ "mean_output_squares { "
"value {"
" float_value: 245"
"}"
@@ -59,8 +59,8 @@ constexpr char kRegressionStatProto[] =
"value {"
" float_value: 46"
"}"
- "}"
-"}";
+ "}"
+ "}";
void TestClassificationNormalUse(const std::unique_ptr<LeafModelOperator>& op) {
Leaf l;
@@ -83,7 +83,6 @@ void TestClassificationNormalUse(const std::unique_ptr<LeafModelOperator>& op) {
EXPECT_FLOAT_EQ(op->GetOutputValue(l, 1), 3.4);
}
-
TEST(DenseLeafModelOperatorsTest, NormalUse) {
TensorForestParams params;
params.set_num_outputs(kNumClasses);
@@ -182,7 +181,7 @@ TEST(SparseLeafModelOperatorsTest, InitWithExisting) {
std::unique_ptr<Leaf> leaf(new Leaf);
- op->ExportModel( *stat, leaf.get());
+ op->ExportModel(*stat, leaf.get());
// Make sure it was initialized correctly.
EXPECT_FLOAT_EQ(op->GetOutputValue(*leaf, 0), 1.1);
@@ -194,7 +193,6 @@ TEST(SparseLeafModelOperatorsTest, InitWithExisting) {
EXPECT_EQ(leaf->sparse_vector().sparse_value().size(), kNumClasses);
}
-
TEST(RegressionLeafModelOperatorsTest, NormalUse) {
TensorForestParams params;
params.set_num_outputs(kNumClasses);
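
[Note] The expectations above (e.g. GetOutputValue(l, 1) returning an accumulated value) suggest a dense classification leaf reduces to one accumulated weight per class. A hypothetical mini-version of that idea, inferred from the tests rather than from the operator's actual implementation, with the Leaf/LeafStat proto plumbing elided:

#include <vector>

// Hypothetical reduction of a dense classification leaf: one accumulated
// weight per class. The real operator reads and writes Leaf protos.
class DenseLeafSketch {
 public:
  explicit DenseLeafSketch(int num_classes) : counts_(num_classes, 0.0f) {}
  void AddExample(int label, float weight) { counts_[label] += weight; }
  float GetOutputValue(int class_num) const { return counts_[class_num]; }

 private:
  std::vector<float> counts_;
};
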
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/params.h b/tensorflow/contrib/tensor_forest/kernels/v4/params.h
index b0ed949424..7583e3d040 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/params.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/params.h
@@ -24,7 +24,6 @@ namespace tensorforest {
// Return the value of the given depth-dependent parameter given a leaf's depth.
float ResolveParam(const DepthDependentParam& param, int32 depth);
-
} // namespace tensorforest
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/params_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/params_test.cc
index 801881af13..4010a71006 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/params_test.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/params_test.cc
@@ -71,5 +71,3 @@ TEST(ParamsTest, TestThreshold) {
}
} // namespace
-
-
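
[Note] params.h above declares ResolveParam, which maps a DepthDependentParam to a concrete value given a leaf's depth; the tests in this diff exercise a constant form (set_constant_value) and, by name, a threshold form (ParamsTest.TestThreshold). A hypothetical illustration of the resolve-at-depth idea, not the actual proto schema:

// Hypothetical depth-dependent parameter whose resolved value switches at
// a threshold depth. The real DepthDependentParam proto has its own
// schema; this only shows the resolve-at-depth idea.
struct DepthParamSketch {
  float before = 0;         // value below the threshold depth
  float after = 0;          // value at or above the threshold depth
  int threshold_depth = 0;
};

float ResolveParamSketch(const DepthParamSketch& p, int depth) {
  return depth < p.threshold_depth ? p.before : p.after;
}
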
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc
index cdb1d80a4b..b7b60d0ab8 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc
@@ -52,8 +52,8 @@ std::unique_ptr<GrowStats> SplitCollectionOperator::CreateGrowStats(
new SparseClassificationGrowStats(params_, depth));
case STATS_LEAST_SQUARES_REGRESSION:
- return std::unique_ptr<GrowStats>(new LeastSquaresRegressionGrowStats(
- params_, depth));
+ return std::unique_ptr<GrowStats>(
+ new LeastSquaresRegressionGrowStats(params_, depth));
case STATS_FIXED_SIZE_SPARSE_GINI:
return std::unique_ptr<GrowStats>(
@@ -136,8 +136,7 @@ void SplitCollectionOperator::CreateAndInitializeCandidateWithExample(
stats_.at(node_id)->AddSplit(split, input_data, target, example);
}
-bool SplitCollectionOperator::BestSplit(int32 node_id,
- SplitCandidate* best,
+bool SplitCollectionOperator::BestSplit(int32 node_id, SplitCandidate* best,
int32* depth) const {
auto* slot = stats_.at(node_id).get();
*depth = slot->depth();
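
[Note] The BestSplit body above shows the collection's dispatch: look up the per-node slot, report its depth, and delegate the decision. A stripped-down sketch of that per-node slot map, using hypothetical stand-in types in place of the GrowStats class and SplitCandidate proto:

#include <map>
#include <memory>

struct SplitCandidateSketch {};
struct GrowStatsSketch {
  int depth = 0;
  bool BestSplit(SplitCandidateSketch* best) const { return false; }
};

// Mirrors SplitCollectionOperator::BestSplit above: find the slot for the
// node, report its depth, delegate the decision to it.
class SplitCollectionSketch {
 public:
  bool BestSplit(int node_id, SplitCandidateSketch* best, int* depth) const {
    const auto& slot = stats_.at(node_id);
    *depth = slot->depth;
    return slot->BestSplit(best);
  }

 private:
  std::map<int, std::unique_ptr<GrowStatsSketch>> stats_;
};
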
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h
index ad52f89fad..c606ff98c6 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h
@@ -71,9 +71,7 @@ class SplitCollectionOperator {
}
// Perform any necessary cleanup for any tracked state for the slot.
- virtual void ClearSlot(int32 node_id) {
- stats_.erase(node_id);
- }
+ virtual void ClearSlot(int32 node_id) { stats_.erase(node_id); }
// Return true if slot is fully initialized.
virtual bool IsInitialized(int32 node_id) const;
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.cc b/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.cc
index 0bec198e97..c749fbe69e 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.cc
@@ -32,9 +32,9 @@ namespace tensorforest {
// smoothed_sum = stats.sum() + #_classes
float GiniImpurity(const LeafStat& stats, int32 num_classes) {
const float smoothed_sum = num_classes + stats.weight_sum();
- return 1.0 - (
- (stats.classification().gini().square()
- + 2 * stats.weight_sum() + num_classes) / (smoothed_sum * smoothed_sum));
+ return 1.0 - ((stats.classification().gini().square() +
+ 2 * stats.weight_sum() + num_classes) /
+ (smoothed_sum * smoothed_sum));
}
float WeightedGiniImpurity(const LeafStat& stats, int32 num_classes) {
@@ -46,21 +46,20 @@ void UpdateGini(LeafStat* stats, float old_val, float weight) {
// Equivalent to stats->square() - old_val * old_val + new_val * new_val,
// (for new_val = old_val + weight), but more numerically stable.
stats->mutable_classification()->mutable_gini()->set_square(
- stats->classification().gini().square()
- + weight * weight + 2 * old_val * weight);
+ stats->classification().gini().square() + weight * weight +
+ 2 * old_val * weight);
}
-
float Variance(const LeafStat& stats, int output) {
if (stats.weight_sum() == 0) {
return 0;
}
const float e_x =
- stats.regression().mean_output().value(output).float_value()
- / stats.weight_sum();
+ stats.regression().mean_output().value(output).float_value() /
+ stats.weight_sum();
const auto e_x2 =
- stats.regression().mean_output_squares().value(output).float_value()
- / stats.weight_sum();
+ stats.regression().mean_output_squares().value(output).float_value() /
+ stats.weight_sum();
return e_x2 - e_x * e_x;
}
@@ -75,8 +74,7 @@ float TotalVariance(const LeafStat& stats) {
float SmoothedGini(float sum, float square, int num_classes) {
// See comments for GiniImpurity above.
const float smoothed_sum = num_classes + sum;
- return 1.0 -
- (square + 2 * sum + num_classes) / (smoothed_sum * smoothed_sum);
+ return 1.0 - (square + 2 * sum + num_classes) / (smoothed_sum * smoothed_sum);
}
float WeightedSmoothedGini(float sum, float square, int num_classes) {
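
[Note] GiniImpurity and SmoothedGini above use Laplace smoothing: add one pseudo-count to each of C classes, so the count total becomes sum + C and the sum of squared counts becomes square + 2*sum + C (since (c_i + 1)^2 = c_i^2 + 2*c_i + 1), giving gini = 1 - (square + 2*sum + C) / (sum + C)^2. UpdateGini applies the same expansion incrementally: replacing a count old_val by old_val + weight adds weight^2 + 2*old_val*weight to the sum of squares. A standalone sketch with a worked check:

#include <cassert>

// Laplace-smoothed Gini as in SmoothedGini above: with C classes,
// smoothing turns sum into sum + C and the squared-count total into
// square + 2*sum + C, because (c_i + 1)^2 = c_i^2 + 2*c_i + 1.
float SmoothedGiniSketch(float sum, float square, int num_classes) {
  const float s = num_classes + sum;
  return 1.0f - (square + 2 * sum + num_classes) / (s * s);
}

int main() {
  // Worked check, two classes with counts {3, 1}: sum = 4, square = 10.
  // Smoothed counts {4, 2} give gini = 1 - (16 + 4) / 36 = 4/9.
  const float g = SmoothedGiniSketch(4.0f, 10.0f, 2);
  assert(g > 0.4443f && g < 0.4445f);
  return 0;
}
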
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h b/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h
index 289c81e9d5..38deb3e3cd 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h
@@ -27,9 +27,7 @@ class TestableInputTarget : public StoredInputTarget<std::vector<float>> {
: StoredInputTarget(new std::vector<float>(t), new std::vector<float>(w),
num_t) {}
- int NumItems() const {
- return target_->size();
- }
+ int NumItems() const { return target_->size(); }
int32 GetTargetAsClassIndex(int example_index,
int target_index) const override {
@@ -51,7 +49,6 @@ class TestableInputTarget : public StoredInputTarget<std::vector<float>> {
}
};
-
class TestableDataSet : public TensorDataSet {
public:
TestableDataSet(const std::vector<float>& data, int num_features)