diff options
author | Brennan Saeta <saeta@google.com> | 2018-07-05 10:38:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-05 10:41:55 -0700 |
commit | 14ef3b67a32279643d468274adb3a3019b0dd9c2 (patch) | |
tree | 27b5cf2e7a8a77afdbccb0156ce400283d71c596 /tensorflow/contrib/bigtable | |
parent | baf41a35baddd6b637a7b9febf6acfa55c573168 (diff) |
[tf.data / Bigtable] Support sampling row keys.
When writing high-performance input pipelines, you often need to read from multiple servers in parallel. Being able to sample row keys from a table allows one to easily construct high-performance parallel input pipelines from Cloud Bigtable.
PiperOrigin-RevId: 203389017
Diffstat (limited to 'tensorflow/contrib/bigtable')
5 files changed, 169 insertions, 0 deletions
diff --git a/tensorflow/contrib/bigtable/BUILD b/tensorflow/contrib/bigtable/BUILD index 037645aa6e..7c4f26d791 100644 --- a/tensorflow/contrib/bigtable/BUILD +++ b/tensorflow/contrib/bigtable/BUILD @@ -44,6 +44,7 @@ KERNEL_FILES = [ "kernels/bigtable_lookup_dataset_op.cc", "kernels/bigtable_prefix_key_dataset_op.cc", "kernels/bigtable_range_key_dataset_op.cc", + "kernels/bigtable_sample_keys_dataset_op.cc", "kernels/bigtable_scan_dataset_op.cc", ] diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc new file mode 100644 index 0000000000..a5a47cfe2d --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc @@ -0,0 +1,113 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
+ public:
+  using DatasetOpKernel::DatasetOpKernel;
+
+  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+    BigtableTableResource* resource;
+    OP_REQUIRES_OK(ctx,
+                   LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+    *output = new Dataset(ctx, resource);
+  }
+
+ private:
+  class Dataset : public GraphDatasetBase {
+   public:
+    explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table)
+        : GraphDatasetBase(ctx), table_(table) {
+      table_->Ref();  // Balanced by the Unref() in ~Dataset().
+    }
+
+    ~Dataset() override { table_->Unref(); }
+
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
+        const string& prefix) const override {
+      return std::unique_ptr<IteratorBase>(new Iterator(
+          {this, strings::StrCat(prefix, "::BigtableSampleKeysDataset")}));
+    }
+
+    const DataTypeVector& output_dtypes() const override {
+      static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
+      return *dtypes;
+    }
+
+    const std::vector<PartialTensorShape>& output_shapes() const override {
+      static std::vector<PartialTensorShape>* shapes =
+          new std::vector<PartialTensorShape>({{}});
+      return *shapes;
+    }
+
+    string DebugString() const override {
+      return "BigtableSampleKeysDatasetOp::Dataset";
+    }
+
+    BigtableTableResource* table() const { return table_; }
+
+   private:
+    class Iterator : public DatasetIterator<Dataset> {
+     public:
+      explicit Iterator(const Params& params)
+          : DatasetIterator<Dataset>(params) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        ::grpc::Status status;  // Receives the outcome of SampleRows().
+        row_keys_ = dataset()->table()->table().SampleRows(status);
+        if (!status.ok()) {
+          row_keys_.clear();
+          return GrpcStatusToTfStatus(status);
+        }
+        return Status::OK();
+      }
+
+      Status GetNextInternal(IteratorContext* ctx,
+                             std::vector<Tensor>* out_tensors,
+                             bool* end_of_sequence) override {
+        mutex_lock l(mu_);  // Serializes updates to index_.
+        if (index_ < row_keys_.size()) {
+          out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
+                                    TensorShape({}));
+          out_tensors->back().scalar<string>()() =
+              string(row_keys_[index_].row_key);
+          *end_of_sequence = false;
+          index_++;
+        } else {  // Every sampled key has already been emitted.
+          *end_of_sequence = true;
+        }
+        return Status::OK();
+      }
+
+     private:
+      mutex mu_;
+      size_t index_ = 0;  // Position in row_keys_ of the next key to emit.
+      std::vector<::google::cloud::bigtable::RowKeySample> row_keys_;
+    };
+
+    BigtableTableResource* const table_;
+  };
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigtableSampleKeysDataset").Device(DEVICE_CPU),
+                        BigtableSampleKeysDatasetOp);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc
index c7ff012ec8..179963457c 100644
--- a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc
+++ b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc
@@ -71,6 +71,13 @@ REGISTER_OP("BigtableRangeKeyDataset")
                       // stateful to inhibit constant folding.
     .SetShapeFn(shape_inference::ScalarShape);
 
+REGISTER_OP("BigtableSampleKeysDataset")
+    .Input("table: resource")
+    .Output("handle: variant")
+    .SetIsStateful()  // TODO(b/65524810): Source dataset ops must be marked
+                      // stateful to inhibit constant folding.
+    .SetShapeFn(shape_inference::ScalarShape);
+
 // TODO(saeta): Support continuing despite bad data (e.g. empty string, or
 // skip incomplete row.)
 REGISTER_OP("BigtableScanDataset")
diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
index d33a66f2df..028c861ca3 100644
--- a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
+++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
@@ -23,6 +23,7 @@ from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
 from tensorflow.contrib.bigtable.ops import gen_bigtable_test_ops
 from tensorflow.contrib.util import loader
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
 from tensorflow.python.platform import resource_loader
 from tensorflow.python.platform import test
 from tensorflow.python.util import compat
@@ -127,6 +128,27 @@ class BigtableOpsTest(test.TestCase):
             "Unequal values at step %d: want: %s, got: %s" %
             (i, compat.as_bytes(elem[1]), compat.as_bytes(output[1])))
 
+  def testSampleKeys(self):
+    ds = self._table.sample_keys()
+    itr = ds.make_initializable_iterator()
+    n = itr.get_next()
+    expected_key = self.COMMON_ROW_KEYS[0]  # Key expected from the first read.
+    with self.test_session() as sess:
+      self._writeCommonValues(sess)
+      sess.run(itr.initializer)
+      output = sess.run(n)
+      self.assertEqual(
+          compat.as_bytes(expected_key), compat.as_bytes(output),
+          "Unequal keys: want: %s, got: %s" % (compat.as_bytes(
+              expected_key), compat.as_bytes(output)))
+      output = sess.run(n)
+      self.assertEqual(
+          compat.as_bytes(self.COMMON_ROW_KEYS[2]), compat.as_bytes(output),
+          "Unequal keys: want: %s, got: %s" % (compat.as_bytes(
+              self.COMMON_ROW_KEYS[2]), compat.as_bytes(output)))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(n)  # A third read must report end of sequence.
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
index 39c58ba665..acf4d34e9d 100644
--- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
+++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
@@ -205,6 +205,18 @@ class BigTable(object):
     """
     return _BigtablePrefixKeyDataset(self, prefix)
 
+  def sample_keys(self):
+    """Retrieves a sampling of row keys from the Bigtable table.
+
+    This dataset is most often used in conjunction with
+    @{tf.contrib.data.parallel_interleave} to construct a set of ranges for
+    scanning in parallel.
+
+    Returns:
+      A @{tf.data.Dataset} whose elements are string row keys.
+    """
+    return _BigtableSampleKeysDataset(self)
+
   def scan_prefix(self, prefix, probability=None, columns=None, **kwargs):
     """Retrieves row (including values) from the Bigtable service.
 
@@ -429,6 +441,20 @@ class _BigtableRangeKeyDataset(_BigtableKeyDataset):
         end_key=self._end)
 
 
+class _BigtableSampleKeysDataset(_BigtableKeyDataset):
+  """_BigtableSampleKeysDataset represents a sampling of row keys.
+  """
+
+  # TODO(saeta): Expose the data size offsets into the keys.
+
+  def __init__(self, table):
+    super(_BigtableSampleKeysDataset, self).__init__(table)
+
+  def _as_variant_tensor(self):
+    return gen_bigtable_ops.bigtable_sample_keys_dataset(
+        table=self._table._resource)  # pylint: disable=protected_access
+
+
 class _BigtableLookupDataset(dataset_ops.Dataset):
   """_BigtableLookupDataset represents a dataset that retrieves values for keys.
   """ |