aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py140
1 files changed, 140 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
new file mode 100644
index 0000000000..c71b8df4ad
--- /dev/null
+++ b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
@@ -0,0 +1,140 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test for checking quantile related ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import boosted_trees_ops
+from tensorflow.python.ops import resources
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_handle_op as resource_handle_op
+from tensorflow.python.ops.gen_boosted_trees_ops import is_boosted_trees_quantile_stream_resource_initialized as resource_initialized
+from tensorflow.python.platform import googletest
+
+
+class QuantileOpsTest(test_util.TensorFlowTestCase):
+
+ def create_resource(self, name, eps, max_elements, num_streams=1):
+ quantile_accumulator_handle = resource_handle_op(
+ container="", shared_name=name, name=name)
+ create_op = boosted_trees_ops.create_quantile_stream_resource(
+ quantile_accumulator_handle,
+ epsilon=eps,
+ max_elements=max_elements,
+ num_streams=num_streams)
+ is_initialized_op = resource_initialized(quantile_accumulator_handle)
+ resources.register_resource(quantile_accumulator_handle, create_op,
+ is_initialized_op)
+ return quantile_accumulator_handle
+
+ def setUp(self):
+ """Sets up the quantile ops test as follows.
+
+ Create a batch of 6 examples having 2 features
+ The data looks like this
+ | Instance | instance weights | Feature 0 | Feature 1
+ | 0 | 10 | 1.2 | 2.3
+ | 1 | 1 | 12.1 | 1.2
+ | 2 | 1 | 0.3 | 1.1
+ | 3 | 1 | 0.5 | 2.6
+ | 4 | 1 | 0.6 | 3.2
+ | 5 | 1 | 2.2 | 0.8
+ """
+
+ self._feature_0 = constant_op.constant(
+ [[1.2], [12.1], [0.3], [0.5], [0.6], [2.2]], dtype=dtypes.float32)
+ self._feature_1 = constant_op.constant(
+ [[2.3], [1.2], [1.1], [2.6], [3.2], [0.8]], dtype=dtypes.float32)
+ self._feature_0_boundaries = constant_op.constant(
+ [0.3, 0.6, 1.2, 12.1], dtype=dtypes.float32)
+ self._feature_1_boundaries = constant_op.constant(
+ [0.8, 1.2, 2.3, 3.2], dtype=dtypes.float32)
+ self._feature_0_quantiles = constant_op.constant(
+ [[2], [3], [0], [1], [1], [3]], dtype=dtypes.int32)
+ self._feature_1_quantiles = constant_op.constant(
+ [[2], [1], [1], [3], [3], [0]], dtype=dtypes.int32)
+
+ self._example_weights = constant_op.constant(
+ [10, 1, 1, 1, 1, 1], dtype=dtypes.float32)
+
+ self.eps = 0.01
+ self.max_elements = 1 << 16
+ self.num_quantiles = constant_op.constant(3, dtype=dtypes.int64)
+
+ def testBasicQuantileBucketsSingleResource(self):
+ with self.test_session() as sess:
+ quantile_accumulator_handle = self.create_resource("floats", self.eps,
+ self.max_elements, 2)
+ resources.initialize_resources(resources.shared_resources()).run()
+ summaries = boosted_trees_ops.make_quantile_summaries(
+ [self._feature_0, self._feature_1], self._example_weights,
+ epsilon=self.eps)
+ summary_op = boosted_trees_ops.quantile_add_summaries(
+ quantile_accumulator_handle, summaries)
+ flush_op = boosted_trees_ops.quantile_flush(
+ quantile_accumulator_handle, self.num_quantiles)
+ buckets = boosted_trees_ops.get_bucket_boundaries(
+ quantile_accumulator_handle, num_features=2)
+ quantiles = boosted_trees_ops.boosted_trees_bucketize(
+ [self._feature_0, self._feature_1], buckets)
+ sess.run(summary_op)
+ sess.run(flush_op)
+ self.assertAllClose(self._feature_0_boundaries, buckets[0].eval())
+ self.assertAllClose(self._feature_1_boundaries, buckets[1].eval())
+
+ self.assertAllClose(self._feature_0_quantiles, quantiles[0].eval())
+ self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
+
+ def testBasicQuantileBucketsMultipleResources(self):
+ with self.test_session() as sess:
+ quantile_accumulator_handle_0 = self.create_resource("float_0", self.eps,
+ self.max_elements)
+ quantile_accumulator_handle_1 = self.create_resource("float_1", self.eps,
+ self.max_elements)
+ resources.initialize_resources(resources.shared_resources()).run()
+ summaries = boosted_trees_ops.make_quantile_summaries(
+ [self._feature_0, self._feature_1], self._example_weights,
+ epsilon=self.eps)
+ summary_op_0 = boosted_trees_ops.quantile_add_summaries(
+ quantile_accumulator_handle_0,
+ [summaries[0]])
+ summary_op_1 = boosted_trees_ops.quantile_add_summaries(
+ quantile_accumulator_handle_1,
+ [summaries[1]])
+ flush_op_0 = boosted_trees_ops.quantile_flush(
+ quantile_accumulator_handle_0, self.num_quantiles)
+ flush_op_1 = boosted_trees_ops.quantile_flush(
+ quantile_accumulator_handle_1, self.num_quantiles)
+ bucket_0 = boosted_trees_ops.get_bucket_boundaries(
+ quantile_accumulator_handle_0, num_features=1)
+ bucket_1 = boosted_trees_ops.get_bucket_boundaries(
+ quantile_accumulator_handle_1, num_features=1)
+ quantiles = boosted_trees_ops.boosted_trees_bucketize(
+ [self._feature_0, self._feature_1], bucket_0 + bucket_1)
+ sess.run([summary_op_0, summary_op_1])
+ sess.run([flush_op_0, flush_op_1])
+ self.assertAllClose(self._feature_0_boundaries, bucket_0[0].eval())
+ self.assertAllClose(self._feature_1_boundaries, bucket_1[0].eval())
+
+ self.assertAllClose(self._feature_0_quantiles, quantiles[0].eval())
+ self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
+
+
+if __name__ == "__main__":
+ googletest.main()