-rw-r--r--  tensorflow/BUILD                                                          |    3
-rw-r--r--  tensorflow/contrib/boosted_trees/BUILD                                    |    2
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/BUILD                                |  161
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/quantiles/BUILD                      |   66
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc              |  152
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/batch_features.h               |   79
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc         |  213
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc               |  138
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h                |   72
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc          |  331
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/example.h                      |   50
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.cc           |   83
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h            |  172
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc      |  182
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/macros.h                       |   26
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/optional_value.h               |   47
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/parallel_for.cc                |   51
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h                 |   33
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/random.h                       |   39
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/random_test.cc                 |   56
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc      |  122
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h       |  128
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc |  100
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.cc                |  103
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h                 |   60
-rw-r--r--  tensorflow/contrib/boosted_trees/proto/BUILD                              |   32
-rw-r--r--  tensorflow/contrib/boosted_trees/proto/learner.proto                      |  136
-rw-r--r--  tensorflow/contrib/boosted_trees/proto/tree_config.proto                  |  109
28 files changed, 2680 insertions(+), 66 deletions(-)
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 56d9c598ff..d8d8cbc56d 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -160,6 +160,9 @@ filegroup(
"//tensorflow/contrib:all_files",
"//tensorflow/contrib/android:all_files",
"//tensorflow/contrib/bayesflow:all_files",
+ "//tensorflow/contrib/boosted_trees:all_files",
+ "//tensorflow/contrib/boosted_trees/lib:all_files",
+ "//tensorflow/contrib/boosted_trees/proto:all_files",
"//tensorflow/contrib/cloud:all_files",
"//tensorflow/contrib/cloud/kernels:all_files",
"//tensorflow/contrib/compiler:all_files",
diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD
index 7823aa621f..c1600bdabd 100644
--- a/tensorflow/contrib/boosted_trees/BUILD
+++ b/tensorflow/contrib/boosted_trees/BUILD
@@ -17,3 +17,5 @@ filegroup(
),
visibility = ["//tensorflow:__subpackages__"],
)
+
+package_group(name = "friends")
diff --git a/tensorflow/contrib/boosted_trees/lib/BUILD b/tensorflow/contrib/boosted_trees/lib/BUILD
new file mode 100644
index 0000000000..5111576bf0
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/BUILD
@@ -0,0 +1,161 @@
+# Description:
+# This directory contains common utilities used in boosted_trees.
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(
+ default_visibility = [
+ "//tensorflow/contrib/boosted_trees:__subpackages__",
+ "//tensorflow/contrib/boosted_trees:friends",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+cc_library(
+ name = "weighted_quantiles",
+ srcs = [],
+ hdrs = [
+ "quantiles/weighted_quantiles_buffer.h",
+ "quantiles/weighted_quantiles_stream.h",
+ "quantiles/weighted_quantiles_summary.h",
+ ],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ ],
+)
+
+cc_test(
+ name = "weighted_quantiles_buffer_test",
+ size = "small",
+ srcs = ["quantiles/weighted_quantiles_buffer_test.cc"],
+ deps = [
+ ":weighted_quantiles",
+ "//tensorflow/core",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_test(
+ name = "weighted_quantiles_summary_test",
+ size = "small",
+ srcs = ["quantiles/weighted_quantiles_summary_test.cc"],
+ deps = [
+ ":weighted_quantiles",
+ "//tensorflow/core",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_test(
+ name = "weighted_quantiles_stream_test",
+ size = "small",
+ srcs = ["quantiles/weighted_quantiles_stream_test.cc"],
+ deps = [
+ ":weighted_quantiles",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "utils",
+ srcs = [
+ "utils/batch_features.cc",
+ "utils/dropout_utils.cc",
+ "utils/examples_iterable.cc",
+ "utils/parallel_for.cc",
+ "utils/sparse_column_iterable.cc",
+ "utils/tensor_utils.cc",
+ ],
+ hdrs = [
+ "utils/batch_features.h",
+ "utils/dropout_utils.h",
+ "utils/example.h",
+ "utils/examples_iterable.h",
+ "utils/macros.h",
+ "utils/optional_value.h",
+ "utils/parallel_for.h",
+ "utils/random.h",
+ "utils/sparse_column_iterable.h",
+ "utils/tensor_utils.h",
+ ],
+ deps = [
+ "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc",
+ "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:protos_all_cc",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_test(
+ name = "sparse_column_iterable_test",
+ size = "small",
+ srcs = ["utils/sparse_column_iterable_test.cc"],
+ deps = [
+ ":utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:tensor_testutil",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_test(
+ name = "examples_iterable_test",
+ size = "small",
+ srcs = ["utils/examples_iterable_test.cc"],
+ deps = [
+ ":utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:tensor_testutil",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_test(
+ name = "batch_features_test",
+ size = "small",
+ srcs = ["utils/batch_features_test.cc"],
+ deps = [
+ ":utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:tensor_testutil",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_test(
+ name = "dropout_utils_test",
+ size = "small",
+ srcs = ["utils/dropout_utils_test.cc"],
+ deps = [
+ ":utils",
+ "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc",
+ "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:tensor_testutil",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/BUILD b/tensorflow/contrib/boosted_trees/lib/quantiles/BUILD
deleted file mode 100644
index dc24286883..0000000000
--- a/tensorflow/contrib/boosted_trees/lib/quantiles/BUILD
+++ /dev/null
@@ -1,66 +0,0 @@
-# Description:
-# This directory contains a runtime O(nlog(log(n))) and
-# memory O(log^2(n)) efficient approximate quantiles
-# implementation allowing streaming and distributed
-# computation of quantiles on weighted data points.
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-package(default_visibility = [
- "//visibility:public",
-])
-
-cc_library(
- name = "weighted_quantiles",
- srcs = [],
- hdrs = [
- "weighted_quantiles_buffer.h",
- "weighted_quantiles_stream.h",
- "weighted_quantiles_summary.h",
- ],
- deps = [
- "//tensorflow/core:framework_headers_lib",
- ],
-)
-
-cc_test(
- name = "weighted_quantiles_buffer_test",
- size = "small",
- srcs = ["weighted_quantiles_buffer_test.cc"],
- deps = [
- ":weighted_quantiles",
- "//tensorflow/core",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- ],
-)
-
-cc_test(
- name = "weighted_quantiles_summary_test",
- size = "small",
- srcs = ["weighted_quantiles_summary_test.cc"],
- deps = [
- ":weighted_quantiles",
- "//tensorflow/core",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- ],
-)
-
-cc_test(
- name = "weighted_quantiles_stream_test",
- size = "small",
- srcs = ["weighted_quantiles_stream_test.cc"],
- deps = [
- ":weighted_quantiles",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- ],
-)
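The weighted_quantiles build rule deleted here moves into lib/BUILD above; the library it builds computes approximate quantiles over weighted, streamed data. As a point of reference only, the sketch below shows the exact quantity being approximated, computed naively by sorting. It is not the O(n log(log(n))) streaming summary from this directory, and the function name ExactWeightedQuantile is made up for illustration.

// Naive reference: exact weighted quantile by sorting. Illustrative only; not
// the streaming summary implemented by weighted_quantiles_stream.h.
#include <algorithm>
#include <utility>
#include <vector>

// Returns the smallest value whose cumulative weight reaches
// quantile * total_weight. Assumes non-negative weights and a non-empty input.
float ExactWeightedQuantile(std::vector<std::pair<float, float>> value_weight,
                            float quantile) {
  std::sort(value_weight.begin(), value_weight.end());
  float total_weight = 0.0f;
  for (const auto& vw : value_weight) total_weight += vw.second;
  float cumulative_weight = 0.0f;
  for (const auto& vw : value_weight) {
    cumulative_weight += vw.second;
    if (cumulative_weight >= quantile * total_weight) return vw.first;
  }
  return value_weight.back().first;
}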
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc
new file mode 100644
index 0000000000..12b377dda7
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc
@@ -0,0 +1,152 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
+#include "tensorflow/contrib/boosted_trees/lib/utils/macros.h"
+#include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace utils {
+
+Status BatchFeatures::Initialize(
+ std::vector<Tensor> dense_float_features_list,
+ std::vector<Tensor> sparse_float_feature_indices_list,
+ std::vector<Tensor> sparse_float_feature_values_list,
+ std::vector<Tensor> sparse_float_feature_shapes_list,
+ std::vector<Tensor> sparse_int_feature_indices_list,
+ std::vector<Tensor> sparse_int_feature_values_list,
+ std::vector<Tensor> sparse_int_feature_shapes_list) {
+ // Validate number of feature columns.
+ auto num_dense_float_features = dense_float_features_list.size();
+ auto num_sparse_float_features = sparse_float_feature_indices_list.size();
+ auto num_sparse_int_features = sparse_int_feature_indices_list.size();
+ QCHECK(num_dense_float_features + num_sparse_float_features +
+ num_sparse_int_features >
+ 0)
+ << "Must have at least one feature column.";
+
+ // Read dense float features.
+ dense_float_feature_columns_.reserve(num_dense_float_features);
+ for (uint32 dense_feat_idx = 0; dense_feat_idx < num_dense_float_features;
+ ++dense_feat_idx) {
+ auto dense_float_feature = dense_float_features_list[dense_feat_idx];
+ TF_CHECK_AND_RETURN_IF_ERROR(
+ TensorShapeUtils::IsMatrix(dense_float_feature.shape()),
+ errors::InvalidArgument("Dense float feature must be a matrix."));
+ TF_CHECK_AND_RETURN_IF_ERROR(
+ dense_float_feature.dim_size(0) == batch_size_,
+ errors::InvalidArgument(
+ "Dense float vector must have batch_size rows: ", batch_size_,
+ " vs. ", dense_float_feature.dim_size(0)));
+ TF_CHECK_AND_RETURN_IF_ERROR(
+ dense_float_feature.dim_size(1) == 1,
+ errors::InvalidArgument(
+ "Dense float features may not be multi-valent: dim_size(1) = ",
+ dense_float_feature.dim_size(1)));
+ dense_float_feature_columns_.emplace_back(dense_float_feature);
+ }
+
+ // Read sparse float features.
+ sparse_float_feature_columns_.reserve(num_sparse_float_features);
+ TF_CHECK_AND_RETURN_IF_ERROR(
+ sparse_float_feature_values_list.size() == num_sparse_float_features &&
+ sparse_float_feature_shapes_list.size() == num_sparse_float_features,
+ errors::InvalidArgument("Inconsistent number of sparse float features."));
+ for (uint32 sparse_feat_idx = 0; sparse_feat_idx < num_sparse_float_features;
+ ++sparse_feat_idx) {
+ auto sparse_float_feature_indices =
+ sparse_float_feature_indices_list[sparse_feat_idx];
+ auto sparse_float_feature_values =
+ sparse_float_feature_values_list[sparse_feat_idx];
+ auto sparse_float_feature_shape =
+ sparse_float_feature_shapes_list[sparse_feat_idx];
+ TF_CHECK_AND_RETURN_IF_ERROR(
+ TensorShapeUtils::IsMatrix(sparse_float_feature_indices.shape()),
+ errors::InvalidArgument(
+ "Sparse float feature indices must be a matrix."));
+ TF_CHECK_AND_RETURN_IF_ERROR(
+ TensorShapeUtils::IsVector(sparse_float_feature_values.shape()),
+ errors::InvalidArgument(
+ "Sparse float feature values must be a vector."));
+ TF_CHECK_AND_RETURN_IF_ERROR(
+ TensorShapeUtils::IsVector(sparse_float_feature_shape.shape()),
+ errors::InvalidArgument(
+ "Sparse float feature shape must be a vector."));
+ auto shape_flat = sparse_float_feature_shape.flat<int64>();
+ TF_CHECK_AND_RETURN_IF_ERROR(
+ shape_flat.size() == 2,
+ errors::InvalidArgument(
+ "Sparse float feature column must be two-dimensional."));
+ TF_CHECK_AND_RETURN_IF_ERROR(
+ shape_flat(0) == batch_size_,
+ errors::InvalidArgument(
+ "Sparse float feature shape incompatible with batch size."));
+ TF_CHECK_AND_RETURN_IF_ERROR(
+ shape_flat(1) <= 1,
+ errors::InvalidArgument(
+ "Sparse float features may not be multi-valent."));
+ auto tensor_shape = TensorShape({shape_flat(0), shape_flat(1)});
+ auto order_dims = sparse::SparseTensor::VarDimArray({0, 1});
+ sparse_float_feature_columns_.emplace_back(sparse_float_feature_indices,
+ sparse_float_feature_values,
+ tensor_shape, order_dims);
+ }
+
+ // Read sparse int features.
+ sparse_int_feature_columns_.reserve(num_sparse_int_features);
+ TF_CHECK_AND_RETURN_IF_ERROR(
+ sparse_int_feature_values_list.size() == num_sparse_int_features &&
+ sparse_int_feature_shapes_list.size() == num_sparse_int_features,
+ errors::InvalidArgument("Inconsistent number of sparse int features."));
+ for (uint32 sparse_feat_idx = 0; sparse_feat_idx < num_sparse_int_features;
+ ++sparse_feat_idx) {
+ auto sparse_int_feature_indices =
+ sparse_int_feature_indices_list[sparse_feat_idx];
+ auto sparse_int_feature_values =
+ sparse_int_feature_values_list[sparse_feat_idx];
+ auto sparse_int_feature_shape =
+ sparse_int_feature_shapes_list[sparse_feat_idx];
+ TF_CHECK_AND_RETURN_IF_ERROR(
+ TensorShapeUtils::IsMatrix(sparse_int_feature_indices.shape()),
+ errors::InvalidArgument(
+ "Sparse int feature indices must be a matrix."));
+ TF_CHECK_AND_RETURN_IF_ERROR(
+ TensorShapeUtils::IsVector(sparse_int_feature_values.shape()),
+ errors::InvalidArgument("Sparse int feature values must be a vector."));
+ TF_CHECK_AND_RETURN_IF_ERROR(
+ TensorShapeUtils::IsVector(sparse_int_feature_shape.shape()),
+ errors::InvalidArgument("Sparse int feature shape must be a vector."));
+ auto shape_flat = sparse_int_feature_shape.flat<int64>();
+ TF_CHECK_AND_RETURN_IF_ERROR(
+ shape_flat.size() == 2,
+ errors::InvalidArgument(
+ "Sparse int feature column must be two-dimensional."));
+ TF_CHECK_AND_RETURN_IF_ERROR(
+ shape_flat(0) == batch_size_,
+ errors::InvalidArgument(
+ "Sparse int feature shape incompatible with batch size."));
+ auto tensor_shape = TensorShape({shape_flat(0), shape_flat(1)});
+ auto order_dims = sparse::SparseTensor::VarDimArray({0, 1});
+ sparse_int_feature_columns_.emplace_back(sparse_int_feature_indices,
+ sparse_int_feature_values,
+ tensor_shape, order_dims);
+ }
+ return Status::OK();
+}
+
+} // namespace utils
+} // namespace boosted_trees
+} // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h
new file mode 100644
index 0000000000..bb11dc9a07
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h
@@ -0,0 +1,79 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_
+
+#include <vector>
+#include "tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace utils {
+
+class BatchFeatures {
+ public:
+ // Constructs batch features with a fixed batch size.
+ explicit BatchFeatures(int64 batch_size) : batch_size_(batch_size) {}
+
+ // Disallow copy and assign.
+ BatchFeatures(const BatchFeatures& other) = delete;
+ BatchFeatures& operator=(const BatchFeatures& other) = delete;
+
+ // Initializes batch features from the given feature column tensors.
+ Status Initialize(std::vector<Tensor> dense_float_features_list,
+ std::vector<Tensor> sparse_float_feature_indices_list,
+ std::vector<Tensor> sparse_float_feature_values_list,
+ std::vector<Tensor> sparse_float_feature_shapes_list,
+ std::vector<Tensor> sparse_int_feature_indices_list,
+ std::vector<Tensor> sparse_int_feature_values_list,
+ std::vector<Tensor> sparse_int_feature_shapes_list);
+
+ // Creates an example iterable for the requested slice.
+ ExamplesIterable examples_iterable(int64 example_start,
+ int64 example_end) const {
+ QCHECK(example_start >= 0 && example_end >= 0);
+ QCHECK(example_start < batch_size_ && example_end <= batch_size_);
+ return ExamplesIterable(
+ dense_float_feature_columns_, sparse_float_feature_columns_,
+ sparse_int_feature_columns_, example_start, example_end);
+ }
+
+ // Returns the fixed batch size.
+ int64 batch_size() const { return batch_size_; }
+
+ private:
+ // Total number of examples in the batch.
+ const int64 batch_size_;
+
+ // Dense float feature columns.
+ std::vector<Tensor> dense_float_feature_columns_;
+
+ // Sparse float feature columns.
+ std::vector<sparse::SparseTensor> sparse_float_feature_columns_;
+
+ // Sparse int feature columns.
+ std::vector<sparse::SparseTensor> sparse_int_feature_columns_;
+};
+
+} // namespace utils
+} // namespace boosted_trees
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_
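For context, a minimal construction sketch for the class above; it is not part of this change, the tensor values are made up, the helper name BuildToyBatch is hypothetical, and test::AsTensor is the same helper used in the tests that follow.

// A minimal sketch, assuming `batch_features` was constructed as
// BatchFeatures(2): two examples, one single-valued dense float column, and
// no sparse columns.
#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
#include "tensorflow/core/framework/tensor_testutil.h"

namespace tensorflow {
namespace boosted_trees {
namespace utils {

Status BuildToyBatch(BatchFeatures* batch_features) {
  // Dense float features must be a [batch_size, 1] matrix.
  auto dense_col = test::AsTensor<float>({3.0f, 7.0f}, {2, 1});
  // The remaining arguments are the (empty) sparse float/int indices, values
  // and shapes.
  return batch_features->Initialize({dense_col}, {}, {}, {}, {}, {}, {});
}

}  // namespace utils
}  // namespace boosted_trees
}  // namespace tensorflow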
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc
new file mode 100644
index 0000000000..7f523d527a
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc
@@ -0,0 +1,213 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace utils {
+namespace {
+
+using test::AsTensor;
+using errors::InvalidArgument;
+
+class BatchFeaturesTest : public ::testing::Test {};
+
+TEST_F(BatchFeaturesTest, InvalidNumFeatures) {
+ BatchFeatures batch_features(8);
+ EXPECT_DEATH(({ batch_features.Initialize({}, {}, {}, {}, {}, {}, {}); })
+ .IgnoreError(),
+ "Must have at least one feature column.");
+}
+
+TEST_F(BatchFeaturesTest, DenseFloatFeatures_WrongShape) {
+ BatchFeatures batch_features(8);
+ auto dense_vec = AsTensor<float>({3.0f, 7.0f});
+ auto expected_error =
+ InvalidArgument("Dense float feature must be a matrix.");
+ EXPECT_EQ(expected_error,
+ batch_features.Initialize({dense_vec}, {}, {}, {}, {}, {}, {}));
+}
+
+TEST_F(BatchFeaturesTest, DenseFloatFeatures_WrongBatchDimension) {
+ BatchFeatures batch_features(8);
+ auto dense_vec = AsTensor<float>({3.0f, 7.0f}, {2, 1});
+ auto expected_error =
+ InvalidArgument("Dense float vector must have batch_size rows: 8 vs. 2");
+ EXPECT_EQ(expected_error,
+ batch_features.Initialize({dense_vec}, {}, {}, {}, {}, {}, {}));
+}
+
+TEST_F(BatchFeaturesTest, DenseFloatFeatures_Multivalent) {
+ BatchFeatures batch_features(1);
+ auto dense_vec = AsTensor<float>({3.0f, 7.0f}, {1, 2});
+ auto expected_error = InvalidArgument(
+ "Dense float features may not be multi-valent: dim_size(1) = 2");
+ EXPECT_EQ(expected_error,
+ batch_features.Initialize({dense_vec}, {}, {}, {}, {}, {}, {}));
+}
+
+TEST_F(BatchFeaturesTest, SparseFloatFeatures_WrongShapeIndices) {
+ BatchFeatures batch_features(2);
+ auto sparse_float_feature_indices = AsTensor<int64>({0, 0, 1, 0});
+ auto sparse_float_feature_values = AsTensor<float>({3.0f, 7.0f});
+ auto sparse_float_feature_shape = AsTensor<int64>({2, 1});
+ auto expected_error =
+ InvalidArgument("Sparse float feature indices must be a matrix.");
+ EXPECT_EQ(expected_error, batch_features.Initialize(
+ {}, {sparse_float_feature_indices},
+ {sparse_float_feature_values},
+ {sparse_float_feature_shape}, {}, {}, {}));
+}
+
+TEST_F(BatchFeaturesTest, SparseFloatFeatures_WrongShapeValues) {
+ BatchFeatures batch_features(2);
+ auto sparse_float_feature_indices = AsTensor<int64>({0, 0, 1, 0}, {2, 2});
+ auto sparse_float_feature_values = AsTensor<float>({3.0f, 7.0f}, {1, 2});
+ auto sparse_float_feature_shape = AsTensor<int64>({2, 1});
+ auto expected_error =
+ InvalidArgument("Sparse float feature values must be a vector.");
+ EXPECT_EQ(expected_error, batch_features.Initialize(
+ {}, {sparse_float_feature_indices},
+ {sparse_float_feature_values},
+ {sparse_float_feature_shape}, {}, {}, {}));
+}
+
+TEST_F(BatchFeaturesTest, SparseFloatFeatures_WrongShapeShape) {
+ BatchFeatures batch_features(2);
+ auto sparse_float_feature_indices = AsTensor<int64>({0, 0, 1, 0}, {2, 2});
+ auto sparse_float_feature_values = AsTensor<float>({3.0f, 7.0f});
+ auto sparse_float_feature_shape = AsTensor<int64>({2, 1}, {1, 2});
+ auto expected_error =
+ InvalidArgument("Sparse float feature shape must be a vector.");
+ EXPECT_EQ(expected_error, batch_features.Initialize(
+ {}, {sparse_float_feature_indices},
+ {sparse_float_feature_values},
+ {sparse_float_feature_shape}, {}, {}, {}));
+}
+
+TEST_F(BatchFeaturesTest, SparseFloatFeatures_WrongSizeShape) {
+ BatchFeatures batch_features(2);
+ auto sparse_float_feature_indices = AsTensor<int64>({0, 0, 1, 0}, {2, 2});
+ auto sparse_float_feature_values = AsTensor<float>({3.0f, 7.0f});
+ auto sparse_float_feature_shape = AsTensor<int64>({2, 1, 9});
+ auto expected_error =
+ InvalidArgument("Sparse float feature column must be two-dimensional.");
+ EXPECT_EQ(expected_error, batch_features.Initialize(
+ {}, {sparse_float_feature_indices},
+ {sparse_float_feature_values},
+ {sparse_float_feature_shape}, {}, {}, {}));
+}
+
+TEST_F(BatchFeaturesTest, SparseFloatFeatures_IncompatibleShape) {
+ BatchFeatures batch_features(2);
+ auto sparse_float_feature_indices = AsTensor<int64>({0, 0, 1, 0}, {2, 2});
+ auto sparse_float_feature_values = AsTensor<float>({3.0f, 7.0f});
+ auto sparse_float_feature_shape = AsTensor<int64>({8, 1});
+ auto expected_error = InvalidArgument(
+ "Sparse float feature shape incompatible with batch size.");
+ EXPECT_EQ(expected_error, batch_features.Initialize(
+ {}, {sparse_float_feature_indices},
+ {sparse_float_feature_values},
+ {sparse_float_feature_shape}, {}, {}, {}));
+}
+
+TEST_F(BatchFeaturesTest, SparseFloatFeatures_Multivalent) {
+ BatchFeatures batch_features(2);
+ auto sparse_float_feature_indices = AsTensor<int64>({0, 0, 1, 0}, {2, 2});
+ auto sparse_float_feature_values = AsTensor<float>({3.0f, 7.0f});
+ auto sparse_float_feature_shape = AsTensor<int64>({2, 2});
+ auto expected_error =
+ InvalidArgument("Sparse float features may not be multi-valent.");
+ EXPECT_EQ(expected_error, batch_features.Initialize(
+ {}, {sparse_float_feature_indices},
+ {sparse_float_feature_values},
+ {sparse_float_feature_shape}, {}, {}, {}));
+}
+
+TEST_F(BatchFeaturesTest, SparseIntFeatures_WrongShapeIndices) {
+ BatchFeatures batch_features(2);
+ auto sparse_int_feature_indices = AsTensor<int64>({0, 0, 1, 0});
+ auto sparse_int_feature_values = AsTensor<int64>({3, 7});
+ auto sparse_int_feature_shape = AsTensor<int64>({2, 1});
+ auto expected_error =
+ InvalidArgument("Sparse int feature indices must be a matrix.");
+ EXPECT_EQ(expected_error,
+ batch_features.Initialize(
+ {}, {}, {}, {}, {sparse_int_feature_indices},
+ {sparse_int_feature_values}, {sparse_int_feature_shape}));
+}
+
+TEST_F(BatchFeaturesTest, SparseIntFeatures_WrongShapeValues) {
+ BatchFeatures batch_features(2);
+ auto sparse_int_feature_indices = AsTensor<int64>({0, 0, 1, 0}, {2, 2});
+ auto sparse_int_feature_values = AsTensor<int64>({3, 7}, {1, 2});
+ auto sparse_int_feature_shape = AsTensor<int64>({2, 1});
+ auto expected_error =
+ InvalidArgument("Sparse int feature values must be a vector.");
+ EXPECT_EQ(expected_error,
+ batch_features.Initialize(
+ {}, {}, {}, {}, {sparse_int_feature_indices},
+ {sparse_int_feature_values}, {sparse_int_feature_shape}));
+}
+
+TEST_F(BatchFeaturesTest, SparseIntFeatures_WrongShapeShape) {
+ BatchFeatures batch_features(2);
+ auto sparse_int_feature_indices = AsTensor<int64>({0, 0, 1, 0}, {2, 2});
+ auto sparse_int_feature_values = AsTensor<int64>({3, 7});
+ auto sparse_int_feature_shape = AsTensor<int64>({2, 1}, {1, 2});
+ auto expected_error =
+ InvalidArgument("Sparse int feature shape must be a vector.");
+ EXPECT_EQ(expected_error,
+ batch_features.Initialize(
+ {}, {}, {}, {}, {sparse_int_feature_indices},
+ {sparse_int_feature_values}, {sparse_int_feature_shape}));
+}
+
+TEST_F(BatchFeaturesTest, SparseIntFeatures_WrongSizeShape) {
+ BatchFeatures batch_features(2);
+ auto sparse_int_feature_indices = AsTensor<int64>({0, 0, 1, 0}, {2, 2});
+ auto sparse_int_feature_values = AsTensor<int64>({3, 7});
+ auto sparse_int_feature_shape = AsTensor<int64>({2, 1, 9});
+ auto expected_error =
+ InvalidArgument("Sparse int feature column must be two-dimensional.");
+ EXPECT_EQ(expected_error,
+ batch_features.Initialize(
+ {}, {}, {}, {}, {sparse_int_feature_indices},
+ {sparse_int_feature_values}, {sparse_int_feature_shape}));
+}
+
+TEST_F(BatchFeaturesTest, SparseIntFeatures_IncompatibleShape) {
+ BatchFeatures batch_features(2);
+ auto sparse_int_feature_indices = AsTensor<int64>({0, 0, 1, 0}, {2, 2});
+ auto sparse_int_feature_values = AsTensor<int64>({3, 7});
+ auto sparse_int_feature_shape = AsTensor<int64>({8, 1});
+ auto expected_error =
+ InvalidArgument("Sparse int feature shape incompatible with batch size.");
+ EXPECT_EQ(expected_error,
+ batch_features.Initialize(
+ {}, {}, {}, {}, {sparse_int_feature_indices},
+ {sparse_int_feature_values}, {sparse_int_feature_shape}));
+}
+
+} // namespace
+} // namespace utils
+} // namespace boosted_trees
+} // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc
new file mode 100644
index 0000000000..7e98fcf789
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc
@@ -0,0 +1,138 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"
+
+#include <iterator>
+#include <numeric>
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+
+using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig;
+using tensorflow::random::PhiloxRandom;
+using tensorflow::random::SimplePhilox;
+using tensorflow::Status;
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace utils {
+
+Status DropoutUtils::DropOutTrees(const uint64 seed,
+ const LearningRateDropoutDrivenConfig& config,
+ const std::vector<float>& weights,
+ std::vector<int32>* dropped_trees,
+ std::vector<float>* original_weights) {
+ // Verify params.
+ if (dropped_trees == nullptr) {
+ return errors::Internal("Dropped trees is nullptr.");
+ }
+ if (original_weights == nullptr) {
+ return errors::InvalidArgument("Original weights is nullptr.");
+ }
+ const float dropout_probability = config.dropout_probability();
+ if (dropout_probability < 0 || dropout_probability > 1) {
+ return errors::InvalidArgument(
+ "Dropout probability must be in [0,1] range");
+ }
+ const float learning_rate = config.learning_rate();
+ if (learning_rate <= 0) {
+ return errors::InvalidArgument("Learning rate must be in (0,1] range.");
+ }
+ const float probability_of_skipping_dropout =
+ config.probability_of_skipping_dropout();
+ if (probability_of_skipping_dropout < 0 ||
+ probability_of_skipping_dropout > 1) {
+ return errors::InvalidArgument(
+ "Probability of skipping dropout must be in [0,1] range");
+ }
+ const auto num_trees = weights.size();
+
+ dropped_trees->clear();
+ original_weights->clear();
+
+ // If dropout is a no-op, return early.
+ if (dropout_probability == 0 || probability_of_skipping_dropout == 1.0) {
+ return Status::OK();
+ }
+
+ // Roll the dice for each tree.
+ PhiloxRandom philox(seed);
+ SimplePhilox rng(&philox);
+
+ std::vector<int32> trees_to_keep;
+
+ // Check whether to skip dropout for this round altogether.
+ if (probability_of_skipping_dropout != 0) {
+ // First roll the dice: should dropout be applied at all?
+ double roll = rng.RandDouble();
+ if (roll < probability_of_skipping_dropout) {
+ // Skip dropout.
+ return Status::OK();
+ }
+ }
+
+ for (int32 i = 0; i < num_trees; ++i) {
+ double roll = rng.RandDouble();
+ if (roll >= dropout_probability) {
+ trees_to_keep.push_back(i);
+ } else {
+ dropped_trees->push_back(i);
+ }
+ }
+
+ // Sort the dropped trees indices.
+ std::sort(dropped_trees->begin(), dropped_trees->end());
+ for (const int32 dropped_tree : *dropped_trees) {
+ original_weights->push_back(weights[dropped_tree]);
+ }
+
+ return Status::OK();
+}
+
+void DropoutUtils::GetTreesWeightsForAddingTrees(
+ const std::vector<int32>& dropped_trees,
+ const std::vector<float>& dropped_trees_original_weights,
+ const int32 num_trees_to_add, std::vector<float>* current_weights,
+ std::vector<int32>* num_updates) {
+ CHECK(num_updates->size() == current_weights->size());
+ // Combined weight of the trees that were dropped out.
+ const float dropped_sum =
+ std::accumulate(dropped_trees_original_weights.begin(),
+ dropped_trees_original_weights.end(), 0.0);
+
+ const int num_dropped = dropped_trees.size();
+
+ // Allocate the additional weight for the new trees.
+ const float total_new_trees_weight = dropped_sum / (num_dropped + 1);
+ for (int i = 0; i < num_trees_to_add; ++i) {
+ current_weights->push_back(total_new_trees_weight / num_trees_to_add);
+ num_updates->push_back(1);
+ }
+
+ for (int32 i = 0; i < dropped_trees.size(); ++i) {
+ const int32 dropped = dropped_trees[i];
+ const float original_weight = dropped_trees_original_weights[i];
+ const float new_weight = original_weight * num_dropped / (num_dropped + 1);
+ (*current_weights)[dropped] = new_weight;
+ // Update the number of updates per tree.
+ ++(*num_updates)[dropped];
+ }
+}
+
+} // namespace utils
+} // namespace boosted_trees
+} // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h
new file mode 100644
index 0000000000..6d6cc594c6
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h
@@ -0,0 +1,72 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_
+
+#include <vector>
+
+#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h" // NOLINT
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace utils {
+
+// Utils for deciding which trees should be (or were) dropped when building a new tree.
+class DropoutUtils {
+ public:
+ // This method determines what trees should be dropped and returns their
+ // indices and the weights they had when this method ran.
+ // seed: random seed to be used.
+ // config: dropout config that defines the dropout probability etc.
+ // weights: weights of the trees currently in the ensemble; its size
+ // determines how many trees are considered for dropout.
+ // Returns (via the output parameters) the sorted indices of the dropped
+ // trees and their original weights.
+ static tensorflow::Status DropOutTrees(
+ const uint64 seed, const learner::LearningRateDropoutDrivenConfig& config,
+ const std::vector<float>& weights, std::vector<int32>* dropped_trees,
+ std::vector<float>* original_weights);
+
+ // Recalculates the tree weights when new trees are added to the ensemble.
+ // dropped_trees: ids of trees that were dropped while the trees to add were
+ // being built.
+ // dropped_trees_original_weights: the weights the dropped trees had at
+ // dropout time.
+ // num_trees_to_add: how many trees are being added to the ensemble.
+ // Returns
+ // current_weights: updated vector of tree weights; the weights of the
+ // dropped trees are rescaled and the weights of the new trees are appended,
+ // so the resulting size is total_num_trees + num_trees_to_add.
+ // num_updates: updated vector with an increased update count for each
+ // dropped tree.
+ static void GetTreesWeightsForAddingTrees(
+ const std::vector<int32>& dropped_trees,
+ const std::vector<float>& dropped_trees_original_weights,
+ const int32 num_trees_to_add,
+ // Current weights and num_updates will be updated as a result of this
+ // call.
+ std::vector<float>* current_weights,
+ // How many weight assignments have been done for each tree already.
+ std::vector<int32>* num_updates);
+};
+
+} // namespace utils
+} // namespace boosted_trees
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_
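To make the call sequence documented above concrete, here is a minimal usage sketch of one dropout round; it is not part of this change, and the seed, tree weights, config values and the function name DropoutRoundSketch are made up for illustration.

// A minimal sketch of one dropout round for a toy three-tree ensemble.
#include <vector>
#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"

namespace tensorflow {
namespace boosted_trees {
namespace utils {

Status DropoutRoundSketch() {
  learner::LearningRateDropoutDrivenConfig config;
  config.set_dropout_probability(0.5);
  config.set_learning_rate(1.0);

  std::vector<float> weights = {1.0f, 1.0f, 1.0f};
  std::vector<int32> num_updates = {1, 1, 1};

  // Decide which trees to drop before building the next tree.
  std::vector<int32> dropped_trees;
  std::vector<float> original_weights;
  Status status = DropoutUtils::DropOutTrees(
      /*seed=*/42, config, weights, &dropped_trees, &original_weights);
  if (!status.ok()) return status;

  // After the new tree is built: with k dropped trees, each dropped weight is
  // scaled by k / (k + 1) and the new tree receives sum(dropped) / (k + 1),
  // appended at the end of `weights`.
  DropoutUtils::GetTreesWeightsForAddingTrees(dropped_trees, original_weights,
                                              /*num_trees_to_add=*/1, &weights,
                                              &num_updates);
  return Status::OK();
}

}  // namespace utils
}  // namespace boosted_trees
}  // namespace tensorflow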
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc
new file mode 100644
index 0000000000..8bc1dbfdf2
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc
@@ -0,0 +1,331 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"
+
+#include <sys/types.h>
+#include <algorithm>
+#include <cstdlib>
+#include <ctime>
+#include <functional>
+#include <iterator>
+#include <unordered_set>
+#include <utility>
+
+#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" // NOLINT
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig;
+using tensorflow::boosted_trees::trees::DecisionTreeEnsembleConfig;
+using std::unordered_set;
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace utils {
+namespace {
+
+const uint32 kSeed = 123;
+const int32 kNumTrees = 1000;
+
+class DropoutUtilsTest : public ::testing::Test {
+ public:
+ void SetUp() override {
+ // Fill in the tree weights.
+ for (int i = 0; i < kNumTrees; ++i) {
+ weights_.push_back(1.1 + 0.4 * i);
+ }
+ }
+
+ protected:
+ std::vector<float> weights_;
+};
+
+TEST_F(DropoutUtilsTest, DropoutProbabilityTest) {
+ std::vector<int32> dropped_trees;
+ std::vector<float> original_weights;
+
+ // Do not drop any trees
+ {
+ LearningRateDropoutDrivenConfig config;
+ config.set_dropout_probability(0.0);
+ config.set_learning_rate(1.0);
+
+ TF_EXPECT_OK(DropoutUtils::DropOutTrees(kSeed, config, weights_,
+ &dropped_trees, &original_weights));
+
+ // Nothing changed
+ EXPECT_TRUE(dropped_trees.empty());
+ EXPECT_TRUE(original_weights.empty());
+ }
+ // Drop out all trees
+ {
+ LearningRateDropoutDrivenConfig config;
+ config.set_dropout_probability(1.0);
+ config.set_learning_rate(1.0);
+
+ TF_EXPECT_OK(DropoutUtils::DropOutTrees(kSeed, config, weights_,
+ &dropped_trees, &original_weights));
+
+ // No trees left
+ EXPECT_EQ(kNumTrees, dropped_trees.size());
+ EXPECT_EQ(kNumTrees, original_weights.size());
+ EXPECT_EQ(original_weights, weights_);
+ }
+ // 50% probability of dropping a tree
+ {
+ const int32 kNumRuns = 1000;
+ LearningRateDropoutDrivenConfig config;
+ config.set_dropout_probability(0.5);
+ config.set_learning_rate(1.0);
+
+ int32 total_num_trees = 0;
+ for (int i = 0; i < kNumRuns; ++i) {
+ // draw random seeds
+ uint random_generator_seed = static_cast<uint>(std::clock());
+ uint32 seed = rand_r(&random_generator_seed) % 100 + i;
+ TF_EXPECT_OK(DropoutUtils::DropOutTrees(
+ seed, config, weights_, &dropped_trees, &original_weights));
+
+ // We would expect 400-600 trees left
+ EXPECT_NEAR(500, kNumTrees - dropped_trees.size(), 100);
+ total_num_trees += kNumTrees - dropped_trees.size();
+
+ // Trees dropped are unique
+ unordered_set<int32> ids;
+ for (const auto& tree : dropped_trees) {
+ ids.insert(tree);
+ }
+ EXPECT_EQ(ids.size(), dropped_trees.size());
+ }
+ EXPECT_NEAR(500, total_num_trees / kNumRuns, 5);
+ }
+}
+
+TEST_F(DropoutUtilsTest, DropoutSeedTest) {
+ // Different seeds remove different trees
+ {
+ LearningRateDropoutDrivenConfig config;
+ config.set_dropout_probability(0.5);
+ config.set_learning_rate(1.0);
+
+ std::vector<int32> dropped_trees_1;
+ std::vector<float> original_weights_1;
+ std::vector<int32> dropped_trees_2;
+ std::vector<float> original_weights_2;
+
+ DecisionTreeEnsembleConfig new_ensemble_1;
+ DecisionTreeEnsembleConfig new_ensemble_2;
+
+ TF_EXPECT_OK(DropoutUtils::DropOutTrees(
+ kSeed + 1, config, weights_, &dropped_trees_1, &original_weights_1));
+ TF_EXPECT_OK(DropoutUtils::DropOutTrees(
+ kSeed + 2, config, weights_, &dropped_trees_2, &original_weights_2));
+
+ EXPECT_FALSE(dropped_trees_1 == dropped_trees_2);
+ EXPECT_FALSE(original_weights_1 == original_weights_2);
+ }
+ // The same seed produces the same result
+ {
+ LearningRateDropoutDrivenConfig config;
+ config.set_dropout_probability(0.5);
+ config.set_learning_rate(1.0);
+
+ std::vector<int32> dropped_trees_1;
+ std::vector<float> original_weights_1;
+ std::vector<int32> dropped_trees_2;
+ std::vector<float> original_weights_2;
+
+ DecisionTreeEnsembleConfig new_ensemble_1;
+ DecisionTreeEnsembleConfig new_ensemble_2;
+
+ TF_EXPECT_OK(DropoutUtils::DropOutTrees(
+ kSeed, config, weights_, &dropped_trees_1, &original_weights_1));
+ TF_EXPECT_OK(DropoutUtils::DropOutTrees(
+ kSeed, config, weights_, &dropped_trees_2, &original_weights_2));
+
+ EXPECT_TRUE(dropped_trees_1 == dropped_trees_2);
+ EXPECT_TRUE(original_weights_1 == original_weights_2);
+ }
+}
+
+TEST_F(DropoutUtilsTest, InvalidConfigTest) {
+ std::vector<int32> dropped_trees;
+ std::vector<float> original_weights;
+ // Negative prob
+ {
+ LearningRateDropoutDrivenConfig config;
+ config.set_dropout_probability(-1.34);
+
+ EXPECT_FALSE(DropoutUtils::DropOutTrees(kSeed, config, weights_,
+ &dropped_trees, &original_weights)
+ .ok());
+ }
+ // Larger than 1 prob of dropping a tree.
+ {
+ LearningRateDropoutDrivenConfig config;
+ config.set_dropout_probability(1.34);
+
+ EXPECT_FALSE(DropoutUtils::DropOutTrees(kSeed, config, weights_,
+ &dropped_trees, &original_weights)
+ .ok());
+ }
+ // Negative probability of skipping dropout.
+ {
+ LearningRateDropoutDrivenConfig config;
+ config.set_dropout_probability(0.5);
+ config.set_probability_of_skipping_dropout(-10);
+
+ DecisionTreeEnsembleConfig new_ensemble;
+ EXPECT_FALSE(DropoutUtils::DropOutTrees(kSeed, config, weights_,
+ &dropped_trees, &original_weights)
+ .ok());
+ }
+ // Larger than 1 probability of skipping dropout.
+ {
+ LearningRateDropoutDrivenConfig config;
+ config.set_dropout_probability(0.5);
+ config.set_probability_of_skipping_dropout(1.2);
+
+ DecisionTreeEnsembleConfig new_ensemble;
+ EXPECT_FALSE(DropoutUtils::DropOutTrees(kSeed, config, weights_,
+ &dropped_trees, &original_weights)
+ .ok());
+ }
+}
+namespace {
+
+void ExpectVecsEquiv(const std::vector<float>& vec1,
+ const std::vector<float>& vec2) {
+ EXPECT_EQ(vec1.size(), vec2.size());
+ for (int i = 0; i < vec1.size(); ++i) {
+ EXPECT_NEAR(vec1[i], vec2[i], 1e-3);
+ }
+}
+
+std::vector<float> GetWeightsByIndex(const std::vector<float>& weights,
+ const std::vector<int>& indices) {
+ std::vector<float> res;
+ for (const int index : indices) {
+ res.push_back(weights[index]);
+ }
+ return res;
+}
+
+void MergeLastElements(const int32 last_n, std::vector<float>* weights) {
+ float sum = 0.0;
+ for (int i = 0; i < last_n; ++i) {
+ sum += weights->back();
+ weights->pop_back();
+ }
+ weights->push_back(sum);
+}
+
+} // namespace
+
+TEST_F(DropoutUtilsTest, GetTreesWeightsForAddingTreesTest) {
+ // Adding trees should give the same result in any order.
+ {
+ std::vector<float> weights = {1.0, 1.0, 1.0, 1.0, 1.0};
+ std::vector<int32> dropped_1 = {0, 3};
+
+ std::vector<int32> dropped_2 = {0};
+
+ std::vector<float> res_1;
+ std::vector<float> res_2;
+ // Do one order
+ {
+ std::vector<float> current_weights = weights;
+ std::vector<int32> num_updates =
+ std::vector<int32>(current_weights.size(), 1);
+ DropoutUtils::GetTreesWeightsForAddingTrees(
+ dropped_1, GetWeightsByIndex(current_weights, dropped_1), 1,
+ &current_weights, &num_updates);
+ DropoutUtils::GetTreesWeightsForAddingTrees(
+ dropped_2, GetWeightsByIndex(current_weights, dropped_2), 1,
+ &current_weights, &num_updates);
+ res_1 = current_weights;
+ }
+ // Do another order
+ {
+ std::vector<float> current_weights = weights;
+ std::vector<int32> num_updates =
+ std::vector<int32>(current_weights.size(), 1);
+
+ DropoutUtils::GetTreesWeightsForAddingTrees(
+ dropped_2, GetWeightsByIndex(current_weights, dropped_2), 1,
+ &current_weights, &num_updates);
+ DropoutUtils::GetTreesWeightsForAddingTrees(
+ dropped_1, GetWeightsByIndex(current_weights, dropped_1), 1,
+ &current_weights, &num_updates);
+ res_2 = current_weights;
+ }
+ // The two vectors differ only in how the weight is split between the last
+ // two (new) trees; their sum is the same, so merge them before comparing.
+ EXPECT_EQ(res_1.size(), 7);
+ EXPECT_EQ(res_2.size(), 7);
+
+ MergeLastElements(2, &res_1);
+ MergeLastElements(2, &res_2);
+
+ EXPECT_EQ(res_1, res_2);
+ }
+ // Now when the weights are not all 1s
+ {
+ std::vector<float> weights = {1.1, 2.1, 3.1, 4.1, 5.1};
+ std::vector<int32> dropped_1 = {0, 3};
+
+ std::vector<int32> dropped_2 = {0};
+
+ std::vector<float> res_1;
+ std::vector<float> res_2;
+ // Do one order
+ {
+ std::vector<float> current_weights = weights;
+ std::vector<int32> num_updates =
+ std::vector<int32>(current_weights.size(), 1);
+ DropoutUtils::GetTreesWeightsForAddingTrees(
+ dropped_1, GetWeightsByIndex(current_weights, dropped_1), 1,
+ &current_weights, &num_updates);
+ DropoutUtils::GetTreesWeightsForAddingTrees(
+ dropped_2, GetWeightsByIndex(current_weights, dropped_2), 1,
+ &current_weights, &num_updates);
+ res_1 = current_weights;
+ }
+ // Do another order
+ {
+ std::vector<float> current_weights = weights;
+ std::vector<int32> num_updates =
+ std::vector<int32>(current_weights.size(), 1);
+ DropoutUtils::GetTreesWeightsForAddingTrees(
+ dropped_2, GetWeightsByIndex(current_weights, dropped_2), 1,
+ &current_weights, &num_updates);
+ DropoutUtils::GetTreesWeightsForAddingTrees(
+ dropped_1, GetWeightsByIndex(current_weights, dropped_1), 1,
+ &current_weights, &num_updates);
+ res_2 = current_weights;
+ }
+ EXPECT_EQ(res_1.size(), 7);
+ EXPECT_EQ(res_2.size(), 7);
+
+ // The two vectors differ only in how the weight is split between the last
+ // two (new) trees; their sum is the same, so merge them before comparing.
+ MergeLastElements(2, &res_1);
+ MergeLastElements(2, &res_2);
+
+ ExpectVecsEquiv(res_1, res_2);
+ }
+}
+
+} // namespace
+} // namespace utils
+} // namespace boosted_trees
+} // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/example.h b/tensorflow/contrib/boosted_trees/lib/utils/example.h
new file mode 100644
index 0000000000..4681eb06aa
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/example.h
@@ -0,0 +1,50 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_
+
+#include <unordered_set>
+#include <vector>
+#include "tensorflow/contrib/boosted_trees/lib/utils/optional_value.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace utils {
+
+// Holds data for one example and enables lookup by feature column.
+struct Example {
+ // Default constructor creates an empty example.
+ Example() : example_idx(-1) {}
+
+ // Example index.
+ int64 example_idx;
+
+ // Dense and sparse float features indexed by feature column.
+ // TODO(salehay): figure out a design to support multivalent float features.
+ std::vector<float> dense_float_features;
+ std::vector<OptionalValue<float>> sparse_float_features;
+
+ // Sparse integer features indexed by feature column.
+ // Note that all integer features are assumed to be categorical, i.e. will
+ // never be compared by order. Also these features can be multivalent.
+ std::vector<std::unordered_set<int64>> sparse_int_features;
+};
+
+} // namespace utils
+} // namespace boosted_trees
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.cc b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.cc
new file mode 100644
index 0000000000..c73dc8e15d
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.cc
@@ -0,0 +1,83 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace utils {
+
+using Iterator = ExamplesIterable::Iterator;
+
+ExamplesIterable::ExamplesIterable(
+ const std::vector<Tensor>& dense_float_feature_columns,
+ const std::vector<sparse::SparseTensor>& sparse_float_feature_columns,
+ const std::vector<sparse::SparseTensor>& sparse_int_feature_columns,
+ int64 example_start, int64 example_end)
+ : example_start_(example_start), example_end_(example_end) {
+ // Create dense float column values.
+ dense_float_column_values_.reserve(dense_float_feature_columns.size());
+ for (auto& dense_float_column : dense_float_feature_columns) {
+ dense_float_column_values_.emplace_back(
+ dense_float_column.template matrix<float>());
+ }
+
+ // Create sparse float column iterables and values.
+ sparse_float_column_iterables_.reserve(sparse_float_feature_columns.size());
+ sparse_float_column_values_.reserve(sparse_float_feature_columns.size());
+ for (auto& sparse_float_column : sparse_float_feature_columns) {
+ sparse_float_column_iterables_.emplace_back(
+ sparse_float_column.indices().template matrix<int64>(), example_start,
+ example_end);
+ sparse_float_column_values_.emplace_back(
+ sparse_float_column.values().template vec<float>());
+ }
+
+ // Create sparse int column iterables and values.
+ sparse_int_column_iterables_.reserve(sparse_int_feature_columns.size());
+ sparse_int_column_values_.reserve(sparse_int_feature_columns.size());
+ for (auto& sparse_int_column : sparse_int_feature_columns) {
+ sparse_int_column_iterables_.emplace_back(
+ sparse_int_column.indices().template matrix<int64>(), example_start,
+ example_end);
+ sparse_int_column_values_.emplace_back(
+ sparse_int_column.values().template vec<int64>());
+ }
+}
+
+Iterator::Iterator(ExamplesIterable* iter, int64 example_idx)
+ : iter_(iter), example_idx_(example_idx) {
+ // Create sparse iterators.
+ sparse_float_column_iterators_.reserve(
+ iter->sparse_float_column_iterables_.size());
+ for (auto& iterable : iter->sparse_float_column_iterables_) {
+ sparse_float_column_iterators_.emplace_back(iterable.begin());
+ }
+ sparse_int_column_iterators_.reserve(
+ iter->sparse_int_column_iterables_.size());
+ for (auto& iterable : iter->sparse_int_column_iterables_) {
+ sparse_int_column_iterators_.emplace_back(iterable.begin());
+ }
+
+ // Pre-size example features.
+ example_.dense_float_features.resize(
+ iter_->dense_float_column_values_.size());
+ example_.sparse_float_features.resize(
+ iter_->sparse_float_column_values_.size());
+ example_.sparse_int_features.resize(iter_->sparse_int_column_values_.size());
+}
+
+} // namespace utils
+} // namespace boosted_trees
+} // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h
new file mode 100644
index 0000000000..67efb82a22
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h
@@ -0,0 +1,172 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_
+
+#include <vector>
+
+#include "tensorflow/contrib/boosted_trees/lib/utils/example.h"
+#include "tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace utils {
+
+// Enables row-wise iteration through examples from feature columns.
+class ExamplesIterable {
+ public:
+ // Constructs an iterable given the desired examples slice and corresponding
+ // feature columns.
+ ExamplesIterable(
+ const std::vector<Tensor>& dense_float_feature_columns,
+ const std::vector<sparse::SparseTensor>& sparse_float_feature_columns,
+ const std::vector<sparse::SparseTensor>& sparse_int_feature_columns,
+ int64 example_start, int64 example_end);
+
+ // Helper class to iterate through examples.
+ class Iterator {
+ public:
+ Iterator(ExamplesIterable* iter, int64 example_idx);
+
+ Iterator& operator++() {
+ // Advance to next example.
+ ++example_idx_;
+
+ // Update sparse column iterables.
+ for (auto& it : sparse_float_column_iterators_) {
+ ++it;
+ }
+ for (auto& it : sparse_int_column_iterators_) {
+ ++it;
+ }
+ return (*this);
+ }
+
+ Iterator operator++(int) {
+ Iterator tmp(*this);
+ ++(*this);
+ return tmp;
+ }
+
+ bool operator!=(const Iterator& other) const {
+ QCHECK_EQ(iter_, other.iter_);
+ return (example_idx_ != other.example_idx_);
+ }
+
+ bool operator==(const Iterator& other) const {
+ QCHECK_EQ(iter_, other.iter_);
+ return (example_idx_ == other.example_idx_);
+ }
+
+ const Example& operator*() {
+ // Set example index based on iterator.
+ example_.example_idx = example_idx_;
+
+ // Get dense float values per column.
+ auto& dense_float_features = example_.dense_float_features;
+ for (size_t dense_float_idx = 0;
+ dense_float_idx < dense_float_features.size(); ++dense_float_idx) {
+ dense_float_features[dense_float_idx] =
+ iter_->dense_float_column_values_[dense_float_idx](example_idx_, 0);
+ }
+
+ // Get sparse float values per column.
+ auto& sparse_float_features = example_.sparse_float_features;
+ for (size_t sparse_float_idx = 0;
+ sparse_float_idx < sparse_float_features.size();
+ ++sparse_float_idx) {
+ const auto& row_range =
+ (*sparse_float_column_iterators_[sparse_float_idx]);
+ DCHECK_EQ(example_idx_, row_range.example_idx);
+ if (row_range.start < row_range.end) {
+ DCHECK_EQ(1, row_range.end - row_range.start);
+ sparse_float_features[sparse_float_idx] = OptionalValue<float>(
+ iter_->sparse_float_column_values_[sparse_float_idx](
+ row_range.start));
+ } else {
+ sparse_float_features[sparse_float_idx] = OptionalValue<float>();
+ }
+ }
+
+ // Get sparse int values per column.
+ auto& sparse_int_features = example_.sparse_int_features;
+ for (size_t sparse_int_idx = 0;
+ sparse_int_idx < sparse_int_features.size(); ++sparse_int_idx) {
+ const auto& row_range = (*sparse_int_column_iterators_[sparse_int_idx]);
+ DCHECK_EQ(example_idx_, row_range.example_idx);
+ sparse_int_features[sparse_int_idx].clear();
+ if (row_range.start < row_range.end) {
+ sparse_int_features[sparse_int_idx].reserve(row_range.end -
+ row_range.start);
+ for (int64 row_idx = row_range.start; row_idx < row_range.end;
+ ++row_idx) {
+ sparse_int_features[sparse_int_idx].insert(
+ iter_->sparse_int_column_values_[sparse_int_idx](row_idx));
+ }
+ }
+ }
+
+ return example_;
+ }
+
+ private:
+ // Examples iterable (not owned).
+ const ExamplesIterable* iter_;
+
+ // Example index.
+ int64 example_idx_;
+
+ // Sparse float column iterators.
+ std::vector<SparseColumnIterable::Iterator> sparse_float_column_iterators_;
+
+ // Sparse int column iterators.
+ std::vector<SparseColumnIterable::Iterator> sparse_int_column_iterators_;
+
+ // Example placeholder.
+ Example example_;
+ };
+
+ Iterator begin() { return Iterator(this, example_start_); }
+ Iterator end() { return Iterator(this, example_end_); }
+
+ private:
+ // Example slice spec.
+ const int64 example_start_;
+ const int64 example_end_;
+
+ // Dense float column values.
+ std::vector<TTypes<float>::ConstMatrix> dense_float_column_values_;
+
+ // Sparse float column iterables.
+ std::vector<SparseColumnIterable> sparse_float_column_iterables_;
+
+ // Sparse float column values.
+ std::vector<TTypes<float>::ConstVec> sparse_float_column_values_;
+
+ // Sparse int column iterables.
+ std::vector<SparseColumnIterable> sparse_int_column_iterables_;
+
+ // Sparse int column values.
+ std::vector<TTypes<int64>::ConstVec> sparse_int_column_values_;
+};
+
+} // namespace utils
+} // namespace boosted_trees
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc
new file mode 100644
index 0000000000..d12618217a
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc
@@ -0,0 +1,182 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace utils {
+namespace {
+
+class ExamplesIterableTest : public ::testing::Test {};
+
+TEST_F(ExamplesIterableTest, Iterate) {
+ // Create a batch of 8 examples having one dense float, two sparse float and
+ // two sparse int features.
+ // The data looks like the following:
+ // Instance | DenseF1 | SparseF1 | SparseF2 | SparseI1 | SparseI2 |
+ // 0 | 7 | -3 | | 1, 8 | |
+ // 1 | -2 | | 4 | 0 | 7 |
+ // 2 | 8 | 0 | | | 13 |
+ // 3 | 1 | 5 | 7 | 2, 0 | 4 |
+ // 4 | 0 | 0 | | | 0 |
+ // 5 | -4 | | 9 | | |
+ // 6 | 7 | | | | |
+ // 7 | -2 | | -4 | 5 | |
+ auto dense_float_tensor = test::AsTensor<float>(
+ {7.0f, -2.0f, 8.0f, 1.0f, 0.0f, -4.0f, 7.0f, -2.0f}, {8, 1});
+ auto sparse_float_indices1 =
+ test::AsTensor<int64>({0, 0, 2, 0, 3, 0, 4, 0}, {4, 2});
+ auto sparse_float_values1 = test::AsTensor<float>({-3.0f, 0.0f, 5.0f, 0.0f});
+ auto sparse_float_shape1 = TensorShape({8, 1});
+ sparse::SparseTensor sparse_float_tensor1(
+ sparse_float_indices1, sparse_float_values1, sparse_float_shape1);
+ auto sparse_float_indices2 =
+ test::AsTensor<int64>({1, 0, 3, 0, 5, 0, 7, 0}, {4, 2});
+ auto sparse_float_values2 = test::AsTensor<float>({4.0f, 7.0f, 9.0f, -4.0f});
+ auto sparse_float_shape2 = TensorShape({8, 1});
+ sparse::SparseTensor sparse_float_tensor2(
+ sparse_float_indices2, sparse_float_values2, sparse_float_shape2);
+ auto sparse_int_indices1 =
+ test::AsTensor<int64>({0, 0, 0, 1, 1, 0, 3, 0, 3, 1, 7, 0}, {6, 2});
+ auto sparse_int_values1 = test::AsTensor<int64>({1, 8, 0, 2, 0, 5});
+ auto sparse_int_shape1 = TensorShape({8, 2});
+ sparse::SparseTensor sparse_int_tensor1(
+ sparse_int_indices1, sparse_int_values1, sparse_int_shape1);
+ auto sparse_int_indices2 =
+ test::AsTensor<int64>({1, 0, 2, 0, 3, 0, 4, 0}, {4, 2});
+ auto sparse_int_values2 = test::AsTensor<int64>({7, 13, 4, 0});
+ auto sparse_int_shape2 = TensorShape({8, 1});
+ sparse::SparseTensor sparse_int_tensor2(
+ sparse_int_indices2, sparse_int_values2, sparse_int_shape2);
+
+ auto validate_example_features = [](int64 example_idx,
+ const Example& example) {
+ EXPECT_EQ(1, example.dense_float_features.size());
+ EXPECT_EQ(2, example.sparse_float_features.size());
+
+ switch (example_idx) {
+ case 0: {
+ EXPECT_EQ(0, example.example_idx);
+ EXPECT_EQ(7.0f, example.dense_float_features[0]);
+ EXPECT_TRUE(example.sparse_float_features[0].has_value());
+ EXPECT_EQ(-3.0f, example.sparse_float_features[0].get_value());
+ EXPECT_FALSE(example.sparse_float_features[1].has_value());
+ EXPECT_EQ(2, example.sparse_int_features[0].size());
+ EXPECT_EQ(1, example.sparse_int_features[0].count(1));
+ EXPECT_EQ(1, example.sparse_int_features[0].count(8));
+ EXPECT_EQ(0, example.sparse_int_features[1].size());
+ } break;
+ case 1: {
+ EXPECT_EQ(1, example.example_idx);
+ EXPECT_EQ(-2.0f, example.dense_float_features[0]);
+ EXPECT_FALSE(example.sparse_float_features[0].has_value());
+ EXPECT_TRUE(example.sparse_float_features[1].has_value());
+ EXPECT_EQ(4.0f, example.sparse_float_features[1].get_value());
+ EXPECT_EQ(1, example.sparse_int_features[0].size());
+ EXPECT_EQ(1, example.sparse_int_features[0].count(0));
+ EXPECT_EQ(1, example.sparse_int_features[1].size());
+ EXPECT_EQ(1, example.sparse_int_features[1].count(7));
+ } break;
+ case 2: {
+ EXPECT_EQ(2, example.example_idx);
+ EXPECT_EQ(8.0f, example.dense_float_features[0]);
+ EXPECT_TRUE(example.sparse_float_features[0].has_value());
+ EXPECT_EQ(0.0f, example.sparse_float_features[0].get_value());
+ EXPECT_FALSE(example.sparse_float_features[1].has_value());
+ EXPECT_EQ(0, example.sparse_int_features[0].size());
+ EXPECT_EQ(1, example.sparse_int_features[1].size());
+ EXPECT_EQ(1, example.sparse_int_features[1].count(13));
+ } break;
+ case 3: {
+ EXPECT_EQ(3, example.example_idx);
+ EXPECT_EQ(1.0f, example.dense_float_features[0]);
+ EXPECT_TRUE(example.sparse_float_features[0].has_value());
+ EXPECT_EQ(5.0f, example.sparse_float_features[0].get_value());
+ EXPECT_TRUE(example.sparse_float_features[1].has_value());
+ EXPECT_EQ(7.0f, example.sparse_float_features[1].get_value());
+ EXPECT_EQ(2, example.sparse_int_features[0].size());
+ EXPECT_EQ(1, example.sparse_int_features[0].count(2));
+ EXPECT_EQ(1, example.sparse_int_features[0].count(0));
+ EXPECT_EQ(1, example.sparse_int_features[1].size());
+ EXPECT_EQ(1, example.sparse_int_features[1].count(4));
+ } break;
+ case 4: {
+ EXPECT_EQ(4, example.example_idx);
+ EXPECT_EQ(0.0f, example.dense_float_features[0]);
+ EXPECT_TRUE(example.sparse_float_features[0].has_value());
+ EXPECT_EQ(0.0f, example.sparse_float_features[0].get_value());
+ EXPECT_FALSE(example.sparse_float_features[1].has_value());
+ EXPECT_EQ(0, example.sparse_int_features[0].size());
+ EXPECT_EQ(1, example.sparse_int_features[1].size());
+ EXPECT_EQ(1, example.sparse_int_features[1].count(0));
+ } break;
+ case 5: {
+ EXPECT_EQ(5, example.example_idx);
+ EXPECT_EQ(-4.0f, example.dense_float_features[0]);
+ EXPECT_FALSE(example.sparse_float_features[0].has_value());
+ EXPECT_TRUE(example.sparse_float_features[1].has_value());
+ EXPECT_EQ(9.0f, example.sparse_float_features[1].get_value());
+ EXPECT_EQ(0, example.sparse_int_features[0].size());
+ } break;
+ case 6: {
+ EXPECT_EQ(6, example.example_idx);
+ EXPECT_EQ(7.0f, example.dense_float_features[0]);
+ EXPECT_FALSE(example.sparse_float_features[0].has_value());
+ EXPECT_FALSE(example.sparse_float_features[1].has_value());
+ EXPECT_EQ(0, example.sparse_int_features[0].size());
+ } break;
+ case 7: {
+ EXPECT_EQ(7, example.example_idx);
+ EXPECT_EQ(-2.0f, example.dense_float_features[0]);
+ EXPECT_FALSE(example.sparse_float_features[0].has_value());
+ EXPECT_TRUE(example.sparse_float_features[1].has_value());
+ EXPECT_EQ(-4.0f, example.sparse_float_features[1].get_value());
+ EXPECT_EQ(1, example.sparse_int_features[0].size());
+ EXPECT_EQ(1, example.sparse_int_features[0].count(5));
+ } break;
+ default: { QCHECK(false) << "Invalid example index."; } break;
+ }
+ };
+
+ // Iterate through all examples sequentially.
+ ExamplesIterable full_iterable(
+ {dense_float_tensor}, {sparse_float_tensor1, sparse_float_tensor2},
+ {sparse_int_tensor1, sparse_int_tensor2}, 0, 8);
+ int64 example_idx = 0;
+ for (const auto& example : full_iterable) {
+ validate_example_features(example_idx, example);
+ ++example_idx;
+ }
+ EXPECT_EQ(8, example_idx);
+
+ // Iterate through slice (2, 6) of examples.
+ ExamplesIterable slice_iterable(
+ {dense_float_tensor}, {sparse_float_tensor1, sparse_float_tensor2},
+ {sparse_int_tensor1, sparse_int_tensor2}, 2, 6);
+ example_idx = 2;
+ for (const auto& example : slice_iterable) {
+ validate_example_features(example_idx, example);
+ ++example_idx;
+ }
+ EXPECT_EQ(6, example_idx);
+}
+
+} // namespace
+} // namespace utils
+} // namespace boosted_trees
+} // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/macros.h b/tensorflow/contrib/boosted_trees/lib/utils/macros.h
new file mode 100644
index 0000000000..28ea0a4dc1
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/macros.h
@@ -0,0 +1,26 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_
+
+#include "tensorflow/core/platform/macros.h"
+
+#define TF_CHECK_AND_RETURN_IF_ERROR(EXP, STATUS) \
+ if (!TF_PREDICT_TRUE(EXP)) { \
+ return (STATUS); \
+ }
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_
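
The TF_CHECK_AND_RETURN_IF_ERROR macro above converts a failed boolean check into an early Status return instead of a fatal CHECK. A minimal usage sketch, assuming a hypothetical ValidateSlice helper that is not part of this change:

#include "tensorflow/contrib/boosted_trees/lib/utils/macros.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace boosted_trees {
namespace utils {

// Returns an InvalidArgument status (rather than CHECK-failing) when the
// requested example slice is malformed.
Status ValidateSlice(int64 example_start, int64 example_end) {
  TF_CHECK_AND_RETURN_IF_ERROR(
      example_start >= 0 && example_start <= example_end,
      errors::InvalidArgument("Invalid example slice [", example_start, ", ",
                              example_end, ")."));
  return Status::OK();
}

}  // namespace utils
}  // namespace boosted_trees
}  // namespace tensorflow
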
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/optional_value.h b/tensorflow/contrib/boosted_trees/lib/utils/optional_value.h
new file mode 100644
index 0000000000..c141fe059d
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/optional_value.h
@@ -0,0 +1,47 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace utils {
+
+// Utility class holding an optional value.
+template <typename T>
+class OptionalValue {
+ public:
+ OptionalValue() : value_(), has_value_(false) {}
+ explicit OptionalValue(T value) : value_(value), has_value_(true) {}
+
+ bool has_value() const { return has_value_; }
+ const T& get_value() const {
+ QCHECK(has_value());
+ return value_;
+ }
+
+ private:
+ T value_;
+ bool has_value_;
+};
+
+} // namespace utils
+} // namespace boosted_trees
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_
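
A brief usage sketch for OptionalValue, assuming a hypothetical ValueOrDefault helper: get_value() QCHECK-fails when no value is present, so callers are expected to test has_value() first.

#include "tensorflow/contrib/boosted_trees/lib/utils/optional_value.h"

namespace tensorflow {
namespace boosted_trees {
namespace utils {

// Returns the held value when present, otherwise a caller-supplied default.
float ValueOrDefault(const OptionalValue<float>& value, float default_value) {
  return value.has_value() ? value.get_value() : default_value;
}

}  // namespace utils
}  // namespace boosted_trees
}  // namespace tensorflow
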
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.cc b/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.cc
new file mode 100644
index 0000000000..b00d80b522
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.cc
@@ -0,0 +1,51 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace utils {
+
+void ParallelFor(int64 batch_size, int64 desired_parallelism,
+ thread::ThreadPool* thread_pool,
+ std::function<void(int64, int64)> do_work) {
+ // Parallelize work over the batch.
+ if (desired_parallelism <= 0) {
+ do_work(0, batch_size);
+ return;
+ }
+ const int num_shards = std::max<int>(
+ 1, std::min(static_cast<int64>(desired_parallelism), batch_size));
+ const int64 block_size = (batch_size + num_shards - 1) / num_shards;
+ CHECK_GT(block_size, 0);
+ const int num_shards_used = (batch_size + block_size - 1) / block_size;
+ BlockingCounter counter(num_shards_used - 1);
+ for (int64 start = block_size; start < batch_size; start += block_size) {
+ auto end = std::min(start + block_size, batch_size);
+ thread_pool->Schedule([&do_work, &counter, start, end]() {
+ do_work(start, end);
+ counter.DecrementCount();
+ });
+ }
+
+ // Execute first shard on main thread.
+ do_work(0, std::min(block_size, batch_size));
+ counter.Wait();
+}
+
+} // namespace utils
+} // namespace boosted_trees
+} // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h b/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h
new file mode 100644
index 0000000000..c80431b558
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h
@@ -0,0 +1,33 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_PARALLEL_FOR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_PARALLEL_FOR_H_
+
+#include "tensorflow/core/lib/core/threadpool.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace utils {
+
+// Executes a parallel for over the batch for the desired parallelism level.
+void ParallelFor(int64 batch_size, int64 desired_parallelism,
+ thread::ThreadPool* thread_pool,
+ std::function<void(int64, int64)> do_work);
+
+} // namespace utils
+} // namespace boosted_trees
+} // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_PARALLEL_FOR_H_
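
A usage sketch for ParallelFor, assuming a hypothetical SquareAll helper: do_work receives half-open [start, end) shards of the batch, the first shard runs on the calling thread, and the call returns only after all shards have completed.

#include <vector>

#include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {
namespace boosted_trees {
namespace utils {

// Squares every element of `values`, sharding the work over `thread_pool`.
void SquareAll(std::vector<float>* values, thread::ThreadPool* thread_pool,
               int64 desired_parallelism) {
  ParallelFor(static_cast<int64>(values->size()), desired_parallelism,
              thread_pool, [values](int64 start, int64 end) {
                for (int64 i = start; i < end; ++i) {
                  (*values)[i] *= (*values)[i];
                }
              });
}

}  // namespace utils
}  // namespace boosted_trees
}  // namespace tensorflow

// Example call site:
//   thread::ThreadPool pool(Env::Default(), "square_all", 4);
//   std::vector<float> values = {1.0f, 2.0f, 3.0f};
//   SquareAll(&values, &pool, /*desired_parallelism=*/4);
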
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/random.h b/tensorflow/contrib/boosted_trees/lib/utils/random.h
new file mode 100644
index 0000000000..6dd55fcacc
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/random.h
@@ -0,0 +1,39 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_RANDOM_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_RANDOM_H_
+
+#include "tensorflow/core/lib/random/simple_philox.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace utils {
+
+// Generates a Poisson-distributed number with mean 1 for use in bootstrapping.
+inline int32 PoissonBootstrap(random::SimplePhilox* rng) {
+  // Knuth's multiplicative method, special-cased for lambda = 1.0: keep
+  // multiplying uniform draws into a running product until it falls below
+  // exp(-1); the number of draws minus one is then Poisson(1) distributed.
+ static const float lbound = exp(-1.0f);
+ int32 n = 0;
+ for (float r = 1; r > lbound; r *= rng->RandFloat()) {
+ ++n;
+ }
+ return n - 1;
+}
+
+} // namespace utils
+} // namespace boosted_trees
+} // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_RANDOM_H_
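
A sketch of how PoissonBootstrap can be used to draw per-example bag weights, assuming a hypothetical SampleBootstrapWeights helper: a weight of zero leaves an example out of the resampled batch, while a weight of k counts it k times.

#include <vector>

#include "tensorflow/contrib/boosted_trees/lib/utils/random.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace boosted_trees {
namespace utils {

// Draws an integer Poisson(1) bootstrap weight for every example in the batch.
std::vector<int32> SampleBootstrapWeights(int64 batch_size, uint64 seed) {
  random::PhiloxRandom philox(seed);
  random::SimplePhilox rng(&philox);
  std::vector<int32> weights(batch_size);
  for (int64 i = 0; i < batch_size; ++i) {
    weights[i] = PoissonBootstrap(&rng);
  }
  return weights;
}

}  // namespace utils
}  // namespace boosted_trees
}  // namespace tensorflow
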
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/random_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/random_test.cc
new file mode 100644
index 0000000000..51162f410e
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/random_test.cc
@@ -0,0 +1,56 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/contrib/boosted_trees/lib/utils/random.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace utils {
+namespace {
+
+TEST(RandomTest, Poisson) {
+ random::PhiloxRandom philox(77L);
+ random::SimplePhilox rng(&philox);
+ for (int trial = 0; trial < 10; ++trial) {
+ const int32 num_bootstrap = 10000;
+ double sum = 0;
+ double zeros = 0;
+ double ones = 0;
+ for (int i = 0; i < num_bootstrap; ++i) {
+ auto n = PoissonBootstrap(&rng);
+ sum += n;
+ zeros += (n == 0) ? 1 : 0;
+ ones += (n == 1) ? 1 : 0;
+ }
+
+ // Ensure mean is near expected value.
+ const double expected_mean = 1.0; // lambda
+ const double mean_std_error = 1.0 / sqrt(num_bootstrap);
+ double mean = sum / num_bootstrap;
+ EXPECT_NEAR(mean, expected_mean, 3 * mean_std_error);
+
+ // Ensure probability mass for values 0 and 1 are near expected value.
+ const double expected_p = 0.368;
+ const double proportion_std_error =
+ sqrt(expected_p * (1 - expected_p) / num_bootstrap);
+ EXPECT_NEAR(zeros / num_bootstrap, expected_p, 3 * proportion_std_error);
+ EXPECT_NEAR(ones / num_bootstrap, expected_p, 3 * proportion_std_error);
+ }
+}
+
+} // namespace
+} // namespace utils
+} // namespace boosted_trees
+} // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc
new file mode 100644
index 0000000000..21df5d13ff
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc
@@ -0,0 +1,122 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace utils {
+
+using ExampleRowRange = SparseColumnIterable::ExampleRowRange;
+using Iterator = SparseColumnIterable::Iterator;
+
+namespace {
+
+// Iterator over indices matrix rows.
+class IndicesRowIterator
+ : public std::iterator<std::random_access_iterator_tag, const int64> {
+ public:
+ IndicesRowIterator() : iter_(nullptr), row_idx_(-1) {}
+ IndicesRowIterator(SparseColumnIterable* iter, int row_idx)
+ : iter_(iter), row_idx_(row_idx) {}
+ IndicesRowIterator(const IndicesRowIterator& other)
+ : iter_(other.iter_), row_idx_(other.row_idx_) {}
+
+ IndicesRowIterator& operator=(const IndicesRowIterator& other) {
+ iter_ = other.iter_;
+ row_idx_ = other.row_idx_;
+ return (*this);
+ }
+
+ IndicesRowIterator& operator++() {
+ ++row_idx_;
+ return (*this);
+ }
+
+ IndicesRowIterator operator++(int) {
+ IndicesRowIterator tmp(*this);
+ ++row_idx_;
+ return tmp;
+ }
+
+ reference operator*() { return iter_->ix()(row_idx_, 0); }
+
+ pointer operator->() { return &iter_->ix()(row_idx_, 0); }
+
+ IndicesRowIterator& operator--() {
+ --row_idx_;
+ return (*this);
+ }
+
+ IndicesRowIterator operator--(int) {
+ IndicesRowIterator tmp(*this);
+ --row_idx_;
+ return tmp;
+ }
+
+ IndicesRowIterator& operator+=(const difference_type& step) {
+ row_idx_ += step;
+ return (*this);
+ }
+ IndicesRowIterator& operator-=(const difference_type& step) {
+ row_idx_ -= step;
+ return (*this);
+ }
+
+ IndicesRowIterator operator+(const difference_type& step) const {
+ IndicesRowIterator tmp(*this);
+ tmp += step;
+ return tmp;
+ }
+
+ IndicesRowIterator operator-(const difference_type& step) const {
+ IndicesRowIterator tmp(*this);
+ tmp -= step;
+ return tmp;
+ }
+
+ difference_type operator-(const IndicesRowIterator& other) {
+ return row_idx_ - other.row_idx_;
+ }
+
+ bool operator!=(const IndicesRowIterator& other) const {
+ QCHECK_EQ(iter_, other.iter_);
+ return (row_idx_ != other.row_idx_);
+ }
+
+ bool operator==(const IndicesRowIterator& other) const {
+ QCHECK_EQ(iter_, other.iter_);
+ return (row_idx_ == other.row_idx_);
+ }
+
+ Eigen::Index row_idx() const { return row_idx_; }
+
+ private:
+ SparseColumnIterable* iter_;
+ Eigen::Index row_idx_;
+};
+} // namespace
+
+Iterator::Iterator(SparseColumnIterable* iter, int64 example_idx)
+ : iter_(iter), example_idx_(example_idx), end_(iter->ix_.dimension(0)) {
+ cur_ = next_ = std::lower_bound(IndicesRowIterator(iter, 0),
+ IndicesRowIterator(iter, end_), example_idx_)
+ .row_idx();
+ UpdateNext();
+}
+
+} // namespace utils
+} // namespace boosted_trees
+} // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h
new file mode 100644
index 0000000000..78a5752730
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h
@@ -0,0 +1,128 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace utils {
+
+// Enables row-wise iteration through examples on sparse feature columns.
+class SparseColumnIterable {
+ public:
+ // Indicates a contiguous range for an example: [start, end).
+ struct ExampleRowRange {
+ int64 example_idx;
+ int64 start;
+ int64 end;
+ };
+
+ // Helper class to iterate through examples and return the corresponding
+ // indices row range. Note that the row range can be empty in case a given
+ // example has no corresponding indices.
+  // An Iterator can be initialized from any example start offset; the
+  // corresponding range indicators are initialized in logarithmic time.
+ class Iterator {
+ public:
+ Iterator(SparseColumnIterable* iter, int64 example_idx);
+
+ Iterator& operator++() {
+ ++example_idx_;
+ if (cur_ < end_ && iter_->ix()(cur_, 0) < example_idx_) {
+ cur_ = next_;
+ UpdateNext();
+ }
+ return (*this);
+ }
+
+ Iterator operator++(int) {
+ Iterator tmp(*this);
+ ++(*this);
+ return tmp;
+ }
+
+ bool operator!=(const Iterator& other) const {
+ QCHECK_EQ(iter_, other.iter_);
+ return (example_idx_ != other.example_idx_);
+ }
+
+ bool operator==(const Iterator& other) const {
+ QCHECK_EQ(iter_, other.iter_);
+ return (example_idx_ == other.example_idx_);
+ }
+
+ const ExampleRowRange& operator*() {
+ range_.example_idx = example_idx_;
+ if (cur_ < end_ && iter_->ix()(cur_, 0) == example_idx_) {
+ range_.start = cur_;
+ range_.end = next_;
+ } else {
+ range_.start = 0;
+ range_.end = 0;
+ }
+ return range_;
+ }
+
+ private:
+ void UpdateNext() {
+ next_ = std::min(next_ + 1, end_);
+ while (next_ < end_ && iter_->ix()(cur_, 0) == iter_->ix()(next_, 0)) {
+ ++next_;
+ }
+ }
+
+ const SparseColumnIterable* iter_;
+ int64 example_idx_;
+ int64 cur_;
+ int64 next_;
+ const int64 end_;
+ ExampleRowRange range_;
+ };
+
+ // Constructs an iterable given the desired examples slice and corresponding
+ // feature columns.
+ SparseColumnIterable(TTypes<int64>::ConstMatrix ix, int64 example_start,
+ int64 example_end)
+ : ix_(ix), example_start_(example_start), example_end_(example_end) {
+ QCHECK(example_start >= 0 && example_end >= 0);
+ }
+
+ Iterator begin() { return Iterator(this, example_start_); }
+ Iterator end() { return Iterator(this, example_end_); }
+
+ const TTypes<int64>::ConstMatrix& ix() const { return ix_; }
+ int64 example_start() const { return example_start_; }
+ int64 example_end() const { return example_end_; }
+
+ private:
+ // Sparse indices matrix.
+ TTypes<int64>::ConstMatrix ix_;
+
+ // Example slice spec.
+ const int64 example_start_;
+ const int64 example_end_;
+};
+
+} // namespace utils
+} // namespace boosted_trees
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc
new file mode 100644
index 0000000000..7792bd8c66
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc
@@ -0,0 +1,100 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace utils {
+namespace {
+
+using test::AsTensor;
+using ExampleRowRange = SparseColumnIterable::ExampleRowRange;
+
+class SparseColumnIterableTest : public ::testing::Test {};
+
+TEST_F(SparseColumnIterableTest, Empty) {
+ const auto indices = Tensor(DT_INT64, {0, 2});
+ SparseColumnIterable iterable(indices.template matrix<int64>(), 0, 0);
+ EXPECT_EQ(iterable.begin(), iterable.end());
+}
+
+TEST_F(SparseColumnIterableTest, Iterate) {
+  // 8 examples with 7 sparse index rows in a single feature column; examples
+  // 3 and 7 are multi-valent.
+ // This can be visualized like the following:
+ // Instance | Sparse |
+ // 0 | x |
+ // 1 | |
+ // 2 | |
+ // 3 | xxx |
+ // 4 | x |
+ // 5 | |
+ // 6 | |
+ // 7 | xx |
+ const auto indices =
+ AsTensor<int64>({0, 0, 3, 0, 3, 1, 3, 2, 4, 0, 7, 0, 7, 1}, {7, 2});
+
+ auto validate_example_range = [](const ExampleRowRange& range) {
+ switch (range.example_idx) {
+ case 0: {
+ EXPECT_EQ(0, range.start);
+ EXPECT_EQ(1, range.end);
+ } break;
+ case 3: {
+ EXPECT_EQ(1, range.start);
+ EXPECT_EQ(4, range.end);
+ } break;
+ case 4: {
+ EXPECT_EQ(4, range.start);
+ EXPECT_EQ(5, range.end);
+ } break;
+ case 7: {
+ EXPECT_EQ(5, range.start);
+ EXPECT_EQ(7, range.end);
+ } break;
+ default: {
+ // Empty examples.
+ EXPECT_GE(range.start, range.end);
+ } break;
+ }
+ };
+
+ // Iterate through all examples sequentially.
+ SparseColumnIterable full_iterable(indices.template matrix<int64>(), 0, 8);
+ int64 expected_example_idx = 0;
+ for (const ExampleRowRange& range : full_iterable) {
+ EXPECT_EQ(expected_example_idx, range.example_idx);
+ validate_example_range(range);
+ ++expected_example_idx;
+ }
+ EXPECT_EQ(8, expected_example_idx);
+
+ // Iterate through slice (2, 6) of examples.
+ SparseColumnIterable slice_iterable(indices.template matrix<int64>(), 2, 6);
+ expected_example_idx = 2;
+ for (const ExampleRowRange& range : slice_iterable) {
+ EXPECT_EQ(expected_example_idx, range.example_idx);
+ validate_example_range(range);
+ ++expected_example_idx;
+ }
+ EXPECT_EQ(6, expected_example_idx);
+}
+
+} // namespace
+} // namespace utils
+} // namespace boosted_trees
+} // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.cc b/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.cc
new file mode 100644
index 0000000000..be2f787fd8
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.cc
@@ -0,0 +1,103 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
+#include "tensorflow/contrib/boosted_trees/lib/utils/macros.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace utils {
+
+std::vector<Tensor> TensorUtils::OpInputListToTensorVec(
+ const OpInputList& input_list) {
+ std::vector<Tensor> tensor_vec;
+ tensor_vec.reserve(input_list.size());
+ for (const Tensor& tensor : input_list) {
+ tensor_vec.emplace_back(tensor);
+ }
+ return tensor_vec;
+}
+
+Status TensorUtils::ReadDenseFloatFeatures(OpKernelContext* const context,
+ OpInputList* features_list) {
+ // Constants.
+ constexpr auto kDenseFloatFeaturesName = "dense_float_features";
+
+  // Read the dense float features list.
+ TF_RETURN_IF_ERROR(
+ context->input_list(kDenseFloatFeaturesName, features_list));
+ return Status::OK();
+}
+
+Status TensorUtils::ReadSparseFloatFeatures(OpKernelContext* const context,
+ OpInputList* features_indices_list,
+ OpInputList* feature_values_list,
+ OpInputList* feature_shapes_list) {
+ // Constants.
+ constexpr auto kSparseFloatFeatureIndicesName =
+ "sparse_float_feature_indices";
+ constexpr auto kSparseFloatFeatureValuesName = "sparse_float_feature_values";
+ constexpr auto kSparseFloatFeatureShapesName = "sparse_float_feature_shapes";
+
+  // Read the sparse float features lists.
+ TF_RETURN_IF_ERROR(context->input_list(kSparseFloatFeatureIndicesName,
+ features_indices_list));
+ TF_RETURN_IF_ERROR(
+ context->input_list(kSparseFloatFeatureValuesName, feature_values_list));
+ TF_RETURN_IF_ERROR(
+ context->input_list(kSparseFloatFeatureShapesName, feature_shapes_list));
+ return Status::OK();
+}
+
+Status TensorUtils::ReadSparseIntFeatures(OpKernelContext* const context,
+ OpInputList* features_indices_list,
+ OpInputList* feature_values_list,
+ OpInputList* feature_shapes_list) {
+ // Constants.
+ constexpr auto kSparseIntFeatureIndicesName = "sparse_int_feature_indices";
+ constexpr auto kSparseIntFeatureValuesName = "sparse_int_feature_values";
+ constexpr auto kSparseIntFeatureShapesName = "sparse_int_feature_shapes";
+
+  // Read the sparse int features lists.
+ TF_RETURN_IF_ERROR(
+ context->input_list(kSparseIntFeatureIndicesName, features_indices_list));
+ TF_RETURN_IF_ERROR(
+ context->input_list(kSparseIntFeatureValuesName, feature_values_list));
+ TF_RETURN_IF_ERROR(
+ context->input_list(kSparseIntFeatureShapesName, feature_shapes_list));
+ return Status::OK();
+}
+
+int64 TensorUtils::InferBatchSize(
+ const OpInputList& dense_float_features_list,
+ const OpInputList& sparse_float_feature_shapes_list,
+ const OpInputList& sparse_int_feature_shapes_list) {
+ if (dense_float_features_list.size() > 0) {
+ return dense_float_features_list[0].dim_size(0);
+ }
+ if (sparse_float_feature_shapes_list.size() > 0) {
+ return sparse_float_feature_shapes_list[0].flat<int64>()(0);
+ }
+ if (sparse_int_feature_shapes_list.size() > 0) {
+ return sparse_int_feature_shapes_list[0].flat<int64>()(0);
+ }
+ QCHECK(false) << "Could not infer batch size due to empty feature set.";
+}
+
+} // namespace utils
+} // namespace boosted_trees
+} // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h b/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h
new file mode 100644
index 0000000000..58f5e5a0d1
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h
@@ -0,0 +1,60 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace utils {
+
+class TensorUtils {
+ public:
+ // Read an input list into a vector of tensors.
+ static std::vector<Tensor> OpInputListToTensorVec(
+ const OpInputList& input_list);
+
+ // Reads the dense float features input list.
+ static Status ReadDenseFloatFeatures(OpKernelContext* const context,
+ OpInputList* features_list);
+
+ // Reads the sparse float features input list.
+ static Status ReadSparseFloatFeatures(OpKernelContext* const context,
+ OpInputList* features_indices_list,
+ OpInputList* feature_values_list,
+ OpInputList* feature_shapes_list);
+
+ // Reads the sparse int features input list.
+ static Status ReadSparseIntFeatures(OpKernelContext* const context,
+ OpInputList* features_indices_list,
+ OpInputList* feature_values_list,
+ OpInputList* feature_shapes_list);
+
+ // Infers the batch size by looking at the op input features.
+ static int64 InferBatchSize(
+ const OpInputList& dense_float_features_list,
+ const OpInputList& sparse_float_feature_shapes_list,
+ const OpInputList& sparse_int_feature_shapes_list);
+};
+
+} // namespace utils
+} // namespace boosted_trees
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_
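
A sketch of the intended call pattern inside an op kernel, assuming a hypothetical ExampleCountingOp whose op registration declares the dense_float_features, sparse_float_feature_* and sparse_int_feature_* input lists named above:

#include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {
namespace boosted_trees {

class ExampleCountingOp : public OpKernel {
 public:
  explicit ExampleCountingOp(OpKernelConstruction* const context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* const context) override {
    // Read the feature input lists.
    OpInputList dense_float_features_list;
    OP_REQUIRES_OK(context, utils::TensorUtils::ReadDenseFloatFeatures(
                                context, &dense_float_features_list));
    OpInputList sparse_float_indices_list, sparse_float_values_list,
        sparse_float_shapes_list;
    OP_REQUIRES_OK(context,
                   utils::TensorUtils::ReadSparseFloatFeatures(
                       context, &sparse_float_indices_list,
                       &sparse_float_values_list, &sparse_float_shapes_list));
    OpInputList sparse_int_indices_list, sparse_int_values_list,
        sparse_int_shapes_list;
    OP_REQUIRES_OK(context,
                   utils::TensorUtils::ReadSparseIntFeatures(
                       context, &sparse_int_indices_list,
                       &sparse_int_values_list, &sparse_int_shapes_list));

    // Infer the batch size from whichever feature group is non-empty.
    const int64 batch_size = utils::TensorUtils::InferBatchSize(
        dense_float_features_list, sparse_float_shapes_list,
        sparse_int_shapes_list);
    // Per-example work over [0, batch_size) would go here.
  }
};

}  // namespace boosted_trees
}  // namespace tensorflow
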
diff --git a/tensorflow/contrib/boosted_trees/proto/BUILD b/tensorflow/contrib/boosted_trees/proto/BUILD
new file mode 100644
index 0000000000..3b6b0339d2
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/proto/BUILD
@@ -0,0 +1,32 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+tf_proto_library(
+ name = "learner_proto",
+ srcs = [
+ "learner.proto",
+ ],
+ cc_api_version = 2,
+ visibility = ["//visibility:public"],
+)
+
+tf_proto_library(
+ name = "tree_config_proto",
+ srcs = ["tree_config.proto"],
+ cc_api_version = 2,
+ visibility = ["//visibility:public"],
+)
diff --git a/tensorflow/contrib/boosted_trees/proto/learner.proto b/tensorflow/contrib/boosted_trees/proto/learner.proto
new file mode 100644
index 0000000000..06ee223467
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/proto/learner.proto
@@ -0,0 +1,136 @@
+syntax = "proto3";
+
+option cc_enable_arenas = true;
+
+package tensorflow.boosted_trees.learner;
+
+// Tree regularization config.
+message TreeRegularizationConfig {
+ // Classic L1/L2.
+ float l1 = 1;
+ float l2 = 2;
+
+  // Tree complexity penalizes overall model complexity, effectively
+  // limiting how deep the tree can grow in regions with small gain.
+ float tree_complexity = 3;
+}
+
+// Tree constraints config.
+message TreeConstraintsConfig {
+ // Maximum depth of the trees.
+ uint32 max_tree_depth = 1;
+
+ // Min hessian weight per node.
+ float min_node_weight = 2;
+}
+
+// LearningRateConfig describes all supported learning rate tuners.
+message LearningRateConfig {
+ oneof tuner {
+ LearningRateFixedConfig fixed = 1;
+ LearningRateDropoutDrivenConfig dropout = 2;
+ LearningRateLineSearchConfig line_search = 3;
+ }
+}
+
+// Config for a fixed learning rate.
+message LearningRateFixedConfig {
+ float learning_rate = 1;
+}
+
+// Config for a tuned learning rate.
+message LearningRateLineSearchConfig {
+ // Max learning rate. Must be strictly positive.
+ float max_learning_rate = 1;
+
+  // Number of learning rate values to consider in [0, max_learning_rate).
+ int32 num_steps = 2;
+}
+
+// Given a sequence of trees 1, 2, 3 ... n, each tree essentially represents a
+// weight update in functional space, so we can average those updates to
+// achieve better performance. For example, the final ensemble can be the
+// average of the ensemble containing only tree 1, the ensemble of trees 1 and
+// 2, and so on up to the ensemble of all trees.
+// Note that this averaging applies ONLY DURING PREDICTION; training stays the
+// same.
+message AveragingConfig {
+ oneof config {
+ float average_last_n_trees = 1;
+    // Between 0 and 1. If set to 1.0, we average the ensemble of tree 1, the
+    // ensemble of trees 1 and 2, and so on up to the ensemble of all trees. If
+    // set to 0.5, only the last half of the trees is averaged, etc.
+ float average_last_percent_trees = 2;
+ }
+}
+
+message LearningRateDropoutDrivenConfig {
+  // Probability of dropping each tree in the ensemble built so far.
+ float dropout_probability = 1;
+
+  // Trees built after a dropout step don't "advance" toward the optimal
+  // solution; they only rearrange the path. However, you can still choose to
+  // skip dropout periodically to allow a new tree that does "advance" to be
+  // added.
+  // For example, when running for 200 steps with a dropout probability of
+  // 1/100, you would expect dropout to occur on essentially every iteration
+  // after 100. Setting probability_of_skipping_dropout to 0.1 means that
+  // iterations 100-200 will include approximately 90 dropout iterations and 10
+  // normal steps. Set it to 0 if you want to just keep building refinement
+  // trees after dropout kicks in.
+ float probability_of_skipping_dropout = 2;
+
+ // Between 0 and 1.
+ float learning_rate = 3;
+}
+
+message LearnerConfig {
+ enum PruningMode {
+ PRE_PRUNE = 0;
+ POST_PRUNE = 1;
+ }
+
+ enum GrowingMode {
+ WHOLE_TREE = 0;
+ // Layer by layer is only supported by the batch learner.
+ LAYER_BY_LAYER = 1;
+ }
+
+ enum MultiClassStrategy {
+ TREE_PER_CLASS = 0;
+ FULL_HESSIAN = 1;
+ DIAGONAL_HESSIAN = 2;
+ }
+
+ // Number of classes.
+ uint32 num_classes = 1;
+
+  // Fraction of features to consider for each tree, sampled randomly from all
+  // available features.
+ oneof feature_fraction {
+ float feature_fraction_per_tree = 2;
+ float feature_fraction_per_level = 3;
+ };
+
+ // Regularization.
+ TreeRegularizationConfig regularization = 4;
+
+ // Constraints.
+ TreeConstraintsConfig constraints = 5;
+
+ // Pruning.
+ PruningMode pruning_mode = 8;
+
+ // Growing Mode.
+ GrowingMode growing_mode = 9;
+
+ // Learning rate.
+ LearningRateConfig learning_rate_tuner = 6;
+
+ // Multi-class strategy.
+ MultiClassStrategy multi_class_strategy = 10;
+
+ // If you want to average the ensembles (for regularization), provide the
+ // config below.
+ AveragingConfig averaging_config = 11;
+}
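
For reference, a sketch of populating LearnerConfig from C++ through the generated proto API, with illustrative field values (MakeExampleLearnerConfig is a hypothetical helper, not part of this change):

#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"

namespace tensorflow {
namespace boosted_trees {
namespace learner {

// Builds a binary-classification config with mild regularization,
// depth-limited trees and a fixed learning rate.
LearnerConfig MakeExampleLearnerConfig() {
  LearnerConfig config;
  config.set_num_classes(2);
  config.set_feature_fraction_per_tree(0.8f);
  config.mutable_regularization()->set_l2(1.0f);
  config.mutable_regularization()->set_tree_complexity(0.1f);
  config.mutable_constraints()->set_max_tree_depth(6);
  config.set_pruning_mode(LearnerConfig::POST_PRUNE);
  config.set_growing_mode(LearnerConfig::WHOLE_TREE);
  config.mutable_learning_rate_tuner()->mutable_fixed()->set_learning_rate(
      0.1f);
  config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
  return config;
}

}  // namespace learner
}  // namespace boosted_trees
}  // namespace tensorflow
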
diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto
new file mode 100644
index 0000000000..3daa613b5d
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto
@@ -0,0 +1,109 @@
+syntax = "proto3";
+option cc_enable_arenas = true;
+
+package tensorflow.boosted_trees.trees;
+
+// TreeNode describes a node in a tree.
+message TreeNode {
+ oneof node {
+ Leaf leaf = 1;
+ DenseFloatBinarySplit dense_float_binary_split = 2;
+ SparseFloatBinarySplitDefaultLeft sparse_float_binary_split_default_left =
+ 3;
+ SparseFloatBinarySplitDefaultRight sparse_float_binary_split_default_right =
+ 4;
+ CategoricalIdBinarySplit categorical_id_binary_split = 5;
+ }
+ TreeNodeMetadata node_metadata = 777;
+}
+
+// TreeNodeMetadata encodes metadata associated with each node in a tree.
+message TreeNodeMetadata {
+ // The gain associated with this node.
+ float gain = 1;
+
+ // The original leaf node before this node was split.
+ Leaf original_leaf = 2;
+}
+
+// Leaves can either hold dense or sparse information.
+message Leaf {
+ oneof leaf {
+ // See learning/decision_trees/proto/generic_tree_model.proto?l=133
+ // for a description of how vector and sparse_vector might be used.
+ Vector vector = 1;
+ SparseVector sparse_vector = 2;
+ }
+}
+
+message Vector {
+ repeated float value = 1;
+}
+
+message SparseVector {
+ repeated int32 index = 1;
+ repeated float value = 2;
+}
+
+// Split rule for dense float features.
+message DenseFloatBinarySplit {
+ // Float feature column and split threshold describing
+ // the rule feature <= threshold.
+ int32 feature_column = 1;
+ float threshold = 2;
+
+ // Node children indexing into a contiguous
+ // vector of nodes starting from the root.
+ int32 left_id = 3;
+ int32 right_id = 4;
+}
+
+// Split rule for sparse float features defaulting left for missing features.
+message SparseFloatBinarySplitDefaultLeft {
+ DenseFloatBinarySplit split = 1;
+}
+
+// Split rule for sparse float features defaulting right for missing features.
+message SparseFloatBinarySplitDefaultRight {
+ DenseFloatBinarySplit split = 1;
+}
+
+// Split rule for categorical features with a single feature Id.
+message CategoricalIdBinarySplit {
+ // Categorical feature column and Id describing
+ // the rule feature == Id.
+ int32 feature_column = 1;
+ int64 feature_id = 2;
+
+ // Node children indexing into a contiguous
+ // vector of nodes starting from the root.
+ int32 left_id = 3;
+ int32 right_id = 4;
+}
+
+// DecisionTreeConfig describes a list of connected nodes.
+// Node 0 must be the root and can carry any payload including a leaf
+// in the case of representing the bias.
+// Note that each node id is implicitly its index in the list of nodes.
+message DecisionTreeConfig {
+ repeated TreeNode nodes = 1;
+}
+
+message DecisionTreeMetadata {
+  // How many times the tree weight was updated (due to reweighting of the
+  // final ensemble, dropout, shrinkage, etc.).
+ int32 num_tree_weight_updates = 1;
+
+ // Number of layers grown for this tree.
+ int32 num_layers_grown = 2;
+
+  // Whether the tree is finalized, meaning no more layers can be grown.
+ bool is_finalized = 3;
+}
+
+// DecisionTreeEnsembleConfig describes an ensemble of decision trees.
+message DecisionTreeEnsembleConfig {
+ repeated DecisionTreeConfig trees = 1;
+ repeated float tree_weights = 2;
+ repeated DecisionTreeMetadata tree_metadata = 3;
+}
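
For reference, a sketch of assembling a one-split tree through the generated proto API, with illustrative values (MakeExampleTree is a hypothetical helper): node 0 tests feature_column 0 against a threshold and routes to leaf nodes 1 and 2.

#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"

namespace tensorflow {
namespace boosted_trees {
namespace trees {

// Root node 0 is a dense float split; nodes 1 and 2 are scalar leaves.
DecisionTreeConfig MakeExampleTree() {
  DecisionTreeConfig tree;
  auto* split = tree.add_nodes()->mutable_dense_float_binary_split();
  split->set_feature_column(0);
  split->set_threshold(3.14f);
  split->set_left_id(1);
  split->set_right_id(2);
  tree.add_nodes()->mutable_leaf()->mutable_vector()->add_value(-0.5f);
  tree.add_nodes()->mutable_leaf()->mutable_vector()->add_value(0.5f);
  return tree;
}

}  // namespace trees
}  // namespace boosted_trees
}  // namespace tensorflow

// A one-tree ensemble with unit weight could then be built as:
//   DecisionTreeEnsembleConfig ensemble;
//   *ensemble.add_trees() = MakeExampleTree();
//   ensemble.add_tree_weights(1.0f);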