author    A. Unique TensorFlower <gardener@tensorflow.org> 2016-12-21 11:11:54 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-12-21 11:27:19 -0800
commit 1583f18633925dfe414f4e122eb7dafa9c07d56e (patch)
tree 2b37a186876e1102e6425f4dd9426ae087c73a62
parent 9a1e2d5d3d2c6420c410378c385b0c4665cedb9b (diff)
Adding BigQuery Reader OP.
This changelist adds the necessary C++ and Python code to register a new op to partition and read BigQuery tables.

Change: 142681535
-rw-r--r-- tensorflow/core/BUILD                                          |  12
-rw-r--r-- tensorflow/core/kernels/cloud/BUILD                            |  19
-rw-r--r-- tensorflow/core/kernels/cloud/bigquery_reader_ops.cc           | 193
-rw-r--r-- tensorflow/core/kernels/cloud/bigquery_table_accessor.cc       |  79
-rw-r--r-- tensorflow/core/kernels/cloud/bigquery_table_accessor.h        |  45
-rw-r--r-- tensorflow/core/kernels/cloud/bigquery_table_accessor_test.cc  |  38
-rw-r--r-- tensorflow/core/ops/cloud_ops.cc                               |  88
-rw-r--r-- tensorflow/python/BUILD                                        |  30
-rw-r--r-- tensorflow/python/__init__.py                                  |   2
-rw-r--r-- tensorflow/python/ops/cloud/__init__.py                        |  22
-rw-r--r-- tensorflow/python/ops/cloud/bigquery_reader_ops.py             | 150
-rw-r--r-- tensorflow/python/ops/cloud/bigquery_reader_ops_test.py        | 274
12 files changed, 900 insertions, 52 deletions
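
For orientation before the diff, here is a minimal sketch of how the new reader is meant to be used. It is adapted from the docstring added in tensorflow/python/ops/cloud/bigquery_reader_ops.py below and from the new unit test; PROJECT, DATASET, TABLE, TIME, and NUM_PARTITIONS are placeholder values, and the reader is exposed as tf.cloud.BigQueryReader via the new tensorflow.python.ops.cloud package.

import tensorflow as tf

# Columns to parse from each row; names must match the BigQuery table schema.
features = dict(
    name=tf.FixedLenFeature([1], tf.string),
    age=tf.FixedLenFeature([1], tf.int64))

# Create a reader pinned to a consistent snapshot of the table.
reader = tf.cloud.BigQueryReader(project_id=PROJECT,
                                 dataset_id=DATASET,
                                 table_id=TABLE,
                                 timestamp_millis=TIME,
                                 num_partitions=NUM_PARTITIONS,
                                 features=features)

# Populate a queue with the serialized table partitions, then read rows.
queue = tf.train.string_input_producer(reader.partitions())
row_id, example_serialized = reader.read(queue)
examples = tf.parse_example(tf.reshape(example_serialized, [1]),
                            features=features)

Each read returns the BigQuery row index as the key and a serialized tf.Example as the value.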
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 95b111f23b..fd45006c95 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -451,11 +451,22 @@ cc_library(
)
cc_library(
+ name = "cloud_ops_op_lib",
+ srcs = ["ops/cloud_ops.cc"],
+ copts = tf_copts(),
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+ deps = [":framework"],
+ alwayslink = 1,
+)
+
+cc_library(
name = "ops",
visibility = ["//visibility:public"],
deps = [
":array_ops_op_lib",
":candidate_sampling_ops_op_lib",
+ ":cloud_ops_op_lib",
":control_flow_ops_op_lib",
":ctc_ops_op_lib",
":data_flow_ops_op_lib",
@@ -602,6 +613,7 @@ cc_library(
"//tensorflow/core/kernels:string",
"//tensorflow/core/kernels:training_ops",
"//tensorflow/core/kernels:word2vec_kernels",
+ "//tensorflow/core/kernels/cloud:bigquery_reader_ops",
] + if_not_windows([
"//tensorflow/core/kernels:fact_op",
"//tensorflow/core/kernels:array_not_windows",
diff --git a/tensorflow/core/kernels/cloud/BUILD b/tensorflow/core/kernels/cloud/BUILD
index dfb4772b97..710cb5aa14 100644
--- a/tensorflow/core/kernels/cloud/BUILD
+++ b/tensorflow/core/kernels/cloud/BUILD
@@ -9,6 +9,7 @@ licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
+ "tf_kernel_library",
"tf_cc_test",
)
@@ -30,6 +31,24 @@ filegroup(
visibility = ["//tensorflow:__subpackages__"],
)
+tf_kernel_library(
+ name = "bigquery_reader_ops",
+ srcs = [
+ "bigquery_reader_ops.cc",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":bigquery_table_accessor",
+ ":bigquery_table_partition_proto_cc",
+ "//tensorflow/core:cloud_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/kernels:reader_base",
+ ],
+)
+
cc_library(
name = "bigquery_table_accessor",
srcs = [
diff --git a/tensorflow/core/kernels/cloud/bigquery_reader_ops.cc b/tensorflow/core/kernels/cloud/bigquery_reader_ops.cc
new file mode 100644
index 0000000000..a3b026e2a1
--- /dev/null
+++ b/tensorflow/core/kernels/cloud/bigquery_reader_ops.cc
@@ -0,0 +1,193 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <map>
+#include <memory>
+#include <set>
+
+#include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/framework/reader_op_kernel.h"
+#include "tensorflow/core/kernels/cloud/bigquery_table_accessor.h"
+#include "tensorflow/core/kernels/cloud/bigquery_table_partition.pb.h"
+#include "tensorflow/core/kernels/reader_base.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/math/math_util.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+
+namespace tensorflow {
+namespace {
+
+constexpr int64 kDefaultRowBufferSize = 1000; // Number of rows to buffer.
+
+// This is a helper function for reading table attributes from context.
+Status GetTableAttrs(OpKernelConstruction* context, string* project_id,
+ string* dataset_id, string* table_id,
+ int64* timestamp_millis, std::vector<string>* columns,
+ string* test_end_point) {
+ TF_RETURN_IF_ERROR(context->GetAttr("project_id", project_id));
+ TF_RETURN_IF_ERROR(context->GetAttr("dataset_id", dataset_id));
+ TF_RETURN_IF_ERROR(context->GetAttr("table_id", table_id));
+ TF_RETURN_IF_ERROR(context->GetAttr("timestamp_millis", timestamp_millis));
+ TF_RETURN_IF_ERROR(context->GetAttr("columns", columns));
+ TF_RETURN_IF_ERROR(context->GetAttr("test_end_point", test_end_point));
+ return Status::OK();
+}
+
+} // namespace
+
+// Note that overridden methods with names ending in "Locked" are called by
+// ReaderBase while a mutex is held.
+// See comments for ReaderBase.
+class BigQueryReader : public ReaderBase {
+ public:
+ explicit BigQueryReader(BigQueryTableAccessor* bigquery_table_accessor,
+ const string& node_name)
+ : ReaderBase(strings::StrCat("BigQueryReader '", node_name, "'")),
+ bigquery_table_accessor_(CHECK_NOTNULL(bigquery_table_accessor)) {}
+
+ Status OnWorkStartedLocked() override {
+ BigQueryTablePartition partition;
+ if (!partition.ParseFromString(current_work())) {
+ return errors::InvalidArgument(
+ "Could not parse work as a valid partition.");
+ }
+ TF_RETURN_IF_ERROR(bigquery_table_accessor_->SetPartition(partition));
+ return Status::OK();
+ }
+
+ Status ReadLocked(string* key, string* value, bool* produced,
+ bool* at_end) override {
+ *at_end = false;
+ *produced = false;
+ if (bigquery_table_accessor_->Done()) {
+ *at_end = true;
+ return Status::OK();
+ }
+
+ Example example;
+ int64 row_id;
+ TF_RETURN_IF_ERROR(bigquery_table_accessor_->ReadRow(&row_id, &example));
+
+ *key = std::to_string(row_id);
+ *value = example.SerializeAsString();
+ *produced = true;
+ return Status::OK();
+ }
+
+ private:
+ // Not owned.
+ BigQueryTableAccessor* bigquery_table_accessor_;
+};
+
+class BigQueryReaderOp : public ReaderOpKernel {
+ public:
+ explicit BigQueryReaderOp(OpKernelConstruction* context)
+ : ReaderOpKernel(context) {
+ string table_id;
+ string project_id;
+ string dataset_id;
+ int64 timestamp_millis;
+ std::vector<string> columns;
+ string test_end_point;
+
+ OP_REQUIRES_OK(context,
+ GetTableAttrs(context, &project_id, &dataset_id, &table_id,
+ &timestamp_millis, &columns, &test_end_point));
+ OP_REQUIRES_OK(context,
+ BigQueryTableAccessor::New(
+ project_id, dataset_id, table_id, timestamp_millis,
+ kDefaultRowBufferSize, test_end_point, columns,
+ BigQueryTablePartition(), &bigquery_table_accessor_));
+
+ SetReaderFactory([this]() {
+ return new BigQueryReader(bigquery_table_accessor_.get(), name());
+ });
+ }
+
+ private:
+ std::unique_ptr<BigQueryTableAccessor> bigquery_table_accessor_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigQueryReader").Device(DEVICE_CPU),
+ BigQueryReaderOp);
+
+class GenerateBigQueryReaderPartitionsOp : public OpKernel {
+ public:
+ explicit GenerateBigQueryReaderPartitionsOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ string project_id;
+ string dataset_id;
+ string table_id;
+ int64 timestamp_millis;
+ std::vector<string> columns;
+ string test_end_point;
+
+ OP_REQUIRES_OK(context,
+ GetTableAttrs(context, &project_id, &dataset_id, &table_id,
+ &timestamp_millis, &columns, &test_end_point));
+ OP_REQUIRES_OK(context,
+ BigQueryTableAccessor::New(
+ project_id, dataset_id, table_id, timestamp_millis,
+ kDefaultRowBufferSize, test_end_point, columns,
+ BigQueryTablePartition(), &bigquery_table_accessor_));
+ OP_REQUIRES_OK(context, InitializeNumberOfPartitions(context));
+ OP_REQUIRES_OK(context, InitializeTotalNumberOfRows());
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const int64 partition_size = tensorflow::MathUtil::CeilOfRatio<int64>(
+ total_num_rows_, num_partitions_);
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({num_partitions_}),
+ &output_tensor));
+
+ auto output = output_tensor->template flat<string>();
+ for (int64 i = 0; i < num_partitions_; ++i) {
+ BigQueryTablePartition partition;
+ partition.set_start_index(i * partition_size);
+ partition.set_end_index(
+ std::min(total_num_rows_, (i + 1) * partition_size) - 1);
+ output(i) = partition.SerializeAsString();
+ }
+ }
+
+ private:
+ Status InitializeTotalNumberOfRows() {
+ total_num_rows_ = bigquery_table_accessor_->total_num_rows();
+ if (total_num_rows_ <= 0) {
+ return errors::FailedPrecondition("Invalid total number of rows.");
+ }
+ return Status::OK();
+ }
+
+ Status InitializeNumberOfPartitions(OpKernelConstruction* context) {
+ TF_RETURN_IF_ERROR(context->GetAttr("num_partitions", &num_partitions_));
+ if (num_partitions_ <= 0) {
+ return errors::FailedPrecondition("Invalid number of partitions.");
+ }
+ return Status::OK();
+ }
+
+ int64 num_partitions_;
+ int64 total_num_rows_;
+ std::unique_ptr<BigQueryTableAccessor> bigquery_table_accessor_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("GenerateBigQueryReaderPartitions").Device(DEVICE_CPU),
+ GenerateBigQueryReaderPartitionsOp);
+
+} // namespace tensorflow
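
As an aside (not part of the change), the partitioning arithmetic in GenerateBigQueryReaderPartitionsOp::Compute can be sanity-checked with a small Python sketch: 10 rows split into 4 partitions gives a partition size of ceil(10/4) = 3 and row ranges [0,2], [3,5], [6,8], [9,9].

# Illustrative Python equivalent of the partition computation above.
def generate_partitions(total_num_rows, num_partitions):
    partition_size = -(-total_num_rows // num_partitions)  # CeilOfRatio
    partitions = []
    for i in range(num_partitions):
        start = i * partition_size
        end = min(total_num_rows, (i + 1) * partition_size) - 1
        partitions.append((start, end))
    return partitions

print(generate_partitions(10, 4))  # [(0, 2), (3, 5), (6, 8), (9, 9)]

Note that when num_partitions exceeds total_num_rows, the trailing partitions come out with end < start; those are the empty partitions that the IsPartitionEmpty() helper added in bigquery_table_accessor.cc below is meant to handle.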
diff --git a/tensorflow/core/kernels/cloud/bigquery_table_accessor.cc b/tensorflow/core/kernels/cloud/bigquery_table_accessor.cc
index 293d47d975..3e9adfa372 100644
--- a/tensorflow/core/kernels/cloud/bigquery_table_accessor.cc
+++ b/tensorflow/core/kernels/cloud/bigquery_table_accessor.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
#include "tensorflow/core/kernels/cloud/bigquery_table_accessor.h"
#include "tensorflow/core/example/feature.pb.h"
@@ -23,6 +22,15 @@ namespace tensorflow {
namespace {
constexpr size_t kBufferSize = 1024 * 1024; // In bytes.
+const string kBigQueryEndPoint = "https://www.googleapis.com/bigquery/v2";
+
+bool IsPartitionEmpty(const BigQueryTablePartition& partition) {
+ if (partition.end_index() != -1 &&
+ partition.end_index() < partition.start_index()) {
+ return true;
+ }
+ return false;
+}
Status ParseJson(StringPiece json, Json::Value* result) {
Json::Reader reader;
@@ -92,17 +100,18 @@ Status ParseColumnType(const string& type,
Status BigQueryTableAccessor::New(
const string& project_id, const string& dataset_id, const string& table_id,
- int64 timestamp_millis, int64 row_buffer_size,
- const std::set<string>& columns, const BigQueryTablePartition& partition,
+ int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
+ const std::vector<string>& columns, const BigQueryTablePartition& partition,
std::unique_ptr<BigQueryTableAccessor>* accessor) {
return New(project_id, dataset_id, table_id, timestamp_millis,
- row_buffer_size, columns, partition, nullptr, nullptr, accessor);
+ row_buffer_size, end_point, columns, partition, nullptr, nullptr,
+ accessor);
}
Status BigQueryTableAccessor::New(
const string& project_id, const string& dataset_id, const string& table_id,
- int64 timestamp_millis, int64 row_buffer_size,
- const std::set<string>& columns, const BigQueryTablePartition& partition,
+ int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
+ const std::vector<string>& columns, const BigQueryTablePartition& partition,
std::unique_ptr<AuthProvider> auth_provider,
std::unique_ptr<HttpRequest::Factory> http_request_factory,
std::unique_ptr<BigQueryTableAccessor>* accessor) {
@@ -110,14 +119,16 @@ Status BigQueryTableAccessor::New(
return errors::InvalidArgument(
"Cannot use zero or negative timestamp to query a table.");
}
+ const string& big_query_end_point =
+ end_point.empty() ? kBigQueryEndPoint : end_point;
if (auth_provider == nullptr && http_request_factory == nullptr) {
- accessor->reset(new BigQueryTableAccessor(project_id, dataset_id, table_id,
- timestamp_millis, row_buffer_size,
- columns, partition));
+ accessor->reset(new BigQueryTableAccessor(
+ project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
+ big_query_end_point, columns, partition));
} else {
accessor->reset(new BigQueryTableAccessor(
project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
- columns, partition, std::move(auth_provider),
+ big_query_end_point, columns, partition, std::move(auth_provider),
std::move(http_request_factory)));
}
return (*accessor)->ReadSchema();
@@ -125,11 +136,11 @@ Status BigQueryTableAccessor::New(
BigQueryTableAccessor::BigQueryTableAccessor(
const string& project_id, const string& dataset_id, const string& table_id,
- int64 timestamp_millis, int64 row_buffer_size,
- const std::set<string>& columns, const BigQueryTablePartition& partition)
+ int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
+ const std::vector<string>& columns, const BigQueryTablePartition& partition)
: BigQueryTableAccessor(
project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
- columns, partition,
+ end_point, columns, partition,
std::unique_ptr<AuthProvider>(new GoogleAuthProvider()),
std::unique_ptr<HttpRequest::Factory>(new HttpRequest::Factory())) {
row_buffer_.resize(row_buffer_size);
@@ -137,15 +148,16 @@ BigQueryTableAccessor::BigQueryTableAccessor(
BigQueryTableAccessor::BigQueryTableAccessor(
const string& project_id, const string& dataset_id, const string& table_id,
- int64 timestamp_millis, int64 row_buffer_size,
- const std::set<string>& columns, const BigQueryTablePartition& partition,
+ int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
+ const std::vector<string>& columns, const BigQueryTablePartition& partition,
std::unique_ptr<AuthProvider> auth_provider,
std::unique_ptr<HttpRequest::Factory> http_request_factory)
: project_id_(project_id),
dataset_id_(dataset_id),
table_id_(table_id),
timestamp_millis_(timestamp_millis),
- columns_(columns),
+ columns_(columns.begin(), columns.end()),
+ bigquery_end_point_(end_point),
partition_(partition),
auth_provider_(std::move(auth_provider)),
http_request_factory_(std::move(http_request_factory)) {
@@ -153,10 +165,14 @@ BigQueryTableAccessor::BigQueryTableAccessor(
Reset();
}
-void BigQueryTableAccessor::SetPartition(
+Status BigQueryTableAccessor::SetPartition(
const BigQueryTablePartition& partition) {
+ if (partition.start_index() < 0) {
+ return errors::InvalidArgument("Start index cannot be negative.");
+ }
partition_ = partition;
Reset();
+ return Status::OK();
}
void BigQueryTableAccessor::Reset() {
@@ -172,7 +188,8 @@ Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) {
// If the next row is already fetched and cached, return the row from the
// buffer. Otherwise, fill up the row buffer from BigQuery and return a row.
- if (next_row_in_buffer_ != -1 && next_row_in_buffer_ < row_buffer_.size()) {
+ if (next_row_in_buffer_ != -1 &&
+ next_row_in_buffer_ < ComputeMaxResultsArg()) {
*row_id = first_buffered_row_index_ + next_row_in_buffer_;
*example = row_buffer_[next_row_in_buffer_];
next_row_in_buffer_++;
@@ -190,12 +207,12 @@ Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) {
// we use the page token (which returns rows faster).
if (!next_page_token_.empty()) {
TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat(
- BigQueryUriPrefix(), "data?maxResults=", row_buffer_.size(),
+ BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(),
"&pageToken=", request->EscapeString(next_page_token_))));
first_buffered_row_index_ += row_buffer_.size();
} else {
TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat(
- BigQueryUriPrefix(), "data?maxResults=", row_buffer_.size(),
+ BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(),
"&startIndex=", first_buffered_row_index_)));
}
TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token));
@@ -222,6 +239,18 @@ Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) {
return Status::OK();
}
+int64 BigQueryTableAccessor::ComputeMaxResultsArg() {
+ if (partition_.end_index() == -1) {
+ return row_buffer_.size();
+ }
+ if (IsPartitionEmpty(partition_)) {
+ return 0;
+ }
+ return std::min(static_cast<int64>(row_buffer_.size()),
+ static_cast<int64>(partition_.end_index() -
+ partition_.start_index() + 1));
+}
+
Status BigQueryTableAccessor::ParseColumnValues(
const Json::Value& value, const SchemaNode& root_schema_node,
Example* example) {
@@ -364,21 +393,17 @@ Status BigQueryTableAccessor::AppendValueToExample(
string BigQueryTableAccessor::BigQueryTableAccessor::BigQueryUriPrefix() {
HttpRequest request;
- return strings::StrCat("https://www.googleapis.com/bigquery/v2/projects/",
+ return strings::StrCat(bigquery_end_point_, "/projects/",
request.EscapeString(project_id_), "/datasets/",
request.EscapeString(dataset_id_), "/tables/",
request.EscapeString(table_id_), "/");
}
-string BigQueryTableAccessor::FullTableName() {
- return strings::StrCat(project_id_, ":", dataset_id_, ".", table_id_, "@",
- timestamp_millis_);
-}
-
bool BigQueryTableAccessor::Done() {
return (total_num_rows_ <= first_buffered_row_index_ + next_row_in_buffer_) ||
+ IsPartitionEmpty(partition_) ||
(partition_.end_index() != -1 &&
- partition_.end_index() <=
+ partition_.end_index() <
first_buffered_row_index_ + next_row_in_buffer_);
}
diff --git a/tensorflow/core/kernels/cloud/bigquery_table_accessor.h b/tensorflow/core/kernels/cloud/bigquery_table_accessor.h
index fafda9cdd6..33d1905b8a 100644
--- a/tensorflow/core/kernels/cloud/bigquery_table_accessor.h
+++ b/tensorflow/core/kernels/cloud/bigquery_table_accessor.h
@@ -55,16 +55,23 @@ class BigQueryTableAccessor {
};
/// \brief Creates a new BigQueryTableAccessor object.
+ //
+ // We do not allow relative (negative or zero) snapshot times here since we
+ // want to have a consistent snapshot of the table for the lifetime of this
+ // object.
+ // Use end_point if you want to connect to a different end point than the
+ // official BigQuery end point. Otherwise send an empty string.
static Status New(const string& project_id, const string& dataset_id,
const string& table_id, int64 timestamp_millis,
- int64 row_buffer_size, const std::set<string>& columns,
+ int64 row_buffer_size, const string& end_point,
+ const std::vector<string>& columns,
const BigQueryTablePartition& partition,
std::unique_ptr<BigQueryTableAccessor>* accessor);
/// \brief Starts reading a new partition.
- void SetPartition(const BigQueryTablePartition& partition);
+ Status SetPartition(const BigQueryTablePartition& partition);
- /// \brief Returns false if there are more rows available in the current
+ /// \brief Returns true if there are more rows available in the current
/// partition.
bool Done();
@@ -74,9 +81,11 @@ class BigQueryTableAccessor {
/// in the BigQuery service.
Status ReadRow(int64* row_id, Example* example);
- /// \brief Returns total number of rows.
+ /// \brief Returns total number of rows in the table.
int64 total_num_rows() { return total_num_rows_; }
+ virtual ~BigQueryTableAccessor() {}
+
private:
friend class BigQueryTableAccessorTest;
@@ -95,7 +104,8 @@ class BigQueryTableAccessor {
/// these two variables.
static Status New(const string& project_id, const string& dataset_id,
const string& table_id, int64 timestamp_millis,
- int64 row_buffer_size, const std::set<string>& columns,
+ int64 row_buffer_size, const string& end_point,
+ const std::vector<string>& columns,
const BigQueryTablePartition& partition,
std::unique_ptr<AuthProvider> auth_provider,
std::unique_ptr<HttpRequest::Factory> http_request_factory,
@@ -104,14 +114,16 @@ class BigQueryTableAccessor {
/// \brief Constructs an object for a given table and partition.
BigQueryTableAccessor(const string& project_id, const string& dataset_id,
const string& table_id, int64 timestamp_millis,
- int64 row_buffer_size, const std::set<string>& columns,
+ int64 row_buffer_size, const string& end_point,
+ const std::vector<string>& columns,
const BigQueryTablePartition& partition);
/// Used for unit testing.
BigQueryTableAccessor(
const string& project_id, const string& dataset_id,
const string& table_id, int64 timestamp_millis, int64 row_buffer_size,
- const std::set<string>& columns, const BigQueryTablePartition& partition,
+ const string& end_point, const std::vector<string>& columns,
+ const BigQueryTablePartition& partition,
std::unique_ptr<AuthProvider> auth_provider,
std::unique_ptr<HttpRequest::Factory> http_request_factory);
@@ -132,7 +144,7 @@ class BigQueryTableAccessor {
Status AppendValueToExample(const string& column_name,
const Json::Value& column_value,
const BigQueryTableAccessor::ColumnType type,
- Example* ex);
+ Example* example);
/// \brief Resets internal counters for reading a partition.
void Reset();
@@ -140,25 +152,28 @@ class BigQueryTableAccessor {
/// \brief Helper function that returns BigQuery http endpoint prefix.
string BigQueryUriPrefix();
+ /// \brief Computes the maxResults arg to send to BigQuery.
+ int64 ComputeMaxResultsArg();
+
/// \brief Returns full name of the underlying table name.
- string FullTableName();
+ string FullTableName() {
+ return strings::StrCat(project_id_, ":", dataset_id_, ".", table_id_, "@",
+ timestamp_millis_);
+ }
const string project_id_;
const string dataset_id_;
const string table_id_;
// Snapshot timestamp.
- //
- // Indicates a snapshot of the table in milliseconds since the epoch.
- //
- // We do not allow relative (negative or zero) times here since we want to
- // have a consistent snapshot of the table for the lifetime of this object.
- // For more details, see 'Table Decorators' in BigQuery documentation.
const int64 timestamp_millis_;
// Columns that should be read. Empty means all columns.
const std::set<string> columns_;
+ // HTTP address of BigQuery end point to use.
+ const string bigquery_end_point_;
+
// Describes the portion of the table that we are currently accessing.
BigQueryTablePartition partition_;
diff --git a/tensorflow/core/kernels/cloud/bigquery_table_accessor_test.cc b/tensorflow/core/kernels/cloud/bigquery_table_accessor_test.cc
index 306cc5a4e1..57a4b89251 100644
--- a/tensorflow/core/kernels/cloud/bigquery_table_accessor_test.cc
+++ b/tensorflow/core/kernels/cloud/bigquery_table_accessor_test.cc
@@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
-
namespace {
constexpr char kTestProject[] = "test-project";
@@ -69,10 +68,10 @@ class BigQueryTableAccessorTest : public ::testing::Test {
Status CreateTableAccessor(const string& project_id, const string& dataset_id,
const string& table_id, int64 timestamp_millis,
int64 row_buffer_size,
- const std::set<string>& columns,
+ const std::vector<string>& columns,
const BigQueryTablePartition& partition) {
return BigQueryTableAccessor::New(
- project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
+ project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, "",
columns, partition, std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests_)),
@@ -197,7 +196,7 @@ TEST_F(BigQueryTableAccessorTest, ReadOneRowTest) {
kTestRow));
BigQueryTablePartition partition;
partition.set_start_index(2);
- partition.set_end_index(3);
+ partition.set_end_index(2);
TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
{}, partition));
int64 row_id;
@@ -227,7 +226,7 @@ TEST_F(BigQueryTableAccessorTest, ReadOneRowPartialTest) {
kTestRow));
BigQueryTablePartition partition;
partition.set_start_index(2);
- partition.set_end_index(3);
+ partition.set_end_index(2);
TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
{"bool_field", "rec_field.float_field"},
partition));
@@ -258,7 +257,7 @@ TEST_F(BigQueryTableAccessorTest, ReadOneRowWithNullsTest) {
kTestRowWithNulls));
BigQueryTablePartition partition;
partition.set_start_index(2);
- partition.set_end_index(3);
+ partition.set_end_index(2);
TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
{}, partition));
int64 row_id;
@@ -288,7 +287,7 @@ TEST_F(BigQueryTableAccessorTest, BrokenRowTest) {
kBrokenTestRow));
BigQueryTablePartition partition;
partition.set_start_index(2);
- partition.set_end_index(3);
+ partition.set_end_index(2);
TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
{}, partition));
int64 row_id;
@@ -357,7 +356,7 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) {
kSampleSchema));
requests_.emplace_back(new FakeHttpRequest(
"Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
- "datasets/test-dataset/tables/test-table/data?maxResults=2&startIndex=0\n"
+ "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=0\n"
"Auth Token: fake_token\n",
kTestTwoRows));
requests_.emplace_back(new FakeHttpRequest(
@@ -374,7 +373,7 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) {
BigQueryTablePartition partition;
partition.set_start_index(0);
- partition.set_end_index(1);
+ partition.set_end_index(0);
TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 2,
{}, partition));
@@ -396,7 +395,7 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) {
1234);
partition.set_start_index(0);
- partition.set_end_index(2);
+ partition.set_end_index(1);
accessor_->SetPartition(partition);
TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example));
EXPECT_EQ(0, row_id);
@@ -410,4 +409,23 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) {
2222);
}
+TEST_F(BigQueryTableAccessorTest, EmptyPartitionTest) {
+ requests_.emplace_back(new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
+ "datasets/test-dataset/tables/test-table/\n"
+ "Auth Token: fake_token\n",
+ kSampleSchema));
+
+ BigQueryTablePartition partition;
+ partition.set_start_index(3);
+ partition.set_end_index(2);
+ TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
+ {}, partition));
+ EXPECT_TRUE(accessor_->Done());
+
+ int64 row_id;
+ Example example;
+ EXPECT_TRUE(errors::IsOutOfRange(accessor_->ReadRow(&row_id, &example)));
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/cloud_ops.cc b/tensorflow/core/ops/cloud_ops.cc
new file mode 100644
index 0000000000..89f31a46ab
--- /dev/null
+++ b/tensorflow/core/ops/cloud_ops.cc
@@ -0,0 +1,88 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/* This file registers all cloud ops. */
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+namespace tensorflow {
+
+using shape_inference::InferenceContext;
+
+REGISTER_OP("BigQueryReader")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("project_id: string")
+ .Attr("dataset_id: string")
+ .Attr("table_id: string")
+ .Attr("columns: list(string)")
+ .Attr("timestamp_millis: int")
+ .Attr("test_end_point: string = ''")
+ .Output("reader_handle: Ref(string)")
+ .SetIsStateful()
+ .SetShapeFn([](InferenceContext* c) {
+ c->set_output(0, c->Vector(2));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+A Reader that outputs rows from a BigQuery table as tensorflow Examples.
+
+container: If non-empty, this reader is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this reader is named in the given bucket
+ with this shared_name. Otherwise, the node name is used instead.
+project_id: GCP project ID.
+dataset_id: BigQuery Dataset ID.
+table_id: Table to read.
+columns: List of columns to read. Leave empty to read all columns.
+timestamp_millis: Table snapshot timestamp in millis since epoch. Relative
+(negative or zero) snapshot times are not allowed. For more details, see
+'Table Decorators' in BigQuery docs.
+test_end_point: Do not use. For testing purposes only.
+reader_handle: The handle to reference the Reader.
+)doc");
+
+REGISTER_OP("GenerateBigQueryReaderPartitions")
+ .Attr("project_id: string")
+ .Attr("dataset_id: string")
+ .Attr("table_id: string")
+ .Attr("columns: list(string)")
+ .Attr("timestamp_millis: int")
+ .Attr("num_partitions: int")
+ .Attr("test_end_point: string = ''")
+ .Output("partitions: string")
+ .SetShapeFn([](InferenceContext* c) {
+ c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Generates serialized partition messages suitable for batch reads.
+
+This op should not be used directly by clients. Instead, the
+bigquery_reader_ops.py file defines a clean interface to the reader.
+
+project_id: GCP project ID.
+dataset_id: BigQuery Dataset ID.
+table_id: Table to read.
+columns: List of columns to read. Leave empty to read all columns.
+timestamp_millis: Table snapshot timestamp in millis since epoch. Relative
+(negative or zero) snapshot times are not allowed. For more details, see
+'Table Decorators' in BigQuery docs.
+num_partitions: Number of partitions to split the table into.
+test_end_point: Do not use. For testing purposes only.
+partitions: Serialized table partitions.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index c4162b6e22..d5f6ecc78c 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -35,6 +35,7 @@ py_library(
":check_ops",
":client",
":client_testlib",
+ ":cloud_ops",
":confusion_matrix",
":control_flow_ops",
":errors",
@@ -790,6 +791,11 @@ tf_gen_op_wrapper_private_py(
)
tf_gen_op_wrapper_private_py(
+ name = "cloud_ops_gen",
+ require_shape_functions = True,
+)
+
+tf_gen_op_wrapper_private_py(
name = "control_flow_ops_gen",
require_shape_functions = True,
deps = [
@@ -1432,6 +1438,19 @@ py_library(
)
py_library(
+ name = "cloud_ops",
+ srcs = [
+ "ops/cloud/__init__.py",
+ "ops/cloud/bigquery_reader_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":cloud_ops_gen",
+ ":framework",
+ ],
+)
+
+py_library(
name = "script_ops",
srcs = ["ops/script_ops.py"],
srcs_version = "PY2AND3",
@@ -2028,6 +2047,17 @@ cuda_py_test(
],
)
+tf_py_test(
+ name = "bigquery_reader_ops_test",
+ size = "small",
+ srcs = ["ops/cloud/bigquery_reader_ops_test.py"],
+ additional_deps = [
+ ":cloud_ops",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:util",
+ ],
+)
+
py_library(
name = "training",
srcs = glob(
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 70ea38ffb2..6626a80149 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -83,6 +83,7 @@ from tensorflow.python.ops.standard_ops import *
# Bring in subpackages.
from tensorflow.python.layers import layers
+from tensorflow.python.ops import cloud
from tensorflow.python.ops import metrics
from tensorflow.python.ops import nn
from tensorflow.python.ops import sdca_ops as sdca
@@ -213,6 +214,7 @@ _allowed_symbols.extend([
_allowed_symbols.extend([
'app',
'compat',
+ 'cloud',
'errors',
'flags',
'gfile',
diff --git a/tensorflow/python/ops/cloud/__init__.py b/tensorflow/python/ops/cloud/__init__.py
new file mode 100644
index 0000000000..536d911a48
--- /dev/null
+++ b/tensorflow/python/ops/cloud/__init__.py
@@ -0,0 +1,22 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Import cloud ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.cloud.bigquery_reader_ops import *
diff --git a/tensorflow/python/ops/cloud/bigquery_reader_ops.py b/tensorflow/python/ops/cloud/bigquery_reader_ops.py
new file mode 100644
index 0000000000..dbdc3fd7a2
--- /dev/null
+++ b/tensorflow/python/ops/cloud/bigquery_reader_ops.py
@@ -0,0 +1,150 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""BigQuery reading support for TensorFlow."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_cloud_ops
+from tensorflow.python.ops import io_ops
+
+
+class BigQueryReader(io_ops.ReaderBase):
+ """A Reader that outputs keys and tf.Example values from a BigQuery table.
+
+ Example use:
+ ```python
+ # Assume a BigQuery has the following schema,
+ # name STRING,
+ # age INT,
+ # state STRING
+
+ # Create the parse_examples list of features.
+ features = dict(
+ name=tf.FixedLenFeature([1], tf.string),
+ age=tf.FixedLenFeature([1], tf.int32),
+ state=tf.FixedLenFeature([1], dtype=tf.string, default_value="UNK"))
+
+ # Create a Reader.
+ reader = bigquery_reader_ops.BigQueryReader(project_id=PROJECT,
+ dataset_id=DATASET,
+ table_id=TABLE,
+ timestamp_millis=TIME,
+ num_partitions=NUM_PARTITIONS,
+ features=features)
+
+ # Populate a queue with the BigQuery Table partitions.
+ queue = tf.train.string_input_producer(reader.partitions())
+
+ # Read and parse examples.
+ row_id, examples_serialized = reader.read(queue)
+ examples = tf.parse_example(examples_serialized, features=features)
+
+ # Process the Tensors examples["name"], examples["age"], etc...
+ ```
+
+ Note that to create a reader a snapshot timestamp is necessary. This
+ will enable the reader to look at a consistent snapshot of the table.
+ For more information, see 'Table Decorators' in BigQuery docs.
+
+ See ReaderBase for supported methods.
+ """
+
+ def __init__(self,
+ project_id,
+ dataset_id,
+ table_id,
+ timestamp_millis,
+ num_partitions,
+ features=None,
+ columns=None,
+ test_end_point=None,
+ name=None):
+ """Creates a BigQueryReader.
+
+ Args:
+ project_id: GCP project ID.
+ dataset_id: BigQuery dataset ID.
+ table_id: BigQuery table ID.
+ timestamp_millis: timestamp to snapshot the table in milliseconds since
+ the epoch. Relative (negative or zero) snapshot times are not allowed.
+ For more details, see 'Table Decorators' in BigQuery docs.
+ num_partitions: Number of non-overlapping partitions to read from.
+ features: parse_example compatible dict from keys to `VarLenFeature` and
+ `FixedLenFeature` objects. Keys are read as columns from the db.
+ columns: list of columns to read, can be set iff features is None.
+ test_end_point: Used only for testing purposes (optional).
+ name: a name for the operation (optional).
+
+ Raises:
+ TypeError: - If features is neither None nor a dict or
+ - If columns is neither None nor a list or
+ - If both features and columns are None or set.
+ """
+ if (features is None) == (columns is None):
+ raise TypeError("exactly one of features and columns must be set.")
+
+ if features is not None:
+ if not isinstance(features, dict):
+ raise TypeError("features must be a dict.")
+ self._columns = list(features.keys())
+ elif columns is not None:
+ if not isinstance(columns, list):
+ raise TypeError("columns must be a list.")
+ self._columns = columns
+
+ self._project_id = project_id
+ self._dataset_id = dataset_id
+ self._table_id = table_id
+ self._timestamp_millis = timestamp_millis
+ self._num_partitions = num_partitions
+ self._test_end_point = test_end_point
+
+ reader = gen_cloud_ops.big_query_reader(
+ name=name,
+ project_id=self._project_id,
+ dataset_id=self._dataset_id,
+ table_id=self._table_id,
+ timestamp_millis=self._timestamp_millis,
+ columns=self._columns,
+ test_end_point=self._test_end_point)
+ super(BigQueryReader, self).__init__(reader)
+
+ def partitions(self, name=None):
+ """Returns serialized BigQueryTablePartition messages.
+
+ These messages represent a non-overlapping division of a table for a
+ bulk read.
+
+ Args:
+ name: a name for the operation (optional).
+
+ Returns:
+ `1-D` string `Tensor` of serialized `BigQueryTablePartition` messages.
+ """
+ return gen_cloud_ops.generate_big_query_reader_partitions(
+ name=name,
+ project_id=self._project_id,
+ dataset_id=self._dataset_id,
+ table_id=self._table_id,
+ timestamp_millis=self._timestamp_millis,
+ num_partitions=self._num_partitions,
+ test_end_point=self._test_end_point,
+ columns=self._columns)
+
+
+ops.NotDifferentiable("BigQueryReader")
diff --git a/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py b/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py
new file mode 100644
index 0000000000..8e59985ff4
--- /dev/null
+++ b/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py
@@ -0,0 +1,274 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for BigQueryReader Op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import re
+import threading
+
+from six.moves import SimpleHTTPServer
+from six.moves import socketserver
+import tensorflow as tf
+
+from tensorflow.core.example import example_pb2
+from tensorflow.core.framework import types_pb2
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import compat
+
+_PROJECT = "test-project"
+_DATASET = "test-dataset"
+_TABLE = "test-table"
+# List representation of the test rows in the 'test-table' in BigQuery.
+# The schema for each row is: [int64, string, float].
+# The values for rows are generated such that some columns have null values. The
+# general formula here is:
+# - The int64 column is present in every row.
+# - The string column is only available in even rows.
+# - The float column is only available in every third row.
+_ROWS = [[0, "s_0", 0.1], [1, None, None], [2, "s_2", None], [3, None, 3.1],
+ [4, "s_4", None], [5, None, None], [6, "s_6", 6.1], [7, None, None],
+ [8, "s_8", None], [9, None, 9.1]]
+# Schema for 'test-table'.
+# The schema currently has three columns: int64, string, and float
+_SCHEMA = {
+ "kind": "bigquery#table",
+ "id": "test-project:test-dataset.test-table",
+ "schema": {
+ "fields": [{
+ "name": "int64_col",
+ "type": "INTEGER",
+ "mode": "NULLABLE"
+ }, {
+ "name": "string_col",
+ "type": "STRING",
+ "mode": "NULLABLE"
+ }, {
+ "name": "float_col",
+ "type": "FLOAT",
+ "mode": "NULLABLE"
+ }]
+ }
+}
+
+
+def _ConvertRowToExampleProto(row):
+ """Converts the input row to an Example proto.
+
+ Args:
+ row: Input Row instance.
+
+ Returns:
+ An Example proto initialized with row values.
+ """
+
+ example = example_pb2.Example()
+ example.features.feature["int64_col"].int64_list.value.append(row[0])
+ if row[1] is not None:
+ example.features.feature["string_col"].bytes_list.value.append(compat.as_bytes(row[1]))
+ if row[2] is not None:
+ example.features.feature["float_col"].float_list.value.append(row[2])
+ return example
+
+
+class FakeBigQueryServer(threading.Thread):
+ """Fake http server to return schema and data for sample table."""
+
+ def __init__(self, address, port):
+ """Creates a FakeBigQueryServer.
+
+ Args:
+ address: Server address
+ port: Server port. Pass 0 to automatically pick an empty port.
+ """
+ threading.Thread.__init__(self)
+ self.handler = BigQueryRequestHandler
+ self.httpd = socketserver.TCPServer((address, port), self.handler)
+
+ def run(self):
+ self.httpd.serve_forever()
+
+ def shutdown(self):
+ self.httpd.shutdown()
+ self.httpd.socket.close()
+
+
+class BigQueryRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
+ """Responds to BigQuery HTTP requests.
+
+ Attributes:
+ num_rows: Number of rows in the underlying table served by this class.
+ """
+
+ num_rows = 0
+
+ def do_GET(self):
+ if "data?maxResults=" not in self.path:
+ # This is a schema request.
+ _SCHEMA["numRows"] = self.num_rows
+ response = json.dumps(_SCHEMA)
+ else:
+ # This is a data request.
+ #
+ # Extract max results and start index.
+ max_results = int(re.findall(r"maxResults=(\d+)", self.path)[0])
+ start_index = int(re.findall(r"startIndex=(\d+)", self.path)[0])
+
+ # Send the rows as JSON.
+ rows = []
+ for row in _ROWS[start_index:start_index + max_results]:
+ row_json = {
+ "f": [{
+ "v": str(row[0])
+ }, {
+ "v": str(row[1]) if row[1] is not None else None
+ }, {
+ "v": str(row[2]) if row[2] is not None else None
+ }]
+ }
+ rows.append(row_json)
+ response = json.dumps({
+ "kind": "bigquery#table",
+ "id": "test-project:test-dataset.test-table",
+ "rows": rows
+ })
+ self.send_response(200)
+ self.end_headers()
+ self.wfile.write(compat.as_bytes(response))
+
+
+def _SetUpQueue(reader):
+ """Sets up a queue for a reader."""
+ queue = tf.FIFOQueue(8, [types_pb2.DT_STRING], shapes=())
+ key, value = reader.read(queue)
+ queue.enqueue_many(reader.partitions()).run()
+ queue.close().run()
+ return key, value
+
+
+class BigQueryReaderOpsTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(BigQueryReaderOpsTest, self).setUp()
+ self.server = FakeBigQueryServer("127.0.0.1", 0)
+ self.server.start()
+ logging.info("server address is %s:%s", self.server.httpd.server_address[0],
+ self.server.httpd.server_address[1])
+
+ def tearDown(self):
+ self.server.shutdown()
+ super(BigQueryReaderOpsTest, self).tearDown()
+
+ def _ReadAndCheckRowsUsingFeatures(self, num_rows):
+ self.server.handler.num_rows = num_rows
+
+ with self.test_session() as sess:
+ feature_configs = {
+ "int64_col":
+ tf.FixedLenFeature(
+ [1], dtype=tf.int64),
+ "string_col":
+ tf.FixedLenFeature(
+ [1], dtype=tf.string, default_value="s_default"),
+ }
+ reader = tf.cloud.BigQueryReader(
+ project_id=_PROJECT,
+ dataset_id=_DATASET,
+ table_id=_TABLE,
+ num_partitions=4,
+ features=feature_configs,
+ timestamp_millis=1,
+ test_end_point=("%s:%s" % (self.server.httpd.server_address[0],
+ self.server.httpd.server_address[1])))
+
+ key, value = _SetUpQueue(reader)
+
+ seen_rows = []
+ features = tf.parse_example(tf.reshape(value, [1]), feature_configs)
+ for _ in range(num_rows):
+ int_value, str_value = sess.run(
+ [features["int64_col"], features["string_col"]])
+
+ # Parse values returned from the session.
+ self.assertEqual(int_value.shape, (1, 1))
+ self.assertEqual(str_value.shape, (1, 1))
+ int64_col = int_value[0][0]
+ string_col = str_value[0][0]
+ seen_rows.append(int64_col)
+
+ # Compare.
+ expected_row = _ROWS[int64_col]
+ self.assertEqual(int64_col, expected_row[0])
+ self.assertEqual(compat.as_str(string_col), ("s_%d" % int64_col) if expected_row[1]
+ else "s_default")
+
+ self.assertItemsEqual(seen_rows, range(num_rows))
+
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ sess.run([key, value])
+
+ def testReadingSingleRowUsingFeatures(self):
+ self._ReadAndCheckRowsUsingFeatures(1)
+
+ def testReadingMultipleRowsUsingFeatures(self):
+ self._ReadAndCheckRowsUsingFeatures(10)
+
+ def testReadingMultipleRowsUsingColumns(self):
+ num_rows = 10
+ self.server.handler.num_rows = num_rows
+
+ with self.test_session() as sess:
+ reader = tf.cloud.BigQueryReader(
+ project_id=_PROJECT,
+ dataset_id=_DATASET,
+ table_id=_TABLE,
+ num_partitions=4,
+ columns=["int64_col", "float_col", "string_col"],
+ timestamp_millis=1,
+ test_end_point=("%s:%s" % (self.server.httpd.server_address[0],
+ self.server.httpd.server_address[1])))
+ key, value = _SetUpQueue(reader)
+ seen_rows = []
+ for row_index in range(num_rows):
+ returned_row_id, example_proto = sess.run([key, value])
+ example = example_pb2.Example()
+ example.ParseFromString(example_proto)
+ self.assertIn("int64_col", example.features.feature)
+ feature = example.features.feature["int64_col"]
+ self.assertEqual(len(feature.int64_list.value), 1)
+ int64_col = feature.int64_list.value[0]
+ seen_rows.append(int64_col)
+
+ # Create our expected Example.
+ expected_example = example_pb2.Example()
+ expected_example = _ConvertRowToExampleProto(_ROWS[int64_col])
+
+ # Compare.
+ self.assertProtoEquals(example, expected_example)
+ self.assertEqual(row_index, int(returned_row_id))
+
+ self.assertItemsEqual(seen_rows, range(num_rows))
+
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ sess.run([key, value])
+
+
+if __name__ == "__main__":
+ tf.test.main()