aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt34
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt29
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt40
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt22
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt31
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt27
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt20
-rw-r--r--tensorflow/core/kernels/boosted_trees/BUILD16
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantile_ops.cc453
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/BUILD4
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h96
-rw-r--r--tensorflow/core/ops/boosted_trees_ops.cc125
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/BUILD13
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py140
-rw-r--r--tensorflow/python/ops/boosted_trees_ops.py6
16 files changed, 1059 insertions, 2 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt
new file mode 100644
index 0000000000..cdaeb5091c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt
@@ -0,0 +1,34 @@
+op {
+ graph_op_name: "BoostedTreesBucketize"
+ visibility: HIDDEN
+ in_arg {
+ name: "float_values"
+ description: <<END
+float; List of Rank 2 Tensor each containing float values for a single feature.
+END
+ }
+ in_arg {
+ name: "bucket_boundaries"
+ description: <<END
+float; List of Rank 1 Tensors each containing the bucket boundaries for a single
+feature.
+END
+ }
+ out_arg {
+ name: "buckets"
+ description: <<END
+int; List of Rank 2 Tensors each containing the bucketized values for a single feature.
+END
+ }
+ attr {
+ name: "num_features"
+ description: <<END
+inferred int; number of features.
+END
+ }
+ summary: "Bucketize each feature based on bucket boundaries."
+ description: <<END
+An op that returns a list of float tensors, where each tensor represents the
+bucketized values for a single feature.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt
new file mode 100644
index 0000000000..20da1295f6
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt
@@ -0,0 +1,29 @@
+op {
+ graph_op_name: "BoostedTreesCreateQuantileStreamResource"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource; Handle to quantile stream resource.
+END
+ }
+ in_arg {
+ name: "epsilon"
+ description: <<END
+float; The required approximation error of the stream resource.
+END
+ }
+ in_arg {
+ name: "num_streams"
+ description: <<END
+int; The number of streams managed by the resource that shares the same epsilon.
+END
+ }
+ attr {
+ name: "max_elements"
+ description : <<END
+int; The maximum number of data points that can be fed to the stream.
+END
+ }
+ summary: "Create the Resource for Quantile Streams."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt
new file mode 100644
index 0000000000..ca111af312
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt
@@ -0,0 +1,40 @@
+op {
+ graph_op_name: "BoostedTreesMakeQuantileSummaries"
+ visibility: HIDDEN
+ in_arg {
+ name: "float_values"
+ description: <<END
+float; List of Rank 2 Tensors each containing values for a single feature.
+END
+ }
+ in_arg {
+ name: "example_weights"
+ description: <<END
+float; Rank 1 Tensor with weights per instance.
+END
+ }
+ in_arg {
+ name: "epsilon"
+ description: <<END
+float; The required maximum approximation error.
+END
+ }
+ out_arg {
+ name: "summaries"
+ description: <<END
+float; List of Rank 2 Tensors each containing the quantile summary (value, weight,
+min_rank, max_rank) of a single feature.
+END
+ }
+ attr {
+ name: "num_features"
+ description: <<END
+int; Inferred from the size of float_values.
+The number of float features.
+END
+ }
+ summary: "Makes the summary of quantiles for the batch."
+ description: <<END
+An op that takes a list of tensors and outputs the quantile summaries for each tensor.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt
new file mode 100644
index 0000000000..bbeecbf32b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt
@@ -0,0 +1,22 @@
+op {
+ graph_op_name: "BoostedTreesQuantileStreamResourceAddSummaries"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource handle referring to a QuantileStreamResource.
+END
+ }
+ in_arg {
+ name: "summaries"
+ description: <<END
+string; List of Rank 2 Tensor each containing the summaries for a single feature.
+END
+ }
+ summary: "Add the quantile summaries to each quantile stream resource."
+ description: <<END
+An op that adds a list of quantile summaries to a quantile stream resource. Each
+summary Tensor is rank 2, containing summaries (value, weight, min_rank, max_rank)
+for a single feature.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt
new file mode 100644
index 0000000000..2fd94efa10
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt
@@ -0,0 +1,31 @@
+op {
+ graph_op_name: "BoostedTreesQuantileStreamResourceFlush"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource handle referring to a QuantileStreamResource.
+END
+ }
+ in_arg {
+ name: "num_buckets",
+ description: <<END
+int; approximate number of buckets unless using generate_quantiles.
+END
+ }
+ attr {
+ name: "generate_quantiles"
+ description: <<END
+bool; If True, the output will be the num_quantiles for each stream where the ith
+entry is the ith quantile of the input with an approximation error of epsilon.
+Duplicate values may be present.
+If False, the output will be the points in the histogram that we got which roughly
+translates to 1/epsilon boundaries and without any duplicates.
+Default to False.
+END
+ }
+ summary: "Flush the summaries for a quantile stream resource."
+ description: <<END
+An op that flushes the summaries for a quantile stream resource.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt
new file mode 100644
index 0000000000..206672802f
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt
@@ -0,0 +1,27 @@
+op {
+ graph_op_name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource handle referring to a QuantileStreamResource.
+END
+ }
+ out_arg {
+ name: "bucket_boundaries"
+ description: <<END
+float; List of Rank 1 Tensors each containing the bucket boundaries for a feature.
+END
+ }
+ attr {
+ name: "num_features"
+ description: <<END
+inferred int; number of features to get bucket boundaries for.
+END
+ }
+ summary: "Generate the bucket boundaries for each feature based on accumulated summaries."
+ description: <<END
+An op that returns a list of float tensors for a quantile stream resource. Each
+tensor is Rank 1 containing bucket boundaries for a single feature.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt
new file mode 100644
index 0000000000..cb7786c051
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt
@@ -0,0 +1,5 @@
+op {
+ graph_op_name: "BoostedTreesQuantileStreamResourceHandleOp"
+ visibility: HIDDEN
+ summary: "Creates a handle to a BoostedTreesQuantileStreamResource."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt b/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt
new file mode 100644
index 0000000000..758eeb96f0
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt
@@ -0,0 +1,20 @@
+op {
+ graph_op_name: "IsBoostedTreesQuantileStreamResourceInitialized"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource; The reference to quantile stream resource handle.
+END
+ }
+ out_arg {
+ name: "is_initialized"
+ description: <<END
+bool; True if the resource is initialized, False otherwise.
+END
+ }
+ summary: "Checks whether a quantile stream has been initialized."
+ description: <<END
+An Op that checks if quantile stream resource is initialized.
+END
+}
diff --git a/tensorflow/core/kernels/boosted_trees/BUILD b/tensorflow/core/kernels/boosted_trees/BUILD
index 4910021c63..4e8bfa02fc 100644
--- a/tensorflow/core/kernels/boosted_trees/BUILD
+++ b/tensorflow/core/kernels/boosted_trees/BUILD
@@ -15,7 +15,9 @@ load(
tf_proto_library(
name = "boosted_trees_proto",
- srcs = ["boosted_trees.proto"],
+ srcs = [
+ "boosted_trees.proto",
+ ],
cc_api_version = 2,
visibility = ["//visibility:public"],
)
@@ -87,9 +89,21 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "quantile_ops",
+ srcs = ["quantile_ops.cc"],
+ deps = [
+ "//tensorflow/core:boosted_trees_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/kernels/boosted_trees/quantiles:weighted_quantiles",
+ ],
+)
+
+tf_kernel_library(
name = "boosted_trees_ops",
deps = [
":prediction_ops",
+ ":quantile_ops",
":resource_ops",
":stats_ops",
":training_ops",
diff --git a/tensorflow/core/kernels/boosted_trees/quantile_ops.cc b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc
new file mode 100644
index 0000000000..d1840941c1
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc
@@ -0,0 +1,453 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include <algorithm>
+#include <iterator>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+const char* const kExampleWeightsName = "example_weights";
+const char* const kMaxElementsName = "max_elements";
+const char* const kGenerateQuantiles = "generate_quantiles";
+const char* const kNumBucketsName = "num_buckets";
+const char* const kEpsilonName = "epsilon";
+const char* const kBucketBoundariesName = "bucket_boundaries";
+const char* const kBucketsName = "buckets";
+const char* const kSummariesName = "summaries";
+const char* const kNumStreamsName = "num_streams";
+const char* const kNumFeaturesName = "num_features";
+const char* const kFloatFeaturesName = "float_values";
+const char* const kResourceHandleName = "quantile_stream_resource_handle";
+
+using QuantileStreamResource = BoostedTreesQuantileStreamResource;
+using QuantileStream =
+ boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
+using QuantileSummary =
+ boosted_trees::quantiles::WeightedQuantilesSummary<float, float>;
+using QuantileSummaryEntry =
+ boosted_trees::quantiles::WeightedQuantilesSummary<float,
+ float>::SummaryEntry;
+
+// Generates quantiles on a finalized QuantileStream.
+std::vector<float> GenerateBoundaries(const QuantileStream& stream,
+ const int64 num_boundaries) {
+ std::vector<float> boundaries = stream.GenerateBoundaries(num_boundaries);
+
+ // Uniquify elements as we may get dupes.
+ auto end_it = std::unique(boundaries.begin(), boundaries.end());
+ boundaries.resize(std::distance(boundaries.begin(), end_it));
+ return boundaries;
+}
+
+// Generates quantiles on a finalized QuantileStream.
+std::vector<float> GenerateQuantiles(const QuantileStream& stream,
+ const int64 num_quantiles) {
+ // Do not de-dup boundaries. Exactly num_quantiles+1 boundary values
+ // will be returned.
+ std::vector<float> boundaries = stream.GenerateQuantiles(num_quantiles - 1);
+ CHECK_EQ(boundaries.size(), num_quantiles);
+ return boundaries;
+}
+
+std::vector<float> GetBuckets(const int32 feature,
+ const OpInputList& buckets_list) {
+ const auto& buckets = buckets_list[feature].flat<float>();
+ std::vector<float> buckets_vector(buckets.data(),
+ buckets.data() + buckets.size());
+ return buckets_vector;
+}
+
+REGISTER_RESOURCE_HANDLE_KERNEL(BoostedTreesQuantileStreamResource);
+
+REGISTER_KERNEL_BUILDER(
+ Name("IsBoostedTreesQuantileStreamResourceInitialized").Device(DEVICE_CPU),
+ IsResourceInitialized<BoostedTreesQuantileStreamResource>);
+
+class BoostedTreesCreateQuantileStreamResourceOp : public OpKernel {
+ public:
+ explicit BoostedTreesCreateQuantileStreamResourceOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kMaxElementsName, &max_elements_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // Only create one, if one does not exist already. Report status for all
+ // other exceptions. If one already exists, it unrefs the new one.
+ // An epsilon value of zero could cause perfoamance issues and is therefore,
+ // disallowed.
+ const Tensor* epsilon_t;
+ OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
+ float epsilon = epsilon_t->scalar<float>()();
+ OP_REQUIRES(
+ context, epsilon > 0,
+ errors::InvalidArgument("An epsilon value of zero is not allowed."));
+
+ const Tensor* num_streams_t;
+ OP_REQUIRES_OK(context, context->input(kNumStreamsName, &num_streams_t));
+ int64 num_streams = num_streams_t->scalar<int64>()();
+
+ auto result =
+ new QuantileStreamResource(epsilon, max_elements_, num_streams);
+ auto status = CreateResource(context, HandleFromInput(context, 0), result);
+ if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
+ OP_REQUIRES(context, false, status);
+ }
+ }
+
+ private:
+ // An upper bound on the number of entries that the summaries might have
+ // for a feature.
+ int64 max_elements_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesCreateQuantileStreamResource").Device(DEVICE_CPU),
+ BoostedTreesCreateQuantileStreamResourceOp);
+
+class BoostedTreesMakeQuantileSummariesOp : public OpKernel {
+ public:
+ explicit BoostedTreesMakeQuantileSummariesOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ // Read float features list;
+ OpInputList float_features_list;
+ OP_REQUIRES_OK(
+ context, context->input_list(kFloatFeaturesName, &float_features_list));
+
+ // Parse example weights and get batch size.
+ const Tensor* example_weights_t;
+ OP_REQUIRES_OK(context,
+ context->input(kExampleWeightsName, &example_weights_t));
+ auto example_weights = example_weights_t->flat<float>();
+ const int64 batch_size = example_weights.size();
+ const Tensor* epsilon_t;
+ OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
+ float epsilon = epsilon_t->scalar<float>()();
+
+ OpOutputList summaries_output_list;
+ OP_REQUIRES_OK(
+ context, context->output_list(kSummariesName, &summaries_output_list));
+
+ auto do_quantile_summary_gen = [&](const int64 begin, const int64 end) {
+ // Iterating features.
+ for (int64 index = begin; index < end; index++) {
+ const auto feature_values = float_features_list[index].flat<float>();
+ QuantileStream stream(epsilon, batch_size + 1);
+ // Run quantile summary generation.
+ for (int64 j = 0; j < batch_size; j++) {
+ stream.PushEntry(feature_values(j), example_weights(j));
+ }
+ stream.Finalize();
+ const auto summary_entry_list = stream.GetFinalSummary().GetEntryList();
+ Tensor* output_t;
+ OP_REQUIRES_OK(
+ context,
+ summaries_output_list.allocate(
+ index,
+ TensorShape({static_cast<int64>(summary_entry_list.size()), 4}),
+ &output_t));
+ auto output = output_t->matrix<float>();
+ for (auto row = 0; row < summary_entry_list.size(); row++) {
+ const auto& entry = summary_entry_list[row];
+ output(row, 0) = entry.value;
+ output(row, 1) = entry.weight;
+ output(row, 2) = entry.min_rank;
+ output(row, 3) = entry.max_rank;
+ }
+ }
+ };
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * batch_size;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
+ kCostPerUnit, do_quantile_summary_gen);
+ }
+
+ private:
+ int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesMakeQuantileSummaries").Device(DEVICE_CPU),
+ BoostedTreesMakeQuantileSummariesOp);
+
+class BoostedTreesQuantileStreamResourceAddSummariesOp : public OpKernel {
+ public:
+ explicit BoostedTreesQuantileStreamResourceAddSummariesOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(context,
+ HandleFromInput(context, kResourceHandleName, &handle));
+ QuantileStreamResource* stream_resource;
+ // Create a reference to the underlying resource using the handle.
+ OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+ // Remove the reference at the end of this scope.
+ mutex_lock l(*stream_resource->mutex());
+ core::ScopedUnref unref_me(stream_resource);
+
+ OpInputList summaries_list;
+ OP_REQUIRES_OK(context,
+ context->input_list(kSummariesName, &summaries_list));
+ int32 num_streams = stream_resource->num_streams();
+ CHECK_EQ(static_cast<int>(num_streams), summaries_list.size());
+
+ auto do_quantile_add_summary = [&](const int64 begin, const int64 end) {
+ // Iterating all features.
+ for (int64 feature_idx = begin; feature_idx < end; ++feature_idx) {
+ const Tensor& summaries = summaries_list[feature_idx];
+ const auto summary_values = summaries.matrix<float>();
+ const auto& tensor_shape = summaries.shape();
+ const int64 entries_size = tensor_shape.dim_size(0);
+ CHECK_EQ(tensor_shape.dim_size(1), 4);
+ std::vector<QuantileSummaryEntry> summary_entries;
+ summary_entries.reserve(entries_size);
+ for (int64 i = 0; i < entries_size; i++) {
+ float value = summary_values(i, 0);
+ float weight = summary_values(i, 1);
+ float min_rank = summary_values(i, 2);
+ float max_rank = summary_values(i, 3);
+ QuantileSummaryEntry entry(value, weight, min_rank, max_rank);
+ summary_entries.push_back(entry);
+ }
+ stream_resource->stream(feature_idx)->PushSummary(summary_entries);
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_streams;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+ kCostPerUnit, do_quantile_add_summary);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesQuantileStreamResourceAddSummaries").Device(DEVICE_CPU),
+ BoostedTreesQuantileStreamResourceAddSummariesOp);
+
+class BoostedTreesQuantileStreamResourceFlushOp : public OpKernel {
+ public:
+ explicit BoostedTreesQuantileStreamResourceFlushOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context,
+ context->GetAttr(kGenerateQuantiles, &generate_quantiles_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(context,
+ HandleFromInput(context, kResourceHandleName, &handle));
+ QuantileStreamResource* stream_resource;
+ // Create a reference to the underlying resource using the handle.
+ OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+ // Remove the reference at the end of this scope.
+ mutex_lock l(*stream_resource->mutex());
+ core::ScopedUnref unref_me(stream_resource);
+
+ const Tensor* num_buckets_t;
+ OP_REQUIRES_OK(context, context->input(kNumBucketsName, &num_buckets_t));
+ const int64 num_buckets = num_buckets_t->scalar<int64>()();
+ const int64 num_streams = stream_resource->num_streams();
+
+ auto do_quantile_flush = [&](const int64 begin, const int64 end) {
+ // Iterating over all streams.
+ for (int64 stream_idx = begin; stream_idx < end; ++stream_idx) {
+ QuantileStream* stream = stream_resource->stream(stream_idx);
+ stream->Finalize();
+ stream_resource->set_boundaries(
+ generate_quantiles_ ? GenerateQuantiles(*stream, num_buckets)
+ : GenerateBoundaries(*stream, num_buckets),
+ stream_idx);
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_streams;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+ kCostPerUnit, do_quantile_flush);
+
+ stream_resource->set_buckets_ready(true);
+ }
+
+ private:
+ bool generate_quantiles_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesQuantileStreamResourceFlush").Device(DEVICE_CPU),
+ BoostedTreesQuantileStreamResourceFlushOp);
+
+class BoostedTreesQuantileStreamResourceGetBucketBoundariesOp
+ : public OpKernel {
+ public:
+ explicit BoostedTreesQuantileStreamResourceGetBucketBoundariesOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(context,
+ HandleFromInput(context, kResourceHandleName, &handle));
+ QuantileStreamResource* stream_resource;
+ // Create a reference to the underlying resource using the handle.
+ OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+ // Remove the reference at the end of this scope.
+ mutex_lock l(*stream_resource->mutex());
+ core::ScopedUnref unref_me(stream_resource);
+
+ const int64 num_streams = stream_resource->num_streams();
+ CHECK_EQ(num_features_, num_streams);
+ OpOutputList bucket_boundaries_list;
+ OP_REQUIRES_OK(context, context->output_list(kBucketBoundariesName,
+ &bucket_boundaries_list));
+
+ auto do_quantile_get_buckets = [&](const int64 begin, const int64 end) {
+ // Iterating over all streams.
+ for (int64 stream_idx = begin; stream_idx < end; stream_idx++) {
+ const auto& boundaries = stream_resource->boundaries(stream_idx);
+ Tensor* bucket_boundaries_t = nullptr;
+ OP_REQUIRES_OK(context,
+ bucket_boundaries_list.allocate(
+ stream_idx, {static_cast<int64>(boundaries.size())},
+ &bucket_boundaries_t));
+ auto* quantiles_flat = bucket_boundaries_t->flat<float>().data();
+ memcpy(quantiles_flat, boundaries.data(),
+ sizeof(float) * boundaries.size());
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_streams;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+ kCostPerUnit, do_quantile_get_buckets);
+ }
+
+ private:
+ int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesQuantileStreamResourceGetBucketBoundaries")
+ .Device(DEVICE_CPU),
+ BoostedTreesQuantileStreamResourceGetBucketBoundariesOp);
+
+// Given the calculated quantiles thresholds and input data, this operation
+// converts the input features into the buckets (categorical values), depending
+// on which quantile they fall into.
+class BoostedTreesBucketizeOp : public OpKernel {
+ public:
+ explicit BoostedTreesBucketizeOp(OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ // Read float features list;
+ OpInputList float_features_list;
+ OP_REQUIRES_OK(
+ context, context->input_list(kFloatFeaturesName, &float_features_list));
+ OpInputList bucket_boundaries_list;
+ OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName,
+ &bucket_boundaries_list));
+ OP_REQUIRES(context,
+ tensorflow::TensorShapeUtils::IsVector(
+ bucket_boundaries_list[0].shape()),
+ errors::InvalidArgument(
+ strings::Printf("Buckets should be flat vectors.")));
+ OpOutputList buckets_list;
+ OP_REQUIRES_OK(context, context->output_list(kBucketsName, &buckets_list));
+
+ auto do_quantile_get_quantiles = [&](const int64 begin, const int64 end) {
+ // Iterating over all resources
+ for (int64 feature_idx = begin; feature_idx < end; feature_idx++) {
+ const Tensor& values_tensor = float_features_list[feature_idx];
+ const int64 num_values = values_tensor.dim_size(0);
+
+ Tensor* output_t = nullptr;
+ OP_REQUIRES_OK(
+ context, buckets_list.allocate(
+ feature_idx, TensorShape({num_values, 1}), &output_t));
+ auto output = output_t->matrix<int32>();
+
+ const std::vector<float>& bucket_boundaries_vector =
+ GetBuckets(feature_idx, bucket_boundaries_list);
+ CHECK(!bucket_boundaries_vector.empty())
+ << "Got empty buckets for feature " << feature_idx;
+ auto flat_values = values_tensor.flat<float>();
+ for (int64 instance = 0; instance < num_values; instance++) {
+ const float value = flat_values(instance);
+ auto bucket_iter =
+ std::lower_bound(bucket_boundaries_vector.begin(),
+ bucket_boundaries_vector.end(), value);
+ if (bucket_iter == bucket_boundaries_vector.end()) {
+ --bucket_iter;
+ }
+ const int32 bucket = static_cast<int32>(
+ bucket_iter - bucket_boundaries_vector.begin());
+ // Bucket id.
+ output(instance, 0) = bucket;
+ }
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_features_;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
+ kCostPerUnit, do_quantile_get_quantiles);
+ }
+
+ private:
+ int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BoostedTreesBucketize").Device(DEVICE_CPU),
+ BoostedTreesBucketizeOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
index 3163c63949..12d9473776 100644
--- a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
@@ -1,5 +1,5 @@
# Description:
-# This directory contains common utilities used in boosted_trees.
+# This directory contains common quantile utilities used in boosted_trees.
package(
default_visibility = ["//tensorflow:internal"],
)
@@ -16,6 +16,7 @@ cc_library(
name = "weighted_quantiles",
srcs = [],
hdrs = [
+ "quantile_stream_resource.h",
"weighted_quantiles_buffer.h",
"weighted_quantiles_stream.h",
"weighted_quantiles_summary.h",
@@ -23,6 +24,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
],
)
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h b/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h
new file mode 100644
index 0000000000..1c31724272
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h
@@ -0,0 +1,96 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
+#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
+
+#include <vector>
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+using QuantileStream =
+ boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
+
+// Quantile Stream Resource for a list of streams sharing the same number of
+// quantiles, maximum elements, and epsilon.
+class BoostedTreesQuantileStreamResource : public ResourceBase {
+ public:
+ BoostedTreesQuantileStreamResource(const float epsilon,
+ const int64 max_elements,
+ const int64 num_streams)
+ : are_buckets_ready_(false),
+ epsilon_(epsilon),
+ num_streams_(num_streams),
+ max_elements_(max_elements) {
+ streams_.reserve(num_streams_);
+ boundaries_.reserve(num_streams_);
+ for (int64 idx = 0; idx < num_streams; ++idx) {
+ streams_.push_back(QuantileStream(epsilon, max_elements));
+ boundaries_.push_back(std::vector<float>());
+ }
+ }
+
+ string DebugString() override { return "QuantileStreamResource"; }
+
+ tensorflow::mutex* mutex() { return &mu_; }
+
+ QuantileStream* stream(const int64 index) { return &streams_[index]; }
+
+ const std::vector<float>& boundaries(const int64 index) {
+ return boundaries_[index];
+ }
+
+ void set_boundaries(const std::vector<float>& boundaries, const int64 index) {
+ boundaries_[index] = boundaries;
+ }
+
+ float epsilon() const { return epsilon_; }
+ int64 num_streams() const { return num_streams_; }
+
+ bool are_buckets_ready() const { return are_buckets_ready_; }
+ void set_buckets_ready(const bool are_buckets_ready) {
+ are_buckets_ready_ = are_buckets_ready;
+ }
+
+ private:
+ ~BoostedTreesQuantileStreamResource() override {}
+
+ // Mutex for the whole resource.
+ tensorflow::mutex mu_;
+
+ // Quantile streams.
+ std::vector<QuantileStream> streams_;
+
+ // Stores the boundaries. Same size as streams_.
+ std::vector<std::vector<float>> boundaries_;
+
+ // Whether boundaries are created. Initially boundaries are empty until
+ // set_boundaries are called.
+ bool are_buckets_ready_;
+
+ const float epsilon_;
+ const int64 num_streams_;
+ // An upper-bound for the number of elements.
+ int64 max_elements_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BoostedTreesQuantileStreamResource);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc
index 01452b3e85..7c4184bff4 100644
--- a/tensorflow/core/ops/boosted_trees_ops.cc
+++ b/tensorflow/core/ops/boosted_trees_ops.cc
@@ -22,6 +22,10 @@ limitations under the License.
namespace tensorflow {
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
REGISTER_RESOURCE_HANDLE_OP(BoostedTreesEnsembleResource);
REGISTER_OP("IsBoostedTreesEnsembleInitialized")
@@ -354,4 +358,125 @@ REGISTER_OP("BoostedTreesCenterBias")
return Status::OK();
});
+REGISTER_RESOURCE_HANDLE_OP(BoostedTreesQuantileStreamResource);
+
+REGISTER_OP("IsBoostedTreesQuantileStreamResourceInitialized")
+ .Input("quantile_stream_resource_handle: resource")
+ .Output("is_initialized: bool")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesCreateQuantileStreamResource")
+ .Attr("max_elements: int = 1099511627776") // 1 << 40
+ .Input("quantile_stream_resource_handle: resource")
+ .Input("epsilon: float")
+ .Input("num_streams: int64")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesMakeQuantileSummaries")
+ .Attr("num_features: int >= 0")
+ .Input("float_values: num_features * float")
+ .Input("example_weights: float")
+ .Input("epsilon: float")
+ .Output("summaries: num_features * float")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ ShapeHandle example_weights_shape;
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(num_features), 1, &example_weights_shape));
+ for (int i = 0; i < num_features; ++i) {
+ ShapeHandle feature_shape;
+ DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &feature_shape));
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0),
+ c->Dim(example_weights_shape, 0),
+ &unused_dim));
+ // the columns are value, weight, min_rank, max_rank.
+ c->set_output(i, c->MakeShape({c->UnknownDim(), 4}));
+ }
+ // epsilon must be a scalar.
+ ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(num_features + 1), 0, &unused_input));
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesQuantileStreamResourceAddSummaries")
+ .Attr("num_features: int >= 0")
+ .Input("quantile_stream_resource_handle: resource")
+ .Input("summaries: num_features * float")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ // resource handle must be a scalar.
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ // each summary must be rank 2.
+ for (int i = 1; i < num_features + 1; i++) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &unused_input));
+ }
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesQuantileStreamResourceFlush")
+ .Attr("generate_quantiles: bool = False")
+ .Input("quantile_stream_resource_handle: resource")
+ .Input("num_buckets: int64")
+ .SetShapeFn([](InferenceContext* c) {
+ // All the inputs are scalars.
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesQuantileStreamResourceGetBucketBoundaries")
+ .Attr("num_features: int >= 0")
+ .Input("quantile_stream_resource_handle: resource")
+ .Output("bucket_boundaries: num_features * float")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ shape_inference::ShapeHandle unused_input;
+ // resource handle must be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ for (int i = 0; i < num_features; i++) {
+ c->set_output(i, c->Vector(c->UnknownDim()));
+ }
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesBucketize")
+ .Attr("num_features: int >= 0")
+ .Input("float_values: num_features * float")
+ .Input("bucket_boundaries: num_features * float")
+ .Output("buckets: num_features * int32")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ ShapeHandle feature_shape;
+ DimensionHandle unused_dim;
+ for (int i = 0; i < num_features; i++) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &feature_shape));
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0),
+ c->Dim(c->input(0), 0), &unused_dim));
+ }
+ // Bucketized result should have same dimension as input.
+ for (int i = 0; i < num_features; i++) {
+ c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0), 1}));
+ }
+ return Status::OK();
+ });
+
} // namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/boosted_trees/BUILD b/tensorflow/python/kernel_tests/boosted_trees/BUILD
index 4f92ab0795..20446781f0 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/BUILD
+++ b/tensorflow/python/kernel_tests/boosted_trees/BUILD
@@ -74,3 +74,16 @@ tf_py_test(
"//tensorflow/python:resources",
],
)
+
+tf_py_test(
+ name = "quantile_ops_test",
+ size = "small",
+ srcs = ["quantile_ops_test.py"],
+ additional_deps = [
+ "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
+ "//tensorflow/python:boosted_trees_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:resources",
+ ],
+)
diff --git a/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
new file mode 100644
index 0000000000..c71b8df4ad
--- /dev/null
+++ b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
@@ -0,0 +1,140 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test for checking quantile related ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import boosted_trees_ops
+from tensorflow.python.ops import resources
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_handle_op as resource_handle_op
+from tensorflow.python.ops.gen_boosted_trees_ops import is_boosted_trees_quantile_stream_resource_initialized as resource_initialized
+from tensorflow.python.platform import googletest
+
+
+class QuantileOpsTest(test_util.TensorFlowTestCase):
+
+ def create_resource(self, name, eps, max_elements, num_streams=1):
+ quantile_accumulator_handle = resource_handle_op(
+ container="", shared_name=name, name=name)
+ create_op = boosted_trees_ops.create_quantile_stream_resource(
+ quantile_accumulator_handle,
+ epsilon=eps,
+ max_elements=max_elements,
+ num_streams=num_streams)
+ is_initialized_op = resource_initialized(quantile_accumulator_handle)
+ resources.register_resource(quantile_accumulator_handle, create_op,
+ is_initialized_op)
+ return quantile_accumulator_handle
+
+ def setUp(self):
+ """Sets up the quantile ops test as follows.
+
+ Create a batch of 6 examples having 2 features
+ The data looks like this
+ | Instance | instance weights | Feature 0 | Feature 1
+ | 0 | 10 | 1.2 | 2.3
+ | 1 | 1 | 12.1 | 1.2
+ | 2 | 1 | 0.3 | 1.1
+ | 3 | 1 | 0.5 | 2.6
+ | 4 | 1 | 0.6 | 3.2
+ | 5 | 1 | 2.2 | 0.8
+ """
+
+ self._feature_0 = constant_op.constant(
+ [[1.2], [12.1], [0.3], [0.5], [0.6], [2.2]], dtype=dtypes.float32)
+ self._feature_1 = constant_op.constant(
+ [[2.3], [1.2], [1.1], [2.6], [3.2], [0.8]], dtype=dtypes.float32)
+ self._feature_0_boundaries = constant_op.constant(
+ [0.3, 0.6, 1.2, 12.1], dtype=dtypes.float32)
+ self._feature_1_boundaries = constant_op.constant(
+ [0.8, 1.2, 2.3, 3.2], dtype=dtypes.float32)
+ self._feature_0_quantiles = constant_op.constant(
+ [[2], [3], [0], [1], [1], [3]], dtype=dtypes.int32)
+ self._feature_1_quantiles = constant_op.constant(
+ [[2], [1], [1], [3], [3], [0]], dtype=dtypes.int32)
+
+ self._example_weights = constant_op.constant(
+ [10, 1, 1, 1, 1, 1], dtype=dtypes.float32)
+
+ self.eps = 0.01
+ self.max_elements = 1 << 16
+ self.num_quantiles = constant_op.constant(3, dtype=dtypes.int64)
+
+ def testBasicQuantileBucketsSingleResource(self):
+ with self.test_session() as sess:
+ quantile_accumulator_handle = self.create_resource("floats", self.eps,
+ self.max_elements, 2)
+ resources.initialize_resources(resources.shared_resources()).run()
+ summaries = boosted_trees_ops.make_quantile_summaries(
+ [self._feature_0, self._feature_1], self._example_weights,
+ epsilon=self.eps)
+ summary_op = boosted_trees_ops.quantile_add_summaries(
+ quantile_accumulator_handle, summaries)
+ flush_op = boosted_trees_ops.quantile_flush(
+ quantile_accumulator_handle, self.num_quantiles)
+ buckets = boosted_trees_ops.get_bucket_boundaries(
+ quantile_accumulator_handle, num_features=2)
+ quantiles = boosted_trees_ops.boosted_trees_bucketize(
+ [self._feature_0, self._feature_1], buckets)
+ sess.run(summary_op)
+ sess.run(flush_op)
+ self.assertAllClose(self._feature_0_boundaries, buckets[0].eval())
+ self.assertAllClose(self._feature_1_boundaries, buckets[1].eval())
+
+ self.assertAllClose(self._feature_0_quantiles, quantiles[0].eval())
+ self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
+
+ def testBasicQuantileBucketsMultipleResources(self):
+ with self.test_session() as sess:
+ quantile_accumulator_handle_0 = self.create_resource("float_0", self.eps,
+ self.max_elements)
+ quantile_accumulator_handle_1 = self.create_resource("float_1", self.eps,
+ self.max_elements)
+ resources.initialize_resources(resources.shared_resources()).run()
+ summaries = boosted_trees_ops.make_quantile_summaries(
+ [self._feature_0, self._feature_1], self._example_weights,
+ epsilon=self.eps)
+ summary_op_0 = boosted_trees_ops.quantile_add_summaries(
+ quantile_accumulator_handle_0,
+ [summaries[0]])
+ summary_op_1 = boosted_trees_ops.quantile_add_summaries(
+ quantile_accumulator_handle_1,
+ [summaries[1]])
+ flush_op_0 = boosted_trees_ops.quantile_flush(
+ quantile_accumulator_handle_0, self.num_quantiles)
+ flush_op_1 = boosted_trees_ops.quantile_flush(
+ quantile_accumulator_handle_1, self.num_quantiles)
+ bucket_0 = boosted_trees_ops.get_bucket_boundaries(
+ quantile_accumulator_handle_0, num_features=1)
+ bucket_1 = boosted_trees_ops.get_bucket_boundaries(
+ quantile_accumulator_handle_1, num_features=1)
+ quantiles = boosted_trees_ops.boosted_trees_bucketize(
+ [self._feature_0, self._feature_1], bucket_0 + bucket_1)
+ sess.run([summary_op_0, summary_op_1])
+ sess.run([flush_op_0, flush_op_1])
+ self.assertAllClose(self._feature_0_boundaries, bucket_0[0].eval())
+ self.assertAllClose(self._feature_1_boundaries, bucket_1[0].eval())
+
+ self.assertAllClose(self._feature_0_quantiles, quantiles[0].eval())
+ self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/ops/boosted_trees_ops.py b/tensorflow/python/ops/boosted_trees_ops.py
index f7cbfe0312..720f9f4d41 100644
--- a/tensorflow/python/ops/boosted_trees_ops.py
+++ b/tensorflow/python/ops/boosted_trees_ops.py
@@ -24,11 +24,17 @@ from tensorflow.python.ops import resources
# Re-exporting ops used by other modules.
# pylint: disable=unused-import
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_bucketize
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_center_bias as center_bias
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_create_quantile_stream_resource as create_quantile_stream_resource
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_example_debug_outputs as example_debug_outputs
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_quantile_summaries as make_quantile_summaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_stats_summary as make_stats_summary
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_predict as predict
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_add_summaries as quantile_add_summaries
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_flush as quantile_flush
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_get_bucket_boundaries as get_bucket_boundaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble as update_ensemble
# pylint: enable=unused-import