path: root/tensorflow/contrib/cloud
author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-03-13 16:58:18 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-03-13 18:09:05 -0700
commit    1b881b7c77bd1e664382785447a170de2b85f688 (patch)
tree      2d3b8c455ef9b93cb55eb03ea91d050db598fa74 /tensorflow/contrib/cloud
parent    ee6f27b647fd51b11f9795042c4f6941c77d1c86 (diff)
First version of BigQuery Reader.
Change: 150016997
Diffstat (limited to 'tensorflow/contrib/cloud')
-rw-r--r--  tensorflow/contrib/cloud/BUILD                                         73
-rw-r--r--  tensorflow/contrib/cloud/__init__.py                                   28
-rw-r--r--  tensorflow/contrib/cloud/kernels/BUILD                                 94
-rw-r--r--  tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc               192
-rw-r--r--  tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc           410
-rw-r--r--  tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h            208
-rw-r--r--  tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc      513
-rw-r--r--  tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h  404
-rw-r--r--  tensorflow/contrib/cloud/kernels/bigquery_table_partition.proto        12
-rw-r--r--  tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc                    88
-rw-r--r--  tensorflow/contrib/cloud/python/ops/bigquery_reader_ops.py            150
-rw-r--r--  tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py       287
12 files changed, 2459 insertions, 0 deletions
diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD
new file mode 100644
index 0000000000..840997223f
--- /dev/null
+++ b/tensorflow/contrib/cloud/BUILD
@@ -0,0 +1,73 @@
+# Description:
+# BigQueryReader implementation
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_gen_op_libs",
+ "tf_gen_op_wrapper_py",
+ "tf_py_test",
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["bigquery_reader_ops"],
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_bigquery_reader_ops",
+ out = "python/ops/gen_bigquery_reader_ops.py",
+ require_shape_functions = True,
+ deps = [":bigquery_reader_ops_op_lib"],
+)
+
+py_library(
+ name = "cloud_py",
+ srcs = [
+ "__init__.py",
+ "python/ops/bigquery_reader_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":gen_bigquery_reader_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform",
+ ],
+)
+
+tf_py_test(
+ name = "bigquery_reader_ops_test",
+ size = "small",
+ srcs = ["python/ops/bigquery_reader_ops_test.py"],
+ additional_deps = [
+ ":bigquery_reader_ops_op_lib",
+ ":cloud_py",
+ "//tensorflow/contrib/cloud/kernels:bigquery_reader_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:parsing_ops",
+ ],
+ tags = ["manual"],
+)
diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py
new file mode 100644
index 0000000000..8870264b95
--- /dev/null
+++ b/tensorflow/contrib/cloud/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Module for cloud ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=line-too-long,wildcard-import
+from tensorflow.contrib.cloud.python.ops.bigquery_reader_ops import *
+# pylint: enable=line-too-long,wildcard-import
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = ['BigQueryReader']
+remove_undocumented(__name__, _allowed_symbols)
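+
+# Illustrative usage sketch only (not authoritative): the actual constructor
+# arguments are defined in python/ops/bigquery_reader_ops.py, and the names
+# `partition_queue` and `feature_spec` below are placeholders.
+#
+#   reader = BigQueryReader(...)  # configured with project/dataset/table etc.
+#   _, serialized_example = reader.read(partition_queue)
+#   parsed = tf.parse_single_example(serialized_example, feature_spec)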
diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD
new file mode 100644
index 0000000000..2500f10c74
--- /dev/null
+++ b/tensorflow/contrib/cloud/kernels/BUILD
@@ -0,0 +1,94 @@
+# Description:
+# BigQueryReader implementation
+
+package(
+ default_visibility = ["//visibility:private"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+ "tf_copts",
+ "tf_kernel_library",
+)
+
+# For platform specific build config
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_proto_library",
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+tf_kernel_library(
+ name = "bigquery_reader_ops",
+ srcs = [
+ "bigquery_reader_ops.cc",
+ ],
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ ":bigquery_table_accessor",
+ ":bigquery_table_partition_proto_cc",
+ "//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:reader_base",
+ ],
+)
+
+cc_library(
+ name = "bigquery_table_accessor",
+ srcs = [
+ "bigquery_table_accessor.cc",
+ ],
+ hdrs = [
+ "bigquery_table_accessor.h",
+ ],
+ copts = tf_copts(),
+ linkstatic = 1,
+ deps = [
+ ":bigquery_table_partition_proto_cc",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/platform/cloud:google_auth_provider",
+ "//tensorflow/core/platform/cloud:http_request",
+ ],
+ alwayslink = 1,
+)
+
+tf_cc_test(
+ name = "bigquery_table_accessor_test",
+ size = "small",
+ srcs = [
+ "bigquery_table_accessor_test.cc",
+ "bigquery_table_accessor_test_data.h",
+ ],
+ deps = [
+ ":bigquery_table_accessor",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/platform/cloud:http_request_fake",
+ ],
+)
+
+tf_proto_library(
+ name = "bigquery_table_partition_proto",
+ srcs = [
+ "bigquery_table_partition.proto",
+ ],
+ cc_api_version = 2,
+)
diff --git a/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc b/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc
new file mode 100644
index 0000000000..02a759eefd
--- /dev/null
+++ b/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc
@@ -0,0 +1,192 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <map>
+#include <memory>
+#include <set>
+
+#include "tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h"
+#include "tensorflow/contrib/cloud/kernels/bigquery_table_partition.pb.h"
+#include "tensorflow/core/framework/reader_base.h"
+#include "tensorflow/core/framework/reader_op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/math/math_util.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+
+namespace tensorflow {
+namespace {
+
+constexpr int64 kDefaultRowBufferSize = 1000; // Number of rows to buffer.
+
+// This is a helper function for reading table attributes from context.
+Status GetTableAttrs(OpKernelConstruction* context, string* project_id,
+ string* dataset_id, string* table_id,
+ int64* timestamp_millis, std::vector<string>* columns,
+ string* test_end_point) {
+ TF_RETURN_IF_ERROR(context->GetAttr("project_id", project_id));
+ TF_RETURN_IF_ERROR(context->GetAttr("dataset_id", dataset_id));
+ TF_RETURN_IF_ERROR(context->GetAttr("table_id", table_id));
+ TF_RETURN_IF_ERROR(context->GetAttr("timestamp_millis", timestamp_millis));
+ TF_RETURN_IF_ERROR(context->GetAttr("columns", columns));
+ TF_RETURN_IF_ERROR(context->GetAttr("test_end_point", test_end_point));
+ return Status::OK();
+}
+
+} // namespace
+
+// Note that overridden methods with names ending in "Locked" are called by
+// ReaderBase while a mutex is held.
+// See comments for ReaderBase.
+class BigQueryReader : public ReaderBase {
+ public:
+ explicit BigQueryReader(BigQueryTableAccessor* bigquery_table_accessor,
+ const string& node_name)
+ : ReaderBase(strings::StrCat("BigQueryReader '", node_name, "'")),
+ bigquery_table_accessor_(CHECK_NOTNULL(bigquery_table_accessor)) {}
+
+ Status OnWorkStartedLocked() override {
+ BigQueryTablePartition partition;
+ if (!partition.ParseFromString(current_work())) {
+ return errors::InvalidArgument(
+          "Could not parse work as a valid partition.");
+ }
+ TF_RETURN_IF_ERROR(bigquery_table_accessor_->SetPartition(partition));
+ return Status::OK();
+ }
+
+ Status ReadLocked(string* key, string* value, bool* produced,
+ bool* at_end) override {
+ *at_end = false;
+ *produced = false;
+ if (bigquery_table_accessor_->Done()) {
+ *at_end = true;
+ return Status::OK();
+ }
+
+ Example example;
+ int64 row_id;
+ TF_RETURN_IF_ERROR(bigquery_table_accessor_->ReadRow(&row_id, &example));
+
+ *key = std::to_string(row_id);
+ *value = example.SerializeAsString();
+ *produced = true;
+ return Status::OK();
+ }
+
+ private:
+ // Not owned.
+ BigQueryTableAccessor* bigquery_table_accessor_;
+};
+
+class BigQueryReaderOp : public ReaderOpKernel {
+ public:
+ explicit BigQueryReaderOp(OpKernelConstruction* context)
+ : ReaderOpKernel(context) {
+ string table_id;
+ string project_id;
+ string dataset_id;
+ int64 timestamp_millis;
+ std::vector<string> columns;
+ string test_end_point;
+
+ OP_REQUIRES_OK(context,
+ GetTableAttrs(context, &project_id, &dataset_id, &table_id,
+ &timestamp_millis, &columns, &test_end_point));
+ OP_REQUIRES_OK(context,
+ BigQueryTableAccessor::New(
+ project_id, dataset_id, table_id, timestamp_millis,
+ kDefaultRowBufferSize, test_end_point, columns,
+ BigQueryTablePartition(), &bigquery_table_accessor_));
+
+ SetReaderFactory([this]() {
+ return new BigQueryReader(bigquery_table_accessor_.get(), name());
+ });
+ }
+
+ private:
+ std::unique_ptr<BigQueryTableAccessor> bigquery_table_accessor_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigQueryReader").Device(DEVICE_CPU),
+ BigQueryReaderOp);
+
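+// Generates one serialized BigQueryTablePartition proto per partition,
+// splitting the table rows roughly evenly. Each output string is intended to
+// be fed to a BigQueryReader as a unit of work, which parses it back in
+// OnWorkStartedLocked() above.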
+class GenerateBigQueryReaderPartitionsOp : public OpKernel {
+ public:
+ explicit GenerateBigQueryReaderPartitionsOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ string project_id;
+ string dataset_id;
+ string table_id;
+ int64 timestamp_millis;
+ std::vector<string> columns;
+ string test_end_point;
+
+ OP_REQUIRES_OK(context,
+ GetTableAttrs(context, &project_id, &dataset_id, &table_id,
+ &timestamp_millis, &columns, &test_end_point));
+ OP_REQUIRES_OK(context,
+ BigQueryTableAccessor::New(
+ project_id, dataset_id, table_id, timestamp_millis,
+ kDefaultRowBufferSize, test_end_point, columns,
+ BigQueryTablePartition(), &bigquery_table_accessor_));
+ OP_REQUIRES_OK(context, InitializeNumberOfPartitions(context));
+ OP_REQUIRES_OK(context, InitializeTotalNumberOfRows());
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const int64 partition_size = tensorflow::MathUtil::CeilOfRatio<int64>(
+ total_num_rows_, num_partitions_);
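+    // For illustration (hypothetical numbers): with total_num_rows_ = 10 and
+    // num_partitions_ = 3, partition_size is 4 and the emitted partitions
+    // cover the row index ranges [0, 3], [4, 7], and [8, 9].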
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({num_partitions_}),
+ &output_tensor));
+
+ auto output = output_tensor->template flat<string>();
+ for (int64 i = 0; i < num_partitions_; ++i) {
+ BigQueryTablePartition partition;
+ partition.set_start_index(i * partition_size);
+ partition.set_end_index(
+ std::min(total_num_rows_, (i + 1) * partition_size) - 1);
+ output(i) = partition.SerializeAsString();
+ }
+ }
+
+ private:
+ Status InitializeTotalNumberOfRows() {
+ total_num_rows_ = bigquery_table_accessor_->total_num_rows();
+ if (total_num_rows_ <= 0) {
+ return errors::FailedPrecondition("Invalid total number of rows.");
+ }
+ return Status::OK();
+ }
+
+ Status InitializeNumberOfPartitions(OpKernelConstruction* context) {
+ TF_RETURN_IF_ERROR(context->GetAttr("num_partitions", &num_partitions_));
+ if (num_partitions_ <= 0) {
+ return errors::FailedPrecondition("Invalid number of partitions.");
+ }
+ return Status::OK();
+ }
+
+ int64 num_partitions_;
+ int64 total_num_rows_;
+ std::unique_ptr<BigQueryTableAccessor> bigquery_table_accessor_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("GenerateBigQueryReaderPartitions").Device(DEVICE_CPU),
+ GenerateBigQueryReaderPartitionsOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc
new file mode 100644
index 0000000000..5e95db55b6
--- /dev/null
+++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc
@@ -0,0 +1,410 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h"
+
+#include "tensorflow/core/example/feature.pb.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+
+namespace tensorflow {
+
+namespace {
+
+constexpr size_t kBufferSize = 1024 * 1024; // In bytes.
+const string kBigQueryEndPoint = "https://www.googleapis.com/bigquery/v2";
+
+bool IsPartitionEmpty(const BigQueryTablePartition& partition) {
+ if (partition.end_index() != -1 &&
+ partition.end_index() < partition.start_index()) {
+ return true;
+ }
+ return false;
+}
+
+Status ParseJson(StringPiece json, Json::Value* result) {
+ Json::Reader reader;
+ if (!reader.parse(json.ToString(), *result)) {
+ return errors::Internal("Couldn't parse JSON response from BigQuery.");
+ }
+ return Status::OK();
+}
+
+string ColumnTypeToString(BigQueryTableAccessor::ColumnType enum_type) {
+ switch (enum_type) {
+ case BigQueryTableAccessor::ColumnType::kRecord:
+ return "RECORD";
+ case BigQueryTableAccessor::ColumnType::kString:
+ return "STRING";
+ case BigQueryTableAccessor::ColumnType::kBytes:
+ return "BYTES";
+ case BigQueryTableAccessor::ColumnType::kInteger:
+ return "INTEGER";
+ case BigQueryTableAccessor::ColumnType::kFloat:
+ return "FLOAT";
+ case BigQueryTableAccessor::ColumnType::kBoolean:
+ return "BOOLEAN";
+ case BigQueryTableAccessor::ColumnType::kTimestamp:
+ return "TIMESTAMP";
+ case BigQueryTableAccessor::ColumnType::kDate:
+ return "DATE";
+ case BigQueryTableAccessor::ColumnType::kTime:
+ return "TIME";
+ case BigQueryTableAccessor::ColumnType::kDatetime:
+ return "DATETIME";
+ case BigQueryTableAccessor::ColumnType::kNone:
+ return "NONE";
+ }
+}
+
+Status ParseColumnType(const string& type,
+ BigQueryTableAccessor::ColumnType* enum_type) {
+ if (type == "RECORD") {
+ *enum_type = BigQueryTableAccessor::ColumnType::kRecord;
+ } else if (type == "STRING") {
+ *enum_type = BigQueryTableAccessor::ColumnType::kString;
+ } else if (type == "BYTES") {
+ *enum_type = BigQueryTableAccessor::ColumnType::kBytes;
+ } else if (type == "INTEGER") {
+ *enum_type = BigQueryTableAccessor::ColumnType::kInteger;
+ } else if (type == "FLOAT") {
+ *enum_type = BigQueryTableAccessor::ColumnType::kFloat;
+ } else if (type == "BOOLEAN") {
+ *enum_type = BigQueryTableAccessor::ColumnType::kBoolean;
+ } else if (type == "TIMESTAMP") {
+ *enum_type = BigQueryTableAccessor::ColumnType::kTimestamp;
+ } else if (type == "DATE") {
+ *enum_type = BigQueryTableAccessor::ColumnType::kDate;
+ } else if (type == "TIME") {
+ *enum_type = BigQueryTableAccessor::ColumnType::kTime;
+ } else if (type == "DATETIME") {
+ *enum_type = BigQueryTableAccessor::ColumnType::kDatetime;
+ } else {
+ return errors::Internal(
+ strings::StrCat("Could not parse column type ", type));
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status BigQueryTableAccessor::New(
+ const string& project_id, const string& dataset_id, const string& table_id,
+ int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
+ const std::vector<string>& columns, const BigQueryTablePartition& partition,
+ std::unique_ptr<BigQueryTableAccessor>* accessor) {
+ return New(project_id, dataset_id, table_id, timestamp_millis,
+ row_buffer_size, end_point, columns, partition, nullptr, nullptr,
+ accessor);
+}
+
+Status BigQueryTableAccessor::New(
+ const string& project_id, const string& dataset_id, const string& table_id,
+ int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
+ const std::vector<string>& columns, const BigQueryTablePartition& partition,
+ std::unique_ptr<AuthProvider> auth_provider,
+ std::unique_ptr<HttpRequest::Factory> http_request_factory,
+ std::unique_ptr<BigQueryTableAccessor>* accessor) {
+ if (timestamp_millis <= 0) {
+ return errors::InvalidArgument(
+ "Cannot use zero or negative timestamp to query a table.");
+ }
+ const string& big_query_end_point =
+ end_point.empty() ? kBigQueryEndPoint : end_point;
+ if (auth_provider == nullptr && http_request_factory == nullptr) {
+ accessor->reset(new BigQueryTableAccessor(
+ project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
+ big_query_end_point, columns, partition));
+ } else {
+ accessor->reset(new BigQueryTableAccessor(
+ project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
+ big_query_end_point, columns, partition, std::move(auth_provider),
+ std::move(http_request_factory)));
+ }
+ return (*accessor)->ReadSchema();
+}
+
+BigQueryTableAccessor::BigQueryTableAccessor(
+ const string& project_id, const string& dataset_id, const string& table_id,
+ int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
+ const std::vector<string>& columns, const BigQueryTablePartition& partition)
+ : BigQueryTableAccessor(
+ project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
+ end_point, columns, partition,
+ std::unique_ptr<AuthProvider>(new GoogleAuthProvider()),
+ std::unique_ptr<HttpRequest::Factory>(new HttpRequest::Factory())) {
+ row_buffer_.resize(row_buffer_size);
+}
+
+BigQueryTableAccessor::BigQueryTableAccessor(
+ const string& project_id, const string& dataset_id, const string& table_id,
+ int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
+ const std::vector<string>& columns, const BigQueryTablePartition& partition,
+ std::unique_ptr<AuthProvider> auth_provider,
+ std::unique_ptr<HttpRequest::Factory> http_request_factory)
+ : project_id_(project_id),
+ dataset_id_(dataset_id),
+ table_id_(table_id),
+ timestamp_millis_(timestamp_millis),
+ columns_(columns.begin(), columns.end()),
+ bigquery_end_point_(end_point),
+ partition_(partition),
+ auth_provider_(std::move(auth_provider)),
+ http_request_factory_(std::move(http_request_factory)) {
+ row_buffer_.resize(row_buffer_size);
+ Reset();
+}
+
+Status BigQueryTableAccessor::SetPartition(
+ const BigQueryTablePartition& partition) {
+ if (partition.start_index() < 0) {
+ return errors::InvalidArgument("Start index cannot be negative.");
+ }
+ partition_ = partition;
+ Reset();
+ return Status::OK();
+}
+
+void BigQueryTableAccessor::Reset() {
+ first_buffered_row_index_ = partition_.start_index();
+ next_row_in_buffer_ = -1;
+ next_page_token_ = "";
+}
+
+Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) {
+ if (Done()) {
+ return errors::OutOfRange("Reached end of table ", FullTableName());
+ }
+
+ // If the next row is already fetched and cached, return the row from the
+ // buffer. Otherwise, fill up the row buffer from BigQuery and return a row.
+ if (next_row_in_buffer_ != -1 &&
+ next_row_in_buffer_ < ComputeMaxResultsArg()) {
+ *row_id = first_buffered_row_index_ + next_row_in_buffer_;
+ *example = row_buffer_[next_row_in_buffer_];
+ next_row_in_buffer_++;
+ } else {
+ string auth_token;
+ TF_RETURN_IF_ERROR(
+ AuthProvider::GetToken(auth_provider_.get(), &auth_token));
+
+ std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
+ std::vector<char> output_buffer;
+ output_buffer.reserve(kBufferSize);
+ TF_RETURN_IF_ERROR(request->Init());
+
+ // The first time that we access BigQuery there is no page token. After that
+ // we use the page token (which returns rows faster).
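+    // For example (illustrative values), the first request for a buffer of
+    // two rows starting at table row 1 uses
+    //   .../tables/<table>/data?maxResults=2&startIndex=1
+    // while follow-up requests use
+    //   .../tables/<table>/data?maxResults=2&pageToken=<token>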
+ if (!next_page_token_.empty()) {
+ TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat(
+ BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(),
+ "&pageToken=", request->EscapeString(next_page_token_))));
+ first_buffered_row_index_ += row_buffer_.size();
+ } else {
+ TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat(
+ BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(),
+ "&startIndex=", first_buffered_row_index_)));
+ }
+ TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token));
+ TF_RETURN_IF_ERROR(request->SetResultBuffer(&output_buffer));
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading rows from ",
+ FullTableName());
+
+ // Parse the returned row.
+ StringPiece response_piece =
+ StringPiece(&output_buffer[0], output_buffer.size());
+ Json::Value root;
+ TF_RETURN_IF_ERROR(ParseJson(response_piece, &root));
+ for (unsigned int i = 0; i < root["rows"].size(); ++i) {
+ row_buffer_[i].Clear();
+ TF_RETURN_IF_ERROR(
+ ParseColumnValues(root["rows"][i], schema_root_, &row_buffer_[i]));
+ }
+
+ next_page_token_ = root["pageToken"].asString();
+ *row_id = first_buffered_row_index_;
+ *example = row_buffer_[0];
+ next_row_in_buffer_ = 1;
+ }
+ return Status::OK();
+}
+
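+// Returns the number of rows to request per fetch: the full row buffer size
+// when the partition is open-ended (end_index == -1), zero for an empty
+// partition, and otherwise at most the total number of rows in the partition.
+// For example (hypothetical values), a 1000-row buffer with a partition
+// covering rows [2, 2] yields a maxResults of 1.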
+int64 BigQueryTableAccessor::ComputeMaxResultsArg() {
+ if (partition_.end_index() == -1) {
+ return row_buffer_.size();
+ }
+ if (IsPartitionEmpty(partition_)) {
+ return 0;
+ }
+ return std::min(static_cast<int64>(row_buffer_.size()),
+ static_cast<int64>(partition_.end_index() -
+ partition_.start_index() + 1));
+}
+
+Status BigQueryTableAccessor::ParseColumnValues(
+ const Json::Value& value, const SchemaNode& root_schema_node,
+ Example* example) {
+ if (value.empty()) {
+ return Status::OK();
+ }
+ if (value["f"].isNull()) {
+ return Status::OK();
+ }
+ int value_index = 0;
+ for (const auto& schema_node : root_schema_node.schema_nodes) {
+ if (value["f"][value_index].isNull()) {
+ value_index++;
+ continue;
+ }
+
+ if (schema_node.type == ColumnType::kRecord) {
+ TF_RETURN_IF_ERROR(ParseColumnValues(value["f"][value_index]["v"],
+ schema_node, example));
+ } else {
+ // Append the column value only if user has requested the column.
+ if (columns_.empty() ||
+ columns_.find(schema_node.name) != columns_.end()) {
+ TF_RETURN_IF_ERROR(AppendValueToExample(schema_node.name,
+ value["f"][value_index]["v"],
+ schema_node.type, example));
+ }
+ }
+ value_index++;
+ }
+ return Status::OK();
+}
+
+Status BigQueryTableAccessor::ReadSchema() {
+ string auth_token;
+ TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_.get(), &auth_token));
+
+ // Send a request to read the schema.
+ std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
+ std::vector<char> output_buffer;
+ output_buffer.reserve(kBufferSize);
+ TF_RETURN_IF_ERROR(request->Init());
+ TF_RETURN_IF_ERROR(request->SetUri(BigQueryUriPrefix()));
+ TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token));
+ TF_RETURN_IF_ERROR(request->SetResultBuffer(&output_buffer));
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading schema for ",
+ FullTableName());
+
+ // Parse the schema.
+ StringPiece response_piece =
+ StringPiece(&output_buffer[0], output_buffer.size());
+
+ Json::Value root;
+ TF_RETURN_IF_ERROR(ParseJson(response_piece, &root));
+ const auto& columns = root["schema"]["fields"];
+ string column_name_prefix = "";
+ schema_root_ = {"", ColumnType::kNone};
+ TF_RETURN_IF_ERROR(
+ ExtractColumnType(columns, column_name_prefix, &schema_root_));
+ if (root["numRows"].isNull()) {
+ return errors::Internal("Number of rows cannot be extracted for table ",
+ FullTableName());
+ }
+ strings::safe_strto64(root["numRows"].asString().c_str(), &total_num_rows_);
+ return Status::OK();
+}
+
+Status BigQueryTableAccessor::ExtractColumnType(
+ const Json::Value& columns, const string& column_name_prefix,
+ SchemaNode* root) {
+ for (auto columns_it = columns.begin(); columns_it != columns.end();
+ ++columns_it) {
+ if ((*columns_it)["mode"].asString() == "REPEATED") {
+ return errors::Unimplemented(strings::StrCat(
+ "Tables with repeated columns are not supported: ", FullTableName()));
+ }
+ ColumnType type;
+ const string current_column_name = strings::StrCat(
+ column_name_prefix, (*columns_it)["name"].asString().c_str());
+ TF_RETURN_IF_ERROR(
+ ParseColumnType((*columns_it)["type"].asString().c_str(), &type));
+ root->schema_nodes.emplace_back(current_column_name, type);
+ if (type == ColumnType::kRecord) {
+ const auto new_prefix = strings::StrCat(current_column_name, ".");
+ TF_RETURN_IF_ERROR(ExtractColumnType((*columns_it)["fields"], new_prefix,
+ &root->schema_nodes.back()));
+ }
+ }
+ return Status::OK();
+}
+
+Status BigQueryTableAccessor::AppendValueToExample(
+ const string& column_name, const Json::Value& column_value,
+ const BigQueryTableAccessor::ColumnType type, Example* example) {
+ if (column_value.isNull()) {
+ return Status::OK();
+ }
+ auto& feature =
+ (*example->mutable_features()->mutable_feature())[column_name];
+
+ switch (type) {
+ case BigQueryTableAccessor::ColumnType::kNone:
+ case BigQueryTableAccessor::ColumnType::kRecord:
+      return errors::Unimplemented(
+          "Cannot append RECORD or NONE values to an example.");
+ case BigQueryTableAccessor::ColumnType::kTimestamp:
+ case BigQueryTableAccessor::ColumnType::kDate:
+ case BigQueryTableAccessor::ColumnType::kTime:
+ case BigQueryTableAccessor::ColumnType::kDatetime:
+ case BigQueryTableAccessor::ColumnType::kString:
+ case BigQueryTableAccessor::ColumnType::kBytes:
+ feature.mutable_bytes_list()->add_value(column_value.asString());
+ break;
+ case BigQueryTableAccessor::ColumnType::kBoolean:
+ feature.mutable_int64_list()->add_value(
+ column_value.asString() == "false" ? 0 : 1);
+ break;
+ case BigQueryTableAccessor::ColumnType::kInteger:
+ int64 column_value_int64;
+ if (!strings::safe_strto64(column_value.asString().c_str(),
+ &column_value_int64)) {
+ return errors::Internal("Cannot convert value to integer ",
+ column_value.asString().c_str());
+ }
+ feature.mutable_int64_list()->add_value(column_value_int64);
+ break;
+ case BigQueryTableAccessor::ColumnType::kFloat:
+ // BigQuery float is actually a double.
+ double column_value_double;
+ if (!strings::safe_strtod(column_value.asString().c_str(),
+ &column_value_double)) {
+ return errors::Internal("Cannot convert value to double: ",
+ column_value.asString().c_str());
+ }
+ feature.mutable_float_list()->add_value(
+ static_cast<float>(column_value_double));
+ break;
+ }
+ return Status::OK();
+}
+
+string BigQueryTableAccessor::BigQueryUriPrefix() {
+ HttpRequest request;
+ return strings::StrCat(bigquery_end_point_, "/projects/",
+ request.EscapeString(project_id_), "/datasets/",
+ request.EscapeString(dataset_id_), "/tables/",
+ request.EscapeString(table_id_), "/");
+}
+
+bool BigQueryTableAccessor::Done() {
+ return (total_num_rows_ <= first_buffered_row_index_ + next_row_in_buffer_) ||
+ IsPartitionEmpty(partition_) ||
+ (partition_.end_index() != -1 &&
+ partition_.end_index() <
+ first_buffered_row_index_ + next_row_in_buffer_);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h
new file mode 100644
index 0000000000..1cd0482186
--- /dev/null
+++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h
@@ -0,0 +1,208 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_
+
+#include <map>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/contrib/cloud/kernels/bigquery_table_partition.pb.h"
+#include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/cloud/google_auth_provider.h"
+#include "tensorflow/core/platform/cloud/http_request.h"
+
+namespace tensorflow {
+
+/// This class facilitates accessing BigQuery tables.
+///
+/// Notes:
+/// - Repeated fields are not supported.
+/// - Nested (RECORD) fields are automatically flattened, with the names of
+///   their leaf fields joined by a period (e.g. "rec_field.float_field").
+/// - The BigQuery FLOAT type is a double, but it is converted to a C++ float
+///   in this class.
+/// - It is possible for a table snapshot to go out-of-scope in the BigQuery
+/// service while accessing the table if a very old timestamp is used. For
+/// exact details, see 'Table Decorators' in BigQuery docs.
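+///
+/// A minimal usage sketch (illustrative only; the string literals below are
+/// placeholders and error handling is elided):
+///
+///   std::unique_ptr<BigQueryTableAccessor> accessor;
+///   TF_CHECK_OK(BigQueryTableAccessor::New(
+///       "my-project", "my-dataset", "my-table", /*timestamp_millis=*/1,
+///       /*row_buffer_size=*/1000, /*end_point=*/"", /*columns=*/{},
+///       BigQueryTablePartition(), &accessor));
+///   while (!accessor->Done()) {
+///     int64 row_id;
+///     Example example;
+///     TF_CHECK_OK(accessor->ReadRow(&row_id, &example));
+///   }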
+class BigQueryTableAccessor {
+ public:
+ // Column types supported by BigQuery.
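+  // When a row is converted to an Example (see AppendValueToExample), STRING,
+  // BYTES, TIMESTAMP, DATE, TIME, and DATETIME values are stored as
+  // bytes_list entries, INTEGER and BOOLEAN as int64_list entries, and FLOAT
+  // as a float_list entry; RECORD columns are flattened into their leaf
+  // fields rather than stored directly.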
+ enum class ColumnType {
+ kString = 0,
+ kBytes,
+ kInteger,
+ kFloat,
+ kBoolean,
+ kTimestamp,
+ kDate,
+ kTime,
+ kDatetime,
+ kRecord,
+ kNone
+ };
+
+ /// \brief Creates a new BigQueryTableAccessor object.
+  ///
+  /// We do not allow relative (negative or zero) snapshot times here since we
+  /// want a consistent snapshot of the table for the lifetime of this object.
+  /// Use end_point to connect to an end point other than the official
+  /// BigQuery end point; otherwise pass an empty string.
+ static Status New(const string& project_id, const string& dataset_id,
+ const string& table_id, int64 timestamp_millis,
+ int64 row_buffer_size, const string& end_point,
+ const std::vector<string>& columns,
+ const BigQueryTablePartition& partition,
+ std::unique_ptr<BigQueryTableAccessor>* accessor);
+
+ /// \brief Starts reading a new partition.
+ Status SetPartition(const BigQueryTablePartition& partition);
+
+  /// \brief Returns true if there are no more rows to read in the current
+  /// partition.
+ bool Done();
+
+ /// \brief Returns a single row as example proto.
+ ///
+ /// This function will return an error if the table snapshot goes out of scope
+ /// in the BigQuery service.
+ Status ReadRow(int64* row_id, Example* example);
+
+ /// \brief Returns total number of rows in the table.
+ int64 total_num_rows() { return total_num_rows_; }
+
+ virtual ~BigQueryTableAccessor() {}
+
+ private:
+ friend class BigQueryTableAccessorTest;
+
+ // This struct encapsulates schema nodes for a BigQuery table.
+ struct SchemaNode {
+ SchemaNode() {}
+ SchemaNode(const string& name, ColumnType type) : name(name), type(type) {}
+
+ string name;
+ ColumnType type;
+ std::vector<SchemaNode> schema_nodes;
+ };
+
+  /// If nullptr is passed for http_request_factory and auth_provider, the
+  /// default production ones are used. This overload lets tests override
+  /// these two dependencies.
+ static Status New(const string& project_id, const string& dataset_id,
+ const string& table_id, int64 timestamp_millis,
+ int64 row_buffer_size, const string& end_point,
+ const std::vector<string>& columns,
+ const BigQueryTablePartition& partition,
+ std::unique_ptr<AuthProvider> auth_provider,
+ std::unique_ptr<HttpRequest::Factory> http_request_factory,
+ std::unique_ptr<BigQueryTableAccessor>* accessor);
+
+ /// \brief Constructs an object for a given table and partition.
+ BigQueryTableAccessor(const string& project_id, const string& dataset_id,
+ const string& table_id, int64 timestamp_millis,
+ int64 row_buffer_size, const string& end_point,
+ const std::vector<string>& columns,
+ const BigQueryTablePartition& partition);
+
+ /// Used for unit testing.
+ BigQueryTableAccessor(
+ const string& project_id, const string& dataset_id,
+ const string& table_id, int64 timestamp_millis, int64 row_buffer_size,
+ const string& end_point, const std::vector<string>& columns,
+ const BigQueryTablePartition& partition,
+ std::unique_ptr<AuthProvider> auth_provider,
+ std::unique_ptr<HttpRequest::Factory> http_request_factory);
+
+ /// \brief Parses column values for a given row.
+ Status ParseColumnValues(const Json::Value& value,
+ const SchemaNode& root_schema_node,
+ Example* example);
+
+ /// \brief Reads the table schema and stores it.
+ Status ReadSchema();
+
+ /// \brief Extracts column type from a column in schema.
+ Status ExtractColumnType(const Json::Value& columns,
+ const string& column_name_prefix, SchemaNode* root);
+
+ /// \brief Appends a single BigQuery column Value to 'example' for a given
+ /// column.
+ Status AppendValueToExample(const string& column_name,
+ const Json::Value& column_value,
+ const BigQueryTableAccessor::ColumnType type,
+ Example* example);
+
+ /// \brief Resets internal counters for reading a partition.
+ void Reset();
+
+ /// \brief Helper function that returns BigQuery http endpoint prefix.
+ string BigQueryUriPrefix();
+
+ /// \brief Computes the maxResults arg to send to BigQuery.
+ int64 ComputeMaxResultsArg();
+
+  /// \brief Returns the full name of the underlying table, in the form
+  /// "project_id:dataset_id.table_id@timestamp_millis".
+ string FullTableName() {
+ return strings::StrCat(project_id_, ":", dataset_id_, ".", table_id_, "@",
+ timestamp_millis_);
+ }
+
+ const string project_id_;
+ const string dataset_id_;
+ const string table_id_;
+
+ // Snapshot timestamp.
+ const int64 timestamp_millis_;
+
+ // Columns that should be read. Empty means all columns.
+ const std::set<string> columns_;
+
+ // HTTP address of BigQuery end point to use.
+ const string bigquery_end_point_;
+
+ // Describes the portion of the table that we are currently accessing.
+ BigQueryTablePartition partition_;
+
+ // Total number of rows in the underlying table.
+ int64 total_num_rows_ = 0;
+
+  // Table row index of the first row currently cached in row_buffer_.
+ int64 first_buffered_row_index_ = 0;
+
+ // Offset of the next row in the row_buffer_. -1 indicates that this index
+ // is invalid.
+ int next_row_in_buffer_ = -1;
+
+  // This buffer holds the next rows to improve performance. Its size is
+  // determined by how much buffering was requested.
+ std::vector<Example> row_buffer_;
+
+  // If set, next_page_token_ is used to read the next batch of data.
+ string next_page_token_;
+
+ // A tree representing the schema for the underlying table.
+ SchemaNode schema_root_;
+
+ std::unique_ptr<AuthProvider> auth_provider_;
+ std::unique_ptr<HttpRequest::Factory> http_request_factory_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BigQueryTableAccessor);
+};
+
+} // namespace tensorflow
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_
diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc
new file mode 100644
index 0000000000..9fb339864d
--- /dev/null
+++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc
@@ -0,0 +1,513 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h"
+#include "tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h"
+#include "tensorflow/core/example/feature.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/platform/cloud/http_request_fake.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+constexpr char kTestProject[] = "test-project";
+constexpr char kTestDataset[] = "test-dataset";
+constexpr char kTestTable[] = "test-table";
+
+static bool HasSubstr(const string& base, const string& substr) {
+ bool ok = StringPiece(base).contains(substr);
+ EXPECT_TRUE(ok) << base << ", expected substring " << substr;
+ return ok;
+}
+
+class FakeAuthProvider : public AuthProvider {
+ public:
+ Status GetToken(string* token) override {
+ *token = "fake_token";
+ return Status::OK();
+ }
+};
+
+static string DeterministicSerialization(const tensorflow::Example& example) {
+ const int size = example.ByteSize();
+ string result(size, '\0');
+ ::tensorflow::protobuf::io::ArrayOutputStream array_stream(
+ gtl::string_as_array(&result), size);
+ ::tensorflow::protobuf::io::CodedOutputStream output_stream(&array_stream);
+
+ output_stream.SetSerializationDeterministic(true);
+ example.SerializeWithCachedSizes(&output_stream);
+ EXPECT_FALSE(output_stream.HadError());
+ EXPECT_EQ(size, output_stream.ByteCount());
+ return result;
+}
+
+} // namespace
+
+class BigQueryTableAccessorTest : public ::testing::Test {
+ protected:
+ BigQueryTableAccessor::SchemaNode GetSchema() {
+ return accessor_->schema_root_;
+ }
+
+ Status CreateTableAccessor(const string& project_id, const string& dataset_id,
+ const string& table_id, int64 timestamp_millis,
+ int64 row_buffer_size,
+ const std::vector<string>& columns,
+ const BigQueryTablePartition& partition) {
+ return BigQueryTableAccessor::New(
+ project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, "",
+ columns, partition, std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests_)),
+ &accessor_);
+ }
+
+ std::vector<HttpRequest*> requests_;
+ std::unique_ptr<BigQueryTableAccessor> accessor_;
+};
+
+TEST_F(BigQueryTableAccessorTest, NegativeTimestamp) {
+ const auto status =
+ CreateTableAccessor(kTestProject, kTestDataset, kTestTable, -1, 3, {},
+ BigQueryTablePartition());
+ EXPECT_TRUE(errors::IsInvalidArgument(status));
+}
+
+TEST_F(BigQueryTableAccessorTest, ZeroTimestamp) {
+ const auto status =
+ CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 0, 3, {},
+ BigQueryTablePartition());
+ EXPECT_TRUE(errors::IsInvalidArgument(status));
+}
+
+TEST_F(BigQueryTableAccessorTest, RepeatedFieldNotAllowedTest) {
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/\n"
+ "Auth Token: fake_token\n",
+ R"({
+ "kind": "bigquery#table",
+ "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"",
+ "id": "test-project:test-dataset.test-table",
+ "schema": {
+ "fields": [
+ {
+ "name": "int_field",
+ "type": "INTEGER",
+ "mode": "REPEATED"
+ }]
+ },
+ "numRows": "10"
+ })"));
+ const auto status =
+ CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 3, {},
+ BigQueryTablePartition());
+ EXPECT_TRUE(errors::IsUnimplemented(status));
+ EXPECT_TRUE(HasSubstr(status.error_message(),
+ "Tables with repeated columns are not supported"));
+}
+
+TEST_F(BigQueryTableAccessorTest, ValidSchemaTest) {
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/\n"
+ "Auth Token: fake_token\n",
+ kSampleSchema));
+ TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 3,
+ {}, BigQueryTablePartition()));
+ // Validate total number of rows.
+ EXPECT_EQ(4, accessor_->total_num_rows());
+
+ // Validate the schema.
+ const auto schema_root = GetSchema();
+ EXPECT_EQ(schema_root.name, "");
+ EXPECT_EQ(schema_root.type, BigQueryTableAccessor::ColumnType::kNone);
+ EXPECT_EQ(9, schema_root.schema_nodes.size());
+
+ EXPECT_EQ(schema_root.schema_nodes[0].name, "int_field");
+ EXPECT_EQ(schema_root.schema_nodes[0].type,
+ BigQueryTableAccessor::ColumnType::kInteger);
+
+ EXPECT_EQ(schema_root.schema_nodes[1].name, "str_field");
+ EXPECT_EQ(schema_root.schema_nodes[1].type,
+ BigQueryTableAccessor::ColumnType::kString);
+
+ EXPECT_EQ(1, schema_root.schema_nodes[2].schema_nodes.size());
+ EXPECT_EQ(schema_root.schema_nodes[2].name, "rec_field");
+ EXPECT_EQ(schema_root.schema_nodes[2].type,
+ BigQueryTableAccessor::ColumnType::kRecord);
+
+ EXPECT_EQ(schema_root.schema_nodes[2].schema_nodes[0].name,
+ "rec_field.float_field");
+ EXPECT_EQ(schema_root.schema_nodes[2].schema_nodes[0].type,
+ BigQueryTableAccessor::ColumnType::kFloat);
+
+ EXPECT_EQ(schema_root.schema_nodes[3].name, "bool_field");
+ EXPECT_EQ(schema_root.schema_nodes[3].type,
+ BigQueryTableAccessor::ColumnType::kBoolean);
+
+ EXPECT_EQ(schema_root.schema_nodes[4].name, "bytes_field");
+ EXPECT_EQ(schema_root.schema_nodes[4].type,
+ BigQueryTableAccessor::ColumnType::kBytes);
+
+ EXPECT_EQ(schema_root.schema_nodes[5].name, "timestamp_field");
+ EXPECT_EQ(schema_root.schema_nodes[5].type,
+ BigQueryTableAccessor::ColumnType::kTimestamp);
+
+ EXPECT_EQ(schema_root.schema_nodes[6].name, "date_field");
+ EXPECT_EQ(schema_root.schema_nodes[6].type,
+ BigQueryTableAccessor::ColumnType::kDate);
+
+ EXPECT_EQ(schema_root.schema_nodes[7].name, "time_field");
+ EXPECT_EQ(schema_root.schema_nodes[7].type,
+ BigQueryTableAccessor::ColumnType::kTime);
+
+ EXPECT_EQ(schema_root.schema_nodes[8].name, "datetime_field");
+ EXPECT_EQ(schema_root.schema_nodes[8].type,
+ BigQueryTableAccessor::ColumnType::kDatetime);
+}
+
+TEST_F(BigQueryTableAccessorTest, ReadOneRowTest) {
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/\n"
+ "Auth Token: fake_token\n",
+ kSampleSchema));
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n"
+ "Auth Token: fake_token\n",
+ kTestRow));
+ BigQueryTablePartition partition;
+ partition.set_start_index(2);
+ partition.set_end_index(2);
+ TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
+ {}, partition));
+ int64 row_id;
+ Example example;
+ TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example));
+
+ // Validate returned result.
+ Example expected_example;
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestExampleProto,
+ &expected_example));
+ EXPECT_EQ(DeterministicSerialization(expected_example),
+ DeterministicSerialization(example));
+ EXPECT_EQ(row_id, 2);
+ EXPECT_TRUE(accessor_->Done());
+}
+
+TEST_F(BigQueryTableAccessorTest, ReadOneRowPartialTest) {
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/\n"
+ "Auth Token: fake_token\n",
+ kSampleSchema));
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n"
+ "Auth Token: fake_token\n",
+ kTestRow));
+ BigQueryTablePartition partition;
+ partition.set_start_index(2);
+ partition.set_end_index(2);
+ TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
+ {"bool_field", "rec_field.float_field"},
+ partition));
+ int64 row_id;
+ Example example;
+ TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example));
+
+ // Validate returned result.
+ EXPECT_EQ(row_id, 2);
+ EXPECT_TRUE(accessor_->Done());
+ Example expected_example;
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestPartialExampleProto,
+ &expected_example));
+ EXPECT_EQ(DeterministicSerialization(expected_example),
+ DeterministicSerialization(example));
+}
+
+TEST_F(BigQueryTableAccessorTest, ReadOneRowWithNullsTest) {
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/\n"
+ "Auth Token: fake_token\n",
+ kSampleSchema));
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n"
+ "Auth Token: fake_token\n",
+ kTestRowWithNulls));
+ BigQueryTablePartition partition;
+ partition.set_start_index(2);
+ partition.set_end_index(2);
+ TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
+ {}, partition));
+ int64 row_id;
+ Example example;
+ TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example));
+
+ // Validate returned result.
+ Example expected_example;
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestExampleProtoWithNulls,
+ &expected_example));
+ EXPECT_EQ(DeterministicSerialization(expected_example),
+ DeterministicSerialization(example));
+ EXPECT_EQ(row_id, 2);
+ EXPECT_TRUE(accessor_->Done());
+}
+
+TEST_F(BigQueryTableAccessorTest, ReadOneRowTwoRecords) {
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/\n"
+ "Auth Token: fake_token\n",
+ kSampleSchemaTwoRecords));
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n"
+ "Auth Token: fake_token\n",
+ kTestRowWithTwoRecords));
+ BigQueryTablePartition partition;
+ partition.set_start_index(2);
+ partition.set_end_index(2);
+ TF_EXPECT_OK(CreateTableAccessor(
+ kTestProject, kTestDataset, kTestTable, 1, 1,
+ {"rec_field2.bool_field", "rec_field1.float_field"}, partition));
+
+ int64 row_id;
+ Example example;
+ TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example));
+
+ // Validate returned result.
+ Example expected_example;
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
+ kTestExampleProtoWithTwoRecords, &expected_example));
+ EXPECT_EQ(DeterministicSerialization(expected_example),
+ DeterministicSerialization(example));
+ EXPECT_EQ(row_id, 2);
+ EXPECT_TRUE(accessor_->Done());
+}
+
+TEST_F(BigQueryTableAccessorTest, NonExistentColumns) {
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/\n"
+ "Auth Token: fake_token\n",
+ kSampleSchemaTwoRecords));
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n"
+ "Auth Token: fake_token\n",
+ kTestRowWithTwoRecords));
+ BigQueryTablePartition partition;
+ partition.set_start_index(2);
+ partition.set_end_index(2);
+ TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
+ {"bool_field", "float_field"}, partition));
+ int64 row_id;
+ Example example;
+ TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example));
+
+ // Validate returned result.
+ EXPECT_EQ(row_id, 2);
+ EXPECT_TRUE(accessor_->Done());
+}
+
+TEST_F(BigQueryTableAccessorTest, EmptyRow) {
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/\n"
+ "Auth Token: fake_token\n",
+ kSampleSchemaTwoRecords));
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n"
+ "Auth Token: fake_token\n",
+ kTestEmptyRow));
+ BigQueryTablePartition partition;
+ partition.set_start_index(2);
+ partition.set_end_index(2);
+ TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
+ {}, partition));
+ int64 row_id;
+ Example example;
+ TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example));
+
+ // Validate returned result.
+ EXPECT_EQ(row_id, 2);
+ EXPECT_TRUE(accessor_->Done());
+}
+
+TEST_F(BigQueryTableAccessorTest, BrokenRowTest) {
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/\n"
+ "Auth Token: fake_token\n",
+ kSampleSchema));
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n"
+ "Auth Token: fake_token\n",
+ kBrokenTestRow));
+ BigQueryTablePartition partition;
+ partition.set_start_index(2);
+ partition.set_end_index(2);
+ TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
+ {}, partition));
+ int64 row_id;
+ Example example;
+ const auto status = accessor_->ReadRow(&row_id, &example);
+ EXPECT_TRUE(errors::IsInternal(status));
+ EXPECT_TRUE(
+ HasSubstr(status.error_message(), "Cannot convert value to integer"));
+}
+
+TEST_F(BigQueryTableAccessorTest, MultiplePagesTest) {
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/\n"
+ "Auth Token: fake_token\n",
+ kSampleSchema));
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/data?maxResults=2&startIndex=1\n"
+ "Auth Token: fake_token\n",
+ kTestTwoRows));
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/"
+ "data?maxResults=2&pageToken=next_page\n"
+ "Auth Token: fake_token\n",
+ kTestRowWithNulls));
+
+ BigQueryTablePartition partition;
+ partition.set_start_index(1);
+ partition.set_end_index(-1);
+ TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 2,
+ {}, partition));
+
+ int64 row_id;
+ Example example;
+ TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example));
+ EXPECT_EQ(1, row_id);
+ EXPECT_FALSE(accessor_->Done());
+ EXPECT_EQ(
+ (example.features().feature()).at("int_field").int64_list().value(0),
+ 1111);
+
+ TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example));
+ EXPECT_EQ(2, row_id);
+ EXPECT_FALSE(accessor_->Done());
+ EXPECT_EQ(example.features().feature().at("int_field").int64_list().value(0),
+ 2222);
+
+ TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example));
+ EXPECT_EQ(3, row_id);
+ EXPECT_TRUE(accessor_->Done());
+
+ Example expected_example;
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestExampleProtoWithNulls,
+ &expected_example));
+ EXPECT_EQ(DeterministicSerialization(expected_example),
+ DeterministicSerialization(example));
+ EXPECT_TRUE(errors::IsOutOfRange(accessor_->ReadRow(&row_id, &example)));
+}
+
+TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) {
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/\n"
+ "Auth Token: fake_token\n",
+ kSampleSchema));
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=0\n"
+ "Auth Token: fake_token\n",
+ kTestTwoRows));
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/"
+ "data?maxResults=2&startIndex=3\n"
+ "Auth Token: fake_token\n",
+ kTestRowWithNulls));
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/data?maxResults=2&startIndex=0\n"
+ "Auth Token: fake_token\n",
+ kTestTwoRows));
+
+ BigQueryTablePartition partition;
+ partition.set_start_index(0);
+ partition.set_end_index(0);
+ TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 2,
+ {}, partition));
+
+ int64 row_id;
+ Example example;
+ TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example));
+ EXPECT_EQ(0, row_id);
+ EXPECT_TRUE(accessor_->Done());
+ EXPECT_EQ(example.features().feature().at("int_field").int64_list().value(0),
+ 1111);
+
+ partition.set_start_index(3);
+ partition.set_end_index(-1);
+ accessor_->SetPartition(partition);
+ TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example));
+ EXPECT_EQ(3, row_id);
+ EXPECT_TRUE(accessor_->Done());
+ EXPECT_EQ(example.features().feature().at("int_field").int64_list().value(0),
+ 1234);
+
+ partition.set_start_index(0);
+ partition.set_end_index(1);
+ accessor_->SetPartition(partition);
+ TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example));
+ EXPECT_EQ(0, row_id);
+ EXPECT_FALSE(accessor_->Done());
+ EXPECT_EQ(example.features().feature().at("int_field").int64_list().value(0),
+ 1111);
+ TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example));
+ EXPECT_EQ(1, row_id);
+ EXPECT_TRUE(accessor_->Done());
+ EXPECT_EQ(example.features().feature().at("int_field").int64_list().value(0),
+ 2222);
+}
+
+TEST_F(BigQueryTableAccessorTest, EmptyPartitionTest) {
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/\n"
+ "Auth Token: fake_token\n",
+ kSampleSchema));
+
+ BigQueryTablePartition partition;
+ partition.set_start_index(3);
+ partition.set_end_index(2);
+ TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
+ {}, partition));
+ EXPECT_TRUE(accessor_->Done());
+
+ int64 row_id;
+ Example example;
+ EXPECT_TRUE(errors::IsOutOfRange(accessor_->ReadRow(&row_id, &example)));
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h
new file mode 100644
index 0000000000..b2b11f4f57
--- /dev/null
+++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h
@@ -0,0 +1,404 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_
+
+#include <string>
+
+namespace tensorflow {
+namespace {
+
+const string kSampleSchema = R"({
+ "kind": "bigquery#table",
+ "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"",
+ "id": "test-project:test-dataset.test-table",
+ "schema": {
+ "fields": [
+ {
+ "name": "int_field",
+ "type": "INTEGER",
+ "mode": "REQUIRED"
+ },{
+ "name": "str_field",
+ "type": "STRING",
+ "mode": "NULLABLE"
+ },{
+ "name": "rec_field",
+ "type": "RECORD",
+ "fields": [
+ {
+ "name": "float_field",
+ "type": "FLOAT",
+ "mode": "NULLABLE"
+ }]
+ },{
+ "name": "bool_field",
+ "type": "BOOLEAN",
+ "mode": "NULLABLE"
+ },{
+ "name": "bytes_field",
+ "type": "BYTES",
+ "mode": "NULLABLE"
+ },{
+ "name": "timestamp_field",
+ "type": "TIMESTAMP",
+ "mode": "NULLABLE"
+ },{
+ "name": "date_field",
+ "type": "DATE",
+ "mode": "NULLABLE"
+ },{
+ "name": "time_field",
+ "type": "TIME",
+ "mode": "NULLABLE"
+ },{
+ "name": "datetime_field",
+ "type": "DATETIME",
+ "mode": "NULLABLE"
+ }]
+ },
+ "numRows": "4"
+})";
+
+const string kSampleSchemaTwoRecords = R"({
+ "kind": "bigquery#table",
+ "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"",
+ "id": "test-project:test-dataset.test-table",
+ "schema": {
+ "fields": [
+ {
+ "name": "rec_field1",
+ "type": "RECORD",
+ "fields": [
+ {
+ "name": "int_field",
+ "type": "INTEGER",
+ "mode": "NULLABLE"
+ }, {
+ "name": "float_field",
+ "type": "FLOAT",
+ "mode": "NULLABLE"
+ }]
+ },{
+ "name": "rec_field2",
+ "type": "RECORD",
+ "fields": [
+ {
+ "name": "bool_field",
+ "type": "BOOLEAN",
+ "mode": "NULLABLE"
+ },{
+ "name": "bytes_field",
+ "type": "BYTES",
+ "mode": "NULLABLE"
+ }]
+ }]
+ },
+ "numRows": "4"
+})";
+
+const string kTestRow = R"({
+ "kind": "bigquery#table",
+ "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"",
+ "id": "test-project:test-dataset.test-table",
+ "rows": [
+ {
+ "f": [
+ {
+ "v": "1234"
+ },{
+ "v": ""
+ },{
+ "v": {
+ "f": [
+ {
+ "v": "1.23456"
+ }]
+ }
+ },{
+ "v": "true"
+ },{
+ "v": "01010100101"
+ },{
+ "v": "timestamp"
+ },{
+ "v": "date"
+ },{
+ "v": "time"
+ },{
+ "v": "datetime"
+ }]}]})";
+
+const string kBrokenTestRow = R"({
+ "kind": "bigquery#table",
+ "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"",
+ "id": "test-project:test-dataset.test-table",
+ "rows": [
+ {
+ "f": [
+ {
+ "v": "1-234" // This does not parse as integer.
+ },{
+ "v": ""
+ },{
+ },{
+ "v": "true"
+ },{
+ "v": "01010100101"
+ },{
+ "v": "timestamp"
+ },{
+ "v": "date"
+ },{
+ "v": "time"
+ },{
+ "v": "datetime"
+ }]}]})";
+
+const string kTestRowWithNulls = R"({
+ "kind": "bigquery#table",
+ "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"",
+ "id": "test-project:test-dataset.test-table",
+ "rows": [
+ {
+ "f": [
+ {
+ "v": "1234"
+ },{
+ "v": "string"
+ },{
+ "v": null
+ },{
+ "v": "true"
+ },{
+ "v": "01010100101"
+ },{
+ "v": ""
+ },{
+ "v": null
+ },{
+ "v": null
+ },{
+ "v": "datetime"
+ }]}]})";
+
+// Example proto corresponding to kTestRow.
+const string kTestExampleProto = R"(features {
+ feature {
+ key: "bool_field"
+ value {
+ int64_list {
+ value: 1
+ }
+ }
+ }
+ feature {
+ key: "bytes_field"
+ value {
+ bytes_list {
+ value: "01010100101"
+ }
+ }
+ }
+ feature {
+ key: "date_field"
+ value {
+ bytes_list {
+ value: "date"
+ }
+ }
+ }
+ feature {
+ key: "datetime_field"
+ value {
+ bytes_list {
+ value: "datetime"
+ }
+ }
+ }
+ feature {
+ key: "int_field"
+ value {
+ int64_list {
+ value: 1234
+ }
+ }
+ }
+ feature {
+ key: "rec_field.float_field"
+ value {
+ float_list {
+ value: 1.23456
+ }
+ }
+ }
+ feature {
+ key: "str_field"
+ value {
+ bytes_list {
+ value: ""
+ }
+ }
+ }
+ feature {
+ key: "time_field"
+ value {
+ bytes_list {
+ value: "time"
+ }
+ }
+ }
+ feature {
+ key: "timestamp_field"
+ value {
+ bytes_list {
+ value: "timestamp"
+ }
+ }
+ }
+}
+)";
+
+// Example proto corresponding to kTestRowWithNulls.
+const string kTestExampleProtoWithNulls = R"(features {
+ feature {
+ key: "bool_field"
+ value {
+ int64_list {
+ value: 1
+ }
+ }
+ }
+ feature {
+ key: "bytes_field"
+ value {
+ bytes_list {
+ value: "01010100101"
+ }
+ }
+ }
+ feature {
+ key: "datetime_field"
+ value {
+ bytes_list {
+ value: "datetime"
+ }
+ }
+ }
+ feature {
+ key: "int_field"
+ value {
+ int64_list {
+ value: 1234
+ }
+ }
+ }
+ feature {
+ key: "timestamp_field"
+ value {
+ bytes_list {
+ value: ""
+ }
+ }
+ }
+ feature {
+ key: "str_field"
+ value {
+ bytes_list {
+ value: "string"
+ }
+ }
+ }
+}
+)";
+
+// Example proto corresponding to part of kTestRow.
+const string kTestPartialExampleProto = R"(features {
+ feature {
+ key: "bool_field"
+ value {
+ int64_list {
+ value: 1
+ }
+ }
+ }
+ feature {
+ key: "rec_field.float_field"
+ value {
+ float_list {
+ value: 1.23456
+ }
+ }
+ }
+}
+)";
+
+const string kTestExampleProtoWithTwoRecords = R"(features {
+ feature {
+ key: "rec_field1.float_field"
+ value {
+ float_list {
+ value: 1.23456
+ }
+ }
+ }
+ feature {
+ key: "rec_field2.bool_field"
+ value {
+ int64_list {
+ value: 1
+ }
+ }
+ }
+}
+)";
+
+const string kTestTwoRows = R"({
+ "kind": "bigquery#table",
+ "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"",
+ "pageToken": "next_page",
+ "id": "test-project:test-dataset.test-table",
+ "rows": [
+ {"f": [{"v": "1111"},{},{},{},{},{},{},{},{}]},
+ {"f": [{"v": "2222"},{},{},{},{},{},{},{},{}]}
+ ]})";
+
+const string kTestRowWithTwoRecords = R"({
+ "kind": "bigquery#table",
+ "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"",
+ "id": "test-project:test-dataset.test-table",
+ "rows": [
+ {
+ "f": [
+ {"v": {"f": [{}, {"v": "1.23456"}]}},
+ {"v": {"f": [{"v": "true"}, {}]}
+ }]}]})";
+
+const string kTestEmptyRow = R"({
+ "kind": "bigquery#table",
+ "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"",
+ "id": "test-project:test-dataset.test-table",
+ "rows": [
+ {
+ "f": [
+ {"v": {"f": [{}, {}]}},
+ {"v": {"f": [{"v": null}, {}]}
+ }]}]})";
+
+} // namespace
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_
diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_partition.proto b/tensorflow/contrib/cloud/kernels/bigquery_table_partition.proto
new file mode 100644
index 0000000000..2d9d1380db
--- /dev/null
+++ b/tensorflow/contrib/cloud/kernels/bigquery_table_partition.proto
@@ -0,0 +1,12 @@
+syntax = "proto3";
+
+package tensorflow;
+
+// This proto specifies a table partition in BigQuery.
+message BigQueryTablePartition {
+ // [start_index, end_index] specify the boundaries of a partition.
+ // If end_index is -1, every row starting from start_index is part of the
+ // partition.
+ int64 start_index = 1;
+ int64 end_index = 2;
+};
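The [start_index, end_index] semantics above are easiest to see from the client side. A minimal sketch, assuming the proto has been compiled into a Python module (the module name bigquery_table_partition_pb2 below is illustrative, not part of this change):

from tensorflow.contrib.cloud.kernels import bigquery_table_partition_pb2  # assumed module path

# A partition covering rows 0..99 inclusive.
partition = bigquery_table_partition_pb2.BigQueryTablePartition()
partition.start_index = 0
partition.end_index = 99

# end_index = -1 means "every row from start_index to the end of the table".
tail = bigquery_table_partition_pb2.BigQueryTablePartition(
    start_index=100, end_index=-1)

serialized = partition.SerializeToString()
# Serialized messages like this are what reader.partitions() emits and what the
# reader later consumes as work units.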
diff --git a/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc b/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc
new file mode 100644
index 0000000000..fbba04a31a
--- /dev/null
+++ b/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc
@@ -0,0 +1,88 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/* This file registers BigQuery reader ops. */
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+namespace tensorflow {
+
+using shape_inference::InferenceContext;
+
+REGISTER_OP("BigQueryReader")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("project_id: string")
+ .Attr("dataset_id: string")
+ .Attr("table_id: string")
+ .Attr("columns: list(string)")
+ .Attr("timestamp_millis: int")
+ .Attr("test_end_point: string = ''")
+ .Output("reader_handle: Ref(string)")
+ .SetIsStateful()
+ .SetShapeFn([](InferenceContext* c) {
+ c->set_output(0, c->Vector(2));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+A Reader that outputs rows from a BigQuery table as TensorFlow Examples.
+
+container: If non-empty, this reader is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this reader is named in the given bucket
+ with this shared_name. Otherwise, the node name is used instead.
+project_id: GCP project ID.
+dataset_id: BigQuery Dataset ID.
+table_id: Table to read.
+columns: List of columns to read. Leave empty to read all columns.
+timestamp_millis: Table snapshot timestamp in millis since epoch. Relative
+(negative or zero) snapshot times are not allowed. For more details, see
+'Table Decorators' in BigQuery docs.
+test_end_point: Do not use. For testing purposes only.
+reader_handle: The handle to reference the Reader.
+)doc");
+
+REGISTER_OP("GenerateBigQueryReaderPartitions")
+ .Attr("project_id: string")
+ .Attr("dataset_id: string")
+ .Attr("table_id: string")
+ .Attr("columns: list(string)")
+ .Attr("timestamp_millis: int")
+ .Attr("num_partitions: int")
+ .Attr("test_end_point: string = ''")
+ .Output("partitions: string")
+ .SetShapeFn([](InferenceContext* c) {
+ c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Generates serialized partition messages suitable for batch reads.
+
+This op should not be used directly by clients. Instead, the
+bigquery_reader_ops.py file defines a clean interface to the reader.
+
+project_id: GCP project ID.
+dataset_id: BigQuery Dataset ID.
+table_id: Table to read.
+columns: List of columns to read. Leave empty to read all columns.
+timestamp_millis: Table snapshot timestamp in millis since epoch. Relative
+(negative or zero) snapshot times are not allowed. For more details, see
+'Table Decorators' in BigQuery docs.
+num_partitions: Number of partitions to split the table into.
+test_end_point: Do not use. For testing purposes only.
+partitions: Serialized table partitions.
+)doc");
+
+} // namespace tensorflow
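For orientation, the two ops registered above surface in Python through the generated gen_bigquery_reader_ops wrapper module used by bigquery_reader_ops.py below; the CamelCase op names become snake_case functions. A minimal sketch with placeholder argument values (the BigQueryReader class below is the intended entry point, not these raw wrappers):

from tensorflow.contrib.cloud.python.ops import gen_bigquery_reader_ops

# Stateful reader handle: a 1-D string ref tensor of shape [2].
handle = gen_bigquery_reader_ops.big_query_reader(
    project_id="my-project", dataset_id="my-dataset", table_id="my-table",
    columns=["int64_col"], timestamp_millis=1490000000000)

# 1-D string tensor of serialized BigQueryTablePartition messages.
partitions = gen_bigquery_reader_ops.generate_big_query_reader_partitions(
    project_id="my-project", dataset_id="my-dataset", table_id="my-table",
    columns=["int64_col"], timestamp_millis=1490000000000, num_partitions=4)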
diff --git a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops.py b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops.py
new file mode 100644
index 0000000000..136707da18
--- /dev/null
+++ b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops.py
@@ -0,0 +1,150 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""BigQuery reading support for TensorFlow."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.cloud.python.ops import gen_bigquery_reader_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import io_ops
+
+
+class BigQueryReader(io_ops.ReaderBase):
+ """A Reader that outputs keys and tf.Example values from a BigQuery table.
+
+ Example use:
+ ```python
+ # Assume a BigQuery table has the following schema:
+ # name STRING,
+ # age INT,
+ # state STRING
+
+ # Create the parse_examples list of features.
+ features = dict(
+ name=tf.FixedLenFeature([1], tf.string),
+ age=tf.FixedLenFeature([1], tf.int64),
+ state=tf.FixedLenFeature([1], dtype=tf.string, default_value="UNK"))
+
+ # Create a Reader.
+ reader = bigquery_reader_ops.BigQueryReader(project_id=PROJECT,
+ dataset_id=DATASET,
+ table_id=TABLE,
+ timestamp_millis=TIME,
+ num_partitions=NUM_PARTITIONS,
+ features=features)
+
+ # Populate a queue with the BigQuery Table partitions.
+ queue = tf.train.string_input_producer(reader.partitions())
+
+ # Read and parse examples.
+ row_id, examples_serialized = reader.read(queue)
+ examples = tf.parse_example(examples_serialized, features=features)
+
+ # Process the Tensors examples["name"], examples["age"], etc...
+ ```
+
+ Note that a snapshot timestamp is necessary to create a reader. This
+ enables the reader to read from a consistent snapshot of the table.
+ For more information, see 'Table Decorators' in the BigQuery docs.
+
+ See ReaderBase for supported methods.
+ """
+
+ def __init__(self,
+ project_id,
+ dataset_id,
+ table_id,
+ timestamp_millis,
+ num_partitions,
+ features=None,
+ columns=None,
+ test_end_point=None,
+ name=None):
+ """Creates a BigQueryReader.
+
+ Args:
+ project_id: GCP project ID.
+ dataset_id: BigQuery dataset ID.
+ table_id: BigQuery table ID.
+ timestamp_millis: timestamp to snapshot the table in milliseconds since
+ the epoch. Relative (negative or zero) snapshot times are not allowed.
+ For more details, see 'Table Decorators' in BigQuery docs.
+ num_partitions: Number of non-overlapping partitions to read from.
+ features: parse_example compatible dict from keys to `VarLenFeature` and
+ `FixedLenFeature` objects. Keys are read as columns from the db.
+ columns: list of columns to read, can be set iff features is None.
+ test_end_point: Used only for testing purposes (optional).
+ name: a name for the operation (optional).
+
+ Raises:
+ TypeError: - If features is neither None nor a dict,
+ - If columns is neither None nor a list, or
+ - If both features and columns are None, or both are set.
+ """
+ if (features is None) == (columns is None):
+ raise TypeError("exactly one of features and columns must be set.")
+
+ if features is not None:
+ if not isinstance(features, dict):
+ raise TypeError("features must be a dict.")
+ self._columns = list(features.keys())
+ elif columns is not None:
+ if not isinstance(columns, list):
+ raise TypeError("columns must be a list.")
+ self._columns = columns
+
+ self._project_id = project_id
+ self._dataset_id = dataset_id
+ self._table_id = table_id
+ self._timestamp_millis = timestamp_millis
+ self._num_partitions = num_partitions
+ self._test_end_point = test_end_point
+
+ reader = gen_bigquery_reader_ops.big_query_reader(
+ name=name,
+ project_id=self._project_id,
+ dataset_id=self._dataset_id,
+ table_id=self._table_id,
+ timestamp_millis=self._timestamp_millis,
+ columns=self._columns,
+ test_end_point=self._test_end_point)
+ super(BigQueryReader, self).__init__(reader)
+
+ def partitions(self, name=None):
+ """Returns serialized BigQueryTablePartition messages.
+
+ These messages represent a non-overlapping division of a table for a
+ bulk read.
+
+ Args:
+ name: a name for the operation (optional).
+
+ Returns:
+ `1-D` string `Tensor` of serialized `BigQueryTablePartition` messages.
+ """
+ return gen_bigquery_reader_ops.generate_big_query_reader_partitions(
+ name=name,
+ project_id=self._project_id,
+ dataset_id=self._dataset_id,
+ table_id=self._table_id,
+ timestamp_millis=self._timestamp_millis,
+ num_partitions=self._num_partitions,
+ test_end_point=self._test_end_point,
+ columns=self._columns)
+
+
+ops.NotDifferentiable("BigQueryReader")
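The class docstring above illustrates the features-based mode; the columns-only mode exercised by the tests below looks like this. A minimal sketch, assuming real project/dataset/table identifiers and valid GCP credentials:

import tensorflow as tf
from tensorflow.contrib.cloud.python.ops import bigquery_reader_ops

reader = bigquery_reader_ops.BigQueryReader(
    project_id="my-project",          # placeholder identifiers
    dataset_id="my_dataset",
    table_id="my_table",
    timestamp_millis=1490000000000,   # absolute snapshot time, in ms since epoch
    num_partitions=4,
    columns=["int64_col", "string_col"])

# Feed the serialized partitions into a queue and read from it; each read
# returns a row id key and a serialized tf.Example with one feature per column.
queue = tf.train.string_input_producer(reader.partitions())
row_id, example_serialized = reader.read(queue)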
diff --git a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py
new file mode 100644
index 0000000000..b7d044ed25
--- /dev/null
+++ b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py
@@ -0,0 +1,287 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for BigQueryReader Op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+import re
+import threading
+
+from six.moves import SimpleHTTPServer
+from six.moves import socketserver
+
+from tensorflow.contrib.cloud.python.ops import bigquery_reader_ops as cloud
+from tensorflow.core.example import example_pb2
+from tensorflow.core.framework import types_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import compat
+
+_PROJECT = "test-project"
+_DATASET = "test-dataset"
+_TABLE = "test-table"
+# List representation of the test rows in the 'test-table' in BigQuery.
+# The schema for each row is: [int64, string, float].
+# The values for rows are generated such that some columns have null values. The
+# general pattern is:
+# - The int64 column is present in every row.
+# - The string column is only available in even rows.
+# - The float column is only available in every third row.
+_ROWS = [[0, "s_0", 0.1], [1, None, None], [2, "s_2", None], [3, None, 3.1],
+ [4, "s_4", None], [5, None, None], [6, "s_6", 6.1], [7, None, None],
+ [8, "s_8", None], [9, None, 9.1]]
+# Schema for 'test-table'.
+# The schema currently has three columns: int64, string, and float
+_SCHEMA = {
+ "kind": "bigquery#table",
+ "id": "test-project:test-dataset.test-table",
+ "schema": {
+ "fields": [{
+ "name": "int64_col",
+ "type": "INTEGER",
+ "mode": "NULLABLE"
+ }, {
+ "name": "string_col",
+ "type": "STRING",
+ "mode": "NULLABLE"
+ }, {
+ "name": "float_col",
+ "type": "FLOAT",
+ "mode": "NULLABLE"
+ }]
+ }
+}
+
+
+def _ConvertRowToExampleProto(row):
+ """Converts the input row to an Example proto.
+
+ Args:
+ row: Input Row instance.
+
+ Returns:
+ An Example proto initialized with row values.
+ """
+
+ example = example_pb2.Example()
+ example.features.feature["int64_col"].int64_list.value.append(row[0])
+ if row[1] is not None:
+ example.features.feature["string_col"].bytes_list.value.append(
+ compat.as_bytes(row[1]))
+ if row[2] is not None:
+ example.features.feature["float_col"].float_list.value.append(row[2])
+ return example
+
+
+class FakeBigQueryServer(threading.Thread):
+ """Fake http server to return schema and data for sample table."""
+
+ def __init__(self, address, port):
+ """Creates a FakeBigQueryServer.
+
+ Args:
+ address: Server address.
+ port: Server port. Pass 0 to automatically pick an unused port.
+ """
+ threading.Thread.__init__(self)
+ self.handler = BigQueryRequestHandler
+ self.httpd = socketserver.TCPServer((address, port), self.handler)
+
+ def run(self):
+ self.httpd.serve_forever()
+
+ def shutdown(self):
+ self.httpd.shutdown()
+ self.httpd.socket.close()
+
+
+class BigQueryRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
+ """Responds to BigQuery HTTP requests.
+
+ Attributes:
+ num_rows: Number of rows in the underlying table served by this class.
+ """
+
+ num_rows = 0
+
+ def do_GET(self):
+ if "data?maxResults=" not in self.path:
+ # This is a schema request.
+ _SCHEMA["numRows"] = self.num_rows
+ response = json.dumps(_SCHEMA)
+ else:
+ # This is a data request.
+ #
+ # Extract max results and start index.
+ max_results = int(re.findall(r"maxResults=(\d+)", self.path)[0])
+ start_index = int(re.findall(r"startIndex=(\d+)", self.path)[0])
+
+ # Send the rows as JSON.
+ rows = []
+ for row in _ROWS[start_index:start_index + max_results]:
+ row_json = {
+ "f": [{
+ "v": str(row[0])
+ }, {
+ "v": str(row[1]) if row[1] is not None else None
+ }, {
+ "v": str(row[2]) if row[2] is not None else None
+ }]
+ }
+ rows.append(row_json)
+ response = json.dumps({
+ "kind": "bigquery#table",
+ "id": "test-project:test-dataset.test-table",
+ "rows": rows
+ })
+ self.send_response(200)
+ self.end_headers()
+ self.wfile.write(compat.as_bytes(response))
+
+
+def _SetUpQueue(reader):
+ """Sets up a queue for a reader."""
+ queue = data_flow_ops.FIFOQueue(8, [types_pb2.DT_STRING], shapes=())
+ key, value = reader.read(queue)
+ queue.enqueue_many(reader.partitions()).run()
+ queue.close().run()
+ return key, value
+
+
+class BigQueryReaderOpsTest(test.TestCase):
+
+ def setUp(self):
+ super(BigQueryReaderOpsTest, self).setUp()
+ self.server = FakeBigQueryServer("127.0.0.1", 0)
+ self.server.start()
+ logging.info("server address is %s:%s", self.server.httpd.server_address[0],
+ self.server.httpd.server_address[1])
+
+ # An override to bypass the GCP auth token retrieval logic
+ # in google_auth_provider.cc.
+ os.environ["GOOGLE_AUTH_TOKEN_FOR_TESTING"] = "not-used"
+
+ def tearDown(self):
+ self.server.shutdown()
+ super(BigQueryReaderOpsTest, self).tearDown()
+
+ def _ReadAndCheckRowsUsingFeatures(self, num_rows):
+ self.server.handler.num_rows = num_rows
+
+ with self.test_session() as sess:
+ feature_configs = {
+ "int64_col":
+ parsing_ops.FixedLenFeature(
+ [1], dtype=dtypes.int64),
+ "string_col":
+ parsing_ops.FixedLenFeature(
+ [1], dtype=dtypes.string, default_value="s_default"),
+ }
+ reader = cloud.BigQueryReader(
+ project_id=_PROJECT,
+ dataset_id=_DATASET,
+ table_id=_TABLE,
+ num_partitions=4,
+ features=feature_configs,
+ timestamp_millis=1,
+ test_end_point=("%s:%s" % (self.server.httpd.server_address[0],
+ self.server.httpd.server_address[1])))
+
+ key, value = _SetUpQueue(reader)
+
+ seen_rows = []
+ features = parsing_ops.parse_example(
+ array_ops.reshape(value, [1]), feature_configs)
+ for _ in range(num_rows):
+ int_value, str_value = sess.run(
+ [features["int64_col"], features["string_col"]])
+
+ # Parse values returned from the session.
+ self.assertEqual(int_value.shape, (1, 1))
+ self.assertEqual(str_value.shape, (1, 1))
+ int64_col = int_value[0][0]
+ string_col = str_value[0][0]
+ seen_rows.append(int64_col)
+
+ # Compare.
+ expected_row = _ROWS[int64_col]
+ self.assertEqual(int64_col, expected_row[0])
+ self.assertEqual(
+ compat.as_str(string_col), ("s_%d" % int64_col) if expected_row[1]
+ else "s_default")
+
+ self.assertItemsEqual(seen_rows, range(num_rows))
+
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ sess.run([key, value])
+
+ def testReadingSingleRowUsingFeatures(self):
+ self._ReadAndCheckRowsUsingFeatures(1)
+
+ def testReadingMultipleRowsUsingFeatures(self):
+ self._ReadAndCheckRowsUsingFeatures(10)
+
+ def testReadingMultipleRowsUsingColumns(self):
+ num_rows = 10
+ self.server.handler.num_rows = num_rows
+
+ with self.test_session() as sess:
+ reader = cloud.BigQueryReader(
+ project_id=_PROJECT,
+ dataset_id=_DATASET,
+ table_id=_TABLE,
+ num_partitions=4,
+ columns=["int64_col", "float_col", "string_col"],
+ timestamp_millis=1,
+ test_end_point=("%s:%s" % (self.server.httpd.server_address[0],
+ self.server.httpd.server_address[1])))
+ key, value = _SetUpQueue(reader)
+ seen_rows = []
+ for row_index in range(num_rows):
+ returned_row_id, example_proto = sess.run([key, value])
+ example = example_pb2.Example()
+ example.ParseFromString(example_proto)
+ self.assertIn("int64_col", example.features.feature)
+ feature = example.features.feature["int64_col"]
+ self.assertEqual(len(feature.int64_list.value), 1)
+ int64_col = feature.int64_list.value[0]
+ seen_rows.append(int64_col)
+
+ # Create our expected Example.
+ expected_example = _ConvertRowToExampleProto(_ROWS[int64_col])
+
+ # Compare.
+ self.assertProtoEquals(example, expected_example)
+ self.assertEqual(row_index, int(returned_row_id))
+
+ self.assertItemsEqual(seen_rows, range(num_rows))
+
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ sess.run([key, value])
+
+
+if __name__ == "__main__":
+ test.main()