diff options
Diffstat (limited to 'tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc')
-rw-r--r-- | tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc | 104 |
1 files changed, 104 insertions, 0 deletions
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc new file mode 100644 index 0000000000..e960719614 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtablePrefixKeyDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string prefix; + OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix)); + + BigtableTableResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + + *output = new Dataset(ctx, resource, std::move(prefix)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, + string prefix) + : GraphDatasetBase(ctx), table_(table), prefix_(std::move(prefix)) { + table_->Ref(); + } + + ~Dataset() override { table_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>(new Iterator( + {this, strings::StrCat(prefix, "::BigtablePrefixKeyDataset")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + return *dtypes; + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + static std::vector<PartialTensorShape>* shapes = + new std::vector<PartialTensorShape>({{}}); + return *shapes; + } + + string DebugString() const override { + return "BigtablePrefixKeyDatasetOp::Dataset"; + } + + BigtableTableResource* table() const { return table_; } + + private: + class Iterator : public BigtableReaderDatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : BigtableReaderDatasetIterator<Dataset>(params) {} + + ::google::cloud::bigtable::RowRange MakeRowRange() override { + return ::google::cloud::bigtable::RowRange::Prefix(dataset()->prefix_); + } + ::google::cloud::bigtable::Filter MakeFilter() override { + return ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::CellsRowLimit(1), + ::google::cloud::bigtable::Filter::StripValueTransformer()); + } + Status ParseRow(IteratorContext* ctx, + const ::google::cloud::bigtable::Row& row, + std::vector<Tensor>* out_tensors) override { + Tensor output_tensor(ctx->allocator({}), DT_STRING, {}); + output_tensor.scalar<string>()() = string(row.row_key()); + out_tensors->emplace_back(std::move(output_tensor)); + return Status::OK(); + } + }; + + BigtableTableResource* const table_; + const string prefix_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtablePrefixKeyDataset").Device(DEVICE_CPU), + BigtablePrefixKeyDatasetOp); + +} // namespace +} // namespace tensorflow |