From 1583f18633925dfe414f4e122eb7dafa9c07d56e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 21 Dec 2016 11:11:54 -0800 Subject: Adding BigQuery Reader OP. This changelist adds the necessary C++ and Python code to register a new op to partition and read BigQuery tables. Change: 142681535 --- tensorflow/core/BUILD | 12 + tensorflow/core/kernels/cloud/BUILD | 19 ++ .../core/kernels/cloud/bigquery_reader_ops.cc | 193 +++++++++++++++ .../core/kernels/cloud/bigquery_table_accessor.cc | 79 ++++-- .../core/kernels/cloud/bigquery_table_accessor.h | 45 ++-- .../kernels/cloud/bigquery_table_accessor_test.cc | 38 ++- tensorflow/core/ops/cloud_ops.cc | 88 +++++++ tensorflow/python/BUILD | 30 +++ tensorflow/python/__init__.py | 2 + tensorflow/python/ops/cloud/__init__.py | 22 ++ tensorflow/python/ops/cloud/bigquery_reader_ops.py | 150 +++++++++++ .../python/ops/cloud/bigquery_reader_ops_test.py | 274 +++++++++++++++++++++ 12 files changed, 900 insertions(+), 52 deletions(-) create mode 100644 tensorflow/core/kernels/cloud/bigquery_reader_ops.cc create mode 100644 tensorflow/core/ops/cloud_ops.cc create mode 100644 tensorflow/python/ops/cloud/__init__.py create mode 100644 tensorflow/python/ops/cloud/bigquery_reader_ops.py create mode 100644 tensorflow/python/ops/cloud/bigquery_reader_ops_test.py diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 95b111f23b..fd45006c95 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -450,12 +450,23 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "cloud_ops_op_lib", + srcs = ["ops/cloud_ops.cc"], + copts = tf_copts(), + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [":framework"], + alwayslink = 1, +) + cc_library( name = "ops", visibility = ["//visibility:public"], deps = [ ":array_ops_op_lib", ":candidate_sampling_ops_op_lib", + ":cloud_ops_op_lib", ":control_flow_ops_op_lib", ":ctc_ops_op_lib", ":data_flow_ops_op_lib", @@ -602,6 +613,7 @@ cc_library( "//tensorflow/core/kernels:string", "//tensorflow/core/kernels:training_ops", "//tensorflow/core/kernels:word2vec_kernels", + "//tensorflow/core/kernels/cloud:bigquery_reader_ops", ] + if_not_windows([ "//tensorflow/core/kernels:fact_op", "//tensorflow/core/kernels:array_not_windows", diff --git a/tensorflow/core/kernels/cloud/BUILD b/tensorflow/core/kernels/cloud/BUILD index dfb4772b97..710cb5aa14 100644 --- a/tensorflow/core/kernels/cloud/BUILD +++ b/tensorflow/core/kernels/cloud/BUILD @@ -9,6 +9,7 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", + "tf_kernel_library", "tf_cc_test", ) @@ -30,6 +31,24 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) +tf_kernel_library( + name = "bigquery_reader_ops", + srcs = [ + "bigquery_reader_ops.cc", + ], + visibility = ["//visibility:public"], + deps = [ + ":bigquery_table_accessor", + ":bigquery_table_partition_proto_cc", + "//tensorflow/core:cloud_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/kernels:reader_base", + ], +) + cc_library( name = "bigquery_table_accessor", srcs = [ diff --git a/tensorflow/core/kernels/cloud/bigquery_reader_ops.cc b/tensorflow/core/kernels/cloud/bigquery_reader_ops.cc new file mode 100644 index 0000000000..a3b026e2a1 --- /dev/null +++ b/tensorflow/core/kernels/cloud/bigquery_reader_ops.cc @@ -0,0 +1,193 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/framework/reader_op_kernel.h" +#include "tensorflow/core/kernels/cloud/bigquery_table_accessor.h" +#include "tensorflow/core/kernels/cloud/bigquery_table_partition.pb.h" +#include "tensorflow/core/kernels/reader_base.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/lib/strings/numbers.h" + +namespace tensorflow { +namespace { + +constexpr int64 kDefaultRowBufferSize = 1000; // Number of rows to buffer. + +// This is a helper function for reading table attributes from context. +Status GetTableAttrs(OpKernelConstruction* context, string* project_id, + string* dataset_id, string* table_id, + int64* timestamp_millis, std::vector* columns, + string* test_end_point) { + TF_RETURN_IF_ERROR(context->GetAttr("project_id", project_id)); + TF_RETURN_IF_ERROR(context->GetAttr("dataset_id", dataset_id)); + TF_RETURN_IF_ERROR(context->GetAttr("table_id", table_id)); + TF_RETURN_IF_ERROR(context->GetAttr("timestamp_millis", timestamp_millis)); + TF_RETURN_IF_ERROR(context->GetAttr("columns", columns)); + TF_RETURN_IF_ERROR(context->GetAttr("test_end_point", test_end_point)); + return Status::OK(); +} + +} // namespace + +// Note that overriden methods with names ending in "Locked" are called by +// ReaderBase while a mutex is held. +// See comments for ReaderBase. +class BigQueryReader : public ReaderBase { + public: + explicit BigQueryReader(BigQueryTableAccessor* bigquery_table_accessor, + const string& node_name) + : ReaderBase(strings::StrCat("BigQueryReader '", node_name, "'")), + bigquery_table_accessor_(CHECK_NOTNULL(bigquery_table_accessor)) {} + + Status OnWorkStartedLocked() override { + BigQueryTablePartition partition; + if (!partition.ParseFromString(current_work())) { + return errors::InvalidArgument( + "Could not parse work as as valid partition."); + } + TF_RETURN_IF_ERROR(bigquery_table_accessor_->SetPartition(partition)); + return Status::OK(); + } + + Status ReadLocked(string* key, string* value, bool* produced, + bool* at_end) override { + *at_end = false; + *produced = false; + if (bigquery_table_accessor_->Done()) { + *at_end = true; + return Status::OK(); + } + + Example example; + int64 row_id; + TF_RETURN_IF_ERROR(bigquery_table_accessor_->ReadRow(&row_id, &example)); + + *key = std::to_string(row_id); + *value = example.SerializeAsString(); + *produced = true; + return Status::OK(); + } + + private: + // Not owned. 
+ BigQueryTableAccessor* bigquery_table_accessor_; +}; + +class BigQueryReaderOp : public ReaderOpKernel { + public: + explicit BigQueryReaderOp(OpKernelConstruction* context) + : ReaderOpKernel(context) { + string table_id; + string project_id; + string dataset_id; + int64 timestamp_millis; + std::vector columns; + string test_end_point; + + OP_REQUIRES_OK(context, + GetTableAttrs(context, &project_id, &dataset_id, &table_id, + ×tamp_millis, &columns, &test_end_point)); + OP_REQUIRES_OK(context, + BigQueryTableAccessor::New( + project_id, dataset_id, table_id, timestamp_millis, + kDefaultRowBufferSize, test_end_point, columns, + BigQueryTablePartition(), &bigquery_table_accessor_)); + + SetReaderFactory([this]() { + return new BigQueryReader(bigquery_table_accessor_.get(), name()); + }); + } + + private: + std::unique_ptr bigquery_table_accessor_; +}; + +REGISTER_KERNEL_BUILDER(Name("BigQueryReader").Device(DEVICE_CPU), + BigQueryReaderOp); + +class GenerateBigQueryReaderPartitionsOp : public OpKernel { + public: + explicit GenerateBigQueryReaderPartitionsOp(OpKernelConstruction* context) + : OpKernel(context) { + string project_id; + string dataset_id; + string table_id; + int64 timestamp_millis; + std::vector columns; + string test_end_point; + + OP_REQUIRES_OK(context, + GetTableAttrs(context, &project_id, &dataset_id, &table_id, + ×tamp_millis, &columns, &test_end_point)); + OP_REQUIRES_OK(context, + BigQueryTableAccessor::New( + project_id, dataset_id, table_id, timestamp_millis, + kDefaultRowBufferSize, test_end_point, columns, + BigQueryTablePartition(), &bigquery_table_accessor_)); + OP_REQUIRES_OK(context, InitializeNumberOfPartitions(context)); + OP_REQUIRES_OK(context, InitializeTotalNumberOfRows()); + } + + void Compute(OpKernelContext* context) override { + const int64 partition_size = tensorflow::MathUtil::CeilOfRatio( + total_num_rows_, num_partitions_); + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({num_partitions_}), + &output_tensor)); + + auto output = output_tensor->template flat(); + for (int64 i = 0; i < num_partitions_; ++i) { + BigQueryTablePartition partition; + partition.set_start_index(i * partition_size); + partition.set_end_index( + std::min(total_num_rows_, (i + 1) * partition_size) - 1); + output(i) = partition.SerializeAsString(); + } + } + + private: + Status InitializeTotalNumberOfRows() { + total_num_rows_ = bigquery_table_accessor_->total_num_rows(); + if (total_num_rows_ <= 0) { + return errors::FailedPrecondition("Invalid total number of rows."); + } + return Status::OK(); + } + + Status InitializeNumberOfPartitions(OpKernelConstruction* context) { + TF_RETURN_IF_ERROR(context->GetAttr("num_partitions", &num_partitions_)); + if (num_partitions_ <= 0) { + return errors::FailedPrecondition("Invalid number of partitions."); + } + return Status::OK(); + } + + int64 num_partitions_; + int64 total_num_rows_; + std::unique_ptr bigquery_table_accessor_; +}; + +REGISTER_KERNEL_BUILDER( + Name("GenerateBigQueryReaderPartitions").Device(DEVICE_CPU), + GenerateBigQueryReaderPartitionsOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cloud/bigquery_table_accessor.cc b/tensorflow/core/kernels/cloud/bigquery_table_accessor.cc index 293d47d975..3e9adfa372 100644 --- a/tensorflow/core/kernels/cloud/bigquery_table_accessor.cc +++ b/tensorflow/core/kernels/cloud/bigquery_table_accessor.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #include "tensorflow/core/kernels/cloud/bigquery_table_accessor.h" #include "tensorflow/core/example/feature.pb.h" @@ -23,6 +22,15 @@ namespace tensorflow { namespace { constexpr size_t kBufferSize = 1024 * 1024; // In bytes. +const string kBigQueryEndPoint = "https://www.googleapis.com/bigquery/v2"; + +bool IsPartitionEmpty(const BigQueryTablePartition& partition) { + if (partition.end_index() != -1 && + partition.end_index() < partition.start_index()) { + return true; + } + return false; +} Status ParseJson(StringPiece json, Json::Value* result) { Json::Reader reader; @@ -92,17 +100,18 @@ Status ParseColumnType(const string& type, Status BigQueryTableAccessor::New( const string& project_id, const string& dataset_id, const string& table_id, - int64 timestamp_millis, int64 row_buffer_size, - const std::set& columns, const BigQueryTablePartition& partition, + int64 timestamp_millis, int64 row_buffer_size, const string& end_point, + const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr* accessor) { return New(project_id, dataset_id, table_id, timestamp_millis, - row_buffer_size, columns, partition, nullptr, nullptr, accessor); + row_buffer_size, end_point, columns, partition, nullptr, nullptr, + accessor); } Status BigQueryTableAccessor::New( const string& project_id, const string& dataset_id, const string& table_id, - int64 timestamp_millis, int64 row_buffer_size, - const std::set& columns, const BigQueryTablePartition& partition, + int64 timestamp_millis, int64 row_buffer_size, const string& end_point, + const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr auth_provider, std::unique_ptr http_request_factory, std::unique_ptr* accessor) { @@ -110,14 +119,16 @@ Status BigQueryTableAccessor::New( return errors::InvalidArgument( "Cannot use zero or negative timestamp to query a table."); } + const string& big_query_end_point = + end_point.empty() ? 
kBigQueryEndPoint : end_point; if (auth_provider == nullptr && http_request_factory == nullptr) { - accessor->reset(new BigQueryTableAccessor(project_id, dataset_id, table_id, - timestamp_millis, row_buffer_size, - columns, partition)); + accessor->reset(new BigQueryTableAccessor( + project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, + big_query_end_point, columns, partition)); } else { accessor->reset(new BigQueryTableAccessor( project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, - columns, partition, std::move(auth_provider), + big_query_end_point, columns, partition, std::move(auth_provider), std::move(http_request_factory))); } return (*accessor)->ReadSchema(); @@ -125,11 +136,11 @@ Status BigQueryTableAccessor::New( BigQueryTableAccessor::BigQueryTableAccessor( const string& project_id, const string& dataset_id, const string& table_id, - int64 timestamp_millis, int64 row_buffer_size, - const std::set& columns, const BigQueryTablePartition& partition) + int64 timestamp_millis, int64 row_buffer_size, const string& end_point, + const std::vector& columns, const BigQueryTablePartition& partition) : BigQueryTableAccessor( project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, - columns, partition, + end_point, columns, partition, std::unique_ptr(new GoogleAuthProvider()), std::unique_ptr(new HttpRequest::Factory())) { row_buffer_.resize(row_buffer_size); @@ -137,15 +148,16 @@ BigQueryTableAccessor::BigQueryTableAccessor( BigQueryTableAccessor::BigQueryTableAccessor( const string& project_id, const string& dataset_id, const string& table_id, - int64 timestamp_millis, int64 row_buffer_size, - const std::set& columns, const BigQueryTablePartition& partition, + int64 timestamp_millis, int64 row_buffer_size, const string& end_point, + const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr auth_provider, std::unique_ptr http_request_factory) : project_id_(project_id), dataset_id_(dataset_id), table_id_(table_id), timestamp_millis_(timestamp_millis), - columns_(columns), + columns_(columns.begin(), columns.end()), + bigquery_end_point_(end_point), partition_(partition), auth_provider_(std::move(auth_provider)), http_request_factory_(std::move(http_request_factory)) { @@ -153,10 +165,14 @@ BigQueryTableAccessor::BigQueryTableAccessor( Reset(); } -void BigQueryTableAccessor::SetPartition( +Status BigQueryTableAccessor::SetPartition( const BigQueryTablePartition& partition) { + if (partition.start_index() < 0) { + return errors::InvalidArgument("Start index cannot be negative."); + } partition_ = partition; Reset(); + return Status::OK(); } void BigQueryTableAccessor::Reset() { @@ -172,7 +188,8 @@ Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) { // If the next row is already fetched and cached, return the row from the // buffer. Otherwise, fill up the row buffer from BigQuery and return a row. - if (next_row_in_buffer_ != -1 && next_row_in_buffer_ < row_buffer_.size()) { + if (next_row_in_buffer_ != -1 && + next_row_in_buffer_ < ComputeMaxResultsArg()) { *row_id = first_buffered_row_index_ + next_row_in_buffer_; *example = row_buffer_[next_row_in_buffer_]; next_row_in_buffer_++; @@ -190,12 +207,12 @@ Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) { // we use the page token (which returns rows faster). 
if (!next_page_token_.empty()) { TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat( - BigQueryUriPrefix(), "data?maxResults=", row_buffer_.size(), + BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(), "&pageToken=", request->EscapeString(next_page_token_)))); first_buffered_row_index_ += row_buffer_.size(); } else { TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat( - BigQueryUriPrefix(), "data?maxResults=", row_buffer_.size(), + BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(), "&startIndex=", first_buffered_row_index_))); } TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); @@ -222,6 +239,18 @@ Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) { return Status::OK(); } +int64 BigQueryTableAccessor::ComputeMaxResultsArg() { + if (partition_.end_index() == -1) { + return row_buffer_.size(); + } + if (IsPartitionEmpty(partition_)) { + return 0; + } + return std::min(static_cast(row_buffer_.size()), + static_cast(partition_.end_index() - + partition_.start_index() + 1)); +} + Status BigQueryTableAccessor::ParseColumnValues( const Json::Value& value, const SchemaNode& root_schema_node, Example* example) { @@ -364,21 +393,17 @@ Status BigQueryTableAccessor::AppendValueToExample( string BigQueryTableAccessor::BigQueryTableAccessor::BigQueryUriPrefix() { HttpRequest request; - return strings::StrCat("https://www.googleapis.com/bigquery/v2/projects/", + return strings::StrCat(bigquery_end_point_, "/projects/", request.EscapeString(project_id_), "/datasets/", request.EscapeString(dataset_id_), "/tables/", request.EscapeString(table_id_), "/"); } -string BigQueryTableAccessor::FullTableName() { - return strings::StrCat(project_id_, ":", dataset_id_, ".", table_id_, "@", - timestamp_millis_); -} - bool BigQueryTableAccessor::Done() { return (total_num_rows_ <= first_buffered_row_index_ + next_row_in_buffer_) || + IsPartitionEmpty(partition_) || (partition_.end_index() != -1 && - partition_.end_index() <= + partition_.end_index() < first_buffered_row_index_ + next_row_in_buffer_); } diff --git a/tensorflow/core/kernels/cloud/bigquery_table_accessor.h b/tensorflow/core/kernels/cloud/bigquery_table_accessor.h index fafda9cdd6..33d1905b8a 100644 --- a/tensorflow/core/kernels/cloud/bigquery_table_accessor.h +++ b/tensorflow/core/kernels/cloud/bigquery_table_accessor.h @@ -55,16 +55,23 @@ class BigQueryTableAccessor { }; /// \brief Creates a new BigQueryTableAccessor object. + // + // We do not allow relative (negative or zero) snapshot times here since we + // want to have a consistent snapshot of the table for the lifetime of this + // object. + // Use end_point if you want to connect to a different end point than the + // official BigQuery end point. Otherwise send an empty string. static Status New(const string& project_id, const string& dataset_id, const string& table_id, int64 timestamp_millis, - int64 row_buffer_size, const std::set& columns, + int64 row_buffer_size, const string& end_point, + const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr* accessor); /// \brief Starts reading a new partition. - void SetPartition(const BigQueryTablePartition& partition); + Status SetPartition(const BigQueryTablePartition& partition); - /// \brief Returns false if there are more rows available in the current + /// \brief Returns true if there are more rows available in the current /// partition. bool Done(); @@ -74,9 +81,11 @@ class BigQueryTableAccessor { /// in the BigQuery service. 
Status ReadRow(int64* row_id, Example* example); - /// \brief Returns total number of rows. + /// \brief Returns total number of rows in the table. int64 total_num_rows() { return total_num_rows_; } + virtual ~BigQueryTableAccessor() {} + private: friend class BigQueryTableAccessorTest; @@ -95,7 +104,8 @@ class BigQueryTableAccessor { /// these two variables. static Status New(const string& project_id, const string& dataset_id, const string& table_id, int64 timestamp_millis, - int64 row_buffer_size, const std::set& columns, + int64 row_buffer_size, const string& end_point, + const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr auth_provider, std::unique_ptr http_request_factory, @@ -104,14 +114,16 @@ class BigQueryTableAccessor { /// \brief Constructs an object for a given table and partition. BigQueryTableAccessor(const string& project_id, const string& dataset_id, const string& table_id, int64 timestamp_millis, - int64 row_buffer_size, const std::set& columns, + int64 row_buffer_size, const string& end_point, + const std::vector& columns, const BigQueryTablePartition& partition); /// Used for unit testing. BigQueryTableAccessor( const string& project_id, const string& dataset_id, const string& table_id, int64 timestamp_millis, int64 row_buffer_size, - const std::set& columns, const BigQueryTablePartition& partition, + const string& end_point, const std::vector& columns, + const BigQueryTablePartition& partition, std::unique_ptr auth_provider, std::unique_ptr http_request_factory); @@ -132,7 +144,7 @@ class BigQueryTableAccessor { Status AppendValueToExample(const string& column_name, const Json::Value& column_value, const BigQueryTableAccessor::ColumnType type, - Example* ex); + Example* example); /// \brief Resets internal counters for reading a partition. void Reset(); @@ -140,25 +152,28 @@ class BigQueryTableAccessor { /// \brief Helper function that returns BigQuery http endpoint prefix. string BigQueryUriPrefix(); + /// \brief Computes the maxResults arg to send to BigQuery. + int64 ComputeMaxResultsArg(); + /// \brief Returns full name of the underlying table name. - string FullTableName(); + string FullTableName() { + return strings::StrCat(project_id_, ":", dataset_id_, ".", table_id_, "@", + timestamp_millis_); + } const string project_id_; const string dataset_id_; const string table_id_; // Snapshot timestamp. - // - // Indicates a snapshot of the table in milliseconds since the epoch. - // - // We do not allow relative (negative or zero) times here since we want to - // have a consistent snapshot of the table for the lifetime of this object. - // For more details, see 'Table Decorators' in BigQuery documentation. const int64 timestamp_millis_; // Columns that should be read. Empty means all columns. const std::set columns_; + // HTTP address of BigQuery end point to use. + const string bigquery_end_point_; + // Describes the portion of the table that we are currently accessing. BigQueryTablePartition partition_; diff --git a/tensorflow/core/kernels/cloud/bigquery_table_accessor_test.cc b/tensorflow/core/kernels/cloud/bigquery_table_accessor_test.cc index 306cc5a4e1..57a4b89251 100644 --- a/tensorflow/core/kernels/cloud/bigquery_table_accessor_test.cc +++ b/tensorflow/core/kernels/cloud/bigquery_table_accessor_test.cc @@ -23,7 +23,6 @@ limitations under the License. 
#include "tensorflow/core/platform/test.h" namespace tensorflow { - namespace { constexpr char kTestProject[] = "test-project"; @@ -69,10 +68,10 @@ class BigQueryTableAccessorTest : public ::testing::Test { Status CreateTableAccessor(const string& project_id, const string& dataset_id, const string& table_id, int64 timestamp_millis, int64 row_buffer_size, - const std::set& columns, + const std::vector& columns, const BigQueryTablePartition& partition) { return BigQueryTableAccessor::New( - project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, + project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, "", columns, partition, std::unique_ptr(new FakeAuthProvider), std::unique_ptr( new FakeHttpRequestFactory(&requests_)), @@ -197,7 +196,7 @@ TEST_F(BigQueryTableAccessorTest, ReadOneRowTest) { kTestRow)); BigQueryTablePartition partition; partition.set_start_index(2); - partition.set_end_index(3); + partition.set_end_index(2); TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, {}, partition)); int64 row_id; @@ -227,7 +226,7 @@ TEST_F(BigQueryTableAccessorTest, ReadOneRowPartialTest) { kTestRow)); BigQueryTablePartition partition; partition.set_start_index(2); - partition.set_end_index(3); + partition.set_end_index(2); TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, {"bool_field", "rec_field.float_field"}, partition)); @@ -258,7 +257,7 @@ TEST_F(BigQueryTableAccessorTest, ReadOneRowWithNullsTest) { kTestRowWithNulls)); BigQueryTablePartition partition; partition.set_start_index(2); - partition.set_end_index(3); + partition.set_end_index(2); TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, {}, partition)); int64 row_id; @@ -288,7 +287,7 @@ TEST_F(BigQueryTableAccessorTest, BrokenRowTest) { kBrokenTestRow)); BigQueryTablePartition partition; partition.set_start_index(2); - partition.set_end_index(3); + partition.set_end_index(2); TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, {}, partition)); int64 row_id; @@ -357,7 +356,7 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) { kSampleSchema)); requests_.emplace_back(new FakeHttpRequest( "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" - "datasets/test-dataset/tables/test-table/data?maxResults=2&startIndex=0\n" + "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=0\n" "Auth Token: fake_token\n", kTestTwoRows)); requests_.emplace_back(new FakeHttpRequest( @@ -374,7 +373,7 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) { BigQueryTablePartition partition; partition.set_start_index(0); - partition.set_end_index(1); + partition.set_end_index(0); TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 2, {}, partition)); @@ -396,7 +395,7 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) { 1234); partition.set_start_index(0); - partition.set_end_index(2); + partition.set_end_index(1); accessor_->SetPartition(partition); TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); EXPECT_EQ(0, row_id); @@ -410,4 +409,23 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) { 2222); } +TEST_F(BigQueryTableAccessorTest, EmptyPartitionTest) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchema)); + + BigQueryTablePartition partition; + partition.set_start_index(3); 
+ partition.set_end_index(2); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, + {}, partition)); + EXPECT_TRUE(accessor_->Done()); + + int64 row_id; + Example example; + EXPECT_TRUE(errors::IsOutOfRange(accessor_->ReadRow(&row_id, &example))); +} + } // namespace tensorflow diff --git a/tensorflow/core/ops/cloud_ops.cc b/tensorflow/core/ops/cloud_ops.cc new file mode 100644 index 0000000000..89f31a46ab --- /dev/null +++ b/tensorflow/core/ops/cloud_ops.cc @@ -0,0 +1,88 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* This file registers all cloud ops. */ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +namespace tensorflow { + +using shape_inference::InferenceContext; + +REGISTER_OP("BigQueryReader") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("project_id: string") + .Attr("dataset_id: string") + .Attr("table_id: string") + .Attr("columns: list(string)") + .Attr("timestamp_millis: int") + .Attr("test_end_point: string = ''") + .Output("reader_handle: Ref(string)") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->Vector(2)); + return Status::OK(); + }) + .Doc(R"doc( +A Reader that outputs rows from a BigQuery table as tensorflow Examples. + +container: If non-empty, this reader is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this reader is named in the given bucket + with this shared_name. Otherwise, the node name is used instead. +project_id: GCP project ID. +dataset_id: BigQuery Dataset ID. +table_id: Table to read. +columns: List of columns to read. Leave empty to read all columns. +timestamp_millis: Table snapshot timestamp in millis since epoch. Relative +(negative or zero) snapshot times are not allowed. For more details, see +'Table Decorators' in BigQuery docs. +test_end_point: Do not use. For testing purposes only. +reader_handle: The handle to reference the Reader. +)doc"); + +REGISTER_OP("GenerateBigQueryReaderPartitions") + .Attr("project_id: string") + .Attr("dataset_id: string") + .Attr("table_id: string") + .Attr("columns: list(string)") + .Attr("timestamp_millis: int") + .Attr("num_partitions: int") + .Attr("test_end_point: string = ''") + .Output("partitions: string") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); + return Status::OK(); + }) + .Doc(R"doc( +Generates serialized partition messages suitable for batch reads. + +This op should not be used directly by clients. Instead, the +bigquery_reader_ops.py file defines a clean interface to the reader. + +project_id: GCP project ID. +dataset_id: BigQuery Dataset ID. +table_id: Table to read. +columns: List of columns to read. Leave empty to read all columns. +timestamp_millis: Table snapshot timestamp in millis since epoch. 
Relative +(negative or zero) snapshot times are not allowed. For more details, see +'Table Decorators' in BigQuery docs. +num_partitions: Number of partitions to split the table into. +test_end_point: Do not use. For testing purposes only. +partitions: Serialized table partitions. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index c4162b6e22..d5f6ecc78c 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -35,6 +35,7 @@ py_library( ":check_ops", ":client", ":client_testlib", + ":cloud_ops", ":confusion_matrix", ":control_flow_ops", ":errors", @@ -789,6 +790,11 @@ tf_gen_op_wrapper_private_py( require_shape_functions = True, ) +tf_gen_op_wrapper_private_py( + name = "cloud_ops_gen", + require_shape_functions = True, +) + tf_gen_op_wrapper_private_py( name = "control_flow_ops_gen", require_shape_functions = True, @@ -1431,6 +1437,19 @@ py_library( ], ) +py_library( + name = "cloud_ops", + srcs = [ + "ops/cloud/__init__.py", + "ops/cloud/bigquery_reader_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":cloud_ops_gen", + ":framework", + ], +) + py_library( name = "script_ops", srcs = ["ops/script_ops.py"], @@ -2028,6 +2047,17 @@ cuda_py_test( ], ) +tf_py_test( + name = "bigquery_reader_ops_test", + size = "small", + srcs = ["ops/cloud/bigquery_reader_ops_test.py"], + additional_deps = [ + ":cloud_ops", + "//tensorflow:tensorflow_py", + "//tensorflow/python:util", + ], +) + py_library( name = "training", srcs = glob( diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 70ea38ffb2..6626a80149 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -83,6 +83,7 @@ from tensorflow.python.ops.standard_ops import * # Bring in subpackages. from tensorflow.python.layers import layers +from tensorflow.python.ops import cloud from tensorflow.python.ops import metrics from tensorflow.python.ops import nn from tensorflow.python.ops import sdca_ops as sdca @@ -213,6 +214,7 @@ _allowed_symbols.extend([ _allowed_symbols.extend([ 'app', 'compat', + 'cloud', 'errors', 'flags', 'gfile', diff --git a/tensorflow/python/ops/cloud/__init__.py b/tensorflow/python/ops/cloud/__init__.py new file mode 100644 index 0000000000..536d911a48 --- /dev/null +++ b/tensorflow/python/ops/cloud/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Import cloud ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import +from tensorflow.python.ops.cloud.bigquery_reader_ops import * diff --git a/tensorflow/python/ops/cloud/bigquery_reader_ops.py b/tensorflow/python/ops/cloud/bigquery_reader_ops.py new file mode 100644 index 0000000000..dbdc3fd7a2 --- /dev/null +++ b/tensorflow/python/ops/cloud/bigquery_reader_ops.py @@ -0,0 +1,150 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""BigQuery reading support for TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_cloud_ops +from tensorflow.python.ops import io_ops + + +class BigQueryReader(io_ops.ReaderBase): + """A Reader that outputs keys and tf.Example values from a BigQuery table. + + Example use: + ```python + # Assume a BigQuery has the following schema, + # name STRING, + # age INT, + # state STRING + + # Create the parse_examples list of features. + features = dict( + name=tf.FixedLenFeature([1], tf.string), + age=tf.FixedLenFeature([1], tf.int32), + state=tf.FixedLenFeature([1], dtype=tf.string, default_value="UNK")) + + # Create a Reader. + reader = bigquery_reader_ops.BigQueryReader(project_id=PROJECT, + dataset_id=DATASET, + table_id=TABLE, + timestamp_millis=TIME, + num_partitions=NUM_PARTITIONS, + features=features) + + # Populate a queue with the BigQuery Table partitions. + queue = tf.training.string_input_producer(reader.partitions()) + + # Read and parse examples. + row_id, examples_serialized = reader.read(queue) + examples = tf.parse_example(examples_serialized, features=features) + + # Process the Tensors examples["name"], examples["age"], etc... + ``` + + Note that to create a reader a snapshot timestamp is necessary. This + will enable the reader to look at a consistent snapshot of the table. + For more information, see 'Table Decorators' in BigQuery docs. + + See ReaderBase for supported methods. + """ + + def __init__(self, + project_id, + dataset_id, + table_id, + timestamp_millis, + num_partitions, + features=None, + columns=None, + test_end_point=None, + name=None): + """Creates a BigQueryReader. + + Args: + project_id: GCP project ID. + dataset_id: BigQuery dataset ID. + table_id: BigQuery table ID. + timestamp_millis: timestamp to snapshot the table in milliseconds since + the epoch. Relative (negative or zero) snapshot times are not allowed. + For more details, see 'Table Decorators' in BigQuery docs. + num_partitions: Number of non-overlapping partitions to read from. + features: parse_example compatible dict from keys to `VarLenFeature` and + `FixedLenFeature` objects. 
Keys are read as columns from the db. + columns: list of columns to read, can be set iff features is None. + test_end_point: Used only for testing purposes (optional). + name: a name for the operation (optional). + + Raises: + TypeError: - If features is neither None nor a dict or + - If columns is is neither None nor a list or + - If both features and columns are None or set. + """ + if (features is None) == (columns is None): + raise TypeError("exactly one of features and columns must be set.") + + if features is not None: + if not isinstance(features, dict): + raise TypeError("features must be a dict.") + self._columns = list(features.keys()) + elif columns is not None: + if not isinstance(columns, list): + raise TypeError("columns must be a list.") + self._columns = columns + + self._project_id = project_id + self._dataset_id = dataset_id + self._table_id = table_id + self._timestamp_millis = timestamp_millis + self._num_partitions = num_partitions + self._test_end_point = test_end_point + + reader = gen_cloud_ops.big_query_reader( + name=name, + project_id=self._project_id, + dataset_id=self._dataset_id, + table_id=self._table_id, + timestamp_millis=self._timestamp_millis, + columns=self._columns, + test_end_point=self._test_end_point) + super(BigQueryReader, self).__init__(reader) + + def partitions(self, name=None): + """Returns serialized BigQueryTablePartition messages. + + These messages represent a non-overlapping division of a table for a + bulk read. + + Args: + name: a name for the operation (optional). + + Returns: + `1-D` string `Tensor` of serialized `BigQueryTablePartition` messages. + """ + return gen_cloud_ops.generate_big_query_reader_partitions( + name=name, + project_id=self._project_id, + dataset_id=self._dataset_id, + table_id=self._table_id, + timestamp_millis=self._timestamp_millis, + num_partitions=self._num_partitions, + test_end_point=self._test_end_point, + columns=self._columns) + + +ops.NotDifferentiable("BigQueryReader") diff --git a/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py b/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py new file mode 100644 index 0000000000..8e59985ff4 --- /dev/null +++ b/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py @@ -0,0 +1,274 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for BigQueryReader Op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import re +import threading + +from six.moves import SimpleHTTPServer +from six.moves import socketserver +import tensorflow as tf + +from tensorflow.core.example import example_pb2 +from tensorflow.core.framework import types_pb2 +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import compat + +_PROJECT = "test-project" +_DATASET = "test-dataset" +_TABLE = "test-table" +# List representation of the test rows in the 'test-table' in BigQuery. +# The schema for each row is: [int64, string, float]. +# The values for rows are generated such that some columns have null values. The +# general formula here is: +# - The int64 column is present in every row. +# - The string column is only avaiable in even rows. +# - The float column is only available in every third row. +_ROWS = [[0, "s_0", 0.1], [1, None, None], [2, "s_2", None], [3, None, 3.1], + [4, "s_4", None], [5, None, None], [6, "s_6", 6.1], [7, None, None], + [8, "s_8", None], [9, None, 9.1]] +# Schema for 'test-table'. +# The schema currently has three columns: int64, string, and float +_SCHEMA = { + "kind": "bigquery#table", + "id": "test-project:test-dataset.test-table", + "schema": { + "fields": [{ + "name": "int64_col", + "type": "INTEGER", + "mode": "NULLABLE" + }, { + "name": "string_col", + "type": "STRING", + "mode": "NULLABLE" + }, { + "name": "float_col", + "type": "FLOAT", + "mode": "NULLABLE" + }] + } +} + + +def _ConvertRowToExampleProto(row): + """Converts the input row to an Example proto. + + Args: + row: Input Row instance. + + Returns: + An Example proto initialized with row values. + """ + + example = example_pb2.Example() + example.features.feature["int64_col"].int64_list.value.append(row[0]) + if row[1] is not None: + example.features.feature["string_col"].bytes_list.value.append(compat.as_bytes(row[1])) + if row[2] is not None: + example.features.feature["float_col"].float_list.value.append(row[2]) + return example + + +class FakeBigQueryServer(threading.Thread): + """Fake http server to return schema and data for sample table.""" + + def __init__(self, address, port): + """Creates a FakeBigQueryServer. + + Args: + address: Server address + port: Server port. Pass 0 to automatically pick an empty port. + """ + threading.Thread.__init__(self) + self.handler = BigQueryRequestHandler + self.httpd = socketserver.TCPServer((address, port), self.handler) + + def run(self): + self.httpd.serve_forever() + + def shutdown(self): + self.httpd.shutdown() + self.httpd.socket.close() + + +class BigQueryRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler): + """Responds to BigQuery HTTP requests. + + Attributes: + num_rows: num_rows in the underlying table served by this class. + """ + + num_rows = 0 + + def do_GET(self): + if "data?maxResults=" not in self.path: + # This is a schema request. + _SCHEMA["numRows"] = self.num_rows + response = json.dumps(_SCHEMA) + else: + # This is a data request. + # + # Extract max results and start index. + max_results = int(re.findall(r"maxResults=(\d+)", self.path)[0]) + start_index = int(re.findall(r"startIndex=(\d+)", self.path)[0]) + + # Send the rows as JSON. 
+ rows = [] + for row in _ROWS[start_index:start_index + max_results]: + row_json = { + "f": [{ + "v": str(row[0]) + }, { + "v": str(row[1]) if row[1] is not None else None + }, { + "v": str(row[2]) if row[2] is not None else None + }] + } + rows.append(row_json) + response = json.dumps({ + "kind": "bigquery#table", + "id": "test-project:test-dataset.test-table", + "rows": rows + }) + self.send_response(200) + self.end_headers() + self.wfile.write(compat.as_bytes(response)) + + +def _SetUpQueue(reader): + """Sets up a queue for a reader.""" + queue = tf.FIFOQueue(8, [types_pb2.DT_STRING], shapes=()) + key, value = reader.read(queue) + queue.enqueue_many(reader.partitions()).run() + queue.close().run() + return key, value + + +class BigQueryReaderOpsTest(tf.test.TestCase): + + def setUp(self): + super(BigQueryReaderOpsTest, self).setUp() + self.server = FakeBigQueryServer("127.0.0.1", 0) + self.server.start() + logging.info("server address is %s:%s", self.server.httpd.server_address[0], + self.server.httpd.server_address[1]) + + def tearDown(self): + self.server.shutdown() + super(BigQueryReaderOpsTest, self).tearDown() + + def _ReadAndCheckRowsUsingFeatures(self, num_rows): + self.server.handler.num_rows = num_rows + + with self.test_session() as sess: + feature_configs = { + "int64_col": + tf.FixedLenFeature( + [1], dtype=tf.int64), + "string_col": + tf.FixedLenFeature( + [1], dtype=tf.string, default_value="s_default"), + } + reader = tf.cloud.BigQueryReader( + project_id=_PROJECT, + dataset_id=_DATASET, + table_id=_TABLE, + num_partitions=4, + features=feature_configs, + timestamp_millis=1, + test_end_point=("%s:%s" % (self.server.httpd.server_address[0], + self.server.httpd.server_address[1]))) + + key, value = _SetUpQueue(reader) + + seen_rows = [] + features = tf.parse_example(tf.reshape(value, [1]), feature_configs) + for _ in range(num_rows): + int_value, str_value = sess.run( + [features["int64_col"], features["string_col"]]) + + # Parse values returned from the session. + self.assertEqual(int_value.shape, (1, 1)) + self.assertEqual(str_value.shape, (1, 1)) + int64_col = int_value[0][0] + string_col = str_value[0][0] + seen_rows.append(int64_col) + + # Compare. 
+ expected_row = _ROWS[int64_col] + self.assertEqual(int64_col, expected_row[0]) + self.assertEqual(compat.as_str(string_col), ("s_%d" % int64_col) if expected_row[1] + else "s_default") + + self.assertItemsEqual(seen_rows, range(num_rows)) + + with self.assertRaisesOpError("is closed and has insufficient elements " + "\\(requested 1, current size 0\\)"): + sess.run([key, value]) + + def testReadingSingleRowUsingFeatures(self): + self._ReadAndCheckRowsUsingFeatures(1) + + def testReadingMultipleRowsUsingFeatures(self): + self._ReadAndCheckRowsUsingFeatures(10) + + def testReadingMultipleRowsUsingColumns(self): + num_rows = 10 + self.server.handler.num_rows = num_rows + + with self.test_session() as sess: + reader = tf.cloud.BigQueryReader( + project_id=_PROJECT, + dataset_id=_DATASET, + table_id=_TABLE, + num_partitions=4, + columns=["int64_col", "float_col", "string_col"], + timestamp_millis=1, + test_end_point=("%s:%s" % (self.server.httpd.server_address[0], + self.server.httpd.server_address[1]))) + key, value = _SetUpQueue(reader) + seen_rows = [] + for row_index in range(num_rows): + returned_row_id, example_proto = sess.run([key, value]) + example = example_pb2.Example() + example.ParseFromString(example_proto) + self.assertIn("int64_col", example.features.feature) + feature = example.features.feature["int64_col"] + self.assertEqual(len(feature.int64_list.value), 1) + int64_col = feature.int64_list.value[0] + seen_rows.append(int64_col) + + # Create our expected Example. + expected_example = example_pb2.Example() + expected_example = _ConvertRowToExampleProto(_ROWS[int64_col]) + + # Compare. + self.assertProtoEquals(example, expected_example) + self.assertEqual(row_index, int(returned_row_id)) + + self.assertItemsEqual(seen_rows, range(num_rows)) + + with self.assertRaisesOpError("is closed and has insufficient elements " + "\\(requested 1, current size 0\\)"): + sess.run([key, value]) + + +if __name__ == "__main__": + tf.test.main() -- cgit v1.2.3
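
For reference, a minimal usage sketch of the reader registered by this change, condensed from the `BigQueryReader` docstring and `bigquery_reader_ops_test.py` above. The project, dataset, table, timestamp and partition-count values are placeholders, and the sketch assumes the TensorFlow APIs of this era (`tf.train.string_input_producer`, `tf.parse_example`, `tf.FixedLenFeature`):

```python
# Minimal sketch: reading rows from a BigQuery table with the reader added in
# this change. All constants below are placeholders, not real resources.
import tensorflow as tf

PROJECT = "my-project"        # placeholder GCP project ID
DATASET = "my_dataset"        # placeholder BigQuery dataset ID
TABLE = "my_table"            # placeholder BigQuery table ID
TIME = 1482321600000          # snapshot timestamp, milliseconds since epoch
NUM_PARTITIONS = 4            # number of non-overlapping row ranges to read

# parse_example-compatible feature spec; column names must match the table.
features = dict(
    name=tf.FixedLenFeature([1], tf.string),
    age=tf.FixedLenFeature([1], tf.int64),
    state=tf.FixedLenFeature([1], dtype=tf.string, default_value="UNK"))

reader = tf.cloud.BigQueryReader(project_id=PROJECT,
                                 dataset_id=DATASET,
                                 table_id=TABLE,
                                 timestamp_millis=TIME,
                                 num_partitions=NUM_PARTITIONS,
                                 features=features)

# reader.partitions() yields serialized BigQueryTablePartition messages; queue
# them so one or more readers can each consume a disjoint range of rows.
queue = tf.train.string_input_producer(reader.partitions())

# Each read() returns the row id as the key and a serialized tf.Example as the
# value; parse_example expects a 1-D batch of serialized protos, hence the
# reshape (as done in _ReadAndCheckRowsUsingFeatures in the test above).
row_id, example_serialized = reader.read(queue)
examples = tf.parse_example(tf.reshape(example_serialized, [1]), features)
```

Exactly one of `features` or `columns` may be passed to the constructor; with `columns`, `read()` still returns serialized `tf.Example` protos containing just those columns, as exercised by `testReadingMultipleRowsUsingColumns`.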