aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/bigtable
diff options
context:
space:
mode:
authorGravatar Brennan Saeta <saeta@google.com>2018-07-10 07:47:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-10 07:50:23 -0700
commit4125012a146fafbe029b7e06c586a3865e0edb71 (patch)
treeae329eb6279ae7a48c3b4554990f6edfb2634f82 /tensorflow/contrib/bigtable
parent70592a563e7d9bf116c21f87e2a535141e25a362 (diff)
[tf.data / Bigtable] Parallel scan Bigtable tables
In order to stream data from Cloud Bigtable at high speed, it's important to use multiple connections to stream simultaneously from multiple tablet servers. This change adds two new methods to the BigTable object to set up a dataset based on the SampleKeys method, and tf.contrib.data.parallel_interleave. Because the keys returned from SampleKeys is not guaranteed to be deterministic (in fact, it can change over time without any new data added to the table), the resulting datasets are not deterministic. (In order to further boost performance, we enable sloppy interleaving.) When comparing the table.scan_* methods vs the table.parallel_scan_* methods for a test workload (based on ImageNet), we see performance gains of over 15x, and over 10x compared to a reasonably tuned GCS input pipeline. PiperOrigin-RevId: 203945580
Diffstat (limited to 'tensorflow/contrib/bigtable')
-rw-r--r--tensorflow/contrib/bigtable/BUILD24
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.cc68
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h67
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_range_helpers_test.cc107
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc200
-rw-r--r--tensorflow/contrib/bigtable/ops/bigtable_ops.cc10
-rw-r--r--tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py118
-rw-r--r--tensorflow/contrib/bigtable/python/ops/bigtable_api.py288
8 files changed, 840 insertions, 42 deletions
diff --git a/tensorflow/contrib/bigtable/BUILD b/tensorflow/contrib/bigtable/BUILD
index a262cd4a49..71538e0770 100644
--- a/tensorflow/contrib/bigtable/BUILD
+++ b/tensorflow/contrib/bigtable/BUILD
@@ -31,6 +31,7 @@ tf_custom_op_py_library(
srcs_version = "PY2AND3",
deps = [
":bigtable_ops",
+ "//tensorflow/contrib/data/python/ops:interleave_ops",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform",
@@ -45,6 +46,7 @@ KERNEL_FILES = [
"kernels/bigtable_prefix_key_dataset_op.cc",
"kernels/bigtable_range_key_dataset_op.cc",
"kernels/bigtable_sample_keys_dataset_op.cc",
+ "kernels/bigtable_sample_key_pairs_dataset_op.cc",
"kernels/bigtable_scan_dataset_op.cc",
]
@@ -55,6 +57,7 @@ tf_custom_op_library(
],
deps = [
":bigtable_lib_cc",
+ ":bigtable_range_helpers",
"@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
],
)
@@ -76,6 +79,7 @@ tf_kernel_library(
srcs = KERNEL_FILES,
deps = [
":bigtable_lib_cc",
+ ":bigtable_range_helpers",
"//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
"@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
@@ -95,6 +99,15 @@ cc_library(
)
cc_library(
+ name = "bigtable_range_helpers",
+ srcs = ["kernels/bigtable_range_helpers.cc"],
+ hdrs = ["kernels/bigtable_range_helpers.h"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ ],
+)
+
+cc_library(
name = "bigtable_test_client",
srcs = ["kernels/test_kernels/bigtable_test_client.cc"],
hdrs = ["kernels/test_kernels/bigtable_test_client.h"],
@@ -118,6 +131,17 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "bigtable_range_helpers_test",
+ size = "small",
+ srcs = ["kernels/bigtable_range_helpers_test.cc"],
+ deps = [
+ ":bigtable_range_helpers",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
tf_gen_op_wrapper_py(
name = "bigtable_test_ops",
deps = [":bigtable_test_ops_op_lib"],
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.cc
new file mode 100644
index 0000000000..51965f6214
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.cc
@@ -0,0 +1,68 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+namespace {
+
+string MakePrefixEndKey(const string& prefix) {
+ string end = prefix;
+ while (true) {
+ if (end.empty()) {
+ return end;
+ }
+ ++end[end.size() - 1];
+ if (end[end.size() - 1] == 0) {
+ // Handle wraparound case.
+ end = end.substr(0, end.size() - 1);
+ } else {
+ return end;
+ }
+ }
+}
+
+} // namespace
+
+/* static */ MultiModeKeyRange MultiModeKeyRange::FromPrefix(string prefix) {
+ string end = MakePrefixEndKey(prefix);
+ VLOG(1) << "Creating MultiModeKeyRange from Prefix: " << prefix
+ << ", with end key: " << end;
+ return MultiModeKeyRange(std::move(prefix), std::move(end));
+}
+
+/* static */ MultiModeKeyRange MultiModeKeyRange::FromRange(string begin,
+ string end) {
+ return MultiModeKeyRange(std::move(begin), std::move(end));
+}
+
+const string& MultiModeKeyRange::begin_key() const { return begin_; }
+
+const string& MultiModeKeyRange::end_key() const { return end_; }
+
+bool MultiModeKeyRange::contains_key(StringPiece key) const {
+ if (StringPiece(begin_) > key) {
+ return false;
+ }
+ if (StringPiece(end_) <= key && !end_.empty()) {
+ return false;
+ }
+ return true;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h
new file mode 100644
index 0000000000..44c628e366
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h
@@ -0,0 +1,67 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_
+#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_
+
+#include <string>
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Represents a continuous range of keys defined by either a prefix or a range.
+//
+// Ranges are represented as "half-open", where the beginning key is included
+// in the range, and the end_key is the first excluded key after the range.
+//
+// The range of keys can be specified either by a key prefix, or by an explicit
+// begin key and end key. All methods on this class are valid no matter which
+// way the range was specified.
+//
+// Example:
+// MultiModeKeyRange range = MultiModeKeyRange::FromPrefix("myPrefix");
+// if (range.contains_key("myPrefixedKey")) {
+// LOG(INFO) << "range from " << range.begin_key() << " to "
+// << range.end_key() << "contains \"myPrefixedKey\"";
+// }
+// if (!range.contains_key("randomKey")) {
+// LOG(INFO) << "range does not contain \"randomKey\"";
+// }
+// range = MultiModeKeyRange::FromRange("a_start_key", "z_end_key");
+class MultiModeKeyRange {
+ public:
+ static MultiModeKeyRange FromPrefix(string prefix);
+ static MultiModeKeyRange FromRange(string begin, string end);
+
+ // The first valid key in the range.
+ const string& begin_key() const;
+ // The first invalid key after the valid range.
+ const string& end_key() const;
+ // Returns true if the provided key is a part of the range, false otherwise.
+ bool contains_key(StringPiece key) const;
+
+ private:
+ MultiModeKeyRange(string begin, string end)
+ : begin_(std::move(begin)), end_(std::move(end)) {}
+
+ const string begin_;
+ const string end_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers_test.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers_test.cc
new file mode 100644
index 0000000000..1bfc547271
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers_test.cc
@@ -0,0 +1,107 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(MultiModeKeyRangeTest, SimplePrefix) {
+ MultiModeKeyRange r = MultiModeKeyRange::FromPrefix("prefix");
+ EXPECT_EQ("prefix", r.begin_key());
+ EXPECT_EQ("prefiy", r.end_key());
+ EXPECT_TRUE(r.contains_key("prefixed_key"));
+ EXPECT_FALSE(r.contains_key("not-prefixed-key"));
+ EXPECT_FALSE(r.contains_key("prefi"));
+ EXPECT_FALSE(r.contains_key("prefiy"));
+ EXPECT_FALSE(r.contains_key("early"));
+ EXPECT_FALSE(r.contains_key(""));
+}
+
+TEST(MultiModeKeyRangeTest, Range) {
+ MultiModeKeyRange r = MultiModeKeyRange::FromRange("a", "b");
+ EXPECT_EQ("a", r.begin_key());
+ EXPECT_EQ("b", r.end_key());
+ EXPECT_TRUE(r.contains_key("a"));
+ EXPECT_TRUE(r.contains_key("ab"));
+ EXPECT_FALSE(r.contains_key("b"));
+ EXPECT_FALSE(r.contains_key("bc"));
+ EXPECT_FALSE(r.contains_key("A"));
+ EXPECT_FALSE(r.contains_key("B"));
+ EXPECT_FALSE(r.contains_key(""));
+}
+
+TEST(MultiModeKeyRangeTest, InvertedRange) {
+ MultiModeKeyRange r = MultiModeKeyRange::FromRange("b", "a");
+ EXPECT_FALSE(r.contains_key("a"));
+ EXPECT_FALSE(r.contains_key("b"));
+ EXPECT_FALSE(r.contains_key(""));
+}
+
+TEST(MultiModeKeyRangeTest, EmptyPrefix) {
+ MultiModeKeyRange r = MultiModeKeyRange::FromPrefix("");
+ EXPECT_EQ("", r.begin_key());
+ EXPECT_EQ("", r.end_key());
+ EXPECT_TRUE(r.contains_key(""));
+ EXPECT_TRUE(r.contains_key("a"));
+ EXPECT_TRUE(r.contains_key("z"));
+ EXPECT_TRUE(r.contains_key("A"));
+ EXPECT_TRUE(r.contains_key("ZZZZZZ"));
+}
+
+TEST(MultiModeKeyRangeTest, HalfRange) {
+ MultiModeKeyRange r = MultiModeKeyRange::FromRange("start", "");
+ EXPECT_EQ("start", r.begin_key());
+ EXPECT_EQ("", r.end_key());
+ EXPECT_TRUE(r.contains_key("start"));
+ EXPECT_TRUE(r.contains_key("starting"));
+ EXPECT_TRUE(r.contains_key("z-end"));
+ EXPECT_FALSE(r.contains_key(""));
+ EXPECT_FALSE(r.contains_key("early"));
+}
+
+TEST(MultiModeKeyRangeTest, PrefixWrapAround) {
+ string prefix = "abc\xff";
+ MultiModeKeyRange r = MultiModeKeyRange::FromPrefix(prefix);
+ EXPECT_EQ(prefix, r.begin_key());
+ EXPECT_EQ("abd", r.end_key());
+
+ EXPECT_TRUE(r.contains_key("abc\xff\x07"));
+ EXPECT_TRUE(r.contains_key("abc\xff\x15"));
+ EXPECT_TRUE(r.contains_key("abc\xff\x61"));
+ EXPECT_TRUE(r.contains_key("abc\xff\xff"));
+ EXPECT_FALSE(r.contains_key("abc\0"));
+ EXPECT_FALSE(r.contains_key("abd"));
+}
+
+TEST(MultiModeKeyRangeTest, PrefixSignedWrapAround) {
+ string prefix = "abc\x7f";
+ MultiModeKeyRange r = MultiModeKeyRange::FromPrefix(prefix);
+ EXPECT_EQ(prefix, r.begin_key());
+ EXPECT_EQ("abc\x80", r.end_key());
+
+ EXPECT_TRUE(r.contains_key("abc\x7f\x07"));
+ EXPECT_TRUE(r.contains_key("abc\x7f\x15"));
+ EXPECT_TRUE(r.contains_key("abc\x7f\x61"));
+ EXPECT_TRUE(r.contains_key("abc\x7f\xff"));
+ EXPECT_FALSE(r.contains_key("abc\0"));
+ EXPECT_FALSE(r.contains_key("abc\x01"));
+ EXPECT_FALSE(r.contains_key("abd"));
+ EXPECT_FALSE(r.contains_key("ab\x80"));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
new file mode 100644
index 0000000000..a1a63a975a
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
@@ -0,0 +1,200 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
+#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ string prefix;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix));
+
+ string start_key;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<string>(ctx, "start_key", &start_key));
+ string end_key;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
+
+ BigtableTableResource* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+
+ OP_REQUIRES(ctx, prefix.empty() || start_key.empty(),
+ errors::InvalidArgument(
+ "Only one of prefix and start_key can be provided"));
+ if (!prefix.empty()) {
+ OP_REQUIRES(ctx, end_key.empty(),
+ errors::InvalidArgument(
+ "If prefix is specified, end_key must be empty."));
+ }
+
+ *output = new Dataset(ctx, resource, std::move(prefix),
+ std::move(start_key), std::move(end_key));
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
+ string prefix, string start_key, string end_key)
+ : GraphDatasetBase(ctx),
+ table_(table),
+ key_range_(MakeMultiModeKeyRange(
+ std::move(prefix), std::move(start_key), std::move(end_key))) {
+ table_->Ref();
+ }
+
+ ~Dataset() override { table_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::BigtableSampleKeyPairsDataset")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes =
+ new DataTypeVector({DT_STRING, DT_STRING});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}, {}});
+ return *shapes;
+ }
+
+ string DebugString() const override {
+ return "BigtableSampleKeyPairsDatasetOp::Dataset";
+ }
+
+ private:
+ static MultiModeKeyRange MakeMultiModeKeyRange(string prefix,
+ string start_key,
+ string end_key) {
+ if (!start_key.empty()) {
+ return MultiModeKeyRange::FromRange(std::move(start_key),
+ std::move(end_key));
+ }
+ return MultiModeKeyRange::FromPrefix(std::move(prefix));
+ }
+
+ BigtableTableResource& table() const { return *table_; }
+
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ // Computes split points (`keys_`) to use when scanning the table.
+ //
+ // Initialize first retrieves the sample keys from the table (`row_keys`),
+ // as these often form good split points within the table. We then iterate
+ // over them, and copy them to `keys_` if they fall within the requested
+ // range to scan (`dataset()->key_range_`). Because the requested range
+ // might start between elements of the sampled keys list, care is taken to
+ // ensure we don't accidentally miss any subsets of the requested range by
+ // including `begin_key()` and `end_key()` as appropriate.
+ Status Initialize(IteratorContext* ctx) override {
+ grpc::Status status;
+ std::vector<google::cloud::bigtable::RowKeySample> row_keys =
+ dataset()->table().table().SampleRows(status);
+ if (!status.ok()) {
+ return GrpcStatusToTfStatus(status);
+ }
+
+ for (size_t i = 0; i < row_keys.size(); ++i) {
+ string row_key(row_keys[i].row_key);
+ if (dataset()->key_range_.contains_key(row_key)) {
+ // First key: check to see if we need to add the begin_key.
+ if (keys_.empty() && dataset()->key_range_.begin_key() != row_key) {
+ keys_.push_back(dataset()->key_range_.begin_key());
+ }
+ keys_.push_back(std::move(row_key));
+ } else if (!keys_.empty()) {
+ // If !keys_.empty(), then we have found at least one element of
+ // `row_keys` that is within our requested range
+ // (`dataset()->key_range_`). Because `row_keys` is sorted, if we
+ // have found an element that's not within our key range, then we
+ // are after our requested range (ranges are contiguous) and can end
+ // iteration early.
+ break;
+ }
+ }
+
+ // Handle the case where we skip over the selected range entirely.
+ if (keys_.empty()) {
+ keys_.push_back(dataset()->key_range_.begin_key());
+ }
+
+ // Last key: check to see if we need to add the end_key.
+ if (keys_.back() != dataset()->key_range_.end_key()) {
+ keys_.push_back(dataset()->key_range_.end_key());
+ }
+ return Status::OK();
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ if (index_ > keys_.size() - 2) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
+ *end_of_sequence = false;
+ out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
+ TensorShape({}));
+ out_tensors->back().scalar<string>()() = keys_[index_];
+
+ out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
+ TensorShape({}));
+ out_tensors->back().scalar<string>()() = keys_[index_ + 1];
+ ++index_;
+
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ size_t index_ GUARDED_BY(mu_) = 0;
+ // Note: we store the keys_ on the iterator instead of the dataset
+ // because we want to re-sample the row keys in case there have been
+ // tablet rebalancing operations since the dataset was created.
+ //
+ // Note: keys_ is readonly after Initialize, and thus does not need a
+ // guarding lock.
+ std::vector<string> keys_;
+ };
+
+ BigtableTableResource* const table_;
+ const MultiModeKeyRange key_range_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BigtableSampleKeyPairsDataset").Device(DEVICE_CPU),
+ BigtableSampleKeyPairsDatasetOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc
index 36a392f2a4..416b719e30 100644
--- a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc
+++ b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc
@@ -79,6 +79,16 @@ REGISTER_OP("BigtableSampleKeysDataset")
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("BigtableSampleKeyPairsDataset")
+ .Input("table: resource")
+ .Input("prefix: string")
+ .Input("start_key: string")
+ .Input("end_key: string")
+ .Output("handle: variant")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape);
+
// TODO(saeta): Support continuing despite bad data (e.g. empty string, or
// skip incomplete row.)
REGISTER_OP("BigtableScanDataset")
diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
index 028c861ca3..2f20064619 100644
--- a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
+++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.contrib import bigtable
from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
from tensorflow.contrib.bigtable.ops import gen_bigtable_test_ops
+from tensorflow.contrib.bigtable.python.ops import bigtable_api
from tensorflow.contrib.util import loader
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
@@ -32,6 +33,10 @@ _bigtable_so = loader.load_op_library(
resource_loader.get_path_to_datafile("_bigtable_test.so"))
+def _ListOfTuplesOfStringsToBytes(values):
+ return [(compat.as_bytes(i[0]), compat.as_bytes(i[1])) for i in values]
+
+
class BigtableOpsTest(test.TestCase):
COMMON_ROW_KEYS = ["r1", "r2", "r3"]
COMMON_VALUES = ["v1", "v2", "v3"]
@@ -100,12 +105,18 @@ class BigtableOpsTest(test.TestCase):
def testScanPrefixListCol(self):
self.runScanTest(self._table.scan_prefix("r", cf1=["c1"]))
+ def testScanPrefixTupleCol(self):
+ self.runScanTest(self._table.scan_prefix("r", columns=("cf1", "c1")))
+
def testScanRangeStringCol(self):
self.runScanTest(self._table.scan_range("r1", "r4", cf1="c1"))
def testScanRangeListCol(self):
self.runScanTest(self._table.scan_range("r1", "r4", cf1=["c1"]))
+ def testScanRangeTupleCol(self):
+ self.runScanTest(self._table.scan_range("r1", "r4", columns=("cf1", "c1")))
+
def testLookup(self):
ds = self._table.keys_by_prefix_dataset("r")
ds = ds.apply(self._table.lookup_columns(cf1="c1"))
@@ -149,6 +160,113 @@ class BigtableOpsTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(n)
+ def runSampleKeyPairsTest(self, ds, expected_key_pairs):
+ itr = ds.make_initializable_iterator()
+ n = itr.get_next()
+ with self.test_session() as sess:
+ self._writeCommonValues(sess)
+ sess.run(itr.initializer)
+ for i, elems in enumerate(expected_key_pairs):
+ output = sess.run(n)
+ self.assertEqual(
+ compat.as_bytes(elems[0]), compat.as_bytes(output[0]),
+ "Unequal key pair (first element) at step %d; want: %s, got %s" %
+ (i, compat.as_bytes(elems[0]), compat.as_bytes(output[0])))
+ self.assertEqual(
+ compat.as_bytes(elems[1]), compat.as_bytes(output[1]),
+ "Unequal key pair (second element) at step %d; want: %s, got %s" %
+ (i, compat.as_bytes(elems[1]), compat.as_bytes(output[1])))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(n)
+
+ def testSampleKeyPairsSimplePrefix(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="r", start="", end="")
+ expected_key_pairs = [("r", "r1"), ("r1", "r3"), ("r3", "s")]
+ self.runSampleKeyPairsTest(ds, expected_key_pairs)
+
+ def testSampleKeyPairsSimpleRange(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="", start="r1", end="r3")
+ expected_key_pairs = [("r1", "r3")]
+ self.runSampleKeyPairsTest(ds, expected_key_pairs)
+
+ def testSampleKeyPairsSkipRangePrefix(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="r2", start="", end="")
+ expected_key_pairs = [("r2", "r3")]
+ self.runSampleKeyPairsTest(ds, expected_key_pairs)
+
+ def testSampleKeyPairsSkipRangeRange(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="", start="r2", end="r3")
+ expected_key_pairs = [("r2", "r3")]
+ self.runSampleKeyPairsTest(ds, expected_key_pairs)
+
+ def testSampleKeyPairsOffsetRanges(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="", start="r2", end="r4")
+ expected_key_pairs = [("r2", "r3"), ("r3", "r4")]
+ self.runSampleKeyPairsTest(ds, expected_key_pairs)
+
+ def testSampleKeyPairEverything(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="", start="", end="")
+ expected_key_pairs = [("", "r1"), ("r1", "r3"), ("r3", "")]
+ self.runSampleKeyPairsTest(ds, expected_key_pairs)
+
+ def testSampleKeyPairsPrefixAndStartKey(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="r", start="r1", end="")
+ itr = ds.make_initializable_iterator()
+ with self.test_session() as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(itr.initializer)
+
+ def testSampleKeyPairsPrefixAndEndKey(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="r", start="", end="r3")
+ itr = ds.make_initializable_iterator()
+ with self.test_session() as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(itr.initializer)
+
+ def testParallelScanPrefix(self):
+ ds = self._table.parallel_scan_prefix(prefix="r", cf1="c1")
+ itr = ds.make_initializable_iterator()
+ n = itr.get_next()
+ with self.test_session() as sess:
+ self._writeCommonValues(sess)
+ sess.run(itr.initializer)
+ expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
+ actual_values = []
+ for _ in range(len(expected_values)):
+ output = sess.run(n)
+ actual_values.append(output)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(n)
+ self.assertItemsEqual(
+ _ListOfTuplesOfStringsToBytes(expected_values),
+ _ListOfTuplesOfStringsToBytes(actual_values))
+
+ def testParallelScanRange(self):
+ ds = self._table.parallel_scan_range(start="r1", end="r4", cf1="c1")
+ itr = ds.make_initializable_iterator()
+ n = itr.get_next()
+ with self.test_session() as sess:
+ self._writeCommonValues(sess)
+ sess.run(itr.initializer)
+ expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
+ actual_values = []
+ for _ in range(len(expected_values)):
+ output = sess.run(n)
+ actual_values.append(output)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(n)
+ self.assertItemsEqual(
+ _ListOfTuplesOfStringsToBytes(expected_values),
+ _ListOfTuplesOfStringsToBytes(actual_values))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
index a7ec3a1142..9f73b7223c 100644
--- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
+++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
@@ -28,8 +28,10 @@ from __future__ import division
from __future__ import print_function
from six import iteritems
+from six import string_types
from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
+from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.contrib.util import loader
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
@@ -251,9 +253,11 @@ class BigTable(object):
Note: only the latest value of a cell will be retrieved.
Args:
- prefix: The prefix all row keys muat match to be retrieved for prefix-
+ prefix: The prefix all row keys must match to be retrieved for prefix-
based scans.
- probability: Probabilistically sample rows.
+ probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
+ A non-1 value indicates to probabilistically sample rows with the
+ provided probability.
columns: The columns to read. Note: most commonly, they are expressed as
kwargs. Use the columns value if you are using column families that are
reserved. The value of columns and kwargs are merged. Columns is a list
@@ -268,26 +272,8 @@ class BigTable(object):
Raises:
ValueError: If the configured probability is unexpected.
"""
- if probability is None:
- probability = 1.0
- if isinstance(probability, float) and (probability <= 0.0 or
- probability > 1.0):
- raise ValueError("probability must be in the range (0, 1].")
-
- normalized = columns
- if normalized is None:
- normalized = []
- if isinstance(normalized, tuple):
- normalized = list(normalized)
- for key, value in iteritems(kwargs):
- if key == "name":
- continue
- if isinstance(value, str):
- normalized.append((key, value))
- continue
- for col in value:
- normalized.append((key, col))
-
+ probability = _normalize_probability(probability)
+ normalized = _normalize_columns(columns, kwargs)
return _BigtableScanDataset(self, prefix, "", "", normalized, probability)
def scan_range(self, start, end, probability=None, columns=None, **kwargs):
@@ -314,7 +300,9 @@ class BigTable(object):
Args:
start: The start of the range when scanning by range.
end: (Optional.) The end of the range when scanning by range.
- probability: Probabilistically sample rows.
+ probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
+ A non-1 value indicates to probabilistically sample rows with the
+ provided probability.
columns: The columns to read. Note: most commonly, they are expressed as
kwargs. Use the columns value if you are using column families that are
reserved. The value of columns and kwargs are merged. Columns is a list
@@ -329,27 +317,129 @@ class BigTable(object):
Raises:
ValueError: If the configured probability is unexpected.
"""
- if probability is None:
- probability = 1.0
- if isinstance(probability, float) and (probability <= 0.0 or
- probability > 1.0):
- raise ValueError("probability must be in the range (0, 1].")
+ probability = _normalize_probability(probability)
+ normalized = _normalize_columns(columns, kwargs)
+ return _BigtableScanDataset(self, "", start, end, normalized, probability)
- normalized = columns
- if normalized is None:
- normalized = []
- if isinstance(normalized, tuple):
- normalized = list(normalized)
- for key, value in iteritems(kwargs):
- if key == "name":
- continue
- if isinstance(value, str):
- normalized.append((key, value))
- continue
- for col in value:
- normalized.append((key, col))
+ def parallel_scan_prefix(self,
+ prefix,
+ num_parallel_scans=None,
+ probability=None,
+ columns=None,
+ **kwargs):
+ """Retrieves row (including values) from the Bigtable service at high speed.
- return _BigtableScanDataset(self, "", start, end, normalized, probability)
+ Rows with row-key prefixed by `prefix` will be retrieved. This method is
+ similar to `scan_prefix`, but by constrast performs multiple sub-scans in
+ parallel in order to achieve higher performance.
+
+ Note: The dataset produced by this method is not deterministic!
+
+ Specifying the columns to retrieve for each row is done by either using
+ kwargs or in the columns parameter. To retrieve values of the columns "c1",
+ and "c2" from the column family "cfa", and the value of the column "c3"
+ from column family "cfb", the following datasets (`ds1`, and `ds2`) are
+ equivalent:
+
+ ```
+ table = # ...
+ ds1 = table.parallel_scan_prefix("row_prefix", columns=[("cfa", "c1"),
+ ("cfa", "c2"),
+ ("cfb", "c3")])
+ ds2 = table.parallel_scan_prefix("row_prefix", cfa=["c1", "c2"], cfb="c3")
+ ```
+
+ Note: only the latest value of a cell will be retrieved.
+
+ Args:
+ prefix: The prefix all row keys must match to be retrieved for prefix-
+ based scans.
+ num_parallel_scans: (Optional.) The number of concurrent scans against the
+ Cloud Bigtable instance.
+ probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
+ A non-1 value indicates to probabilistically sample rows with the
+ provided probability.
+ columns: The columns to read. Note: most commonly, they are expressed as
+ kwargs. Use the columns value if you are using column families that are
+ reserved. The value of columns and kwargs are merged. Columns is a list
+ of tuples of strings ("column_family", "column_qualifier").
+ **kwargs: The column families and columns to read. Keys are treated as
+ column_families, and values can be either lists of strings, or strings
+ that are treated as the column qualifier (column name).
+
+ Returns:
+ A @{tf.data.Dataset} returning the row keys and the cell contents.
+
+ Raises:
+ ValueError: If the configured probability is unexpected.
+ """
+ probability = _normalize_probability(probability)
+ normalized = _normalize_columns(columns, kwargs)
+ ds = _BigtableSampleKeyPairsDataset(self, prefix, "", "")
+ return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability,
+ normalized)
+
+ def parallel_scan_range(self,
+ start,
+ end,
+ num_parallel_scans=None,
+ probability=None,
+ columns=None,
+ **kwargs):
+ """Retrieves rows (including values) from the Bigtable service.
+
+ Rows with row-keys between `start` and `end` will be retrieved. This method
+ is similar to `scan_range`, but by constrast performs multiple sub-scans in
+ parallel in order to achieve higher performance.
+
+ Note: The dataset produced by this method is not deterministic!
+
+ Specifying the columns to retrieve for each row is done by either using
+ kwargs or in the columns parameter. To retrieve values of the columns "c1",
+ and "c2" from the column family "cfa", and the value of the column "c3"
+ from column family "cfb", the following datasets (`ds1`, and `ds2`) are
+ equivalent:
+
+ ```
+ table = # ...
+ ds1 = table.parallel_scan_range("row_start",
+ "row_end",
+ columns=[("cfa", "c1"),
+ ("cfa", "c2"),
+ ("cfb", "c3")])
+ ds2 = table.parallel_scan_range("row_start", "row_end",
+ cfa=["c1", "c2"], cfb="c3")
+ ```
+
+ Note: only the latest value of a cell will be retrieved.
+
+ Args:
+ start: The start of the range when scanning by range.
+ end: (Optional.) The end of the range when scanning by range.
+ num_parallel_scans: (Optional.) The number of concurrent scans against the
+ Cloud Bigtable instance.
+ probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
+ A non-1 value indicates to probabilistically sample rows with the
+ provided probability.
+ columns: The columns to read. Note: most commonly, they are expressed as
+ kwargs. Use the columns value if you are using column families that are
+ reserved. The value of columns and kwargs are merged. Columns is a list
+ of tuples of strings ("column_family", "column_qualifier").
+ **kwargs: The column families and columns to read. Keys are treated as
+ column_families, and values can be either lists of strings, or strings
+ that are treated as the column qualifier (column name).
+
+ Returns:
+ A @{tf.data.Dataset} returning the row keys and the cell contents.
+
+ Raises:
+ ValueError: If the configured probability is unexpected.
+ """
+ probability = _normalize_probability(probability)
+ normalized = _normalize_columns(columns, kwargs)
+ ds = _BigtableSampleKeyPairsDataset(self, "", start, end)
+ return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability,
+ normalized)
def write(self, dataset, column_families, columns, timestamp=None):
"""Writes a dataset to the table.
@@ -396,6 +486,89 @@ class BigTable(object):
columns,
timestamp)
+ def _make_parallel_scan_dataset(self, ds, num_parallel_scans,
+ normalized_probability, normalized_columns):
+ """Builds a parallel dataset from a given range.
+
+ Args:
+ ds: A `_BigtableSampleKeyPairsDataset` returning ranges of keys to use.
+ num_parallel_scans: The number of concurrent parallel scans to use.
+ normalized_probability: A number between 0 and 1 for the keep probability.
+ normalized_columns: The column families and column qualifiers to retrieve.
+
+ Returns:
+ A @{tf.data.Dataset} representing the result of the parallel scan.
+ """
+ if num_parallel_scans is None:
+ num_parallel_scans = 50
+
+ ds = ds.shuffle(buffer_size=10000) # TODO(saeta): Make configurable.
+
+ def _interleave_fn(start, end):
+ return _BigtableScanDataset(
+ self,
+ prefix="",
+ start=start,
+ end=end,
+ normalized=normalized_columns,
+ probability=normalized_probability)
+
+ # Note prefetch_input_elements must be set in order to avoid rpc timeouts.
+ ds = ds.apply(
+ interleave_ops.parallel_interleave(
+ _interleave_fn,
+ cycle_length=num_parallel_scans,
+ sloppy=True,
+ prefetch_input_elements=1))
+ return ds
+
+
+def _normalize_probability(probability):
+ if probability is None:
+ probability = 1.0
+ if isinstance(probability, float) and (probability <= 0.0 or
+ probability > 1.0):
+ raise ValueError("probability must be in the range (0, 1].")
+ return probability
+
+
+def _normalize_columns(columns, provided_kwargs):
+ """Converts arguments (columns, and kwargs dict) to C++ representation.
+
+ Args:
+ columns: a datastructure containing the column families and qualifier to
+ retrieve. Valid types include (1) None, (2) list of tuples, (3) a tuple of
+ strings.
+ provided_kwargs: a dictionary containing the column families and qualifiers
+ to retrieve
+
+ Returns:
+ A list of pairs of column family+qualifier to retrieve.
+
+ Raises:
+ ValueError: If there are no cells to retrieve or the columns are in an
+ incorrect format.
+ """
+ normalized = columns
+ if normalized is None:
+ normalized = []
+ if isinstance(normalized, tuple):
+ if len(normalized) == 2:
+ normalized = [normalized]
+ else:
+ raise ValueError("columns was a tuple of inappropriate length")
+ for key, value in iteritems(provided_kwargs):
+ if key == "name":
+ continue
+ if isinstance(value, string_types):
+ normalized.append((key, value))
+ continue
+ for col in value:
+ normalized.append((key, col))
+ if not normalized:
+ raise ValueError("At least one column + column family must be specified.")
+ return normalized
+
class _BigtableKeyDataset(dataset_ops.Dataset):
"""_BigtableKeyDataset is an abstract class representing the keys of a table.
@@ -535,3 +708,34 @@ class _BigtableScanDataset(dataset_ops.Dataset):
column_families=self._column_families,
columns=self._columns,
probability=self._probability)
+
+
+class _BigtableSampleKeyPairsDataset(dataset_ops.Dataset):
+ """_BigtableKeyRangeDataset returns key pairs from the Bigtable.
+ """
+
+ def __init__(self, table, prefix, start, end):
+ self._table = table
+ self._prefix = prefix
+ self._start = start
+ self._end = end
+
+ @property
+ def output_classes(self):
+ return (ops.Tensor, ops.Tensor)
+
+ @property
+ def output_shapes(self):
+ return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([]))
+
+ @property
+ def output_types(self):
+ return (dtypes.string, dtypes.string)
+
+ def _as_variant_tensor(self):
+ # pylint: disable=protected-access
+ return gen_bigtable_ops.bigtable_sample_key_pairs_dataset(
+ table=self._table._resource,
+ prefix=self._prefix,
+ start_key=self._start,
+ end_key=self._end)