| author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-03-13 16:58:18 -0800 |
| --- | --- | --- |
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-13 18:09:05 -0700 |
| commit | 1b881b7c77bd1e664382785447a170de2b85f688 (patch) | |
| tree | 2d3b8c455ef9b93cb55eb03ea91d050db598fa74 /tensorflow/contrib/cloud | |
| parent | ee6f27b647fd51b11f9795042c4f6941c77d1c86 (diff) | |
First version of BigQuery Reader.
Change: 150016997
Diffstat (limited to 'tensorflow/contrib/cloud')
-rw-r--r-- | tensorflow/contrib/cloud/BUILD | 73
-rw-r--r-- | tensorflow/contrib/cloud/__init__.py | 28
-rw-r--r-- | tensorflow/contrib/cloud/kernels/BUILD | 94
-rw-r--r-- | tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc | 192
-rw-r--r-- | tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc | 410
-rw-r--r-- | tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h | 208
-rw-r--r-- | tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc | 513
-rw-r--r-- | tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h | 404
-rw-r--r-- | tensorflow/contrib/cloud/kernels/bigquery_table_partition.proto | 12
-rw-r--r-- | tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc | 88
-rw-r--r-- | tensorflow/contrib/cloud/python/ops/bigquery_reader_ops.py | 150
-rw-r--r-- | tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py | 287
12 files changed, 2459 insertions, 0 deletions
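
Before the full diff, a rough sketch of how the new reader is meant to be driven from Python. The actual wrapper is defined in `python/ops/bigquery_reader_ops.py` (added by this change but not excerpted here), so the constructor keyword arguments, the `partitions()` helper, the feature spec, and the table identifiers below are illustrative assumptions inferred from the `BigQueryReader` and `GenerateBigQueryReaderPartitions` op attributes and from the `BigQueryReader` symbol exported in `__init__.py`.

```python
# Illustrative sketch only: every keyword argument here is an assumption
# inferred from the attrs of the BigQueryReader / GenerateBigQueryReaderPartitions
# ops registered in ops/bigquery_reader_ops.cc.
import tensorflow as tf
from tensorflow.contrib import cloud

reader = cloud.BigQueryReader(
    project_id="my-gcp-project",       # hypothetical identifiers
    dataset_id="my_dataset",
    table_id="my_table",
    timestamp_millis=1489449600000,    # fixed table snapshot; must be > 0
    num_partitions=4,
    columns=["int_field", "str_field"])  # empty would mean "all columns"

# partitions() is assumed to wrap GenerateBigQueryReaderPartitions and return
# serialized BigQueryTablePartition protos, one per unit of reader "work".
partitions = reader.partitions()
queue = tf.train.string_input_producer(partitions)

# Each read returns (row id, serialized tf.Example), mirroring ReadLocked().
row_id, example_serialized = reader.read(queue)
parsed = tf.parse_single_example(
    example_serialized,
    features={
        "int_field": tf.FixedLenFeature([1], tf.int64),
        "str_field": tf.FixedLenFeature([1], tf.string),
    })

with tf.Session() as sess:
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  print(sess.run([row_id, parsed]))
  coord.request_stop()
  coord.join(threads)
```

The reader's unit of work is a serialized `BigQueryTablePartition` proto (a `[start_index, end_index]` range, with `end_index = -1` meaning "through the last row"), so the partitions tensor can be pushed through an ordinary string queue and each reader instance consumes one partition at a time.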
diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD new file mode 100644 index 0000000000..840997223f --- /dev/null +++ b/tensorflow/contrib/cloud/BUILD @@ -0,0 +1,73 @@ +# Description: +# BigQueryReader implementation + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_py_test", +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +tf_gen_op_libs( + op_lib_names = ["bigquery_reader_ops"], + deps = [ + "//tensorflow/core:lib", + ], +) + +tf_gen_op_wrapper_py( + name = "gen_bigquery_reader_ops", + out = "python/ops/gen_bigquery_reader_ops.py", + require_shape_functions = True, + deps = [":bigquery_reader_ops_op_lib"], +) + +py_library( + name = "cloud_py", + srcs = [ + "__init__.py", + "python/ops/bigquery_reader_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":gen_bigquery_reader_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + ], +) + +tf_py_test( + name = "bigquery_reader_ops_test", + size = "small", + srcs = ["python/ops/bigquery_reader_ops_test.py"], + additional_deps = [ + ":bigquery_reader_ops_op_lib", + ":cloud_py", + "//tensorflow/contrib/cloud/kernels:bigquery_reader_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:parsing_ops", + ], + tags = ["manual"], +) diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py new file mode 100644 index 0000000000..8870264b95 --- /dev/null +++ b/tensorflow/contrib/cloud/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Module for cloud ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=line-too-long,wildcard-import +from tensorflow.contrib.cloud.python.ops.bigquery_reader_ops import * +# pylint: enable=line-too-long,wildcard-import + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ['BigQueryReader'] +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD new file mode 100644 index 0000000000..2500f10c74 --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -0,0 +1,94 @@ +# Description: +# BigQueryReader implementation + +package( + default_visibility = ["//visibility:private"], +) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", + "tf_copts", + "tf_kernel_library", +) + +# For platform specific build config +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library", +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +tf_kernel_library( + name = "bigquery_reader_ops", + srcs = [ + "bigquery_reader_ops.cc", + ], + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":bigquery_table_accessor", + ":bigquery_table_partition_proto_cc", + "//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:reader_base", + ], +) + +cc_library( + name = "bigquery_table_accessor", + srcs = [ + "bigquery_table_accessor.cc", + ], + hdrs = [ + "bigquery_table_accessor.h", + ], + copts = tf_copts(), + linkstatic = 1, + deps = [ + ":bigquery_table_partition_proto_cc", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform/cloud:google_auth_provider", + "//tensorflow/core/platform/cloud:http_request", + ], + alwayslink = 1, +) + +tf_cc_test( + name = "bigquery_table_accessor_test", + size = "small", + srcs = [ + "bigquery_table_accessor_test.cc", + "bigquery_table_accessor_test_data.h", + ], + deps = [ + ":bigquery_table_accessor", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/platform/cloud:http_request_fake", + ], +) + +tf_proto_library( + name = "bigquery_table_partition_proto", + srcs = [ + "bigquery_table_partition.proto", + ], + cc_api_version = 2, +) diff --git a/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc b/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc new file mode 100644 index 0000000000..02a759eefd --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc @@ -0,0 +1,192 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <map> +#include <memory> +#include <set> + +#include "tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h" +#include "tensorflow/contrib/cloud/kernels/bigquery_table_partition.pb.h" +#include "tensorflow/core/framework/reader_base.h" +#include "tensorflow/core/framework/reader_op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/lib/strings/numbers.h" + +namespace tensorflow { +namespace { + +constexpr int64 kDefaultRowBufferSize = 1000; // Number of rows to buffer. + +// This is a helper function for reading table attributes from context. +Status GetTableAttrs(OpKernelConstruction* context, string* project_id, + string* dataset_id, string* table_id, + int64* timestamp_millis, std::vector<string>* columns, + string* test_end_point) { + TF_RETURN_IF_ERROR(context->GetAttr("project_id", project_id)); + TF_RETURN_IF_ERROR(context->GetAttr("dataset_id", dataset_id)); + TF_RETURN_IF_ERROR(context->GetAttr("table_id", table_id)); + TF_RETURN_IF_ERROR(context->GetAttr("timestamp_millis", timestamp_millis)); + TF_RETURN_IF_ERROR(context->GetAttr("columns", columns)); + TF_RETURN_IF_ERROR(context->GetAttr("test_end_point", test_end_point)); + return Status::OK(); +} + +} // namespace + +// Note that overriden methods with names ending in "Locked" are called by +// ReaderBase while a mutex is held. +// See comments for ReaderBase. +class BigQueryReader : public ReaderBase { + public: + explicit BigQueryReader(BigQueryTableAccessor* bigquery_table_accessor, + const string& node_name) + : ReaderBase(strings::StrCat("BigQueryReader '", node_name, "'")), + bigquery_table_accessor_(CHECK_NOTNULL(bigquery_table_accessor)) {} + + Status OnWorkStartedLocked() override { + BigQueryTablePartition partition; + if (!partition.ParseFromString(current_work())) { + return errors::InvalidArgument( + "Could not parse work as as valid partition."); + } + TF_RETURN_IF_ERROR(bigquery_table_accessor_->SetPartition(partition)); + return Status::OK(); + } + + Status ReadLocked(string* key, string* value, bool* produced, + bool* at_end) override { + *at_end = false; + *produced = false; + if (bigquery_table_accessor_->Done()) { + *at_end = true; + return Status::OK(); + } + + Example example; + int64 row_id; + TF_RETURN_IF_ERROR(bigquery_table_accessor_->ReadRow(&row_id, &example)); + + *key = std::to_string(row_id); + *value = example.SerializeAsString(); + *produced = true; + return Status::OK(); + } + + private: + // Not owned. 
+ BigQueryTableAccessor* bigquery_table_accessor_; +}; + +class BigQueryReaderOp : public ReaderOpKernel { + public: + explicit BigQueryReaderOp(OpKernelConstruction* context) + : ReaderOpKernel(context) { + string table_id; + string project_id; + string dataset_id; + int64 timestamp_millis; + std::vector<string> columns; + string test_end_point; + + OP_REQUIRES_OK(context, + GetTableAttrs(context, &project_id, &dataset_id, &table_id, + ×tamp_millis, &columns, &test_end_point)); + OP_REQUIRES_OK(context, + BigQueryTableAccessor::New( + project_id, dataset_id, table_id, timestamp_millis, + kDefaultRowBufferSize, test_end_point, columns, + BigQueryTablePartition(), &bigquery_table_accessor_)); + + SetReaderFactory([this]() { + return new BigQueryReader(bigquery_table_accessor_.get(), name()); + }); + } + + private: + std::unique_ptr<BigQueryTableAccessor> bigquery_table_accessor_; +}; + +REGISTER_KERNEL_BUILDER(Name("BigQueryReader").Device(DEVICE_CPU), + BigQueryReaderOp); + +class GenerateBigQueryReaderPartitionsOp : public OpKernel { + public: + explicit GenerateBigQueryReaderPartitionsOp(OpKernelConstruction* context) + : OpKernel(context) { + string project_id; + string dataset_id; + string table_id; + int64 timestamp_millis; + std::vector<string> columns; + string test_end_point; + + OP_REQUIRES_OK(context, + GetTableAttrs(context, &project_id, &dataset_id, &table_id, + ×tamp_millis, &columns, &test_end_point)); + OP_REQUIRES_OK(context, + BigQueryTableAccessor::New( + project_id, dataset_id, table_id, timestamp_millis, + kDefaultRowBufferSize, test_end_point, columns, + BigQueryTablePartition(), &bigquery_table_accessor_)); + OP_REQUIRES_OK(context, InitializeNumberOfPartitions(context)); + OP_REQUIRES_OK(context, InitializeTotalNumberOfRows()); + } + + void Compute(OpKernelContext* context) override { + const int64 partition_size = tensorflow::MathUtil::CeilOfRatio<int64>( + total_num_rows_, num_partitions_); + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({num_partitions_}), + &output_tensor)); + + auto output = output_tensor->template flat<string>(); + for (int64 i = 0; i < num_partitions_; ++i) { + BigQueryTablePartition partition; + partition.set_start_index(i * partition_size); + partition.set_end_index( + std::min(total_num_rows_, (i + 1) * partition_size) - 1); + output(i) = partition.SerializeAsString(); + } + } + + private: + Status InitializeTotalNumberOfRows() { + total_num_rows_ = bigquery_table_accessor_->total_num_rows(); + if (total_num_rows_ <= 0) { + return errors::FailedPrecondition("Invalid total number of rows."); + } + return Status::OK(); + } + + Status InitializeNumberOfPartitions(OpKernelConstruction* context) { + TF_RETURN_IF_ERROR(context->GetAttr("num_partitions", &num_partitions_)); + if (num_partitions_ <= 0) { + return errors::FailedPrecondition("Invalid number of partitions."); + } + return Status::OK(); + } + + int64 num_partitions_; + int64 total_num_rows_; + std::unique_ptr<BigQueryTableAccessor> bigquery_table_accessor_; +}; + +REGISTER_KERNEL_BUILDER( + Name("GenerateBigQueryReaderPartitions").Device(DEVICE_CPU), + GenerateBigQueryReaderPartitionsOp); + +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc new file mode 100644 index 0000000000..5e95db55b6 --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc @@ -0,0 +1,410 @@ +/* Copyright 2016 The TensorFlow 
Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h" + +#include "tensorflow/core/example/feature.pb.h" +#include "tensorflow/core/lib/strings/numbers.h" + +namespace tensorflow { + +namespace { + +constexpr size_t kBufferSize = 1024 * 1024; // In bytes. +const string kBigQueryEndPoint = "https://www.googleapis.com/bigquery/v2"; + +bool IsPartitionEmpty(const BigQueryTablePartition& partition) { + if (partition.end_index() != -1 && + partition.end_index() < partition.start_index()) { + return true; + } + return false; +} + +Status ParseJson(StringPiece json, Json::Value* result) { + Json::Reader reader; + if (!reader.parse(json.ToString(), *result)) { + return errors::Internal("Couldn't parse JSON response from BigQuery."); + } + return Status::OK(); +} + +string ColumnTypeToString(BigQueryTableAccessor::ColumnType enum_type) { + switch (enum_type) { + case BigQueryTableAccessor::ColumnType::kRecord: + return "RECORD"; + case BigQueryTableAccessor::ColumnType::kString: + return "STRING"; + case BigQueryTableAccessor::ColumnType::kBytes: + return "BYTES"; + case BigQueryTableAccessor::ColumnType::kInteger: + return "INTEGER"; + case BigQueryTableAccessor::ColumnType::kFloat: + return "FLOAT"; + case BigQueryTableAccessor::ColumnType::kBoolean: + return "BOOLEAN"; + case BigQueryTableAccessor::ColumnType::kTimestamp: + return "TIMESTAMP"; + case BigQueryTableAccessor::ColumnType::kDate: + return "DATE"; + case BigQueryTableAccessor::ColumnType::kTime: + return "TIME"; + case BigQueryTableAccessor::ColumnType::kDatetime: + return "DATETIME"; + case BigQueryTableAccessor::ColumnType::kNone: + return "NONE"; + } +} + +Status ParseColumnType(const string& type, + BigQueryTableAccessor::ColumnType* enum_type) { + if (type == "RECORD") { + *enum_type = BigQueryTableAccessor::ColumnType::kRecord; + } else if (type == "STRING") { + *enum_type = BigQueryTableAccessor::ColumnType::kString; + } else if (type == "BYTES") { + *enum_type = BigQueryTableAccessor::ColumnType::kBytes; + } else if (type == "INTEGER") { + *enum_type = BigQueryTableAccessor::ColumnType::kInteger; + } else if (type == "FLOAT") { + *enum_type = BigQueryTableAccessor::ColumnType::kFloat; + } else if (type == "BOOLEAN") { + *enum_type = BigQueryTableAccessor::ColumnType::kBoolean; + } else if (type == "TIMESTAMP") { + *enum_type = BigQueryTableAccessor::ColumnType::kTimestamp; + } else if (type == "DATE") { + *enum_type = BigQueryTableAccessor::ColumnType::kDate; + } else if (type == "TIME") { + *enum_type = BigQueryTableAccessor::ColumnType::kTime; + } else if (type == "DATETIME") { + *enum_type = BigQueryTableAccessor::ColumnType::kDatetime; + } else { + return errors::Internal( + strings::StrCat("Could not parse column type ", type)); + } + return Status::OK(); +} + +} // namespace + +Status BigQueryTableAccessor::New( + const string& project_id, const string& dataset_id, const 
string& table_id, + int64 timestamp_millis, int64 row_buffer_size, const string& end_point, + const std::vector<string>& columns, const BigQueryTablePartition& partition, + std::unique_ptr<BigQueryTableAccessor>* accessor) { + return New(project_id, dataset_id, table_id, timestamp_millis, + row_buffer_size, end_point, columns, partition, nullptr, nullptr, + accessor); +} + +Status BigQueryTableAccessor::New( + const string& project_id, const string& dataset_id, const string& table_id, + int64 timestamp_millis, int64 row_buffer_size, const string& end_point, + const std::vector<string>& columns, const BigQueryTablePartition& partition, + std::unique_ptr<AuthProvider> auth_provider, + std::unique_ptr<HttpRequest::Factory> http_request_factory, + std::unique_ptr<BigQueryTableAccessor>* accessor) { + if (timestamp_millis <= 0) { + return errors::InvalidArgument( + "Cannot use zero or negative timestamp to query a table."); + } + const string& big_query_end_point = + end_point.empty() ? kBigQueryEndPoint : end_point; + if (auth_provider == nullptr && http_request_factory == nullptr) { + accessor->reset(new BigQueryTableAccessor( + project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, + big_query_end_point, columns, partition)); + } else { + accessor->reset(new BigQueryTableAccessor( + project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, + big_query_end_point, columns, partition, std::move(auth_provider), + std::move(http_request_factory))); + } + return (*accessor)->ReadSchema(); +} + +BigQueryTableAccessor::BigQueryTableAccessor( + const string& project_id, const string& dataset_id, const string& table_id, + int64 timestamp_millis, int64 row_buffer_size, const string& end_point, + const std::vector<string>& columns, const BigQueryTablePartition& partition) + : BigQueryTableAccessor( + project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, + end_point, columns, partition, + std::unique_ptr<AuthProvider>(new GoogleAuthProvider()), + std::unique_ptr<HttpRequest::Factory>(new HttpRequest::Factory())) { + row_buffer_.resize(row_buffer_size); +} + +BigQueryTableAccessor::BigQueryTableAccessor( + const string& project_id, const string& dataset_id, const string& table_id, + int64 timestamp_millis, int64 row_buffer_size, const string& end_point, + const std::vector<string>& columns, const BigQueryTablePartition& partition, + std::unique_ptr<AuthProvider> auth_provider, + std::unique_ptr<HttpRequest::Factory> http_request_factory) + : project_id_(project_id), + dataset_id_(dataset_id), + table_id_(table_id), + timestamp_millis_(timestamp_millis), + columns_(columns.begin(), columns.end()), + bigquery_end_point_(end_point), + partition_(partition), + auth_provider_(std::move(auth_provider)), + http_request_factory_(std::move(http_request_factory)) { + row_buffer_.resize(row_buffer_size); + Reset(); +} + +Status BigQueryTableAccessor::SetPartition( + const BigQueryTablePartition& partition) { + if (partition.start_index() < 0) { + return errors::InvalidArgument("Start index cannot be negative."); + } + partition_ = partition; + Reset(); + return Status::OK(); +} + +void BigQueryTableAccessor::Reset() { + first_buffered_row_index_ = partition_.start_index(); + next_row_in_buffer_ = -1; + next_page_token_ = ""; +} + +Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) { + if (Done()) { + return errors::OutOfRange("Reached end of table ", FullTableName()); + } + + // If the next row is already fetched and cached, return the row from the + // 
buffer. Otherwise, fill up the row buffer from BigQuery and return a row. + if (next_row_in_buffer_ != -1 && + next_row_in_buffer_ < ComputeMaxResultsArg()) { + *row_id = first_buffered_row_index_ + next_row_in_buffer_; + *example = row_buffer_[next_row_in_buffer_]; + next_row_in_buffer_++; + } else { + string auth_token; + TF_RETURN_IF_ERROR( + AuthProvider::GetToken(auth_provider_.get(), &auth_token)); + + std::unique_ptr<HttpRequest> request(http_request_factory_->Create()); + std::vector<char> output_buffer; + output_buffer.reserve(kBufferSize); + TF_RETURN_IF_ERROR(request->Init()); + + // The first time that we access BigQuery there is no page token. After that + // we use the page token (which returns rows faster). + if (!next_page_token_.empty()) { + TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat( + BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(), + "&pageToken=", request->EscapeString(next_page_token_)))); + first_buffered_row_index_ += row_buffer_.size(); + } else { + TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat( + BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(), + "&startIndex=", first_buffered_row_index_))); + } + TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); + TF_RETURN_IF_ERROR(request->SetResultBuffer(&output_buffer)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading rows from ", + FullTableName()); + + // Parse the returned row. + StringPiece response_piece = + StringPiece(&output_buffer[0], output_buffer.size()); + Json::Value root; + TF_RETURN_IF_ERROR(ParseJson(response_piece, &root)); + for (unsigned int i = 0; i < root["rows"].size(); ++i) { + row_buffer_[i].Clear(); + TF_RETURN_IF_ERROR( + ParseColumnValues(root["rows"][i], schema_root_, &row_buffer_[i])); + } + + next_page_token_ = root["pageToken"].asString(); + *row_id = first_buffered_row_index_; + *example = row_buffer_[0]; + next_row_in_buffer_ = 1; + } + return Status::OK(); +} + +int64 BigQueryTableAccessor::ComputeMaxResultsArg() { + if (partition_.end_index() == -1) { + return row_buffer_.size(); + } + if (IsPartitionEmpty(partition_)) { + return 0; + } + return std::min(static_cast<int64>(row_buffer_.size()), + static_cast<int64>(partition_.end_index() - + partition_.start_index() + 1)); +} + +Status BigQueryTableAccessor::ParseColumnValues( + const Json::Value& value, const SchemaNode& root_schema_node, + Example* example) { + if (value.empty()) { + return Status::OK(); + } + if (value["f"].isNull()) { + return Status::OK(); + } + int value_index = 0; + for (const auto& schema_node : root_schema_node.schema_nodes) { + if (value["f"][value_index].isNull()) { + value_index++; + continue; + } + + if (schema_node.type == ColumnType::kRecord) { + TF_RETURN_IF_ERROR(ParseColumnValues(value["f"][value_index]["v"], + schema_node, example)); + } else { + // Append the column value only if user has requested the column. + if (columns_.empty() || + columns_.find(schema_node.name) != columns_.end()) { + TF_RETURN_IF_ERROR(AppendValueToExample(schema_node.name, + value["f"][value_index]["v"], + schema_node.type, example)); + } + } + value_index++; + } + return Status::OK(); +} + +Status BigQueryTableAccessor::ReadSchema() { + string auth_token; + TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_.get(), &auth_token)); + + // Send a request to read the schema. 
+ std::unique_ptr<HttpRequest> request(http_request_factory_->Create()); + std::vector<char> output_buffer; + output_buffer.reserve(kBufferSize); + TF_RETURN_IF_ERROR(request->Init()); + TF_RETURN_IF_ERROR(request->SetUri(BigQueryUriPrefix())); + TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); + TF_RETURN_IF_ERROR(request->SetResultBuffer(&output_buffer)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading schema for ", + FullTableName()); + + // Parse the schema. + StringPiece response_piece = + StringPiece(&output_buffer[0], output_buffer.size()); + + Json::Value root; + TF_RETURN_IF_ERROR(ParseJson(response_piece, &root)); + const auto& columns = root["schema"]["fields"]; + string column_name_prefix = ""; + schema_root_ = {"", ColumnType::kNone}; + TF_RETURN_IF_ERROR( + ExtractColumnType(columns, column_name_prefix, &schema_root_)); + if (root["numRows"].isNull()) { + return errors::Internal("Number of rows cannot be extracted for table ", + FullTableName()); + } + strings::safe_strto64(root["numRows"].asString().c_str(), &total_num_rows_); + return Status::OK(); +} + +Status BigQueryTableAccessor::ExtractColumnType( + const Json::Value& columns, const string& column_name_prefix, + SchemaNode* root) { + for (auto columns_it = columns.begin(); columns_it != columns.end(); + ++columns_it) { + if ((*columns_it)["mode"].asString() == "REPEATED") { + return errors::Unimplemented(strings::StrCat( + "Tables with repeated columns are not supported: ", FullTableName())); + } + ColumnType type; + const string current_column_name = strings::StrCat( + column_name_prefix, (*columns_it)["name"].asString().c_str()); + TF_RETURN_IF_ERROR( + ParseColumnType((*columns_it)["type"].asString().c_str(), &type)); + root->schema_nodes.emplace_back(current_column_name, type); + if (type == ColumnType::kRecord) { + const auto new_prefix = strings::StrCat(current_column_name, "."); + TF_RETURN_IF_ERROR(ExtractColumnType((*columns_it)["fields"], new_prefix, + &root->schema_nodes.back())); + } + } + return Status::OK(); +} + +Status BigQueryTableAccessor::AppendValueToExample( + const string& column_name, const Json::Value& column_value, + const BigQueryTableAccessor::ColumnType type, Example* example) { + if (column_value.isNull()) { + return Status::OK(); + } + auto& feature = + (*example->mutable_features()->mutable_feature())[column_name]; + + switch (type) { + case BigQueryTableAccessor::ColumnType::kNone: + case BigQueryTableAccessor::ColumnType::kRecord: + return errors::Unimplemented("Cannot append type to an example."); + case BigQueryTableAccessor::ColumnType::kTimestamp: + case BigQueryTableAccessor::ColumnType::kDate: + case BigQueryTableAccessor::ColumnType::kTime: + case BigQueryTableAccessor::ColumnType::kDatetime: + case BigQueryTableAccessor::ColumnType::kString: + case BigQueryTableAccessor::ColumnType::kBytes: + feature.mutable_bytes_list()->add_value(column_value.asString()); + break; + case BigQueryTableAccessor::ColumnType::kBoolean: + feature.mutable_int64_list()->add_value( + column_value.asString() == "false" ? 0 : 1); + break; + case BigQueryTableAccessor::ColumnType::kInteger: + int64 column_value_int64; + if (!strings::safe_strto64(column_value.asString().c_str(), + &column_value_int64)) { + return errors::Internal("Cannot convert value to integer ", + column_value.asString().c_str()); + } + feature.mutable_int64_list()->add_value(column_value_int64); + break; + case BigQueryTableAccessor::ColumnType::kFloat: + // BigQuery float is actually a double. 
+ double column_value_double; + if (!strings::safe_strtod(column_value.asString().c_str(), + &column_value_double)) { + return errors::Internal("Cannot convert value to double: ", + column_value.asString().c_str()); + } + feature.mutable_float_list()->add_value( + static_cast<float>(column_value_double)); + break; + } + return Status::OK(); +} + +string BigQueryTableAccessor::BigQueryTableAccessor::BigQueryUriPrefix() { + HttpRequest request; + return strings::StrCat(bigquery_end_point_, "/projects/", + request.EscapeString(project_id_), "/datasets/", + request.EscapeString(dataset_id_), "/tables/", + request.EscapeString(table_id_), "/"); +} + +bool BigQueryTableAccessor::Done() { + return (total_num_rows_ <= first_buffered_row_index_ + next_row_in_buffer_) || + IsPartitionEmpty(partition_) || + (partition_.end_index() != -1 && + partition_.end_index() < + first_buffered_row_index_ + next_row_in_buffer_); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h new file mode 100644 index 0000000000..1cd0482186 --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h @@ -0,0 +1,208 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_ + +#include <map> +#include <memory> +#include <vector> + +#include "tensorflow/contrib/cloud/kernels/bigquery_table_partition.pb.h" +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/cloud/google_auth_provider.h" +#include "tensorflow/core/platform/cloud/http_request.h" + +namespace tensorflow { + +/// This class facilitates accessing BigQuery tables. +/// +/// Notes: +/// - Nested fields are not supported. +/// - BigQuery 'Record's are automatically flattened, +/// - BigQuery float type is a double but is converted to a C++ float in this +/// class. +/// - It is possible for a table snapshot to go out-of-scope in the BigQuery +/// service while accessing the table if a very old timestamp is used. For +/// exact details, see 'Table Decorators' in BigQuery docs. +class BigQueryTableAccessor { + public: + // Column types supported by BigQuery. + enum class ColumnType { + kString = 0, + kBytes, + kInteger, + kFloat, + kBoolean, + kTimestamp, + kDate, + kTime, + kDatetime, + kRecord, + kNone + }; + + /// \brief Creates a new BigQueryTableAccessor object. + // + // We do not allow relative (negative or zero) snapshot times here since we + // want to have a consistent snapshot of the table for the lifetime of this + // object. + // Use end_point if you want to connect to a different end point than the + // official BigQuery end point. Otherwise send an empty string. 
+ static Status New(const string& project_id, const string& dataset_id, + const string& table_id, int64 timestamp_millis, + int64 row_buffer_size, const string& end_point, + const std::vector<string>& columns, + const BigQueryTablePartition& partition, + std::unique_ptr<BigQueryTableAccessor>* accessor); + + /// \brief Starts reading a new partition. + Status SetPartition(const BigQueryTablePartition& partition); + + /// \brief Returns true if there are more rows available in the current + /// partition. + bool Done(); + + /// \brief Returns a single row as example proto. + /// + /// This function will return an error if the table snapshot goes out of scope + /// in the BigQuery service. + Status ReadRow(int64* row_id, Example* example); + + /// \brief Returns total number of rows in the table. + int64 total_num_rows() { return total_num_rows_; } + + virtual ~BigQueryTableAccessor() {} + + private: + friend class BigQueryTableAccessorTest; + + // This struct encapsulates schema nodes for a BigQuery table. + struct SchemaNode { + SchemaNode() {} + SchemaNode(const string& name, ColumnType type) : name(name), type(type) {} + + string name; + ColumnType type; + std::vector<SchemaNode> schema_nodes; + }; + + /// If nullptr is passed for http_request_factory and auth_provider the + /// default production ones are used. This can be used by tests to override + /// these two variables. + static Status New(const string& project_id, const string& dataset_id, + const string& table_id, int64 timestamp_millis, + int64 row_buffer_size, const string& end_point, + const std::vector<string>& columns, + const BigQueryTablePartition& partition, + std::unique_ptr<AuthProvider> auth_provider, + std::unique_ptr<HttpRequest::Factory> http_request_factory, + std::unique_ptr<BigQueryTableAccessor>* accessor); + + /// \brief Constructs an object for a given table and partition. + BigQueryTableAccessor(const string& project_id, const string& dataset_id, + const string& table_id, int64 timestamp_millis, + int64 row_buffer_size, const string& end_point, + const std::vector<string>& columns, + const BigQueryTablePartition& partition); + + /// Used for unit testing. + BigQueryTableAccessor( + const string& project_id, const string& dataset_id, + const string& table_id, int64 timestamp_millis, int64 row_buffer_size, + const string& end_point, const std::vector<string>& columns, + const BigQueryTablePartition& partition, + std::unique_ptr<AuthProvider> auth_provider, + std::unique_ptr<HttpRequest::Factory> http_request_factory); + + /// \brief Parses column values for a given row. + Status ParseColumnValues(const Json::Value& value, + const SchemaNode& root_schema_node, + Example* example); + + /// \brief Reads the table schema and stores it. + Status ReadSchema(); + + /// \brief Extracts column type from a column in schema. + Status ExtractColumnType(const Json::Value& columns, + const string& column_name_prefix, SchemaNode* root); + + /// \brief Appends a single BigQuery column Value to 'example' for a given + /// column. + Status AppendValueToExample(const string& column_name, + const Json::Value& column_value, + const BigQueryTableAccessor::ColumnType type, + Example* example); + + /// \brief Resets internal counters for reading a partition. + void Reset(); + + /// \brief Helper function that returns BigQuery http endpoint prefix. + string BigQueryUriPrefix(); + + /// \brief Computes the maxResults arg to send to BigQuery. + int64 ComputeMaxResultsArg(); + + /// \brief Returns full name of the underlying table name. 
+ string FullTableName() { + return strings::StrCat(project_id_, ":", dataset_id_, ".", table_id_, "@", + timestamp_millis_); + } + + const string project_id_; + const string dataset_id_; + const string table_id_; + + // Snapshot timestamp. + const int64 timestamp_millis_; + + // Columns that should be read. Empty means all columns. + const std::set<string> columns_; + + // HTTP address of BigQuery end point to use. + const string bigquery_end_point_; + + // Describes the portion of the table that we are currently accessing. + BigQueryTablePartition partition_; + + // Total number of rows in the underlying table. + int64 total_num_rows_ = 0; + + // Offset of the first row in the underlying row_buffer_. + int64 first_buffered_row_index_ = 0; + + // Offset of the next row in the row_buffer_. -1 indicates that this index + // is invalid. + int next_row_in_buffer_ = -1; + + // This buffer holds next rows to improve performance. Its size will be + // based on how much buffering was requested. + std::vector<Example> row_buffer_; + + // If next_page is set, it will used to read next batch of data. + string next_page_token_; + + // A tree representing the schema for the underlying table. + SchemaNode schema_root_; + + std::unique_ptr<AuthProvider> auth_provider_; + std::unique_ptr<HttpRequest::Factory> http_request_factory_; + + TF_DISALLOW_COPY_AND_ASSIGN(BigQueryTableAccessor); +}; + +} // namespace tensorflow +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_ diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc new file mode 100644 index 0000000000..9fb339864d --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc @@ -0,0 +1,513 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h" +#include "tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h" +#include "tensorflow/core/example/feature.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/cloud/http_request_fake.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +constexpr char kTestProject[] = "test-project"; +constexpr char kTestDataset[] = "test-dataset"; +constexpr char kTestTable[] = "test-table"; + +static bool HasSubstr(const string& base, const string& substr) { + bool ok = StringPiece(base).contains(substr); + EXPECT_TRUE(ok) << base << ", expected substring " << substr; + return ok; +} + +class FakeAuthProvider : public AuthProvider { + public: + Status GetToken(string* token) override { + *token = "fake_token"; + return Status::OK(); + } +}; + +static string DeterministicSerialization(const tensorflow::Example& example) { + const int size = example.ByteSize(); + string result(size, '\0'); + ::tensorflow::protobuf::io::ArrayOutputStream array_stream( + gtl::string_as_array(&result), size); + ::tensorflow::protobuf::io::CodedOutputStream output_stream(&array_stream); + + output_stream.SetSerializationDeterministic(true); + example.SerializeWithCachedSizes(&output_stream); + EXPECT_FALSE(output_stream.HadError()); + EXPECT_EQ(size, output_stream.ByteCount()); + return result; +} + +} // namespace + +class BigQueryTableAccessorTest : public ::testing::Test { + protected: + BigQueryTableAccessor::SchemaNode GetSchema() { + return accessor_->schema_root_; + } + + Status CreateTableAccessor(const string& project_id, const string& dataset_id, + const string& table_id, int64 timestamp_millis, + int64 row_buffer_size, + const std::vector<string>& columns, + const BigQueryTablePartition& partition) { + return BigQueryTableAccessor::New( + project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, "", + columns, partition, std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests_)), + &accessor_); + } + + std::vector<HttpRequest*> requests_; + std::unique_ptr<BigQueryTableAccessor> accessor_; +}; + +TEST_F(BigQueryTableAccessorTest, NegativeTimestamp) { + const auto status = + CreateTableAccessor(kTestProject, kTestDataset, kTestTable, -1, 3, {}, + BigQueryTablePartition()); + EXPECT_TRUE(errors::IsInvalidArgument(status)); +} + +TEST_F(BigQueryTableAccessorTest, ZeroTimestamp) { + const auto status = + CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 0, 3, {}, + BigQueryTablePartition()); + EXPECT_TRUE(errors::IsInvalidArgument(status)); +} + +TEST_F(BigQueryTableAccessorTest, RepeatedFieldNoAllowedTest) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + R"({ + "kind": "bigquery#table", + "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", + "id": "test-project:test-dataset.test-table", + "schema": { + "fields": [ + { + "name": "int_field", + "type": "INTEGER", + "mode": "REPEATED" + }] + }, + "numRows": "10" + })")); + const auto status = + CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 3, {}, + BigQueryTablePartition()); + 
EXPECT_TRUE(errors::IsUnimplemented(status)); + EXPECT_TRUE(HasSubstr(status.error_message(), + "Tables with repeated columns are not supported")); +} + +TEST_F(BigQueryTableAccessorTest, ValidSchemaTest) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchema)); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 3, + {}, BigQueryTablePartition())); + // Validate total number of rows. + EXPECT_EQ(4, accessor_->total_num_rows()); + + // Validate the schema. + const auto schema_root = GetSchema(); + EXPECT_EQ(schema_root.name, ""); + EXPECT_EQ(schema_root.type, BigQueryTableAccessor::ColumnType::kNone); + EXPECT_EQ(9, schema_root.schema_nodes.size()); + + EXPECT_EQ(schema_root.schema_nodes[0].name, "int_field"); + EXPECT_EQ(schema_root.schema_nodes[0].type, + BigQueryTableAccessor::ColumnType::kInteger); + + EXPECT_EQ(schema_root.schema_nodes[1].name, "str_field"); + EXPECT_EQ(schema_root.schema_nodes[1].type, + BigQueryTableAccessor::ColumnType::kString); + + EXPECT_EQ(1, schema_root.schema_nodes[2].schema_nodes.size()); + EXPECT_EQ(schema_root.schema_nodes[2].name, "rec_field"); + EXPECT_EQ(schema_root.schema_nodes[2].type, + BigQueryTableAccessor::ColumnType::kRecord); + + EXPECT_EQ(schema_root.schema_nodes[2].schema_nodes[0].name, + "rec_field.float_field"); + EXPECT_EQ(schema_root.schema_nodes[2].schema_nodes[0].type, + BigQueryTableAccessor::ColumnType::kFloat); + + EXPECT_EQ(schema_root.schema_nodes[3].name, "bool_field"); + EXPECT_EQ(schema_root.schema_nodes[3].type, + BigQueryTableAccessor::ColumnType::kBoolean); + + EXPECT_EQ(schema_root.schema_nodes[4].name, "bytes_field"); + EXPECT_EQ(schema_root.schema_nodes[4].type, + BigQueryTableAccessor::ColumnType::kBytes); + + EXPECT_EQ(schema_root.schema_nodes[5].name, "timestamp_field"); + EXPECT_EQ(schema_root.schema_nodes[5].type, + BigQueryTableAccessor::ColumnType::kTimestamp); + + EXPECT_EQ(schema_root.schema_nodes[6].name, "date_field"); + EXPECT_EQ(schema_root.schema_nodes[6].type, + BigQueryTableAccessor::ColumnType::kDate); + + EXPECT_EQ(schema_root.schema_nodes[7].name, "time_field"); + EXPECT_EQ(schema_root.schema_nodes[7].type, + BigQueryTableAccessor::ColumnType::kTime); + + EXPECT_EQ(schema_root.schema_nodes[8].name, "datetime_field"); + EXPECT_EQ(schema_root.schema_nodes[8].type, + BigQueryTableAccessor::ColumnType::kDatetime); +} + +TEST_F(BigQueryTableAccessorTest, ReadOneRowTest) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchema)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n" + "Auth Token: fake_token\n", + kTestRow)); + BigQueryTablePartition partition; + partition.set_start_index(2); + partition.set_end_index(2); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, + {}, partition)); + int64 row_id; + Example example; + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + + // Validate returned result. 
+ Example expected_example; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestExampleProto, + &expected_example)); + EXPECT_EQ(DeterministicSerialization(expected_example), + DeterministicSerialization(example)); + EXPECT_EQ(row_id, 2); + EXPECT_TRUE(accessor_->Done()); +} + +TEST_F(BigQueryTableAccessorTest, ReadOneRowPartialTest) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchema)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n" + "Auth Token: fake_token\n", + kTestRow)); + BigQueryTablePartition partition; + partition.set_start_index(2); + partition.set_end_index(2); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, + {"bool_field", "rec_field.float_field"}, + partition)); + int64 row_id; + Example example; + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + + // Validate returned result. + EXPECT_EQ(row_id, 2); + EXPECT_TRUE(accessor_->Done()); + Example expected_example; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestPartialExampleProto, + &expected_example)); + EXPECT_EQ(DeterministicSerialization(expected_example), + DeterministicSerialization(example)); +} + +TEST_F(BigQueryTableAccessorTest, ReadOneRowWithNullsTest) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchema)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n" + "Auth Token: fake_token\n", + kTestRowWithNulls)); + BigQueryTablePartition partition; + partition.set_start_index(2); + partition.set_end_index(2); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, + {}, partition)); + int64 row_id; + Example example; + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + + // Validate returned result. + Example expected_example; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestExampleProtoWithNulls, + &expected_example)); + EXPECT_EQ(DeterministicSerialization(expected_example), + DeterministicSerialization(example)); + EXPECT_EQ(row_id, 2); + EXPECT_TRUE(accessor_->Done()); +} + +TEST_F(BigQueryTableAccessorTest, ReadOneRowTwoRecords) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchemaTwoRecords)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n" + "Auth Token: fake_token\n", + kTestRowWithTwoRecords)); + BigQueryTablePartition partition; + partition.set_start_index(2); + partition.set_end_index(2); + TF_EXPECT_OK(CreateTableAccessor( + kTestProject, kTestDataset, kTestTable, 1, 1, + {"rec_field2.bool_field", "rec_field1.float_field"}, partition)); + + int64 row_id; + Example example; + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + + // Validate returned result. 
+ Example expected_example; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString( + kTestExampleProtoWithTwoRecords, &expected_example)); + EXPECT_EQ(DeterministicSerialization(expected_example), + DeterministicSerialization(example)); + EXPECT_EQ(row_id, 2); + EXPECT_TRUE(accessor_->Done()); +} + +TEST_F(BigQueryTableAccessorTest, NonExistentColumns) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchemaTwoRecords)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n" + "Auth Token: fake_token\n", + kTestRowWithTwoRecords)); + BigQueryTablePartition partition; + partition.set_start_index(2); + partition.set_end_index(2); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, + {"bool_field", "float_field"}, partition)); + int64 row_id; + Example example; + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + + // Validate returned result. + EXPECT_EQ(row_id, 2); + EXPECT_TRUE(accessor_->Done()); +} + +TEST_F(BigQueryTableAccessorTest, EmptyRow) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchemaTwoRecords)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n" + "Auth Token: fake_token\n", + kTestEmptyRow)); + BigQueryTablePartition partition; + partition.set_start_index(2); + partition.set_end_index(2); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, + {}, partition)); + int64 row_id; + Example example; + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + + // Validate returned result. 
+ EXPECT_EQ(row_id, 2); + EXPECT_TRUE(accessor_->Done()); +} + +TEST_F(BigQueryTableAccessorTest, BrokenRowTest) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchema)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n" + "Auth Token: fake_token\n", + kBrokenTestRow)); + BigQueryTablePartition partition; + partition.set_start_index(2); + partition.set_end_index(2); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, + {}, partition)); + int64 row_id; + Example example; + const auto status = accessor_->ReadRow(&row_id, &example); + EXPECT_TRUE(errors::IsInternal(status)); + EXPECT_TRUE( + HasSubstr(status.error_message(), "Cannot convert value to integer")); +} + +TEST_F(BigQueryTableAccessorTest, MultiplePagesTest) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchema)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/data?maxResults=2&startIndex=1\n" + "Auth Token: fake_token\n", + kTestTwoRows)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/" + "data?maxResults=2&pageToken=next_page\n" + "Auth Token: fake_token\n", + kTestRowWithNulls)); + + BigQueryTablePartition partition; + partition.set_start_index(1); + partition.set_end_index(-1); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 2, + {}, partition)); + + int64 row_id; + Example example; + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + EXPECT_EQ(1, row_id); + EXPECT_FALSE(accessor_->Done()); + EXPECT_EQ( + (example.features().feature()).at("int_field").int64_list().value(0), + 1111); + + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + EXPECT_EQ(2, row_id); + EXPECT_FALSE(accessor_->Done()); + EXPECT_EQ(example.features().feature().at("int_field").int64_list().value(0), + 2222); + + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + EXPECT_EQ(3, row_id); + EXPECT_TRUE(accessor_->Done()); + + Example expected_example; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestExampleProtoWithNulls, + &expected_example)); + EXPECT_EQ(DeterministicSerialization(expected_example), + DeterministicSerialization(example)); + EXPECT_TRUE(errors::IsOutOfRange(accessor_->ReadRow(&row_id, &example))); +} + +TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchema)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=0\n" + "Auth Token: fake_token\n", + kTestTwoRows)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/" + 
"data?maxResults=2&startIndex=3\n" + "Auth Token: fake_token\n", + kTestRowWithNulls)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/data?maxResults=2&startIndex=0\n" + "Auth Token: fake_token\n", + kTestTwoRows)); + + BigQueryTablePartition partition; + partition.set_start_index(0); + partition.set_end_index(0); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 2, + {}, partition)); + + int64 row_id; + Example example; + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + EXPECT_EQ(0, row_id); + EXPECT_TRUE(accessor_->Done()); + EXPECT_EQ(example.features().feature().at("int_field").int64_list().value(0), + 1111); + + partition.set_start_index(3); + partition.set_end_index(-1); + accessor_->SetPartition(partition); + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + EXPECT_EQ(3, row_id); + EXPECT_TRUE(accessor_->Done()); + EXPECT_EQ(example.features().feature().at("int_field").int64_list().value(0), + 1234); + + partition.set_start_index(0); + partition.set_end_index(1); + accessor_->SetPartition(partition); + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + EXPECT_EQ(0, row_id); + EXPECT_FALSE(accessor_->Done()); + EXPECT_EQ(example.features().feature().at("int_field").int64_list().value(0), + 1111); + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + EXPECT_EQ(1, row_id); + EXPECT_TRUE(accessor_->Done()); + EXPECT_EQ(example.features().feature().at("int_field").int64_list().value(0), + 2222); +} + +TEST_F(BigQueryTableAccessorTest, EmptyPartitionTest) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchema)); + + BigQueryTablePartition partition; + partition.set_start_index(3); + partition.set_end_index(2); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, + {}, partition)); + EXPECT_TRUE(accessor_->Done()); + + int64 row_id; + Example example; + EXPECT_TRUE(errors::IsOutOfRange(accessor_->ReadRow(&row_id, &example))); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h new file mode 100644 index 0000000000..b2b11f4f57 --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h @@ -0,0 +1,404 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ + +#include <string> + +namespace tensorflow { +namespace { + +const string kSampleSchema = R"({ + "kind": "bigquery#table", + "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", + "id": "test-project:test-dataset.test-table", + "schema": { + "fields": [ + { + "name": "int_field", + "type": "INTEGER", + "mode": "REQUIRED" + },{ + "name": "str_field", + "type": "STRING", + "mode": "NULLABLE" + },{ + "name": "rec_field", + "type": "RECORD", + "fields": [ + { + "name": "float_field", + "type": "FLOAT", + "mode": "NULLABLE" + }] + },{ + "name": "bool_field", + "type": "BOOLEAN", + "mode": "NULLABLE" + },{ + "name": "bytes_field", + "type": "BYTES", + "mode": "NULLABLE" + },{ + "name": "timestamp_field", + "type": "TIMESTAMP", + "mode": "NULLABLE" + },{ + "name": "date_field", + "type": "DATE", + "mode": "NULLABLE" + },{ + "name": "time_field", + "type": "TIME", + "mode": "NULLABLE" + },{ + "name": "datetime_field", + "type": "DATETIME", + "mode": "NULLABLE" + }] + }, + "numRows": "4" +})"; + +const string kSampleSchemaTwoRecords = R"({ + "kind": "bigquery#table", + "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", + "id": "test-project:test-dataset.test-table", + "schema": { + "fields": [ + { + "name": "rec_field1", + "type": "RECORD", + "fields": [ + { + "name": "int_field", + "type": "INTEGER", + "mode": "NULLABLE" + }, { + "name": "float_field", + "type": "FLOAT", + "mode": "NULLABLE" + }] + },{ + "name": "rec_field2", + "type": "RECORD", + "fields": [ + { + "name": "bool_field", + "type": "BOOLEAN", + "mode": "NULLABLE" + },{ + "name": "bytes_field", + "type": "BYTES", + "mode": "NULLABLE" + }] + }] + }, + "numRows": "4" +})"; + +const string kTestRow = R"({ + "kind": "bigquery#table", + "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", + "id": "test-project:test-dataset.test-table", + "rows": [ + { + "f": [ + { + "v": "1234" + },{ + "v": "" + },{ + "v": { + "f": [ + { + "v": "1.23456" + }] + } + },{ + "v": "true" + },{ + "v": "01010100101" + },{ + "v": "timestamp" + },{ + "v": "date" + },{ + "v": "time" + },{ + "v": "datetime" + }]}]})"; + +const string kBrokenTestRow = R"({ + "kind": "bigquery#table", + "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", + "id": "test-project:test-dataset.test-table", + "rows": [ + { + "f": [ + { + "v": "1-234" // This does not parse as integer. + },{ + "v": "" + },{ + },{ + "v": "true" + },{ + "v": "01010100101" + },{ + "v": "timestamp" + },{ + "v": "date" + },{ + "v": "time" + },{ + "v": "datetime" + }]}]})"; + +const string kTestRowWithNulls = R"({ + "kind": "bigquery#table", + "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", + "id": "test-project:test-dataset.test-table", + "rows": [ + { + "f": [ + { + "v": "1234" + },{ + "v": "string" + },{ + "v": null + },{ + "v": "true" + },{ + "v": "01010100101" + },{ + "v": "" + },{ + "v": null + },{ + "v": null + },{ + "v": "datetime" + }]}]})"; + +// Example proto corresponding to kTestRow. 
+const string kTestExampleProto = R"(features { + feature { + key: "bool_field" + value { + int64_list { + value: 1 + } + } + } + feature { + key: "bytes_field" + value { + bytes_list { + value: "01010100101" + } + } + } + feature { + key: "date_field" + value { + bytes_list { + value: "date" + } + } + } + feature { + key: "datetime_field" + value { + bytes_list { + value: "datetime" + } + } + } + feature { + key: "int_field" + value { + int64_list { + value: 1234 + } + } + } + feature { + key: "rec_field.float_field" + value { + float_list { + value: 1.23456 + } + } + } + feature { + key: "str_field" + value { + bytes_list { + value: "" + } + } + } + feature { + key: "time_field" + value { + bytes_list { + value: "time" + } + } + } + feature { + key: "timestamp_field" + value { + bytes_list { + value: "timestamp" + } + } + } +} +)"; + +// Example proto corresponding to kTestRowWithNulls. +const string kTestExampleProtoWithNulls = R"(features { + feature { + key: "bool_field" + value { + int64_list { + value: 1 + } + } + } + feature { + key: "bytes_field" + value { + bytes_list { + value: "01010100101" + } + } + } + feature { + key: "datetime_field" + value { + bytes_list { + value: "datetime" + } + } + } + feature { + key: "int_field" + value { + int64_list { + value: 1234 + } + } + } + feature { + key: "timestamp_field" + value { + bytes_list { + value: "" + } + } + } + feature { + key: "str_field" + value { + bytes_list { + value: "string" + } + } + } +} +)"; + +// Example proto corresponding to part of kTestRow. +const string kTestPartialExampleProto = R"(features { + feature { + key: "bool_field" + value { + int64_list { + value: 1 + } + } + } + feature { + key: "rec_field.float_field" + value { + float_list { + value: 1.23456 + } + } + } +} +)"; + +const string kTestExampleProtoWithTwoRecords = R"(features { + feature { + key: "rec_field1.float_field" + value { + float_list { + value: 1.23456 + } + } + } + feature { + key: "rec_field2.bool_field" + value { + int64_list { + value: 1 + } + } + } +} +)"; + +const string kTestTwoRows = R"({ + "kind": "bigquery#table", + "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", + "pageToken": "next_page", + "id": "test-project:test-dataset.test-table", + "rows": [ + {"f": [{"v": "1111"},{},{},{},{},{},{},{},{}]}, + {"f": [{"v": "2222"},{},{},{},{},{},{},{},{}]} + ]})"; + +const string kTestRowWithTwoRecords = R"({ + "kind": "bigquery#table", + "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", + "id": "test-project:test-dataset.test-table", + "rows": [ + { + "f": [ + {"v": {"f": [{}, {"v": "1.23456"}]}}, + {"v": {"f": [{"v": "true"}, {}]} + }]}]})"; + +const string kTestEmptyRow = R"({ + "kind": "bigquery#table", + "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", + "id": "test-project:test-dataset.test-table", + "rows": [ + { + "f": [ + {"v": {"f": [{}, {}]}}, + {"v": {"f": [{"v": null}, {}]} + }]}]})"; + +} // namespace +} // namepsace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_partition.proto b/tensorflow/contrib/cloud/kernels/bigquery_table_partition.proto new file mode 100644 index 0000000000..2d9d1380db --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_partition.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +package tensorflow; + +// This proto specifies a table partition in BigQuery. 
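+//
+// For example, a partition covering row 3 and every row after it would be,
+// in text format:
+//   start_index: 3
+//   end_index: -1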
+message BigQueryTablePartition { + // [start_index, end_index] specify the boundaries of a partition. + // If end_index is -1, every row starting from start_index is part of the + // partition. + int64 start_index = 1; + int64 end_index = 2; +}; diff --git a/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc b/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc new file mode 100644 index 0000000000..fbba04a31a --- /dev/null +++ b/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc @@ -0,0 +1,88 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* This file registers Bigquery reader ops. */ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +namespace tensorflow { + +using shape_inference::InferenceContext; + +REGISTER_OP("BigQueryReader") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("project_id: string") + .Attr("dataset_id: string") + .Attr("table_id: string") + .Attr("columns: list(string)") + .Attr("timestamp_millis: int") + .Attr("test_end_point: string = ''") + .Output("reader_handle: Ref(string)") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->Vector(2)); + return Status::OK(); + }) + .Doc(R"doc( +A Reader that outputs rows from a BigQuery table as tensorflow Examples. + +container: If non-empty, this reader is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this reader is named in the given bucket + with this shared_name. Otherwise, the node name is used instead. +project_id: GCP project ID. +dataset_id: BigQuery Dataset ID. +table_id: Table to read. +columns: List of columns to read. Leave empty to read all columns. +timestamp_millis: Table snapshot timestamp in millis since epoch. Relative +(negative or zero) snapshot times are not allowed. For more details, see +'Table Decorators' in BigQuery docs. +test_end_point: Do not use. For testing purposes only. +reader_handle: The handle to reference the Reader. +)doc"); + +REGISTER_OP("GenerateBigQueryReaderPartitions") + .Attr("project_id: string") + .Attr("dataset_id: string") + .Attr("table_id: string") + .Attr("columns: list(string)") + .Attr("timestamp_millis: int") + .Attr("num_partitions: int") + .Attr("test_end_point: string = ''") + .Output("partitions: string") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); + return Status::OK(); + }) + .Doc(R"doc( +Generates serialized partition messages suitable for batch reads. + +This op should not be used directly by clients. Instead, the +bigquery_reader_ops.py file defines a clean interface to the reader. + +project_id: GCP project ID. +dataset_id: BigQuery Dataset ID. +table_id: Table to read. +columns: List of columns to read. Leave empty to read all columns. +timestamp_millis: Table snapshot timestamp in millis since epoch. 
Relative +(negative or zero) snapshot times are not allowed. For more details, see +'Table Decorators' in BigQuery docs. +num_partitions: Number of partitions to split the table into. +test_end_point: Do not use. For testing purposes only. +partitions: Serialized table partitions. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops.py b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops.py new file mode 100644 index 0000000000..136707da18 --- /dev/null +++ b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops.py @@ -0,0 +1,150 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""BigQuery reading support for TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.cloud.python.ops import gen_bigquery_reader_ops +from tensorflow.python.framework import ops +from tensorflow.python.ops import io_ops + + +class BigQueryReader(io_ops.ReaderBase): + """A Reader that outputs keys and tf.Example values from a BigQuery table. + + Example use: + ```python + # Assume a BigQuery has the following schema, + # name STRING, + # age INT, + # state STRING + + # Create the parse_examples list of features. + features = dict( + name=tf.FixedLenFeature([1], tf.string), + age=tf.FixedLenFeature([1], tf.int32), + state=tf.FixedLenFeature([1], dtype=tf.string, default_value="UNK")) + + # Create a Reader. + reader = bigquery_reader_ops.BigQueryReader(project_id=PROJECT, + dataset_id=DATASET, + table_id=TABLE, + timestamp_millis=TIME, + num_partitions=NUM_PARTITIONS, + features=features) + + # Populate a queue with the BigQuery Table partitions. + queue = tf.training.string_input_producer(reader.partitions()) + + # Read and parse examples. + row_id, examples_serialized = reader.read(queue) + examples = tf.parse_example(examples_serialized, features=features) + + # Process the Tensors examples["name"], examples["age"], etc... + ``` + + Note that to create a reader a snapshot timestamp is necessary. This + will enable the reader to look at a consistent snapshot of the table. + For more information, see 'Table Decorators' in BigQuery docs. + + See ReaderBase for supported methods. + """ + + def __init__(self, + project_id, + dataset_id, + table_id, + timestamp_millis, + num_partitions, + features=None, + columns=None, + test_end_point=None, + name=None): + """Creates a BigQueryReader. + + Args: + project_id: GCP project ID. + dataset_id: BigQuery dataset ID. + table_id: BigQuery table ID. + timestamp_millis: timestamp to snapshot the table in milliseconds since + the epoch. Relative (negative or zero) snapshot times are not allowed. + For more details, see 'Table Decorators' in BigQuery docs. + num_partitions: Number of non-overlapping partitions to read from. 
+ features: parse_example compatible dict from keys to `VarLenFeature` and + `FixedLenFeature` objects. Keys are read as columns from the db. + columns: list of columns to read, can be set iff features is None. + test_end_point: Used only for testing purposes (optional). + name: a name for the operation (optional). + + Raises: + TypeError: - If features is neither None nor a dict or + - If columns is is neither None nor a list or + - If both features and columns are None or set. + """ + if (features is None) == (columns is None): + raise TypeError("exactly one of features and columns must be set.") + + if features is not None: + if not isinstance(features, dict): + raise TypeError("features must be a dict.") + self._columns = list(features.keys()) + elif columns is not None: + if not isinstance(columns, list): + raise TypeError("columns must be a list.") + self._columns = columns + + self._project_id = project_id + self._dataset_id = dataset_id + self._table_id = table_id + self._timestamp_millis = timestamp_millis + self._num_partitions = num_partitions + self._test_end_point = test_end_point + + reader = gen_bigquery_reader_ops.big_query_reader( + name=name, + project_id=self._project_id, + dataset_id=self._dataset_id, + table_id=self._table_id, + timestamp_millis=self._timestamp_millis, + columns=self._columns, + test_end_point=self._test_end_point) + super(BigQueryReader, self).__init__(reader) + + def partitions(self, name=None): + """Returns serialized BigQueryTablePartition messages. + + These messages represent a non-overlapping division of a table for a + bulk read. + + Args: + name: a name for the operation (optional). + + Returns: + `1-D` string `Tensor` of serialized `BigQueryTablePartition` messages. + """ + return gen_bigquery_reader_ops.generate_big_query_reader_partitions( + name=name, + project_id=self._project_id, + dataset_id=self._dataset_id, + table_id=self._table_id, + timestamp_millis=self._timestamp_millis, + num_partitions=self._num_partitions, + test_end_point=self._test_end_point, + columns=self._columns) + + +ops.NotDifferentiable("BigQueryReader") diff --git a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py new file mode 100644 index 0000000000..b7d044ed25 --- /dev/null +++ b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py @@ -0,0 +1,287 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for BigQueryReader Op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import os +import re +import threading + +from six.moves import SimpleHTTPServer +from six.moves import socketserver + +from tensorflow.contrib.cloud.python.ops import bigquery_reader_ops as cloud +from tensorflow.core.example import example_pb2 +from tensorflow.core.framework import types_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import compat + +_PROJECT = "test-project" +_DATASET = "test-dataset" +_TABLE = "test-table" +# List representation of the test rows in the 'test-table' in BigQuery. +# The schema for each row is: [int64, string, float]. +# The values for rows are generated such that some columns have null values. The +# general formula here is: +# - The int64 column is present in every row. +# - The string column is only avaiable in even rows. +# - The float column is only available in every third row. +_ROWS = [[0, "s_0", 0.1], [1, None, None], [2, "s_2", None], [3, None, 3.1], + [4, "s_4", None], [5, None, None], [6, "s_6", 6.1], [7, None, None], + [8, "s_8", None], [9, None, 9.1]] +# Schema for 'test-table'. +# The schema currently has three columns: int64, string, and float +_SCHEMA = { + "kind": "bigquery#table", + "id": "test-project:test-dataset.test-table", + "schema": { + "fields": [{ + "name": "int64_col", + "type": "INTEGER", + "mode": "NULLABLE" + }, { + "name": "string_col", + "type": "STRING", + "mode": "NULLABLE" + }, { + "name": "float_col", + "type": "FLOAT", + "mode": "NULLABLE" + }] + } +} + + +def _ConvertRowToExampleProto(row): + """Converts the input row to an Example proto. + + Args: + row: Input Row instance. + + Returns: + An Example proto initialized with row values. + """ + + example = example_pb2.Example() + example.features.feature["int64_col"].int64_list.value.append(row[0]) + if row[1] is not None: + example.features.feature["string_col"].bytes_list.value.append( + compat.as_bytes(row[1])) + if row[2] is not None: + example.features.feature["float_col"].float_list.value.append(row[2]) + return example + + +class FakeBigQueryServer(threading.Thread): + """Fake http server to return schema and data for sample table.""" + + def __init__(self, address, port): + """Creates a FakeBigQueryServer. + + Args: + address: Server address + port: Server port. Pass 0 to automatically pick an empty port. + """ + threading.Thread.__init__(self) + self.handler = BigQueryRequestHandler + self.httpd = socketserver.TCPServer((address, port), self.handler) + + def run(self): + self.httpd.serve_forever() + + def shutdown(self): + self.httpd.shutdown() + self.httpd.socket.close() + + +class BigQueryRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler): + """Responds to BigQuery HTTP requests. + + Attributes: + num_rows: num_rows in the underlying table served by this class. + """ + + num_rows = 0 + + def do_GET(self): + if "data?maxResults=" not in self.path: + # This is a schema request. + _SCHEMA["numRows"] = self.num_rows + response = json.dumps(_SCHEMA) + else: + # This is a data request. + # + # Extract max results and start index. 
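+ # Data requests carry both parameters in the query string, e.g.
+ #   ".../data?maxResults=2&startIndex=3"
+ # asks for at most 2 rows beginning at row index 3.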
+ max_results = int(re.findall(r"maxResults=(\d+)", self.path)[0]) + start_index = int(re.findall(r"startIndex=(\d+)", self.path)[0]) + + # Send the rows as JSON. + rows = [] + for row in _ROWS[start_index:start_index + max_results]: + row_json = { + "f": [{ + "v": str(row[0]) + }, { + "v": str(row[1]) if row[1] is not None else None + }, { + "v": str(row[2]) if row[2] is not None else None + }] + } + rows.append(row_json) + response = json.dumps({ + "kind": "bigquery#table", + "id": "test-project:test-dataset.test-table", + "rows": rows + }) + self.send_response(200) + self.end_headers() + self.wfile.write(compat.as_bytes(response)) + + +def _SetUpQueue(reader): + """Sets up a queue for a reader.""" + queue = data_flow_ops.FIFOQueue(8, [types_pb2.DT_STRING], shapes=()) + key, value = reader.read(queue) + queue.enqueue_many(reader.partitions()).run() + queue.close().run() + return key, value + + +class BigQueryReaderOpsTest(test.TestCase): + + def setUp(self): + super(BigQueryReaderOpsTest, self).setUp() + self.server = FakeBigQueryServer("127.0.0.1", 0) + self.server.start() + logging.info("server address is %s:%s", self.server.httpd.server_address[0], + self.server.httpd.server_address[1]) + + # An override to bypass the GCP auth token retrieval logic + # in google_auth_provider.cc. + os.environ["GOOGLE_AUTH_TOKEN_FOR_TESTING"] = "not-used" + + def tearDown(self): + self.server.shutdown() + super(BigQueryReaderOpsTest, self).tearDown() + + def _ReadAndCheckRowsUsingFeatures(self, num_rows): + self.server.handler.num_rows = num_rows + + with self.test_session() as sess: + feature_configs = { + "int64_col": + parsing_ops.FixedLenFeature( + [1], dtype=dtypes.int64), + "string_col": + parsing_ops.FixedLenFeature( + [1], dtype=dtypes.string, default_value="s_default"), + } + reader = cloud.BigQueryReader( + project_id=_PROJECT, + dataset_id=_DATASET, + table_id=_TABLE, + num_partitions=4, + features=feature_configs, + timestamp_millis=1, + test_end_point=("%s:%s" % (self.server.httpd.server_address[0], + self.server.httpd.server_address[1]))) + + key, value = _SetUpQueue(reader) + + seen_rows = [] + features = parsing_ops.parse_example( + array_ops.reshape(value, [1]), feature_configs) + for _ in range(num_rows): + int_value, str_value = sess.run( + [features["int64_col"], features["string_col"]]) + + # Parse values returned from the session. + self.assertEqual(int_value.shape, (1, 1)) + self.assertEqual(str_value.shape, (1, 1)) + int64_col = int_value[0][0] + string_col = str_value[0][0] + seen_rows.append(int64_col) + + # Compare. 
+ expected_row = _ROWS[int64_col] + self.assertEqual(int64_col, expected_row[0]) + self.assertEqual( + compat.as_str(string_col), ("s_%d" % int64_col) if expected_row[1] + else "s_default") + + self.assertItemsEqual(seen_rows, range(num_rows)) + + with self.assertRaisesOpError("is closed and has insufficient elements " + "\\(requested 1, current size 0\\)"): + sess.run([key, value]) + + def testReadingSingleRowUsingFeatures(self): + self._ReadAndCheckRowsUsingFeatures(1) + + def testReadingMultipleRowsUsingFeatures(self): + self._ReadAndCheckRowsUsingFeatures(10) + + def testReadingMultipleRowsUsingColumns(self): + num_rows = 10 + self.server.handler.num_rows = num_rows + + with self.test_session() as sess: + reader = cloud.BigQueryReader( + project_id=_PROJECT, + dataset_id=_DATASET, + table_id=_TABLE, + num_partitions=4, + columns=["int64_col", "float_col", "string_col"], + timestamp_millis=1, + test_end_point=("%s:%s" % (self.server.httpd.server_address[0], + self.server.httpd.server_address[1]))) + key, value = _SetUpQueue(reader) + seen_rows = [] + for row_index in range(num_rows): + returned_row_id, example_proto = sess.run([key, value]) + example = example_pb2.Example() + example.ParseFromString(example_proto) + self.assertIn("int64_col", example.features.feature) + feature = example.features.feature["int64_col"] + self.assertEqual(len(feature.int64_list.value), 1) + int64_col = feature.int64_list.value[0] + seen_rows.append(int64_col) + + # Create our expected Example. + expected_example = example_pb2.Example() + expected_example = _ConvertRowToExampleProto(_ROWS[int64_col]) + + # Compare. + self.assertProtoEquals(example, expected_example) + self.assertEqual(row_index, int(returned_row_id)) + + self.assertItemsEqual(seen_rows, range(num_rows)) + + with self.assertRaisesOpError("is closed and has insufficient elements " + "\\(requested 1, current size 0\\)"): + sess.run([key, value]) + + +if __name__ == "__main__": + test.main() |