From 1b881b7c77bd1e664382785447a170de2b85f688 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 13 Mar 2017 16:58:18 -0800 Subject: First version of BigQuery Reader. Change: 150016997 --- tensorflow/BUILD | 3 +- tensorflow/contrib/BUILD | 1 + tensorflow/contrib/__init__.py | 1 + tensorflow/contrib/cloud/BUILD | 73 +++ tensorflow/contrib/cloud/__init__.py | 28 ++ tensorflow/contrib/cloud/kernels/BUILD | 94 ++++ .../contrib/cloud/kernels/bigquery_reader_ops.cc | 192 ++++++++ .../cloud/kernels/bigquery_table_accessor.cc | 410 ++++++++++++++++ .../cloud/kernels/bigquery_table_accessor.h | 208 +++++++++ .../cloud/kernels/bigquery_table_accessor_test.cc | 513 +++++++++++++++++++++ .../kernels/bigquery_table_accessor_test_data.h | 404 ++++++++++++++++ .../cloud/kernels/bigquery_table_partition.proto | 12 + .../contrib/cloud/ops/bigquery_reader_ops.cc | 88 ++++ .../cloud/python/ops/bigquery_reader_ops.py | 150 ++++++ .../cloud/python/ops/bigquery_reader_ops_test.py | 287 ++++++++++++ tensorflow/contrib/cmake/tf_core_kernels.cmake | 4 +- tensorflow/contrib/cmake/tf_core_ops.cmake | 1 + tensorflow/contrib/cmake/tf_python.cmake | 8 + tensorflow/contrib/cmake/tf_tests.cmake | 1 + tensorflow/core/BUILD | 10 - tensorflow/core/kernels/BUILD | 2 - tensorflow/core/kernels/cloud/BUILD | 98 ---- .../core/kernels/cloud/bigquery_reader_ops.cc | 193 -------- .../core/kernels/cloud/bigquery_table_accessor.cc | 410 ---------------- .../core/kernels/cloud/bigquery_table_accessor.h | 207 --------- .../kernels/cloud/bigquery_table_accessor_test.cc | 431 ----------------- .../cloud/bigquery_table_accessor_test_data.h | 325 ------------- .../kernels/cloud/bigquery_table_partition.proto | 12 - tensorflow/core/ops/cloud_ops.cc | 88 ---- tensorflow/core/platform/cloud/http_request.cc | 1 - tensorflow/core/platform/default/build_config.bzl | 24 +- tensorflow/python/BUILD | 38 -- tensorflow/python/ops/cloud/__init__.py | 0 tensorflow/python/ops/cloud/bigquery_reader_ops.py | 157 ------- .../python/ops/cloud/bigquery_reader_ops_test.py | 286 ------------ tensorflow/python/ops/cloud/cloud.py | 31 -- 36 files changed, 2491 insertions(+), 2300 deletions(-) create mode 100644 tensorflow/contrib/cloud/BUILD create mode 100644 tensorflow/contrib/cloud/__init__.py create mode 100644 tensorflow/contrib/cloud/kernels/BUILD create mode 100644 tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc create mode 100644 tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc create mode 100644 tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h create mode 100644 tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc create mode 100644 tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h create mode 100644 tensorflow/contrib/cloud/kernels/bigquery_table_partition.proto create mode 100644 tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc create mode 100644 tensorflow/contrib/cloud/python/ops/bigquery_reader_ops.py create mode 100644 tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py delete mode 100644 tensorflow/core/kernels/cloud/BUILD delete mode 100644 tensorflow/core/kernels/cloud/bigquery_reader_ops.cc delete mode 100644 tensorflow/core/kernels/cloud/bigquery_table_accessor.cc delete mode 100644 tensorflow/core/kernels/cloud/bigquery_table_accessor.h delete mode 100644 tensorflow/core/kernels/cloud/bigquery_table_accessor_test.cc delete mode 100644 tensorflow/core/kernels/cloud/bigquery_table_accessor_test_data.h delete mode 100644 
tensorflow/core/kernels/cloud/bigquery_table_partition.proto delete mode 100644 tensorflow/core/ops/cloud_ops.cc delete mode 100644 tensorflow/python/ops/cloud/__init__.py delete mode 100644 tensorflow/python/ops/cloud/bigquery_reader_ops.py delete mode 100644 tensorflow/python/ops/cloud/bigquery_reader_ops_test.py delete mode 100644 tensorflow/python/ops/cloud/cloud.py diff --git a/tensorflow/BUILD b/tensorflow/BUILD index ec2861a675..db6d42e1bc 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -160,6 +160,8 @@ filegroup( "//tensorflow/contrib:all_files", "//tensorflow/contrib/android:all_files", "//tensorflow/contrib/bayesflow:all_files", + "//tensorflow/contrib/cloud:all_files", + "//tensorflow/contrib/cloud/kernels:all_files", "//tensorflow/contrib/compiler:all_files", "//tensorflow/contrib/copy_graph:all_files", "//tensorflow/contrib/crf:all_files", @@ -220,7 +222,6 @@ filegroup( "//tensorflow/core/grappler/inputs:all_files", "//tensorflow/core/grappler/optimizers:all_files", "//tensorflow/core/kernels:all_files", - "//tensorflow/core/kernels/cloud:all_files", "//tensorflow/core/kernels/hexagon:all_files", "//tensorflow/core/ops/compat:all_files", "//tensorflow/core/platform/cloud:all_files", diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index ef36702a52..29d60ae241 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -16,6 +16,7 @@ py_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/contrib/bayesflow:bayesflow_py", + "//tensorflow/contrib/cloud:cloud_py", "//tensorflow/contrib/compiler:compiler_py", "//tensorflow/contrib/copy_graph:copy_graph_py", "//tensorflow/contrib/crf:crf_py", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index af4e130870..7c0d1da8a6 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -20,6 +20,7 @@ from __future__ import print_function # Add projects here, they will show up under tf.contrib. 
from tensorflow.contrib import bayesflow +from tensorflow.contrib import cloud from tensorflow.contrib import compiler from tensorflow.contrib import copy_graph from tensorflow.contrib import crf diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD new file mode 100644 index 0000000000..840997223f --- /dev/null +++ b/tensorflow/contrib/cloud/BUILD @@ -0,0 +1,73 @@ +# Description: +# BigQueryReader implementation + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_py_test", +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +tf_gen_op_libs( + op_lib_names = ["bigquery_reader_ops"], + deps = [ + "//tensorflow/core:lib", + ], +) + +tf_gen_op_wrapper_py( + name = "gen_bigquery_reader_ops", + out = "python/ops/gen_bigquery_reader_ops.py", + require_shape_functions = True, + deps = [":bigquery_reader_ops_op_lib"], +) + +py_library( + name = "cloud_py", + srcs = [ + "__init__.py", + "python/ops/bigquery_reader_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":gen_bigquery_reader_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + ], +) + +tf_py_test( + name = "bigquery_reader_ops_test", + size = "small", + srcs = ["python/ops/bigquery_reader_ops_test.py"], + additional_deps = [ + ":bigquery_reader_ops_op_lib", + ":cloud_py", + "//tensorflow/contrib/cloud/kernels:bigquery_reader_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:parsing_ops", + ], + tags = ["manual"], +) diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py new file mode 100644 index 0000000000..8870264b95 --- /dev/null +++ b/tensorflow/contrib/cloud/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Module for cloud ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=line-too-long,wildcard-import +from tensorflow.contrib.cloud.python.ops.bigquery_reader_ops import * +# pylint: enable=line-too-long,wildcard-import + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ['BigQueryReader'] +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD new file mode 100644 index 0000000000..2500f10c74 --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -0,0 +1,94 @@ +# Description: +# BigQueryReader implementation + +package( + default_visibility = ["//visibility:private"], +) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", + "tf_copts", + "tf_kernel_library", +) + +# For platform specific build config +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library", +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +tf_kernel_library( + name = "bigquery_reader_ops", + srcs = [ + "bigquery_reader_ops.cc", + ], + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":bigquery_table_accessor", + ":bigquery_table_partition_proto_cc", + "//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:reader_base", + ], +) + +cc_library( + name = "bigquery_table_accessor", + srcs = [ + "bigquery_table_accessor.cc", + ], + hdrs = [ + "bigquery_table_accessor.h", + ], + copts = tf_copts(), + linkstatic = 1, + deps = [ + ":bigquery_table_partition_proto_cc", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform/cloud:google_auth_provider", + "//tensorflow/core/platform/cloud:http_request", + ], + alwayslink = 1, +) + +tf_cc_test( + name = "bigquery_table_accessor_test", + size = "small", + srcs = [ + "bigquery_table_accessor_test.cc", + "bigquery_table_accessor_test_data.h", + ], + deps = [ + ":bigquery_table_accessor", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/platform/cloud:http_request_fake", + ], +) + +tf_proto_library( + name = "bigquery_table_partition_proto", + srcs = [ + "bigquery_table_partition.proto", + ], + cc_api_version = 2, +) diff --git a/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc b/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc new file mode 100644 index 0000000000..02a759eefd --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc @@ -0,0 +1,192 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h" +#include "tensorflow/contrib/cloud/kernels/bigquery_table_partition.pb.h" +#include "tensorflow/core/framework/reader_base.h" +#include "tensorflow/core/framework/reader_op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/lib/strings/numbers.h" + +namespace tensorflow { +namespace { + +constexpr int64 kDefaultRowBufferSize = 1000; // Number of rows to buffer. + +// This is a helper function for reading table attributes from context. +Status GetTableAttrs(OpKernelConstruction* context, string* project_id, + string* dataset_id, string* table_id, + int64* timestamp_millis, std::vector* columns, + string* test_end_point) { + TF_RETURN_IF_ERROR(context->GetAttr("project_id", project_id)); + TF_RETURN_IF_ERROR(context->GetAttr("dataset_id", dataset_id)); + TF_RETURN_IF_ERROR(context->GetAttr("table_id", table_id)); + TF_RETURN_IF_ERROR(context->GetAttr("timestamp_millis", timestamp_millis)); + TF_RETURN_IF_ERROR(context->GetAttr("columns", columns)); + TF_RETURN_IF_ERROR(context->GetAttr("test_end_point", test_end_point)); + return Status::OK(); +} + +} // namespace + +// Note that overriden methods with names ending in "Locked" are called by +// ReaderBase while a mutex is held. +// See comments for ReaderBase. +class BigQueryReader : public ReaderBase { + public: + explicit BigQueryReader(BigQueryTableAccessor* bigquery_table_accessor, + const string& node_name) + : ReaderBase(strings::StrCat("BigQueryReader '", node_name, "'")), + bigquery_table_accessor_(CHECK_NOTNULL(bigquery_table_accessor)) {} + + Status OnWorkStartedLocked() override { + BigQueryTablePartition partition; + if (!partition.ParseFromString(current_work())) { + return errors::InvalidArgument( + "Could not parse work as as valid partition."); + } + TF_RETURN_IF_ERROR(bigquery_table_accessor_->SetPartition(partition)); + return Status::OK(); + } + + Status ReadLocked(string* key, string* value, bool* produced, + bool* at_end) override { + *at_end = false; + *produced = false; + if (bigquery_table_accessor_->Done()) { + *at_end = true; + return Status::OK(); + } + + Example example; + int64 row_id; + TF_RETURN_IF_ERROR(bigquery_table_accessor_->ReadRow(&row_id, &example)); + + *key = std::to_string(row_id); + *value = example.SerializeAsString(); + *produced = true; + return Status::OK(); + } + + private: + // Not owned. 
+ BigQueryTableAccessor* bigquery_table_accessor_; +}; + +class BigQueryReaderOp : public ReaderOpKernel { + public: + explicit BigQueryReaderOp(OpKernelConstruction* context) + : ReaderOpKernel(context) { + string table_id; + string project_id; + string dataset_id; + int64 timestamp_millis; + std::vector columns; + string test_end_point; + + OP_REQUIRES_OK(context, + GetTableAttrs(context, &project_id, &dataset_id, &table_id, + ×tamp_millis, &columns, &test_end_point)); + OP_REQUIRES_OK(context, + BigQueryTableAccessor::New( + project_id, dataset_id, table_id, timestamp_millis, + kDefaultRowBufferSize, test_end_point, columns, + BigQueryTablePartition(), &bigquery_table_accessor_)); + + SetReaderFactory([this]() { + return new BigQueryReader(bigquery_table_accessor_.get(), name()); + }); + } + + private: + std::unique_ptr bigquery_table_accessor_; +}; + +REGISTER_KERNEL_BUILDER(Name("BigQueryReader").Device(DEVICE_CPU), + BigQueryReaderOp); + +class GenerateBigQueryReaderPartitionsOp : public OpKernel { + public: + explicit GenerateBigQueryReaderPartitionsOp(OpKernelConstruction* context) + : OpKernel(context) { + string project_id; + string dataset_id; + string table_id; + int64 timestamp_millis; + std::vector columns; + string test_end_point; + + OP_REQUIRES_OK(context, + GetTableAttrs(context, &project_id, &dataset_id, &table_id, + ×tamp_millis, &columns, &test_end_point)); + OP_REQUIRES_OK(context, + BigQueryTableAccessor::New( + project_id, dataset_id, table_id, timestamp_millis, + kDefaultRowBufferSize, test_end_point, columns, + BigQueryTablePartition(), &bigquery_table_accessor_)); + OP_REQUIRES_OK(context, InitializeNumberOfPartitions(context)); + OP_REQUIRES_OK(context, InitializeTotalNumberOfRows()); + } + + void Compute(OpKernelContext* context) override { + const int64 partition_size = tensorflow::MathUtil::CeilOfRatio( + total_num_rows_, num_partitions_); + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({num_partitions_}), + &output_tensor)); + + auto output = output_tensor->template flat(); + for (int64 i = 0; i < num_partitions_; ++i) { + BigQueryTablePartition partition; + partition.set_start_index(i * partition_size); + partition.set_end_index( + std::min(total_num_rows_, (i + 1) * partition_size) - 1); + output(i) = partition.SerializeAsString(); + } + } + + private: + Status InitializeTotalNumberOfRows() { + total_num_rows_ = bigquery_table_accessor_->total_num_rows(); + if (total_num_rows_ <= 0) { + return errors::FailedPrecondition("Invalid total number of rows."); + } + return Status::OK(); + } + + Status InitializeNumberOfPartitions(OpKernelConstruction* context) { + TF_RETURN_IF_ERROR(context->GetAttr("num_partitions", &num_partitions_)); + if (num_partitions_ <= 0) { + return errors::FailedPrecondition("Invalid number of partitions."); + } + return Status::OK(); + } + + int64 num_partitions_; + int64 total_num_rows_; + std::unique_ptr bigquery_table_accessor_; +}; + +REGISTER_KERNEL_BUILDER( + Name("GenerateBigQueryReaderPartitions").Device(DEVICE_CPU), + GenerateBigQueryReaderPartitionsOp); + +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc new file mode 100644 index 0000000000..5e95db55b6 --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc @@ -0,0 +1,410 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h" + +#include "tensorflow/core/example/feature.pb.h" +#include "tensorflow/core/lib/strings/numbers.h" + +namespace tensorflow { + +namespace { + +constexpr size_t kBufferSize = 1024 * 1024; // In bytes. +const string kBigQueryEndPoint = "https://www.googleapis.com/bigquery/v2"; + +bool IsPartitionEmpty(const BigQueryTablePartition& partition) { + if (partition.end_index() != -1 && + partition.end_index() < partition.start_index()) { + return true; + } + return false; +} + +Status ParseJson(StringPiece json, Json::Value* result) { + Json::Reader reader; + if (!reader.parse(json.ToString(), *result)) { + return errors::Internal("Couldn't parse JSON response from BigQuery."); + } + return Status::OK(); +} + +string ColumnTypeToString(BigQueryTableAccessor::ColumnType enum_type) { + switch (enum_type) { + case BigQueryTableAccessor::ColumnType::kRecord: + return "RECORD"; + case BigQueryTableAccessor::ColumnType::kString: + return "STRING"; + case BigQueryTableAccessor::ColumnType::kBytes: + return "BYTES"; + case BigQueryTableAccessor::ColumnType::kInteger: + return "INTEGER"; + case BigQueryTableAccessor::ColumnType::kFloat: + return "FLOAT"; + case BigQueryTableAccessor::ColumnType::kBoolean: + return "BOOLEAN"; + case BigQueryTableAccessor::ColumnType::kTimestamp: + return "TIMESTAMP"; + case BigQueryTableAccessor::ColumnType::kDate: + return "DATE"; + case BigQueryTableAccessor::ColumnType::kTime: + return "TIME"; + case BigQueryTableAccessor::ColumnType::kDatetime: + return "DATETIME"; + case BigQueryTableAccessor::ColumnType::kNone: + return "NONE"; + } +} + +Status ParseColumnType(const string& type, + BigQueryTableAccessor::ColumnType* enum_type) { + if (type == "RECORD") { + *enum_type = BigQueryTableAccessor::ColumnType::kRecord; + } else if (type == "STRING") { + *enum_type = BigQueryTableAccessor::ColumnType::kString; + } else if (type == "BYTES") { + *enum_type = BigQueryTableAccessor::ColumnType::kBytes; + } else if (type == "INTEGER") { + *enum_type = BigQueryTableAccessor::ColumnType::kInteger; + } else if (type == "FLOAT") { + *enum_type = BigQueryTableAccessor::ColumnType::kFloat; + } else if (type == "BOOLEAN") { + *enum_type = BigQueryTableAccessor::ColumnType::kBoolean; + } else if (type == "TIMESTAMP") { + *enum_type = BigQueryTableAccessor::ColumnType::kTimestamp; + } else if (type == "DATE") { + *enum_type = BigQueryTableAccessor::ColumnType::kDate; + } else if (type == "TIME") { + *enum_type = BigQueryTableAccessor::ColumnType::kTime; + } else if (type == "DATETIME") { + *enum_type = BigQueryTableAccessor::ColumnType::kDatetime; + } else { + return errors::Internal( + strings::StrCat("Could not parse column type ", type)); + } + return Status::OK(); +} + +} // namespace + +Status BigQueryTableAccessor::New( + const string& project_id, const string& dataset_id, const string& table_id, + int64 
timestamp_millis, int64 row_buffer_size, const string& end_point, + const std::vector& columns, const BigQueryTablePartition& partition, + std::unique_ptr* accessor) { + return New(project_id, dataset_id, table_id, timestamp_millis, + row_buffer_size, end_point, columns, partition, nullptr, nullptr, + accessor); +} + +Status BigQueryTableAccessor::New( + const string& project_id, const string& dataset_id, const string& table_id, + int64 timestamp_millis, int64 row_buffer_size, const string& end_point, + const std::vector& columns, const BigQueryTablePartition& partition, + std::unique_ptr auth_provider, + std::unique_ptr http_request_factory, + std::unique_ptr* accessor) { + if (timestamp_millis <= 0) { + return errors::InvalidArgument( + "Cannot use zero or negative timestamp to query a table."); + } + const string& big_query_end_point = + end_point.empty() ? kBigQueryEndPoint : end_point; + if (auth_provider == nullptr && http_request_factory == nullptr) { + accessor->reset(new BigQueryTableAccessor( + project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, + big_query_end_point, columns, partition)); + } else { + accessor->reset(new BigQueryTableAccessor( + project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, + big_query_end_point, columns, partition, std::move(auth_provider), + std::move(http_request_factory))); + } + return (*accessor)->ReadSchema(); +} + +BigQueryTableAccessor::BigQueryTableAccessor( + const string& project_id, const string& dataset_id, const string& table_id, + int64 timestamp_millis, int64 row_buffer_size, const string& end_point, + const std::vector& columns, const BigQueryTablePartition& partition) + : BigQueryTableAccessor( + project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, + end_point, columns, partition, + std::unique_ptr(new GoogleAuthProvider()), + std::unique_ptr(new HttpRequest::Factory())) { + row_buffer_.resize(row_buffer_size); +} + +BigQueryTableAccessor::BigQueryTableAccessor( + const string& project_id, const string& dataset_id, const string& table_id, + int64 timestamp_millis, int64 row_buffer_size, const string& end_point, + const std::vector& columns, const BigQueryTablePartition& partition, + std::unique_ptr auth_provider, + std::unique_ptr http_request_factory) + : project_id_(project_id), + dataset_id_(dataset_id), + table_id_(table_id), + timestamp_millis_(timestamp_millis), + columns_(columns.begin(), columns.end()), + bigquery_end_point_(end_point), + partition_(partition), + auth_provider_(std::move(auth_provider)), + http_request_factory_(std::move(http_request_factory)) { + row_buffer_.resize(row_buffer_size); + Reset(); +} + +Status BigQueryTableAccessor::SetPartition( + const BigQueryTablePartition& partition) { + if (partition.start_index() < 0) { + return errors::InvalidArgument("Start index cannot be negative."); + } + partition_ = partition; + Reset(); + return Status::OK(); +} + +void BigQueryTableAccessor::Reset() { + first_buffered_row_index_ = partition_.start_index(); + next_row_in_buffer_ = -1; + next_page_token_ = ""; +} + +Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) { + if (Done()) { + return errors::OutOfRange("Reached end of table ", FullTableName()); + } + + // If the next row is already fetched and cached, return the row from the + // buffer. Otherwise, fill up the row buffer from BigQuery and return a row. 
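+  // Illustrative walkthrough (not part of this change; the numbers are made
+  // up): with row_buffer_size = 1000 and a partition covering rows
+  // [2000, 2999], the first ReadRow() call issues a single request of the
+  // form
+  //   .../data?maxResults=1000&startIndex=2000
+  // and, assuming the service returns a full page, the remaining ReadRow()
+  // calls for this partition are answered out of row_buffer_ without further
+  // HTTP traffic.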
+ if (next_row_in_buffer_ != -1 && + next_row_in_buffer_ < ComputeMaxResultsArg()) { + *row_id = first_buffered_row_index_ + next_row_in_buffer_; + *example = row_buffer_[next_row_in_buffer_]; + next_row_in_buffer_++; + } else { + string auth_token; + TF_RETURN_IF_ERROR( + AuthProvider::GetToken(auth_provider_.get(), &auth_token)); + + std::unique_ptr request(http_request_factory_->Create()); + std::vector output_buffer; + output_buffer.reserve(kBufferSize); + TF_RETURN_IF_ERROR(request->Init()); + + // The first time that we access BigQuery there is no page token. After that + // we use the page token (which returns rows faster). + if (!next_page_token_.empty()) { + TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat( + BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(), + "&pageToken=", request->EscapeString(next_page_token_)))); + first_buffered_row_index_ += row_buffer_.size(); + } else { + TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat( + BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(), + "&startIndex=", first_buffered_row_index_))); + } + TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); + TF_RETURN_IF_ERROR(request->SetResultBuffer(&output_buffer)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading rows from ", + FullTableName()); + + // Parse the returned row. + StringPiece response_piece = + StringPiece(&output_buffer[0], output_buffer.size()); + Json::Value root; + TF_RETURN_IF_ERROR(ParseJson(response_piece, &root)); + for (unsigned int i = 0; i < root["rows"].size(); ++i) { + row_buffer_[i].Clear(); + TF_RETURN_IF_ERROR( + ParseColumnValues(root["rows"][i], schema_root_, &row_buffer_[i])); + } + + next_page_token_ = root["pageToken"].asString(); + *row_id = first_buffered_row_index_; + *example = row_buffer_[0]; + next_row_in_buffer_ = 1; + } + return Status::OK(); +} + +int64 BigQueryTableAccessor::ComputeMaxResultsArg() { + if (partition_.end_index() == -1) { + return row_buffer_.size(); + } + if (IsPartitionEmpty(partition_)) { + return 0; + } + return std::min(static_cast(row_buffer_.size()), + static_cast(partition_.end_index() - + partition_.start_index() + 1)); +} + +Status BigQueryTableAccessor::ParseColumnValues( + const Json::Value& value, const SchemaNode& root_schema_node, + Example* example) { + if (value.empty()) { + return Status::OK(); + } + if (value["f"].isNull()) { + return Status::OK(); + } + int value_index = 0; + for (const auto& schema_node : root_schema_node.schema_nodes) { + if (value["f"][value_index].isNull()) { + value_index++; + continue; + } + + if (schema_node.type == ColumnType::kRecord) { + TF_RETURN_IF_ERROR(ParseColumnValues(value["f"][value_index]["v"], + schema_node, example)); + } else { + // Append the column value only if user has requested the column. + if (columns_.empty() || + columns_.find(schema_node.name) != columns_.end()) { + TF_RETURN_IF_ERROR(AppendValueToExample(schema_node.name, + value["f"][value_index]["v"], + schema_node.type, example)); + } + } + value_index++; + } + return Status::OK(); +} + +Status BigQueryTableAccessor::ReadSchema() { + string auth_token; + TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_.get(), &auth_token)); + + // Send a request to read the schema. 
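+  // For example (illustrative, matching the fake requests in the unit test),
+  // for the test table this is a GET of
+  //   https://www.googleapis.com/bigquery/v2/projects/test-project/
+  //       datasets/test-dataset/tables/test-table/
+  // i.e. BigQueryUriPrefix() with no extra query arguments; the row-reading
+  // requests in ReadRow() append "data?maxResults=..." to the same prefix.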
+ std::unique_ptr request(http_request_factory_->Create()); + std::vector output_buffer; + output_buffer.reserve(kBufferSize); + TF_RETURN_IF_ERROR(request->Init()); + TF_RETURN_IF_ERROR(request->SetUri(BigQueryUriPrefix())); + TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); + TF_RETURN_IF_ERROR(request->SetResultBuffer(&output_buffer)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading schema for ", + FullTableName()); + + // Parse the schema. + StringPiece response_piece = + StringPiece(&output_buffer[0], output_buffer.size()); + + Json::Value root; + TF_RETURN_IF_ERROR(ParseJson(response_piece, &root)); + const auto& columns = root["schema"]["fields"]; + string column_name_prefix = ""; + schema_root_ = {"", ColumnType::kNone}; + TF_RETURN_IF_ERROR( + ExtractColumnType(columns, column_name_prefix, &schema_root_)); + if (root["numRows"].isNull()) { + return errors::Internal("Number of rows cannot be extracted for table ", + FullTableName()); + } + strings::safe_strto64(root["numRows"].asString().c_str(), &total_num_rows_); + return Status::OK(); +} + +Status BigQueryTableAccessor::ExtractColumnType( + const Json::Value& columns, const string& column_name_prefix, + SchemaNode* root) { + for (auto columns_it = columns.begin(); columns_it != columns.end(); + ++columns_it) { + if ((*columns_it)["mode"].asString() == "REPEATED") { + return errors::Unimplemented(strings::StrCat( + "Tables with repeated columns are not supported: ", FullTableName())); + } + ColumnType type; + const string current_column_name = strings::StrCat( + column_name_prefix, (*columns_it)["name"].asString().c_str()); + TF_RETURN_IF_ERROR( + ParseColumnType((*columns_it)["type"].asString().c_str(), &type)); + root->schema_nodes.emplace_back(current_column_name, type); + if (type == ColumnType::kRecord) { + const auto new_prefix = strings::StrCat(current_column_name, "."); + TF_RETURN_IF_ERROR(ExtractColumnType((*columns_it)["fields"], new_prefix, + &root->schema_nodes.back())); + } + } + return Status::OK(); +} + +Status BigQueryTableAccessor::AppendValueToExample( + const string& column_name, const Json::Value& column_value, + const BigQueryTableAccessor::ColumnType type, Example* example) { + if (column_value.isNull()) { + return Status::OK(); + } + auto& feature = + (*example->mutable_features()->mutable_feature())[column_name]; + + switch (type) { + case BigQueryTableAccessor::ColumnType::kNone: + case BigQueryTableAccessor::ColumnType::kRecord: + return errors::Unimplemented("Cannot append type to an example."); + case BigQueryTableAccessor::ColumnType::kTimestamp: + case BigQueryTableAccessor::ColumnType::kDate: + case BigQueryTableAccessor::ColumnType::kTime: + case BigQueryTableAccessor::ColumnType::kDatetime: + case BigQueryTableAccessor::ColumnType::kString: + case BigQueryTableAccessor::ColumnType::kBytes: + feature.mutable_bytes_list()->add_value(column_value.asString()); + break; + case BigQueryTableAccessor::ColumnType::kBoolean: + feature.mutable_int64_list()->add_value( + column_value.asString() == "false" ? 0 : 1); + break; + case BigQueryTableAccessor::ColumnType::kInteger: + int64 column_value_int64; + if (!strings::safe_strto64(column_value.asString().c_str(), + &column_value_int64)) { + return errors::Internal("Cannot convert value to integer ", + column_value.asString().c_str()); + } + feature.mutable_int64_list()->add_value(column_value_int64); + break; + case BigQueryTableAccessor::ColumnType::kFloat: + // BigQuery float is actually a double. 
+ double column_value_double; + if (!strings::safe_strtod(column_value.asString().c_str(), + &column_value_double)) { + return errors::Internal("Cannot convert value to double: ", + column_value.asString().c_str()); + } + feature.mutable_float_list()->add_value( + static_cast(column_value_double)); + break; + } + return Status::OK(); +} + +string BigQueryTableAccessor::BigQueryTableAccessor::BigQueryUriPrefix() { + HttpRequest request; + return strings::StrCat(bigquery_end_point_, "/projects/", + request.EscapeString(project_id_), "/datasets/", + request.EscapeString(dataset_id_), "/tables/", + request.EscapeString(table_id_), "/"); +} + +bool BigQueryTableAccessor::Done() { + return (total_num_rows_ <= first_buffered_row_index_ + next_row_in_buffer_) || + IsPartitionEmpty(partition_) || + (partition_.end_index() != -1 && + partition_.end_index() < + first_buffered_row_index_ + next_row_in_buffer_); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h new file mode 100644 index 0000000000..1cd0482186 --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h @@ -0,0 +1,208 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_ + +#include +#include +#include + +#include "tensorflow/contrib/cloud/kernels/bigquery_table_partition.pb.h" +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/cloud/google_auth_provider.h" +#include "tensorflow/core/platform/cloud/http_request.h" + +namespace tensorflow { + +/// This class facilitates accessing BigQuery tables. +/// +/// Notes: +/// - Nested fields are not supported. +/// - BigQuery 'Record's are automatically flattened, +/// - BigQuery float type is a double but is converted to a C++ float in this +/// class. +/// - It is possible for a table snapshot to go out-of-scope in the BigQuery +/// service while accessing the table if a very old timestamp is used. For +/// exact details, see 'Table Decorators' in BigQuery docs. +class BigQueryTableAccessor { + public: + // Column types supported by BigQuery. + enum class ColumnType { + kString = 0, + kBytes, + kInteger, + kFloat, + kBoolean, + kTimestamp, + kDate, + kTime, + kDatetime, + kRecord, + kNone + }; + + /// \brief Creates a new BigQueryTableAccessor object. + // + // We do not allow relative (negative or zero) snapshot times here since we + // want to have a consistent snapshot of the table for the lifetime of this + // object. + // Use end_point if you want to connect to a different end point than the + // official BigQuery end point. Otherwise send an empty string. 
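+  //
+  // Example usage (illustrative sketch only; the project/dataset/table names
+  // and timestamp below are made up):
+  //
+  //   std::unique_ptr<BigQueryTableAccessor> accessor;
+  //   TF_RETURN_IF_ERROR(BigQueryTableAccessor::New(
+  //       "my-project", "my-dataset", "my-table",
+  //       /*timestamp_millis=*/1489000000000, /*row_buffer_size=*/1000,
+  //       /*end_point=*/"", /*columns=*/{}, BigQueryTablePartition(),
+  //       &accessor));
+  //   int64 row_id;
+  //   Example example;
+  //   while (!accessor->Done()) {
+  //     TF_RETURN_IF_ERROR(accessor->ReadRow(&row_id, &example));
+  //   }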
+ static Status New(const string& project_id, const string& dataset_id, + const string& table_id, int64 timestamp_millis, + int64 row_buffer_size, const string& end_point, + const std::vector& columns, + const BigQueryTablePartition& partition, + std::unique_ptr* accessor); + + /// \brief Starts reading a new partition. + Status SetPartition(const BigQueryTablePartition& partition); + + /// \brief Returns true if there are more rows available in the current + /// partition. + bool Done(); + + /// \brief Returns a single row as example proto. + /// + /// This function will return an error if the table snapshot goes out of scope + /// in the BigQuery service. + Status ReadRow(int64* row_id, Example* example); + + /// \brief Returns total number of rows in the table. + int64 total_num_rows() { return total_num_rows_; } + + virtual ~BigQueryTableAccessor() {} + + private: + friend class BigQueryTableAccessorTest; + + // This struct encapsulates schema nodes for a BigQuery table. + struct SchemaNode { + SchemaNode() {} + SchemaNode(const string& name, ColumnType type) : name(name), type(type) {} + + string name; + ColumnType type; + std::vector schema_nodes; + }; + + /// If nullptr is passed for http_request_factory and auth_provider the + /// default production ones are used. This can be used by tests to override + /// these two variables. + static Status New(const string& project_id, const string& dataset_id, + const string& table_id, int64 timestamp_millis, + int64 row_buffer_size, const string& end_point, + const std::vector& columns, + const BigQueryTablePartition& partition, + std::unique_ptr auth_provider, + std::unique_ptr http_request_factory, + std::unique_ptr* accessor); + + /// \brief Constructs an object for a given table and partition. + BigQueryTableAccessor(const string& project_id, const string& dataset_id, + const string& table_id, int64 timestamp_millis, + int64 row_buffer_size, const string& end_point, + const std::vector& columns, + const BigQueryTablePartition& partition); + + /// Used for unit testing. + BigQueryTableAccessor( + const string& project_id, const string& dataset_id, + const string& table_id, int64 timestamp_millis, int64 row_buffer_size, + const string& end_point, const std::vector& columns, + const BigQueryTablePartition& partition, + std::unique_ptr auth_provider, + std::unique_ptr http_request_factory); + + /// \brief Parses column values for a given row. + Status ParseColumnValues(const Json::Value& value, + const SchemaNode& root_schema_node, + Example* example); + + /// \brief Reads the table schema and stores it. + Status ReadSchema(); + + /// \brief Extracts column type from a column in schema. + Status ExtractColumnType(const Json::Value& columns, + const string& column_name_prefix, SchemaNode* root); + + /// \brief Appends a single BigQuery column Value to 'example' for a given + /// column. + Status AppendValueToExample(const string& column_name, + const Json::Value& column_value, + const BigQueryTableAccessor::ColumnType type, + Example* example); + + /// \brief Resets internal counters for reading a partition. + void Reset(); + + /// \brief Helper function that returns BigQuery http endpoint prefix. + string BigQueryUriPrefix(); + + /// \brief Computes the maxResults arg to send to BigQuery. + int64 ComputeMaxResultsArg(); + + /// \brief Returns full name of the underlying table name. 
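+  /// e.g. "test-project:test-dataset.test-table@1234" (illustrative; the
+  /// "project:dataset.table@timestamp_millis" form used by BigQuery table
+  /// decorators).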
+ string FullTableName() { + return strings::StrCat(project_id_, ":", dataset_id_, ".", table_id_, "@", + timestamp_millis_); + } + + const string project_id_; + const string dataset_id_; + const string table_id_; + + // Snapshot timestamp. + const int64 timestamp_millis_; + + // Columns that should be read. Empty means all columns. + const std::set columns_; + + // HTTP address of BigQuery end point to use. + const string bigquery_end_point_; + + // Describes the portion of the table that we are currently accessing. + BigQueryTablePartition partition_; + + // Total number of rows in the underlying table. + int64 total_num_rows_ = 0; + + // Offset of the first row in the underlying row_buffer_. + int64 first_buffered_row_index_ = 0; + + // Offset of the next row in the row_buffer_. -1 indicates that this index + // is invalid. + int next_row_in_buffer_ = -1; + + // This buffer holds next rows to improve performance. Its size will be + // based on how much buffering was requested. + std::vector row_buffer_; + + // If next_page is set, it will used to read next batch of data. + string next_page_token_; + + // A tree representing the schema for the underlying table. + SchemaNode schema_root_; + + std::unique_ptr auth_provider_; + std::unique_ptr http_request_factory_; + + TF_DISALLOW_COPY_AND_ASSIGN(BigQueryTableAccessor); +}; + +} // namespace tensorflow +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_ diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc new file mode 100644 index 0000000000..9fb339864d --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc @@ -0,0 +1,513 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h" +#include "tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h" +#include "tensorflow/core/example/feature.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/cloud/http_request_fake.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +constexpr char kTestProject[] = "test-project"; +constexpr char kTestDataset[] = "test-dataset"; +constexpr char kTestTable[] = "test-table"; + +static bool HasSubstr(const string& base, const string& substr) { + bool ok = StringPiece(base).contains(substr); + EXPECT_TRUE(ok) << base << ", expected substring " << substr; + return ok; +} + +class FakeAuthProvider : public AuthProvider { + public: + Status GetToken(string* token) override { + *token = "fake_token"; + return Status::OK(); + } +}; + +static string DeterministicSerialization(const tensorflow::Example& example) { + const int size = example.ByteSize(); + string result(size, '\0'); + ::tensorflow::protobuf::io::ArrayOutputStream array_stream( + gtl::string_as_array(&result), size); + ::tensorflow::protobuf::io::CodedOutputStream output_stream(&array_stream); + + output_stream.SetSerializationDeterministic(true); + example.SerializeWithCachedSizes(&output_stream); + EXPECT_FALSE(output_stream.HadError()); + EXPECT_EQ(size, output_stream.ByteCount()); + return result; +} + +} // namespace + +class BigQueryTableAccessorTest : public ::testing::Test { + protected: + BigQueryTableAccessor::SchemaNode GetSchema() { + return accessor_->schema_root_; + } + + Status CreateTableAccessor(const string& project_id, const string& dataset_id, + const string& table_id, int64 timestamp_millis, + int64 row_buffer_size, + const std::vector& columns, + const BigQueryTablePartition& partition) { + return BigQueryTableAccessor::New( + project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, "", + columns, partition, std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests_)), + &accessor_); + } + + std::vector requests_; + std::unique_ptr accessor_; +}; + +TEST_F(BigQueryTableAccessorTest, NegativeTimestamp) { + const auto status = + CreateTableAccessor(kTestProject, kTestDataset, kTestTable, -1, 3, {}, + BigQueryTablePartition()); + EXPECT_TRUE(errors::IsInvalidArgument(status)); +} + +TEST_F(BigQueryTableAccessorTest, ZeroTimestamp) { + const auto status = + CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 0, 3, {}, + BigQueryTablePartition()); + EXPECT_TRUE(errors::IsInvalidArgument(status)); +} + +TEST_F(BigQueryTableAccessorTest, RepeatedFieldNoAllowedTest) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + R"({ + "kind": "bigquery#table", + "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", + "id": "test-project:test-dataset.test-table", + "schema": { + "fields": [ + { + "name": "int_field", + "type": "INTEGER", + "mode": "REPEATED" + }] + }, + "numRows": "10" + })")); + const auto status = + CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 3, {}, + BigQueryTablePartition()); + EXPECT_TRUE(errors::IsUnimplemented(status)); + EXPECT_TRUE(HasSubstr(status.error_message(), + "Tables with 
repeated columns are not supported")); +} + +TEST_F(BigQueryTableAccessorTest, ValidSchemaTest) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchema)); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 3, + {}, BigQueryTablePartition())); + // Validate total number of rows. + EXPECT_EQ(4, accessor_->total_num_rows()); + + // Validate the schema. + const auto schema_root = GetSchema(); + EXPECT_EQ(schema_root.name, ""); + EXPECT_EQ(schema_root.type, BigQueryTableAccessor::ColumnType::kNone); + EXPECT_EQ(9, schema_root.schema_nodes.size()); + + EXPECT_EQ(schema_root.schema_nodes[0].name, "int_field"); + EXPECT_EQ(schema_root.schema_nodes[0].type, + BigQueryTableAccessor::ColumnType::kInteger); + + EXPECT_EQ(schema_root.schema_nodes[1].name, "str_field"); + EXPECT_EQ(schema_root.schema_nodes[1].type, + BigQueryTableAccessor::ColumnType::kString); + + EXPECT_EQ(1, schema_root.schema_nodes[2].schema_nodes.size()); + EXPECT_EQ(schema_root.schema_nodes[2].name, "rec_field"); + EXPECT_EQ(schema_root.schema_nodes[2].type, + BigQueryTableAccessor::ColumnType::kRecord); + + EXPECT_EQ(schema_root.schema_nodes[2].schema_nodes[0].name, + "rec_field.float_field"); + EXPECT_EQ(schema_root.schema_nodes[2].schema_nodes[0].type, + BigQueryTableAccessor::ColumnType::kFloat); + + EXPECT_EQ(schema_root.schema_nodes[3].name, "bool_field"); + EXPECT_EQ(schema_root.schema_nodes[3].type, + BigQueryTableAccessor::ColumnType::kBoolean); + + EXPECT_EQ(schema_root.schema_nodes[4].name, "bytes_field"); + EXPECT_EQ(schema_root.schema_nodes[4].type, + BigQueryTableAccessor::ColumnType::kBytes); + + EXPECT_EQ(schema_root.schema_nodes[5].name, "timestamp_field"); + EXPECT_EQ(schema_root.schema_nodes[5].type, + BigQueryTableAccessor::ColumnType::kTimestamp); + + EXPECT_EQ(schema_root.schema_nodes[6].name, "date_field"); + EXPECT_EQ(schema_root.schema_nodes[6].type, + BigQueryTableAccessor::ColumnType::kDate); + + EXPECT_EQ(schema_root.schema_nodes[7].name, "time_field"); + EXPECT_EQ(schema_root.schema_nodes[7].type, + BigQueryTableAccessor::ColumnType::kTime); + + EXPECT_EQ(schema_root.schema_nodes[8].name, "datetime_field"); + EXPECT_EQ(schema_root.schema_nodes[8].type, + BigQueryTableAccessor::ColumnType::kDatetime); +} + +TEST_F(BigQueryTableAccessorTest, ReadOneRowTest) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchema)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n" + "Auth Token: fake_token\n", + kTestRow)); + BigQueryTablePartition partition; + partition.set_start_index(2); + partition.set_end_index(2); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, + {}, partition)); + int64 row_id; + Example example; + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + + // Validate returned result. 
+ Example expected_example; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestExampleProto, + &expected_example)); + EXPECT_EQ(DeterministicSerialization(expected_example), + DeterministicSerialization(example)); + EXPECT_EQ(row_id, 2); + EXPECT_TRUE(accessor_->Done()); +} + +TEST_F(BigQueryTableAccessorTest, ReadOneRowPartialTest) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchema)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n" + "Auth Token: fake_token\n", + kTestRow)); + BigQueryTablePartition partition; + partition.set_start_index(2); + partition.set_end_index(2); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, + {"bool_field", "rec_field.float_field"}, + partition)); + int64 row_id; + Example example; + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + + // Validate returned result. + EXPECT_EQ(row_id, 2); + EXPECT_TRUE(accessor_->Done()); + Example expected_example; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestPartialExampleProto, + &expected_example)); + EXPECT_EQ(DeterministicSerialization(expected_example), + DeterministicSerialization(example)); +} + +TEST_F(BigQueryTableAccessorTest, ReadOneRowWithNullsTest) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchema)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n" + "Auth Token: fake_token\n", + kTestRowWithNulls)); + BigQueryTablePartition partition; + partition.set_start_index(2); + partition.set_end_index(2); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, + {}, partition)); + int64 row_id; + Example example; + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + + // Validate returned result. + Example expected_example; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestExampleProtoWithNulls, + &expected_example)); + EXPECT_EQ(DeterministicSerialization(expected_example), + DeterministicSerialization(example)); + EXPECT_EQ(row_id, 2); + EXPECT_TRUE(accessor_->Done()); +} + +TEST_F(BigQueryTableAccessorTest, ReadOneRowTwoRecords) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchemaTwoRecords)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n" + "Auth Token: fake_token\n", + kTestRowWithTwoRecords)); + BigQueryTablePartition partition; + partition.set_start_index(2); + partition.set_end_index(2); + TF_EXPECT_OK(CreateTableAccessor( + kTestProject, kTestDataset, kTestTable, 1, 1, + {"rec_field2.bool_field", "rec_field1.float_field"}, partition)); + + int64 row_id; + Example example; + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + + // Validate returned result. 
+ Example expected_example; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString( + kTestExampleProtoWithTwoRecords, &expected_example)); + EXPECT_EQ(DeterministicSerialization(expected_example), + DeterministicSerialization(example)); + EXPECT_EQ(row_id, 2); + EXPECT_TRUE(accessor_->Done()); +} + +TEST_F(BigQueryTableAccessorTest, NonExistentColumns) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchemaTwoRecords)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n" + "Auth Token: fake_token\n", + kTestRowWithTwoRecords)); + BigQueryTablePartition partition; + partition.set_start_index(2); + partition.set_end_index(2); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, + {"bool_field", "float_field"}, partition)); + int64 row_id; + Example example; + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + + // Validate returned result. + EXPECT_EQ(row_id, 2); + EXPECT_TRUE(accessor_->Done()); +} + +TEST_F(BigQueryTableAccessorTest, EmptyRow) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchemaTwoRecords)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n" + "Auth Token: fake_token\n", + kTestEmptyRow)); + BigQueryTablePartition partition; + partition.set_start_index(2); + partition.set_end_index(2); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, + {}, partition)); + int64 row_id; + Example example; + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + + // Validate returned result. 
+ EXPECT_EQ(row_id, 2); + EXPECT_TRUE(accessor_->Done()); +} + +TEST_F(BigQueryTableAccessorTest, BrokenRowTest) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchema)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n" + "Auth Token: fake_token\n", + kBrokenTestRow)); + BigQueryTablePartition partition; + partition.set_start_index(2); + partition.set_end_index(2); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, + {}, partition)); + int64 row_id; + Example example; + const auto status = accessor_->ReadRow(&row_id, &example); + EXPECT_TRUE(errors::IsInternal(status)); + EXPECT_TRUE( + HasSubstr(status.error_message(), "Cannot convert value to integer")); +} + +TEST_F(BigQueryTableAccessorTest, MultiplePagesTest) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchema)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/data?maxResults=2&startIndex=1\n" + "Auth Token: fake_token\n", + kTestTwoRows)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/" + "data?maxResults=2&pageToken=next_page\n" + "Auth Token: fake_token\n", + kTestRowWithNulls)); + + BigQueryTablePartition partition; + partition.set_start_index(1); + partition.set_end_index(-1); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 2, + {}, partition)); + + int64 row_id; + Example example; + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + EXPECT_EQ(1, row_id); + EXPECT_FALSE(accessor_->Done()); + EXPECT_EQ( + (example.features().feature()).at("int_field").int64_list().value(0), + 1111); + + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + EXPECT_EQ(2, row_id); + EXPECT_FALSE(accessor_->Done()); + EXPECT_EQ(example.features().feature().at("int_field").int64_list().value(0), + 2222); + + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + EXPECT_EQ(3, row_id); + EXPECT_TRUE(accessor_->Done()); + + Example expected_example; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestExampleProtoWithNulls, + &expected_example)); + EXPECT_EQ(DeterministicSerialization(expected_example), + DeterministicSerialization(example)); + EXPECT_TRUE(errors::IsOutOfRange(accessor_->ReadRow(&row_id, &example))); +} + +TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchema)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=0\n" + "Auth Token: fake_token\n", + kTestTwoRows)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/" + 
"data?maxResults=2&startIndex=3\n" + "Auth Token: fake_token\n", + kTestRowWithNulls)); + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/data?maxResults=2&startIndex=0\n" + "Auth Token: fake_token\n", + kTestTwoRows)); + + BigQueryTablePartition partition; + partition.set_start_index(0); + partition.set_end_index(0); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 2, + {}, partition)); + + int64 row_id; + Example example; + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + EXPECT_EQ(0, row_id); + EXPECT_TRUE(accessor_->Done()); + EXPECT_EQ(example.features().feature().at("int_field").int64_list().value(0), + 1111); + + partition.set_start_index(3); + partition.set_end_index(-1); + accessor_->SetPartition(partition); + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + EXPECT_EQ(3, row_id); + EXPECT_TRUE(accessor_->Done()); + EXPECT_EQ(example.features().feature().at("int_field").int64_list().value(0), + 1234); + + partition.set_start_index(0); + partition.set_end_index(1); + accessor_->SetPartition(partition); + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + EXPECT_EQ(0, row_id); + EXPECT_FALSE(accessor_->Done()); + EXPECT_EQ(example.features().feature().at("int_field").int64_list().value(0), + 1111); + TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); + EXPECT_EQ(1, row_id); + EXPECT_TRUE(accessor_->Done()); + EXPECT_EQ(example.features().feature().at("int_field").int64_list().value(0), + 2222); +} + +TEST_F(BigQueryTableAccessorTest, EmptyPartitionTest) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchema)); + + BigQueryTablePartition partition; + partition.set_start_index(3); + partition.set_end_index(2); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, + {}, partition)); + EXPECT_TRUE(accessor_->Done()); + + int64 row_id; + Example example; + EXPECT_TRUE(errors::IsOutOfRange(accessor_->ReadRow(&row_id, &example))); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h new file mode 100644 index 0000000000..b2b11f4f57 --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h @@ -0,0 +1,404 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ + +#include + +namespace tensorflow { +namespace { + +const string kSampleSchema = R"({ + "kind": "bigquery#table", + "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", + "id": "test-project:test-dataset.test-table", + "schema": { + "fields": [ + { + "name": "int_field", + "type": "INTEGER", + "mode": "REQUIRED" + },{ + "name": "str_field", + "type": "STRING", + "mode": "NULLABLE" + },{ + "name": "rec_field", + "type": "RECORD", + "fields": [ + { + "name": "float_field", + "type": "FLOAT", + "mode": "NULLABLE" + }] + },{ + "name": "bool_field", + "type": "BOOLEAN", + "mode": "NULLABLE" + },{ + "name": "bytes_field", + "type": "BYTES", + "mode": "NULLABLE" + },{ + "name": "timestamp_field", + "type": "TIMESTAMP", + "mode": "NULLABLE" + },{ + "name": "date_field", + "type": "DATE", + "mode": "NULLABLE" + },{ + "name": "time_field", + "type": "TIME", + "mode": "NULLABLE" + },{ + "name": "datetime_field", + "type": "DATETIME", + "mode": "NULLABLE" + }] + }, + "numRows": "4" +})"; + +const string kSampleSchemaTwoRecords = R"({ + "kind": "bigquery#table", + "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", + "id": "test-project:test-dataset.test-table", + "schema": { + "fields": [ + { + "name": "rec_field1", + "type": "RECORD", + "fields": [ + { + "name": "int_field", + "type": "INTEGER", + "mode": "NULLABLE" + }, { + "name": "float_field", + "type": "FLOAT", + "mode": "NULLABLE" + }] + },{ + "name": "rec_field2", + "type": "RECORD", + "fields": [ + { + "name": "bool_field", + "type": "BOOLEAN", + "mode": "NULLABLE" + },{ + "name": "bytes_field", + "type": "BYTES", + "mode": "NULLABLE" + }] + }] + }, + "numRows": "4" +})"; + +const string kTestRow = R"({ + "kind": "bigquery#table", + "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", + "id": "test-project:test-dataset.test-table", + "rows": [ + { + "f": [ + { + "v": "1234" + },{ + "v": "" + },{ + "v": { + "f": [ + { + "v": "1.23456" + }] + } + },{ + "v": "true" + },{ + "v": "01010100101" + },{ + "v": "timestamp" + },{ + "v": "date" + },{ + "v": "time" + },{ + "v": "datetime" + }]}]})"; + +const string kBrokenTestRow = R"({ + "kind": "bigquery#table", + "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", + "id": "test-project:test-dataset.test-table", + "rows": [ + { + "f": [ + { + "v": "1-234" // This does not parse as integer. + },{ + "v": "" + },{ + },{ + "v": "true" + },{ + "v": "01010100101" + },{ + "v": "timestamp" + },{ + "v": "date" + },{ + "v": "time" + },{ + "v": "datetime" + }]}]})"; + +const string kTestRowWithNulls = R"({ + "kind": "bigquery#table", + "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", + "id": "test-project:test-dataset.test-table", + "rows": [ + { + "f": [ + { + "v": "1234" + },{ + "v": "string" + },{ + "v": null + },{ + "v": "true" + },{ + "v": "01010100101" + },{ + "v": "" + },{ + "v": null + },{ + "v": null + },{ + "v": "datetime" + }]}]})"; + +// Example proto corresponding to kTestRow. 
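+// Note that kTestRow's empty str_field value ("") is kept below as an empty
+// bytes_list entry, whereas null column values (compare kTestRowWithNulls
+// with kTestExampleProtoWithNulls) produce no feature at all.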
+const string kTestExampleProto = R"(features { + feature { + key: "bool_field" + value { + int64_list { + value: 1 + } + } + } + feature { + key: "bytes_field" + value { + bytes_list { + value: "01010100101" + } + } + } + feature { + key: "date_field" + value { + bytes_list { + value: "date" + } + } + } + feature { + key: "datetime_field" + value { + bytes_list { + value: "datetime" + } + } + } + feature { + key: "int_field" + value { + int64_list { + value: 1234 + } + } + } + feature { + key: "rec_field.float_field" + value { + float_list { + value: 1.23456 + } + } + } + feature { + key: "str_field" + value { + bytes_list { + value: "" + } + } + } + feature { + key: "time_field" + value { + bytes_list { + value: "time" + } + } + } + feature { + key: "timestamp_field" + value { + bytes_list { + value: "timestamp" + } + } + } +} +)"; + +// Example proto corresponding to kTestRowWithNulls. +const string kTestExampleProtoWithNulls = R"(features { + feature { + key: "bool_field" + value { + int64_list { + value: 1 + } + } + } + feature { + key: "bytes_field" + value { + bytes_list { + value: "01010100101" + } + } + } + feature { + key: "datetime_field" + value { + bytes_list { + value: "datetime" + } + } + } + feature { + key: "int_field" + value { + int64_list { + value: 1234 + } + } + } + feature { + key: "timestamp_field" + value { + bytes_list { + value: "" + } + } + } + feature { + key: "str_field" + value { + bytes_list { + value: "string" + } + } + } +} +)"; + +// Example proto corresponding to part of kTestRow. +const string kTestPartialExampleProto = R"(features { + feature { + key: "bool_field" + value { + int64_list { + value: 1 + } + } + } + feature { + key: "rec_field.float_field" + value { + float_list { + value: 1.23456 + } + } + } +} +)"; + +const string kTestExampleProtoWithTwoRecords = R"(features { + feature { + key: "rec_field1.float_field" + value { + float_list { + value: 1.23456 + } + } + } + feature { + key: "rec_field2.bool_field" + value { + int64_list { + value: 1 + } + } + } +} +)"; + +const string kTestTwoRows = R"({ + "kind": "bigquery#table", + "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", + "pageToken": "next_page", + "id": "test-project:test-dataset.test-table", + "rows": [ + {"f": [{"v": "1111"},{},{},{},{},{},{},{},{}]}, + {"f": [{"v": "2222"},{},{},{},{},{},{},{},{}]} + ]})"; + +const string kTestRowWithTwoRecords = R"({ + "kind": "bigquery#table", + "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", + "id": "test-project:test-dataset.test-table", + "rows": [ + { + "f": [ + {"v": {"f": [{}, {"v": "1.23456"}]}}, + {"v": {"f": [{"v": "true"}, {}]} + }]}]})"; + +const string kTestEmptyRow = R"({ + "kind": "bigquery#table", + "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", + "id": "test-project:test-dataset.test-table", + "rows": [ + { + "f": [ + {"v": {"f": [{}, {}]}}, + {"v": {"f": [{"v": null}, {}]} + }]}]})"; + +} // namespace +} // namepsace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_partition.proto b/tensorflow/contrib/cloud/kernels/bigquery_table_partition.proto new file mode 100644 index 0000000000..2d9d1380db --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_partition.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +package tensorflow; + +// This proto specifies a table partition in BigQuery. 
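+// Both bounds are inclusive. For example, start_index=0 and end_index=4
+// cover rows 0 through 4, while start_index=5 and end_index=-1 cover every
+// row from 5 to the end of the table.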
+message BigQueryTablePartition { + // [start_index, end_index] specify the boundaries of a partition. + // If end_index is -1, every row starting from start_index is part of the + // partition. + int64 start_index = 1; + int64 end_index = 2; +}; diff --git a/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc b/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc new file mode 100644 index 0000000000..fbba04a31a --- /dev/null +++ b/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc @@ -0,0 +1,88 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* This file registers Bigquery reader ops. */ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +namespace tensorflow { + +using shape_inference::InferenceContext; + +REGISTER_OP("BigQueryReader") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("project_id: string") + .Attr("dataset_id: string") + .Attr("table_id: string") + .Attr("columns: list(string)") + .Attr("timestamp_millis: int") + .Attr("test_end_point: string = ''") + .Output("reader_handle: Ref(string)") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->Vector(2)); + return Status::OK(); + }) + .Doc(R"doc( +A Reader that outputs rows from a BigQuery table as tensorflow Examples. + +container: If non-empty, this reader is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this reader is named in the given bucket + with this shared_name. Otherwise, the node name is used instead. +project_id: GCP project ID. +dataset_id: BigQuery Dataset ID. +table_id: Table to read. +columns: List of columns to read. Leave empty to read all columns. +timestamp_millis: Table snapshot timestamp in millis since epoch. Relative +(negative or zero) snapshot times are not allowed. For more details, see +'Table Decorators' in BigQuery docs. +test_end_point: Do not use. For testing purposes only. +reader_handle: The handle to reference the Reader. +)doc"); + +REGISTER_OP("GenerateBigQueryReaderPartitions") + .Attr("project_id: string") + .Attr("dataset_id: string") + .Attr("table_id: string") + .Attr("columns: list(string)") + .Attr("timestamp_millis: int") + .Attr("num_partitions: int") + .Attr("test_end_point: string = ''") + .Output("partitions: string") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); + return Status::OK(); + }) + .Doc(R"doc( +Generates serialized partition messages suitable for batch reads. + +This op should not be used directly by clients. Instead, the +bigquery_reader_ops.py file defines a clean interface to the reader. + +project_id: GCP project ID. +dataset_id: BigQuery Dataset ID. +table_id: Table to read. +columns: List of columns to read. Leave empty to read all columns. +timestamp_millis: Table snapshot timestamp in millis since epoch. 
Relative +(negative or zero) snapshot times are not allowed. For more details, see +'Table Decorators' in BigQuery docs. +num_partitions: Number of partitions to split the table into. +test_end_point: Do not use. For testing purposes only. +partitions: Serialized table partitions. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops.py b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops.py new file mode 100644 index 0000000000..136707da18 --- /dev/null +++ b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops.py @@ -0,0 +1,150 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""BigQuery reading support for TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.cloud.python.ops import gen_bigquery_reader_ops +from tensorflow.python.framework import ops +from tensorflow.python.ops import io_ops + + +class BigQueryReader(io_ops.ReaderBase): + """A Reader that outputs keys and tf.Example values from a BigQuery table. + + Example use: + ```python + # Assume a BigQuery has the following schema, + # name STRING, + # age INT, + # state STRING + + # Create the parse_examples list of features. + features = dict( + name=tf.FixedLenFeature([1], tf.string), + age=tf.FixedLenFeature([1], tf.int32), + state=tf.FixedLenFeature([1], dtype=tf.string, default_value="UNK")) + + # Create a Reader. + reader = bigquery_reader_ops.BigQueryReader(project_id=PROJECT, + dataset_id=DATASET, + table_id=TABLE, + timestamp_millis=TIME, + num_partitions=NUM_PARTITIONS, + features=features) + + # Populate a queue with the BigQuery Table partitions. + queue = tf.training.string_input_producer(reader.partitions()) + + # Read and parse examples. + row_id, examples_serialized = reader.read(queue) + examples = tf.parse_example(examples_serialized, features=features) + + # Process the Tensors examples["name"], examples["age"], etc... + ``` + + Note that to create a reader a snapshot timestamp is necessary. This + will enable the reader to look at a consistent snapshot of the table. + For more information, see 'Table Decorators' in BigQuery docs. + + See ReaderBase for supported methods. + """ + + def __init__(self, + project_id, + dataset_id, + table_id, + timestamp_millis, + num_partitions, + features=None, + columns=None, + test_end_point=None, + name=None): + """Creates a BigQueryReader. + + Args: + project_id: GCP project ID. + dataset_id: BigQuery dataset ID. + table_id: BigQuery table ID. + timestamp_millis: timestamp to snapshot the table in milliseconds since + the epoch. Relative (negative or zero) snapshot times are not allowed. + For more details, see 'Table Decorators' in BigQuery docs. + num_partitions: Number of non-overlapping partitions to read from. 
+ features: parse_example compatible dict from keys to `VarLenFeature` and + `FixedLenFeature` objects. Keys are read as columns from the db. + columns: list of columns to read, can be set iff features is None. + test_end_point: Used only for testing purposes (optional). + name: a name for the operation (optional). + + Raises: + TypeError: - If features is neither None nor a dict or + - If columns is is neither None nor a list or + - If both features and columns are None or set. + """ + if (features is None) == (columns is None): + raise TypeError("exactly one of features and columns must be set.") + + if features is not None: + if not isinstance(features, dict): + raise TypeError("features must be a dict.") + self._columns = list(features.keys()) + elif columns is not None: + if not isinstance(columns, list): + raise TypeError("columns must be a list.") + self._columns = columns + + self._project_id = project_id + self._dataset_id = dataset_id + self._table_id = table_id + self._timestamp_millis = timestamp_millis + self._num_partitions = num_partitions + self._test_end_point = test_end_point + + reader = gen_bigquery_reader_ops.big_query_reader( + name=name, + project_id=self._project_id, + dataset_id=self._dataset_id, + table_id=self._table_id, + timestamp_millis=self._timestamp_millis, + columns=self._columns, + test_end_point=self._test_end_point) + super(BigQueryReader, self).__init__(reader) + + def partitions(self, name=None): + """Returns serialized BigQueryTablePartition messages. + + These messages represent a non-overlapping division of a table for a + bulk read. + + Args: + name: a name for the operation (optional). + + Returns: + `1-D` string `Tensor` of serialized `BigQueryTablePartition` messages. + """ + return gen_bigquery_reader_ops.generate_big_query_reader_partitions( + name=name, + project_id=self._project_id, + dataset_id=self._dataset_id, + table_id=self._table_id, + timestamp_millis=self._timestamp_millis, + num_partitions=self._num_partitions, + test_end_point=self._test_end_point, + columns=self._columns) + + +ops.NotDifferentiable("BigQueryReader") diff --git a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py new file mode 100644 index 0000000000..b7d044ed25 --- /dev/null +++ b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py @@ -0,0 +1,287 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for BigQueryReader Op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import os +import re +import threading + +from six.moves import SimpleHTTPServer +from six.moves import socketserver + +from tensorflow.contrib.cloud.python.ops import bigquery_reader_ops as cloud +from tensorflow.core.example import example_pb2 +from tensorflow.core.framework import types_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import compat + +_PROJECT = "test-project" +_DATASET = "test-dataset" +_TABLE = "test-table" +# List representation of the test rows in the 'test-table' in BigQuery. +# The schema for each row is: [int64, string, float]. +# The values for rows are generated such that some columns have null values. The +# general formula here is: +# - The int64 column is present in every row. +# - The string column is only avaiable in even rows. +# - The float column is only available in every third row. +_ROWS = [[0, "s_0", 0.1], [1, None, None], [2, "s_2", None], [3, None, 3.1], + [4, "s_4", None], [5, None, None], [6, "s_6", 6.1], [7, None, None], + [8, "s_8", None], [9, None, 9.1]] +# Schema for 'test-table'. +# The schema currently has three columns: int64, string, and float +_SCHEMA = { + "kind": "bigquery#table", + "id": "test-project:test-dataset.test-table", + "schema": { + "fields": [{ + "name": "int64_col", + "type": "INTEGER", + "mode": "NULLABLE" + }, { + "name": "string_col", + "type": "STRING", + "mode": "NULLABLE" + }, { + "name": "float_col", + "type": "FLOAT", + "mode": "NULLABLE" + }] + } +} + + +def _ConvertRowToExampleProto(row): + """Converts the input row to an Example proto. + + Args: + row: Input Row instance. + + Returns: + An Example proto initialized with row values. + """ + + example = example_pb2.Example() + example.features.feature["int64_col"].int64_list.value.append(row[0]) + if row[1] is not None: + example.features.feature["string_col"].bytes_list.value.append( + compat.as_bytes(row[1])) + if row[2] is not None: + example.features.feature["float_col"].float_list.value.append(row[2]) + return example + + +class FakeBigQueryServer(threading.Thread): + """Fake http server to return schema and data for sample table.""" + + def __init__(self, address, port): + """Creates a FakeBigQueryServer. + + Args: + address: Server address + port: Server port. Pass 0 to automatically pick an empty port. + """ + threading.Thread.__init__(self) + self.handler = BigQueryRequestHandler + self.httpd = socketserver.TCPServer((address, port), self.handler) + + def run(self): + self.httpd.serve_forever() + + def shutdown(self): + self.httpd.shutdown() + self.httpd.socket.close() + + +class BigQueryRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler): + """Responds to BigQuery HTTP requests. + + Attributes: + num_rows: num_rows in the underlying table served by this class. + """ + + num_rows = 0 + + def do_GET(self): + if "data?maxResults=" not in self.path: + # This is a schema request. + _SCHEMA["numRows"] = self.num_rows + response = json.dumps(_SCHEMA) + else: + # This is a data request. + # + # Extract max results and start index. 
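+      # For instance, a (hypothetical) request path ending in
+      # "data?maxResults=2&startIndex=4" yields max_results == 2 and
+      # start_index == 4 via the regular expressions below.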
+ max_results = int(re.findall(r"maxResults=(\d+)", self.path)[0]) + start_index = int(re.findall(r"startIndex=(\d+)", self.path)[0]) + + # Send the rows as JSON. + rows = [] + for row in _ROWS[start_index:start_index + max_results]: + row_json = { + "f": [{ + "v": str(row[0]) + }, { + "v": str(row[1]) if row[1] is not None else None + }, { + "v": str(row[2]) if row[2] is not None else None + }] + } + rows.append(row_json) + response = json.dumps({ + "kind": "bigquery#table", + "id": "test-project:test-dataset.test-table", + "rows": rows + }) + self.send_response(200) + self.end_headers() + self.wfile.write(compat.as_bytes(response)) + + +def _SetUpQueue(reader): + """Sets up a queue for a reader.""" + queue = data_flow_ops.FIFOQueue(8, [types_pb2.DT_STRING], shapes=()) + key, value = reader.read(queue) + queue.enqueue_many(reader.partitions()).run() + queue.close().run() + return key, value + + +class BigQueryReaderOpsTest(test.TestCase): + + def setUp(self): + super(BigQueryReaderOpsTest, self).setUp() + self.server = FakeBigQueryServer("127.0.0.1", 0) + self.server.start() + logging.info("server address is %s:%s", self.server.httpd.server_address[0], + self.server.httpd.server_address[1]) + + # An override to bypass the GCP auth token retrieval logic + # in google_auth_provider.cc. + os.environ["GOOGLE_AUTH_TOKEN_FOR_TESTING"] = "not-used" + + def tearDown(self): + self.server.shutdown() + super(BigQueryReaderOpsTest, self).tearDown() + + def _ReadAndCheckRowsUsingFeatures(self, num_rows): + self.server.handler.num_rows = num_rows + + with self.test_session() as sess: + feature_configs = { + "int64_col": + parsing_ops.FixedLenFeature( + [1], dtype=dtypes.int64), + "string_col": + parsing_ops.FixedLenFeature( + [1], dtype=dtypes.string, default_value="s_default"), + } + reader = cloud.BigQueryReader( + project_id=_PROJECT, + dataset_id=_DATASET, + table_id=_TABLE, + num_partitions=4, + features=feature_configs, + timestamp_millis=1, + test_end_point=("%s:%s" % (self.server.httpd.server_address[0], + self.server.httpd.server_address[1]))) + + key, value = _SetUpQueue(reader) + + seen_rows = [] + features = parsing_ops.parse_example( + array_ops.reshape(value, [1]), feature_configs) + for _ in range(num_rows): + int_value, str_value = sess.run( + [features["int64_col"], features["string_col"]]) + + # Parse values returned from the session. + self.assertEqual(int_value.shape, (1, 1)) + self.assertEqual(str_value.shape, (1, 1)) + int64_col = int_value[0][0] + string_col = str_value[0][0] + seen_rows.append(int64_col) + + # Compare. 
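+        # _ROWS is constructed so that row k has int64_col == k, so the
+        # parsed int64 value doubles as an index into the expected rows.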
+ expected_row = _ROWS[int64_col] + self.assertEqual(int64_col, expected_row[0]) + self.assertEqual( + compat.as_str(string_col), ("s_%d" % int64_col) if expected_row[1] + else "s_default") + + self.assertItemsEqual(seen_rows, range(num_rows)) + + with self.assertRaisesOpError("is closed and has insufficient elements " + "\\(requested 1, current size 0\\)"): + sess.run([key, value]) + + def testReadingSingleRowUsingFeatures(self): + self._ReadAndCheckRowsUsingFeatures(1) + + def testReadingMultipleRowsUsingFeatures(self): + self._ReadAndCheckRowsUsingFeatures(10) + + def testReadingMultipleRowsUsingColumns(self): + num_rows = 10 + self.server.handler.num_rows = num_rows + + with self.test_session() as sess: + reader = cloud.BigQueryReader( + project_id=_PROJECT, + dataset_id=_DATASET, + table_id=_TABLE, + num_partitions=4, + columns=["int64_col", "float_col", "string_col"], + timestamp_millis=1, + test_end_point=("%s:%s" % (self.server.httpd.server_address[0], + self.server.httpd.server_address[1]))) + key, value = _SetUpQueue(reader) + seen_rows = [] + for row_index in range(num_rows): + returned_row_id, example_proto = sess.run([key, value]) + example = example_pb2.Example() + example.ParseFromString(example_proto) + self.assertIn("int64_col", example.features.feature) + feature = example.features.feature["int64_col"] + self.assertEqual(len(feature.int64_list.value), 1) + int64_col = feature.int64_list.value[0] + seen_rows.append(int64_col) + + # Create our expected Example. + expected_example = example_pb2.Example() + expected_example = _ConvertRowToExampleProto(_ROWS[int64_col]) + + # Compare. + self.assertProtoEquals(example, expected_example) + self.assertEqual(row_index, int(returned_row_id)) + + self.assertItemsEqual(seen_rows, range(num_rows)) + + with self.assertRaisesOpError("is closed and has insufficient elements " + "\\(requested 1, current size 0\\)"): + sess.run([key, value]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 556a0b8919..dd28817b54 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -69,8 +69,8 @@ endif(tensorflow_BUILD_CONTRIB_KERNELS) if(NOT tensorflow_ENABLE_SSL_SUPPORT) # Cloud libraries require boringssl. 
file(GLOB tf_core_kernels_cloud_srcs - "${tensorflow_source_dir}/tensorflow/core/kernels/cloud/*.h" - "${tensorflow_source_dir}/tensorflow/core/kernels/cloud/*.cc" + "${tensorflow_source_dir}/tensorflow/contrib/cloud/kernels/*.h" + "${tensorflow_source_dir}/tensorflow/contrib/cloud/kernels/*.cc" ) list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_cloud_srcs}) endif() diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 8e237a78a7..73686d0dd3 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -51,6 +51,7 @@ GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir GENERATE_CONTRIB_OP_LIBRARY(framework_variable "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(tensor_forest "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(bigquery_reader "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc") ######################################################## # tf_user_ops library diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 38047cae78..e58b672347 100644 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -227,6 +227,11 @@ add_python_module("tensorflow/contrib/bayesflow/examples/reinforce_simple") add_python_module("tensorflow/contrib/bayesflow/python") add_python_module("tensorflow/contrib/bayesflow/python/kernel_tests") add_python_module("tensorflow/contrib/bayesflow/python/ops") +add_python_module("tensorflow/contrib/cloud") +add_python_module("tensorflow/contrib/cloud/kernels") +add_python_module("tensorflow/contrib/cloud/ops") +add_python_module("tensorflow/contrib/cloud/python") +add_python_module("tensorflow/contrib/cloud/python/ops") add_python_module("tensorflow/contrib/compiler") add_python_module("tensorflow/contrib/copy_graph") add_python_module("tensorflow/contrib/copy_graph/python") @@ -542,6 +547,9 @@ GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops" GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/tensor_forest/python/ops/gen_tensor_forest_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_bigquery_reader_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_bigquery_reader_ops.py) + add_custom_target(tf_python_ops SOURCES ${tf_python_ops_generated_files} ${PYTHON_PROTO_GENFILES}) add_dependencies(tf_python_ops tf_python_op_gen_main) diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index d1c6a74937..fc3363189d 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -181,6 +181,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order. "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py" # Bad placement. "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/topn_test.py" # Results inaccurate + "${tensorflow_source_dir}/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py" # No libcurl support # Failing on some CI. 
"${tensorflow_source_dir}/tensorflow/python/debug/cli/analyzer_cli_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_file_test.py" diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index a4d0ec63a3..0fd71f6f0a 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -537,16 +537,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "cloud_ops_op_lib", - srcs = ["ops/cloud_ops.cc"], - copts = tf_copts(), - linkstatic = 1, - visibility = ["//visibility:public"], - deps = [":framework"], - alwayslink = 1, -) - cc_library( name = "ops", visibility = ["//visibility:public"], diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index a4eef8fe24..e2c1ade13d 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3827,8 +3827,6 @@ filegroup( # Excluded due to experimental status: "debug_ops.*", "scatter_nd_op*", - # Lib CURL is not supported on Android. - "bigquery*", ], ), visibility = ["//visibility:public"], diff --git a/tensorflow/core/kernels/cloud/BUILD b/tensorflow/core/kernels/cloud/BUILD deleted file mode 100644 index 52313ef9ed..0000000000 --- a/tensorflow/core/kernels/cloud/BUILD +++ /dev/null @@ -1,98 +0,0 @@ -# Description: -# BigQueryReader implementation - -package( - default_visibility = ["//visibility:private"], -) - -licenses(["notice"]) # Apache 2.0 - -load( - "//tensorflow:tensorflow.bzl", - "tf_kernel_library", - "tf_cc_test", -) - -# For platform specific build config -load( - "//tensorflow/core:platform/default/build_config.bzl", - "tf_proto_library", -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - -tf_kernel_library( - name = "bigquery_reader_ops", - srcs = [ - "bigquery_reader_ops.cc", - ], - visibility = ["//visibility:public"], - deps = [ - ":bigquery_table_accessor", - ":bigquery_table_partition_proto_cc", - "//tensorflow/core:cloud_ops_op_lib", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:reader_base", - ], -) - -cc_library( - name = "bigquery_table_accessor", - srcs = [ - "bigquery_table_accessor.cc", - ], - hdrs = [ - "bigquery_table_accessor.h", - ], - visibility = ["//visibility:public"], - deps = [ - ":bigquery_table_partition_proto_cc", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:reader_base", - "//tensorflow/core/platform/cloud:google_auth_provider", - "//tensorflow/core/platform/cloud:http_request", - ], - alwayslink = 1, -) - -tf_proto_library( - name = "bigquery_table_partition_proto", - srcs = [ - "bigquery_table_partition.proto", - ], - cc_api_version = 2, -) - -tf_cc_test( - name = "bigquery_table_accessor_test", - size = "small", - srcs = [ - "bigquery_table_accessor_test.cc", - "bigquery_table_accessor_test_data.h", - ], - deps = [ - ":bigquery_table_accessor", - "//tensorflow/core:lib_internal", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/platform/cloud:http_request_fake", - ], -) diff --git a/tensorflow/core/kernels/cloud/bigquery_reader_ops.cc b/tensorflow/core/kernels/cloud/bigquery_reader_ops.cc deleted file mode 100644 index bfaa09cfd6..0000000000 --- 
a/tensorflow/core/kernels/cloud/bigquery_reader_ops.cc +++ /dev/null @@ -1,193 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "tensorflow/core/example/example.pb.h" -#include "tensorflow/core/framework/reader_base.h" -#include "tensorflow/core/framework/reader_op_kernel.h" -#include "tensorflow/core/kernels/cloud/bigquery_table_accessor.h" -#include "tensorflow/core/kernels/cloud/bigquery_table_partition.pb.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/math/math_util.h" -#include "tensorflow/core/lib/strings/numbers.h" - -namespace tensorflow { -namespace { - -constexpr int64 kDefaultRowBufferSize = 1000; // Number of rows to buffer. - -// This is a helper function for reading table attributes from context. -Status GetTableAttrs(OpKernelConstruction* context, string* project_id, - string* dataset_id, string* table_id, - int64* timestamp_millis, std::vector* columns, - string* test_end_point) { - TF_RETURN_IF_ERROR(context->GetAttr("project_id", project_id)); - TF_RETURN_IF_ERROR(context->GetAttr("dataset_id", dataset_id)); - TF_RETURN_IF_ERROR(context->GetAttr("table_id", table_id)); - TF_RETURN_IF_ERROR(context->GetAttr("timestamp_millis", timestamp_millis)); - TF_RETURN_IF_ERROR(context->GetAttr("columns", columns)); - TF_RETURN_IF_ERROR(context->GetAttr("test_end_point", test_end_point)); - return Status::OK(); -} - -} // namespace - -// Note that overriden methods with names ending in "Locked" are called by -// ReaderBase while a mutex is held. -// See comments for ReaderBase. -class BigQueryReader : public ReaderBase { - public: - explicit BigQueryReader(BigQueryTableAccessor* bigquery_table_accessor, - const string& node_name) - : ReaderBase(strings::StrCat("BigQueryReader '", node_name, "'")), - bigquery_table_accessor_(CHECK_NOTNULL(bigquery_table_accessor)) {} - - Status OnWorkStartedLocked() override { - BigQueryTablePartition partition; - if (!partition.ParseFromString(current_work())) { - return errors::InvalidArgument( - "Could not parse work as as valid partition."); - } - TF_RETURN_IF_ERROR(bigquery_table_accessor_->SetPartition(partition)); - return Status::OK(); - } - - Status ReadLocked(string* key, string* value, bool* produced, - bool* at_end) override { - *at_end = false; - *produced = false; - if (bigquery_table_accessor_->Done()) { - *at_end = true; - return Status::OK(); - } - - Example example; - int64 row_id; - TF_RETURN_IF_ERROR(bigquery_table_accessor_->ReadRow(&row_id, &example)); - - *key = std::to_string(row_id); - *value = example.SerializeAsString(); - *produced = true; - return Status::OK(); - } - - private: - // Not owned. 
- BigQueryTableAccessor* bigquery_table_accessor_; -}; - -class BigQueryReaderOp : public ReaderOpKernel { - public: - explicit BigQueryReaderOp(OpKernelConstruction* context) - : ReaderOpKernel(context) { - string table_id; - string project_id; - string dataset_id; - int64 timestamp_millis; - std::vector columns; - string test_end_point; - - OP_REQUIRES_OK(context, - GetTableAttrs(context, &project_id, &dataset_id, &table_id, - ×tamp_millis, &columns, &test_end_point)); - OP_REQUIRES_OK(context, - BigQueryTableAccessor::New( - project_id, dataset_id, table_id, timestamp_millis, - kDefaultRowBufferSize, test_end_point, columns, - BigQueryTablePartition(), &bigquery_table_accessor_)); - - SetReaderFactory([this]() { - return new BigQueryReader(bigquery_table_accessor_.get(), name()); - }); - } - - private: - std::unique_ptr bigquery_table_accessor_; -}; - -REGISTER_KERNEL_BUILDER(Name("BigQueryReader").Device(DEVICE_CPU), - BigQueryReaderOp); - -class GenerateBigQueryReaderPartitionsOp : public OpKernel { - public: - explicit GenerateBigQueryReaderPartitionsOp(OpKernelConstruction* context) - : OpKernel(context) { - string project_id; - string dataset_id; - string table_id; - int64 timestamp_millis; - std::vector columns; - string test_end_point; - - OP_REQUIRES_OK(context, - GetTableAttrs(context, &project_id, &dataset_id, &table_id, - ×tamp_millis, &columns, &test_end_point)); - OP_REQUIRES_OK(context, - BigQueryTableAccessor::New( - project_id, dataset_id, table_id, timestamp_millis, - kDefaultRowBufferSize, test_end_point, columns, - BigQueryTablePartition(), &bigquery_table_accessor_)); - OP_REQUIRES_OK(context, InitializeNumberOfPartitions(context)); - OP_REQUIRES_OK(context, InitializeTotalNumberOfRows()); - } - - void Compute(OpKernelContext* context) override { - const int64 partition_size = tensorflow::MathUtil::CeilOfRatio( - total_num_rows_, num_partitions_); - Tensor* output_tensor = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(0, TensorShape({num_partitions_}), - &output_tensor)); - - auto output = output_tensor->template flat(); - for (int64 i = 0; i < num_partitions_; ++i) { - BigQueryTablePartition partition; - partition.set_start_index(i * partition_size); - partition.set_end_index( - std::min(total_num_rows_, (i + 1) * partition_size) - 1); - output(i) = partition.SerializeAsString(); - } - } - - private: - Status InitializeTotalNumberOfRows() { - total_num_rows_ = bigquery_table_accessor_->total_num_rows(); - if (total_num_rows_ <= 0) { - return errors::FailedPrecondition("Invalid total number of rows."); - } - return Status::OK(); - } - - Status InitializeNumberOfPartitions(OpKernelConstruction* context) { - TF_RETURN_IF_ERROR(context->GetAttr("num_partitions", &num_partitions_)); - if (num_partitions_ <= 0) { - return errors::FailedPrecondition("Invalid number of partitions."); - } - return Status::OK(); - } - - int64 num_partitions_; - int64 total_num_rows_; - std::unique_ptr bigquery_table_accessor_; -}; - -REGISTER_KERNEL_BUILDER( - Name("GenerateBigQueryReaderPartitions").Device(DEVICE_CPU), - GenerateBigQueryReaderPartitionsOp); - -} // namespace tensorflow diff --git a/tensorflow/core/kernels/cloud/bigquery_table_accessor.cc b/tensorflow/core/kernels/cloud/bigquery_table_accessor.cc deleted file mode 100644 index 3e9adfa372..0000000000 --- a/tensorflow/core/kernels/cloud/bigquery_table_accessor.cc +++ /dev/null @@ -1,410 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/kernels/cloud/bigquery_table_accessor.h" - -#include "tensorflow/core/example/feature.pb.h" -#include "tensorflow/core/lib/strings/numbers.h" - -namespace tensorflow { - -namespace { - -constexpr size_t kBufferSize = 1024 * 1024; // In bytes. -const string kBigQueryEndPoint = "https://www.googleapis.com/bigquery/v2"; - -bool IsPartitionEmpty(const BigQueryTablePartition& partition) { - if (partition.end_index() != -1 && - partition.end_index() < partition.start_index()) { - return true; - } - return false; -} - -Status ParseJson(StringPiece json, Json::Value* result) { - Json::Reader reader; - if (!reader.parse(json.ToString(), *result)) { - return errors::Internal("Couldn't parse JSON response from BigQuery."); - } - return Status::OK(); -} - -string ColumnTypeToString(BigQueryTableAccessor::ColumnType enum_type) { - switch (enum_type) { - case BigQueryTableAccessor::ColumnType::kRecord: - return "RECORD"; - case BigQueryTableAccessor::ColumnType::kString: - return "STRING"; - case BigQueryTableAccessor::ColumnType::kBytes: - return "BYTES"; - case BigQueryTableAccessor::ColumnType::kInteger: - return "INTEGER"; - case BigQueryTableAccessor::ColumnType::kFloat: - return "FLOAT"; - case BigQueryTableAccessor::ColumnType::kBoolean: - return "BOOLEAN"; - case BigQueryTableAccessor::ColumnType::kTimestamp: - return "TIMESTAMP"; - case BigQueryTableAccessor::ColumnType::kDate: - return "DATE"; - case BigQueryTableAccessor::ColumnType::kTime: - return "TIME"; - case BigQueryTableAccessor::ColumnType::kDatetime: - return "DATETIME"; - case BigQueryTableAccessor::ColumnType::kNone: - return "NONE"; - } -} - -Status ParseColumnType(const string& type, - BigQueryTableAccessor::ColumnType* enum_type) { - if (type == "RECORD") { - *enum_type = BigQueryTableAccessor::ColumnType::kRecord; - } else if (type == "STRING") { - *enum_type = BigQueryTableAccessor::ColumnType::kString; - } else if (type == "BYTES") { - *enum_type = BigQueryTableAccessor::ColumnType::kBytes; - } else if (type == "INTEGER") { - *enum_type = BigQueryTableAccessor::ColumnType::kInteger; - } else if (type == "FLOAT") { - *enum_type = BigQueryTableAccessor::ColumnType::kFloat; - } else if (type == "BOOLEAN") { - *enum_type = BigQueryTableAccessor::ColumnType::kBoolean; - } else if (type == "TIMESTAMP") { - *enum_type = BigQueryTableAccessor::ColumnType::kTimestamp; - } else if (type == "DATE") { - *enum_type = BigQueryTableAccessor::ColumnType::kDate; - } else if (type == "TIME") { - *enum_type = BigQueryTableAccessor::ColumnType::kTime; - } else if (type == "DATETIME") { - *enum_type = BigQueryTableAccessor::ColumnType::kDatetime; - } else { - return errors::Internal( - strings::StrCat("Could not parse column type ", type)); - } - return Status::OK(); -} - -} // namespace - -Status BigQueryTableAccessor::New( - const string& project_id, const string& dataset_id, const string& table_id, - int64 
timestamp_millis, int64 row_buffer_size, const string& end_point, - const std::vector& columns, const BigQueryTablePartition& partition, - std::unique_ptr* accessor) { - return New(project_id, dataset_id, table_id, timestamp_millis, - row_buffer_size, end_point, columns, partition, nullptr, nullptr, - accessor); -} - -Status BigQueryTableAccessor::New( - const string& project_id, const string& dataset_id, const string& table_id, - int64 timestamp_millis, int64 row_buffer_size, const string& end_point, - const std::vector& columns, const BigQueryTablePartition& partition, - std::unique_ptr auth_provider, - std::unique_ptr http_request_factory, - std::unique_ptr* accessor) { - if (timestamp_millis <= 0) { - return errors::InvalidArgument( - "Cannot use zero or negative timestamp to query a table."); - } - const string& big_query_end_point = - end_point.empty() ? kBigQueryEndPoint : end_point; - if (auth_provider == nullptr && http_request_factory == nullptr) { - accessor->reset(new BigQueryTableAccessor( - project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, - big_query_end_point, columns, partition)); - } else { - accessor->reset(new BigQueryTableAccessor( - project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, - big_query_end_point, columns, partition, std::move(auth_provider), - std::move(http_request_factory))); - } - return (*accessor)->ReadSchema(); -} - -BigQueryTableAccessor::BigQueryTableAccessor( - const string& project_id, const string& dataset_id, const string& table_id, - int64 timestamp_millis, int64 row_buffer_size, const string& end_point, - const std::vector& columns, const BigQueryTablePartition& partition) - : BigQueryTableAccessor( - project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, - end_point, columns, partition, - std::unique_ptr(new GoogleAuthProvider()), - std::unique_ptr(new HttpRequest::Factory())) { - row_buffer_.resize(row_buffer_size); -} - -BigQueryTableAccessor::BigQueryTableAccessor( - const string& project_id, const string& dataset_id, const string& table_id, - int64 timestamp_millis, int64 row_buffer_size, const string& end_point, - const std::vector& columns, const BigQueryTablePartition& partition, - std::unique_ptr auth_provider, - std::unique_ptr http_request_factory) - : project_id_(project_id), - dataset_id_(dataset_id), - table_id_(table_id), - timestamp_millis_(timestamp_millis), - columns_(columns.begin(), columns.end()), - bigquery_end_point_(end_point), - partition_(partition), - auth_provider_(std::move(auth_provider)), - http_request_factory_(std::move(http_request_factory)) { - row_buffer_.resize(row_buffer_size); - Reset(); -} - -Status BigQueryTableAccessor::SetPartition( - const BigQueryTablePartition& partition) { - if (partition.start_index() < 0) { - return errors::InvalidArgument("Start index cannot be negative."); - } - partition_ = partition; - Reset(); - return Status::OK(); -} - -void BigQueryTableAccessor::Reset() { - first_buffered_row_index_ = partition_.start_index(); - next_row_in_buffer_ = -1; - next_page_token_ = ""; -} - -Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) { - if (Done()) { - return errors::OutOfRange("Reached end of table ", FullTableName()); - } - - // If the next row is already fetched and cached, return the row from the - // buffer. Otherwise, fill up the row buffer from BigQuery and return a row. 
- if (next_row_in_buffer_ != -1 && - next_row_in_buffer_ < ComputeMaxResultsArg()) { - *row_id = first_buffered_row_index_ + next_row_in_buffer_; - *example = row_buffer_[next_row_in_buffer_]; - next_row_in_buffer_++; - } else { - string auth_token; - TF_RETURN_IF_ERROR( - AuthProvider::GetToken(auth_provider_.get(), &auth_token)); - - std::unique_ptr request(http_request_factory_->Create()); - std::vector output_buffer; - output_buffer.reserve(kBufferSize); - TF_RETURN_IF_ERROR(request->Init()); - - // The first time that we access BigQuery there is no page token. After that - // we use the page token (which returns rows faster). - if (!next_page_token_.empty()) { - TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat( - BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(), - "&pageToken=", request->EscapeString(next_page_token_)))); - first_buffered_row_index_ += row_buffer_.size(); - } else { - TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat( - BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(), - "&startIndex=", first_buffered_row_index_))); - } - TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); - TF_RETURN_IF_ERROR(request->SetResultBuffer(&output_buffer)); - TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading rows from ", - FullTableName()); - - // Parse the returned row. - StringPiece response_piece = - StringPiece(&output_buffer[0], output_buffer.size()); - Json::Value root; - TF_RETURN_IF_ERROR(ParseJson(response_piece, &root)); - for (unsigned int i = 0; i < root["rows"].size(); ++i) { - row_buffer_[i].Clear(); - TF_RETURN_IF_ERROR( - ParseColumnValues(root["rows"][i], schema_root_, &row_buffer_[i])); - } - - next_page_token_ = root["pageToken"].asString(); - *row_id = first_buffered_row_index_; - *example = row_buffer_[0]; - next_row_in_buffer_ = 1; - } - return Status::OK(); -} - -int64 BigQueryTableAccessor::ComputeMaxResultsArg() { - if (partition_.end_index() == -1) { - return row_buffer_.size(); - } - if (IsPartitionEmpty(partition_)) { - return 0; - } - return std::min(static_cast(row_buffer_.size()), - static_cast(partition_.end_index() - - partition_.start_index() + 1)); -} - -Status BigQueryTableAccessor::ParseColumnValues( - const Json::Value& value, const SchemaNode& root_schema_node, - Example* example) { - if (value.empty()) { - return Status::OK(); - } - if (value["f"].isNull()) { - return Status::OK(); - } - int value_index = 0; - for (const auto& schema_node : root_schema_node.schema_nodes) { - if (value["f"][value_index].isNull()) { - value_index++; - continue; - } - - if (schema_node.type == ColumnType::kRecord) { - TF_RETURN_IF_ERROR(ParseColumnValues(value["f"][value_index]["v"], - schema_node, example)); - } else { - // Append the column value only if user has requested the column. - if (columns_.empty() || - columns_.find(schema_node.name) != columns_.end()) { - TF_RETURN_IF_ERROR(AppendValueToExample(schema_node.name, - value["f"][value_index]["v"], - schema_node.type, example)); - } - } - value_index++; - } - return Status::OK(); -} - -Status BigQueryTableAccessor::ReadSchema() { - string auth_token; - TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_.get(), &auth_token)); - - // Send a request to read the schema. 
- std::unique_ptr request(http_request_factory_->Create()); - std::vector output_buffer; - output_buffer.reserve(kBufferSize); - TF_RETURN_IF_ERROR(request->Init()); - TF_RETURN_IF_ERROR(request->SetUri(BigQueryUriPrefix())); - TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); - TF_RETURN_IF_ERROR(request->SetResultBuffer(&output_buffer)); - TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading schema for ", - FullTableName()); - - // Parse the schema. - StringPiece response_piece = - StringPiece(&output_buffer[0], output_buffer.size()); - - Json::Value root; - TF_RETURN_IF_ERROR(ParseJson(response_piece, &root)); - const auto& columns = root["schema"]["fields"]; - string column_name_prefix = ""; - schema_root_ = {"", ColumnType::kNone}; - TF_RETURN_IF_ERROR( - ExtractColumnType(columns, column_name_prefix, &schema_root_)); - if (root["numRows"].isNull()) { - return errors::Internal("Number of rows cannot be extracted for table ", - FullTableName()); - } - strings::safe_strto64(root["numRows"].asString().c_str(), &total_num_rows_); - return Status::OK(); -} - -Status BigQueryTableAccessor::ExtractColumnType( - const Json::Value& columns, const string& column_name_prefix, - SchemaNode* root) { - for (auto columns_it = columns.begin(); columns_it != columns.end(); - ++columns_it) { - if ((*columns_it)["mode"].asString() == "REPEATED") { - return errors::Unimplemented(strings::StrCat( - "Tables with repeated columns are not supported: ", FullTableName())); - } - ColumnType type; - const string current_column_name = strings::StrCat( - column_name_prefix, (*columns_it)["name"].asString().c_str()); - TF_RETURN_IF_ERROR( - ParseColumnType((*columns_it)["type"].asString().c_str(), &type)); - root->schema_nodes.emplace_back(current_column_name, type); - if (type == ColumnType::kRecord) { - const auto new_prefix = strings::StrCat(current_column_name, "."); - TF_RETURN_IF_ERROR(ExtractColumnType((*columns_it)["fields"], new_prefix, - &root->schema_nodes.back())); - } - } - return Status::OK(); -} - -Status BigQueryTableAccessor::AppendValueToExample( - const string& column_name, const Json::Value& column_value, - const BigQueryTableAccessor::ColumnType type, Example* example) { - if (column_value.isNull()) { - return Status::OK(); - } - auto& feature = - (*example->mutable_features()->mutable_feature())[column_name]; - - switch (type) { - case BigQueryTableAccessor::ColumnType::kNone: - case BigQueryTableAccessor::ColumnType::kRecord: - return errors::Unimplemented("Cannot append type to an example."); - case BigQueryTableAccessor::ColumnType::kTimestamp: - case BigQueryTableAccessor::ColumnType::kDate: - case BigQueryTableAccessor::ColumnType::kTime: - case BigQueryTableAccessor::ColumnType::kDatetime: - case BigQueryTableAccessor::ColumnType::kString: - case BigQueryTableAccessor::ColumnType::kBytes: - feature.mutable_bytes_list()->add_value(column_value.asString()); - break; - case BigQueryTableAccessor::ColumnType::kBoolean: - feature.mutable_int64_list()->add_value( - column_value.asString() == "false" ? 0 : 1); - break; - case BigQueryTableAccessor::ColumnType::kInteger: - int64 column_value_int64; - if (!strings::safe_strto64(column_value.asString().c_str(), - &column_value_int64)) { - return errors::Internal("Cannot convert value to integer ", - column_value.asString().c_str()); - } - feature.mutable_int64_list()->add_value(column_value_int64); - break; - case BigQueryTableAccessor::ColumnType::kFloat: - // BigQuery float is actually a double. 
- double column_value_double; - if (!strings::safe_strtod(column_value.asString().c_str(), - &column_value_double)) { - return errors::Internal("Cannot convert value to double: ", - column_value.asString().c_str()); - } - feature.mutable_float_list()->add_value( - static_cast(column_value_double)); - break; - } - return Status::OK(); -} - -string BigQueryTableAccessor::BigQueryTableAccessor::BigQueryUriPrefix() { - HttpRequest request; - return strings::StrCat(bigquery_end_point_, "/projects/", - request.EscapeString(project_id_), "/datasets/", - request.EscapeString(dataset_id_), "/tables/", - request.EscapeString(table_id_), "/"); -} - -bool BigQueryTableAccessor::Done() { - return (total_num_rows_ <= first_buffered_row_index_ + next_row_in_buffer_) || - IsPartitionEmpty(partition_) || - (partition_.end_index() != -1 && - partition_.end_index() < - first_buffered_row_index_ + next_row_in_buffer_); -} - -} // namespace tensorflow diff --git a/tensorflow/core/kernels/cloud/bigquery_table_accessor.h b/tensorflow/core/kernels/cloud/bigquery_table_accessor.h deleted file mode 100644 index 33d1905b8a..0000000000 --- a/tensorflow/core/kernels/cloud/bigquery_table_accessor.h +++ /dev/null @@ -1,207 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_ - -#include -#include -#include -#include "tensorflow/core/example/example.pb.h" -#include "tensorflow/core/kernels/cloud/bigquery_table_partition.pb.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/cloud/google_auth_provider.h" -#include "tensorflow/core/platform/cloud/http_request.h" - -namespace tensorflow { - -/// This class facilitates accessing BigQuery tables. -/// -/// Notes: -/// - Nested fields are not supported. -/// - BigQuery 'Record's are automatically flattened, -/// - BigQuery float type is a double but is converted to a C++ float in this -/// class. -/// - It is possible for a table snapshot to go out-of-scope in the BigQuery -/// service while accessing the table if a very old timestamp is used. For -/// exact details, see 'Table Decorators' in BigQuery docs. -class BigQueryTableAccessor { - public: - // Column types supported by BigQuery. - enum class ColumnType { - kString = 0, - kBytes, - kInteger, - kFloat, - kBoolean, - kTimestamp, - kDate, - kTime, - kDatetime, - kRecord, - kNone - }; - - /// \brief Creates a new BigQueryTableAccessor object. - // - // We do not allow relative (negative or zero) snapshot times here since we - // want to have a consistent snapshot of the table for the lifetime of this - // object. - // Use end_point if you want to connect to a different end point than the - // official BigQuery end point. Otherwise send an empty string. 
- static Status New(const string& project_id, const string& dataset_id, - const string& table_id, int64 timestamp_millis, - int64 row_buffer_size, const string& end_point, - const std::vector& columns, - const BigQueryTablePartition& partition, - std::unique_ptr* accessor); - - /// \brief Starts reading a new partition. - Status SetPartition(const BigQueryTablePartition& partition); - - /// \brief Returns true if there are more rows available in the current - /// partition. - bool Done(); - - /// \brief Returns a single row as example proto. - /// - /// This function will return an error if the table snapshot goes out of scope - /// in the BigQuery service. - Status ReadRow(int64* row_id, Example* example); - - /// \brief Returns total number of rows in the table. - int64 total_num_rows() { return total_num_rows_; } - - virtual ~BigQueryTableAccessor() {} - - private: - friend class BigQueryTableAccessorTest; - - // This struct encapsulates schema nodes for a BigQuery table. - struct SchemaNode { - SchemaNode() {} - SchemaNode(const string& name, ColumnType type) : name(name), type(type) {} - - string name; - ColumnType type; - std::vector schema_nodes; - }; - - /// If nullptr is passed for http_request_factory and auth_provider the - /// default production ones are used. This can be used by tests to override - /// these two variables. - static Status New(const string& project_id, const string& dataset_id, - const string& table_id, int64 timestamp_millis, - int64 row_buffer_size, const string& end_point, - const std::vector& columns, - const BigQueryTablePartition& partition, - std::unique_ptr auth_provider, - std::unique_ptr http_request_factory, - std::unique_ptr* accessor); - - /// \brief Constructs an object for a given table and partition. - BigQueryTableAccessor(const string& project_id, const string& dataset_id, - const string& table_id, int64 timestamp_millis, - int64 row_buffer_size, const string& end_point, - const std::vector& columns, - const BigQueryTablePartition& partition); - - /// Used for unit testing. - BigQueryTableAccessor( - const string& project_id, const string& dataset_id, - const string& table_id, int64 timestamp_millis, int64 row_buffer_size, - const string& end_point, const std::vector& columns, - const BigQueryTablePartition& partition, - std::unique_ptr auth_provider, - std::unique_ptr http_request_factory); - - /// \brief Parses column values for a given row. - Status ParseColumnValues(const Json::Value& value, - const SchemaNode& root_schema_node, - Example* example); - - /// \brief Reads the table schema and stores it. - Status ReadSchema(); - - /// \brief Extracts column type from a column in schema. - Status ExtractColumnType(const Json::Value& columns, - const string& column_name_prefix, SchemaNode* root); - - /// \brief Appends a single BigQuery column Value to 'example' for a given - /// column. - Status AppendValueToExample(const string& column_name, - const Json::Value& column_value, - const BigQueryTableAccessor::ColumnType type, - Example* example); - - /// \brief Resets internal counters for reading a partition. - void Reset(); - - /// \brief Helper function that returns BigQuery http endpoint prefix. - string BigQueryUriPrefix(); - - /// \brief Computes the maxResults arg to send to BigQuery. - int64 ComputeMaxResultsArg(); - - /// \brief Returns full name of the underlying table name. 
-  string FullTableName() {
-    return strings::StrCat(project_id_, ":", dataset_id_, ".", table_id_, "@",
-                           timestamp_millis_);
-  }
-
-  const string project_id_;
-  const string dataset_id_;
-  const string table_id_;
-
-  // Snapshot timestamp.
-  const int64 timestamp_millis_;
-
-  // Columns that should be read. Empty means all columns.
-  const std::set<string> columns_;
-
-  // HTTP address of the BigQuery end point to use.
-  const string bigquery_end_point_;
-
-  // Describes the portion of the table that we are currently accessing.
-  BigQueryTablePartition partition_;
-
-  // Total number of rows in the underlying table.
-  int64 total_num_rows_ = 0;
-
-  // Offset of the first row in the underlying row_buffer_.
-  int64 first_buffered_row_index_ = 0;
-
-  // Offset of the next row in the row_buffer_. -1 indicates that this index
-  // is invalid.
-  int next_row_in_buffer_ = -1;
-
-  // This buffer holds the next rows to improve performance. Its size is
-  // based on how much buffering was requested.
-  std::vector<Example> row_buffer_;
-
-  // If next_page_token_ is set, it will be used to read the next batch of
-  // data.
-  string next_page_token_;
-
-  // A tree representing the schema for the underlying table.
-  SchemaNode schema_root_;
-
-  std::unique_ptr<AuthProvider> auth_provider_;
-  std::unique_ptr<HttpRequest::Factory> http_request_factory_;
-
-  TF_DISALLOW_COPY_AND_ASSIGN(BigQueryTableAccessor);
-};
-
-}  // namespace tensorflow
-#endif  // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_
diff --git a/tensorflow/core/kernels/cloud/bigquery_table_accessor_test.cc b/tensorflow/core/kernels/cloud/bigquery_table_accessor_test.cc
deleted file mode 100644
index 7591f9cfd5..0000000000
--- a/tensorflow/core/kernels/cloud/bigquery_table_accessor_test.cc
+++ /dev/null
@@ -1,431 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/ - -#include "tensorflow/core/kernels/cloud/bigquery_table_accessor.h" -#include "tensorflow/core/example/feature.pb.h" -#include "tensorflow/core/kernels/cloud/bigquery_table_accessor_test_data.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/stl_util.h" -#include "tensorflow/core/platform/cloud/http_request_fake.h" -#include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace { - -constexpr char kTestProject[] = "test-project"; -constexpr char kTestDataset[] = "test-dataset"; -constexpr char kTestTable[] = "test-table"; - -static bool HasSubstr(const string& base, const string& substr) { - bool ok = StringPiece(base).contains(substr); - EXPECT_TRUE(ok) << base << ", expected substring " << substr; - return ok; -} - -class FakeAuthProvider : public AuthProvider { - public: - Status GetToken(string* token) override { - *token = "fake_token"; - return Status::OK(); - } -}; - -static string DeterministicSerialization(const tensorflow::Example& example) { - const int size = example.ByteSize(); - string result(size, '\0'); - ::tensorflow::protobuf::io::ArrayOutputStream array_stream( - gtl::string_as_array(&result), size); - ::tensorflow::protobuf::io::CodedOutputStream output_stream(&array_stream); - - output_stream.SetSerializationDeterministic(true); - example.SerializeWithCachedSizes(&output_stream); - EXPECT_FALSE(output_stream.HadError()); - EXPECT_EQ(size, output_stream.ByteCount()); - return result; -} - -} // namespace - -class BigQueryTableAccessorTest : public ::testing::Test { - protected: - BigQueryTableAccessor::SchemaNode GetSchema() { - return accessor_->schema_root_; - } - - Status CreateTableAccessor(const string& project_id, const string& dataset_id, - const string& table_id, int64 timestamp_millis, - int64 row_buffer_size, - const std::vector& columns, - const BigQueryTablePartition& partition) { - return BigQueryTableAccessor::New( - project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, "", - columns, partition, std::unique_ptr(new FakeAuthProvider), - std::unique_ptr( - new FakeHttpRequestFactory(&requests_)), - &accessor_); - } - - std::vector requests_; - std::unique_ptr accessor_; -}; - -TEST_F(BigQueryTableAccessorTest, NegativeTimestamp) { - const auto status = - CreateTableAccessor(kTestProject, kTestDataset, kTestTable, -1, 3, {}, - BigQueryTablePartition()); - EXPECT_TRUE(errors::IsInvalidArgument(status)); -} - -TEST_F(BigQueryTableAccessorTest, ZeroTimestamp) { - const auto status = - CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 0, 3, {}, - BigQueryTablePartition()); - EXPECT_TRUE(errors::IsInvalidArgument(status)); -} - -TEST_F(BigQueryTableAccessorTest, RepeatedFieldNoAllowedTest) { - requests_.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" - "datasets/test-dataset/tables/test-table/\n" - "Auth Token: fake_token\n", - R"({ - "kind": "bigquery#table", - "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", - "id": "test-project:test-dataset.test-table", - "schema": { - "fields": [ - { - "name": "int_field", - "type": "INTEGER", - "mode": "REPEATED" - }] - }, - "numRows": "10" - })")); - const auto status = - CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 3, {}, - BigQueryTablePartition()); - EXPECT_TRUE(errors::IsUnimplemented(status)); - 
EXPECT_TRUE(HasSubstr(status.error_message(), - "Tables with repeated columns are not supported")); -} - -TEST_F(BigQueryTableAccessorTest, ValidSchemaTest) { - requests_.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" - "datasets/test-dataset/tables/test-table/\n" - "Auth Token: fake_token\n", - kSampleSchema)); - TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 3, - {}, BigQueryTablePartition())); - // Validate total number of rows. - EXPECT_EQ(4, accessor_->total_num_rows()); - - // Validate the schema. - const auto schema_root = GetSchema(); - EXPECT_EQ(schema_root.name, ""); - EXPECT_EQ(schema_root.type, BigQueryTableAccessor::ColumnType::kNone); - EXPECT_EQ(9, schema_root.schema_nodes.size()); - - EXPECT_EQ(schema_root.schema_nodes[0].name, "int_field"); - EXPECT_EQ(schema_root.schema_nodes[0].type, - BigQueryTableAccessor::ColumnType::kInteger); - - EXPECT_EQ(schema_root.schema_nodes[1].name, "str_field"); - EXPECT_EQ(schema_root.schema_nodes[1].type, - BigQueryTableAccessor::ColumnType::kString); - - EXPECT_EQ(1, schema_root.schema_nodes[2].schema_nodes.size()); - EXPECT_EQ(schema_root.schema_nodes[2].name, "rec_field"); - EXPECT_EQ(schema_root.schema_nodes[2].type, - BigQueryTableAccessor::ColumnType::kRecord); - - EXPECT_EQ(schema_root.schema_nodes[2].schema_nodes[0].name, - "rec_field.float_field"); - EXPECT_EQ(schema_root.schema_nodes[2].schema_nodes[0].type, - BigQueryTableAccessor::ColumnType::kFloat); - - EXPECT_EQ(schema_root.schema_nodes[3].name, "bool_field"); - EXPECT_EQ(schema_root.schema_nodes[3].type, - BigQueryTableAccessor::ColumnType::kBoolean); - - EXPECT_EQ(schema_root.schema_nodes[4].name, "bytes_field"); - EXPECT_EQ(schema_root.schema_nodes[4].type, - BigQueryTableAccessor::ColumnType::kBytes); - - EXPECT_EQ(schema_root.schema_nodes[5].name, "timestamp_field"); - EXPECT_EQ(schema_root.schema_nodes[5].type, - BigQueryTableAccessor::ColumnType::kTimestamp); - - EXPECT_EQ(schema_root.schema_nodes[6].name, "date_field"); - EXPECT_EQ(schema_root.schema_nodes[6].type, - BigQueryTableAccessor::ColumnType::kDate); - - EXPECT_EQ(schema_root.schema_nodes[7].name, "time_field"); - EXPECT_EQ(schema_root.schema_nodes[7].type, - BigQueryTableAccessor::ColumnType::kTime); - - EXPECT_EQ(schema_root.schema_nodes[8].name, "datetime_field"); - EXPECT_EQ(schema_root.schema_nodes[8].type, - BigQueryTableAccessor::ColumnType::kDatetime); -} - -TEST_F(BigQueryTableAccessorTest, ReadOneRowTest) { - requests_.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" - "datasets/test-dataset/tables/test-table/\n" - "Auth Token: fake_token\n", - kSampleSchema)); - requests_.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" - "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n" - "Auth Token: fake_token\n", - kTestRow)); - BigQueryTablePartition partition; - partition.set_start_index(2); - partition.set_end_index(2); - TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, - {}, partition)); - int64 row_id; - Example example; - TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); - - // Validate returned result. 
- Example expected_example; - ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestExampleProto, - &expected_example)); - EXPECT_EQ(DeterministicSerialization(expected_example), - DeterministicSerialization(example)); - EXPECT_EQ(row_id, 2); - EXPECT_TRUE(accessor_->Done()); -} - -TEST_F(BigQueryTableAccessorTest, ReadOneRowPartialTest) { - requests_.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" - "datasets/test-dataset/tables/test-table/\n" - "Auth Token: fake_token\n", - kSampleSchema)); - requests_.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" - "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n" - "Auth Token: fake_token\n", - kTestRow)); - BigQueryTablePartition partition; - partition.set_start_index(2); - partition.set_end_index(2); - TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, - {"bool_field", "rec_field.float_field"}, - partition)); - int64 row_id; - Example example; - TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); - - // Validate returned result. - EXPECT_EQ(row_id, 2); - EXPECT_TRUE(accessor_->Done()); - Example expected_example; - ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestPartialExampleProto, - &expected_example)); - EXPECT_EQ(DeterministicSerialization(expected_example), - DeterministicSerialization(example)); -} - -TEST_F(BigQueryTableAccessorTest, ReadOneRowWithNullsTest) { - requests_.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" - "datasets/test-dataset/tables/test-table/\n" - "Auth Token: fake_token\n", - kSampleSchema)); - requests_.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" - "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n" - "Auth Token: fake_token\n", - kTestRowWithNulls)); - BigQueryTablePartition partition; - partition.set_start_index(2); - partition.set_end_index(2); - TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, - {}, partition)); - int64 row_id; - Example example; - TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); - - // Validate returned result. 
- Example expected_example; - ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestExampleProtoWithNulls, - &expected_example)); - EXPECT_EQ(DeterministicSerialization(expected_example), - DeterministicSerialization(example)); - EXPECT_EQ(row_id, 2); - EXPECT_TRUE(accessor_->Done()); -} - -TEST_F(BigQueryTableAccessorTest, BrokenRowTest) { - requests_.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" - "datasets/test-dataset/tables/test-table/\n" - "Auth Token: fake_token\n", - kSampleSchema)); - requests_.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" - "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=2\n" - "Auth Token: fake_token\n", - kBrokenTestRow)); - BigQueryTablePartition partition; - partition.set_start_index(2); - partition.set_end_index(2); - TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, - {}, partition)); - int64 row_id; - Example example; - const auto status = accessor_->ReadRow(&row_id, &example); - EXPECT_TRUE(errors::IsInternal(status)); - EXPECT_TRUE( - HasSubstr(status.error_message(), "Cannot convert value to integer")); -} - -TEST_F(BigQueryTableAccessorTest, MultiplePagesTest) { - requests_.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" - "datasets/test-dataset/tables/test-table/\n" - "Auth Token: fake_token\n", - kSampleSchema)); - requests_.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" - "datasets/test-dataset/tables/test-table/data?maxResults=2&startIndex=1\n" - "Auth Token: fake_token\n", - kTestTwoRows)); - requests_.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" - "datasets/test-dataset/tables/test-table/" - "data?maxResults=2&pageToken=next_page\n" - "Auth Token: fake_token\n", - kTestRowWithNulls)); - - BigQueryTablePartition partition; - partition.set_start_index(1); - partition.set_end_index(-1); - TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 2, - {}, partition)); - - int64 row_id; - Example example; - TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); - EXPECT_EQ(1, row_id); - EXPECT_FALSE(accessor_->Done()); - EXPECT_EQ( - (example.features().feature()).at("int_field").int64_list().value(0), - 1111); - - TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); - EXPECT_EQ(2, row_id); - EXPECT_FALSE(accessor_->Done()); - EXPECT_EQ(example.features().feature().at("int_field").int64_list().value(0), - 2222); - - TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); - EXPECT_EQ(3, row_id); - EXPECT_TRUE(accessor_->Done()); - Example expected_example; - ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestExampleProtoWithNulls, - &expected_example)); - EXPECT_EQ(DeterministicSerialization(expected_example), - DeterministicSerialization(example)); - EXPECT_TRUE(errors::IsOutOfRange(accessor_->ReadRow(&row_id, &example))); -} - -TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) { - requests_.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" - "datasets/test-dataset/tables/test-table/\n" - "Auth Token: fake_token\n", - kSampleSchema)); - requests_.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" - "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=0\n" - "Auth 
Token: fake_token\n", - kTestTwoRows)); - requests_.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" - "datasets/test-dataset/tables/test-table/" - "data?maxResults=2&startIndex=3\n" - "Auth Token: fake_token\n", - kTestRowWithNulls)); - requests_.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" - "datasets/test-dataset/tables/test-table/data?maxResults=2&startIndex=0\n" - "Auth Token: fake_token\n", - kTestTwoRows)); - - BigQueryTablePartition partition; - partition.set_start_index(0); - partition.set_end_index(0); - TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 2, - {}, partition)); - - int64 row_id; - Example example; - TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); - EXPECT_EQ(0, row_id); - EXPECT_TRUE(accessor_->Done()); - EXPECT_EQ(example.features().feature().at("int_field").int64_list().value(0), - 1111); - - partition.set_start_index(3); - partition.set_end_index(-1); - TF_EXPECT_OK(accessor_->SetPartition(partition)); - TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); - EXPECT_EQ(3, row_id); - EXPECT_TRUE(accessor_->Done()); - EXPECT_EQ(example.features().feature().at("int_field").int64_list().value(0), - 1234); - - partition.set_start_index(0); - partition.set_end_index(1); - TF_EXPECT_OK(accessor_->SetPartition(partition)); - TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); - EXPECT_EQ(0, row_id); - EXPECT_FALSE(accessor_->Done()); - EXPECT_EQ(example.features().feature().at("int_field").int64_list().value(0), - 1111); - TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); - EXPECT_EQ(1, row_id); - EXPECT_TRUE(accessor_->Done()); - EXPECT_EQ(example.features().feature().at("int_field").int64_list().value(0), - 2222); -} - -TEST_F(BigQueryTableAccessorTest, EmptyPartitionTest) { - requests_.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" - "datasets/test-dataset/tables/test-table/\n" - "Auth Token: fake_token\n", - kSampleSchema)); - - BigQueryTablePartition partition; - partition.set_start_index(3); - partition.set_end_index(2); - TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, - {}, partition)); - EXPECT_TRUE(accessor_->Done()); - - int64 row_id; - Example example; - EXPECT_TRUE(errors::IsOutOfRange(accessor_->ReadRow(&row_id, &example))); -} - -} // namespace tensorflow diff --git a/tensorflow/core/kernels/cloud/bigquery_table_accessor_test_data.h b/tensorflow/core/kernels/cloud/bigquery_table_accessor_test_data.h deleted file mode 100644 index e339ff25ff..0000000000 --- a/tensorflow/core/kernels/cloud/bigquery_table_accessor_test_data.h +++ /dev/null @@ -1,325 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ - -#include - -namespace tensorflow { -namespace { - -const string kSampleSchema = R"({ - "kind": "bigquery#table", - "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", - "id": "test-project:test-dataset.test-table", - "schema": { - "fields": [ - { - "name": "int_field", - "type": "INTEGER", - "mode": "REQUIRED" - },{ - "name": "str_field", - "type": "STRING", - "mode": "NULLABLE" - },{ - "name": "rec_field", - "type": "RECORD", - "fields": [ - { - "name": "float_field", - "type": "FLOAT", - "mode": "NULLABLE" - }] - },{ - "name": "bool_field", - "type": "BOOLEAN", - "mode": "NULLABLE" - },{ - "name": "bytes_field", - "type": "BYTES", - "mode": "NULLABLE" - },{ - "name": "timestamp_field", - "type": "TIMESTAMP", - "mode": "NULLABLE" - },{ - "name": "date_field", - "type": "DATE", - "mode": "NULLABLE" - },{ - "name": "time_field", - "type": "TIME", - "mode": "NULLABLE" - },{ - "name": "datetime_field", - "type": "DATETIME", - "mode": "NULLABLE" - }] - }, - "numRows": "4" -})"; - -const string kTestRow = R"({ - "kind": "bigquery#table", - "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", - "id": "test-project:test-dataset.test-table", - "rows": [ - { - "f": [ - { - "v": "1234" - },{ - "v": "" - },{ - "v": { - "f": [ - { - "v": "1.23456" - }] - } - },{ - "v": "true" - },{ - "v": "01010100101" - },{ - "v": "timestamp" - },{ - "v": "date" - },{ - "v": "time" - },{ - "v": "datetime" - }]}]})"; - -const string kBrokenTestRow = R"({ - "kind": "bigquery#table", - "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", - "id": "test-project:test-dataset.test-table", - "rows": [ - { - "f": [ - { - "v": "1-234" // This does not parse as integer. - },{ - "v": "" - },{ - },{ - "v": "true" - },{ - "v": "01010100101" - },{ - "v": "timestamp" - },{ - "v": "date" - },{ - "v": "time" - },{ - "v": "datetime" - }]}]})"; - -const string kTestRowWithNulls = R"({ - "kind": "bigquery#table", - "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", - "id": "test-project:test-dataset.test-table", - "rows": [ - { - "f": [ - { - "v": "1234" - },{ - "v": "string" - },{ - "v": null - },{ - "v": "true" - },{ - "v": "01010100101" - },{ - "v": "" - },{ - "v": null - },{ - "v": null - },{ - "v": "datetime" - }]}]})"; - -// Example proto corresponding to kTestRow. -const string kTestExampleProto = R"(features { - feature { - key: "bool_field" - value { - int64_list { - value: 1 - } - } - } - feature { - key: "bytes_field" - value { - bytes_list { - value: "01010100101" - } - } - } - feature { - key: "date_field" - value { - bytes_list { - value: "date" - } - } - } - feature { - key: "datetime_field" - value { - bytes_list { - value: "datetime" - } - } - } - feature { - key: "int_field" - value { - int64_list { - value: 1234 - } - } - } - feature { - key: "rec_field.float_field" - value { - float_list { - value: 1.23456 - } - } - } - feature { - key: "str_field" - value { - bytes_list { - value: "" - } - } - } - feature { - key: "time_field" - value { - bytes_list { - value: "time" - } - } - } - feature { - key: "timestamp_field" - value { - bytes_list { - value: "timestamp" - } - } - } -} -)"; - -// Example proto corresponding to kTestRowWithNulls. 
-const string kTestExampleProtoWithNulls = R"(features { - feature { - key: "bool_field" - value { - int64_list { - value: 1 - } - } - } - feature { - key: "bytes_field" - value { - bytes_list { - value: "01010100101" - } - } - } - feature { - key: "datetime_field" - value { - bytes_list { - value: "datetime" - } - } - } - feature { - key: "int_field" - value { - int64_list { - value: 1234 - } - } - } - feature { - key: "timestamp_field" - value { - bytes_list { - value: "" - } - } - } - feature { - key: "str_field" - value { - bytes_list { - value: "string" - } - } - } -} -)"; - -// Example proto corresponding to part of kTestRow. -const string kTestPartialExampleProto = R"(features { - feature { - key: "bool_field" - value { - int64_list { - value: 1 - } - } - } - feature { - key: "rec_field.float_field" - value { - float_list { - value: 1.23456 - } - } - } -} -)"; - -const string kTestTwoRows = R"({ - "kind": "bigquery#table", - "etag": "\"4zcX32ezvFoFzxHoG04qJqKZk6c/MTQ1Nzk3NTgwNzE4Mw\"", - "pageToken": "next_page", - "id": "test-project:test-dataset.test-table", - "rows": [ - {"f": [{"v": "1111"},{},{},{},{},{},{},{},{}]}, - {"f": [{"v": "2222"},{},{},{},{},{},{},{},{}]} - ]})"; - -} // namespace -} // namepsace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ diff --git a/tensorflow/core/kernels/cloud/bigquery_table_partition.proto b/tensorflow/core/kernels/cloud/bigquery_table_partition.proto deleted file mode 100644 index 2d9d1380db..0000000000 --- a/tensorflow/core/kernels/cloud/bigquery_table_partition.proto +++ /dev/null @@ -1,12 +0,0 @@ -syntax = "proto3"; - -package tensorflow; - -// This proto specifies a table partition in BigQuery. -message BigQueryTablePartition { - // [start_index, end_index] specify the boundaries of a partition. - // If end_index is -1, every row starting from start_index is part of the - // partition. - int64 start_index = 1; - int64 end_index = 2; -}; diff --git a/tensorflow/core/ops/cloud_ops.cc b/tensorflow/core/ops/cloud_ops.cc deleted file mode 100644 index 89f31a46ab..0000000000 --- a/tensorflow/core/ops/cloud_ops.cc +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/* This file registers all cloud ops. 
*/ - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" -namespace tensorflow { - -using shape_inference::InferenceContext; - -REGISTER_OP("BigQueryReader") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("project_id: string") - .Attr("dataset_id: string") - .Attr("table_id: string") - .Attr("columns: list(string)") - .Attr("timestamp_millis: int") - .Attr("test_end_point: string = ''") - .Output("reader_handle: Ref(string)") - .SetIsStateful() - .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->Vector(2)); - return Status::OK(); - }) - .Doc(R"doc( -A Reader that outputs rows from a BigQuery table as tensorflow Examples. - -container: If non-empty, this reader is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this reader is named in the given bucket - with this shared_name. Otherwise, the node name is used instead. -project_id: GCP project ID. -dataset_id: BigQuery Dataset ID. -table_id: Table to read. -columns: List of columns to read. Leave empty to read all columns. -timestamp_millis: Table snapshot timestamp in millis since epoch. Relative -(negative or zero) snapshot times are not allowed. For more details, see -'Table Decorators' in BigQuery docs. -test_end_point: Do not use. For testing purposes only. -reader_handle: The handle to reference the Reader. -)doc"); - -REGISTER_OP("GenerateBigQueryReaderPartitions") - .Attr("project_id: string") - .Attr("dataset_id: string") - .Attr("table_id: string") - .Attr("columns: list(string)") - .Attr("timestamp_millis: int") - .Attr("num_partitions: int") - .Attr("test_end_point: string = ''") - .Output("partitions: string") - .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); - return Status::OK(); - }) - .Doc(R"doc( -Generates serialized partition messages suitable for batch reads. - -This op should not be used directly by clients. Instead, the -bigquery_reader_ops.py file defines a clean interface to the reader. - -project_id: GCP project ID. -dataset_id: BigQuery Dataset ID. -table_id: Table to read. -columns: List of columns to read. Leave empty to read all columns. -timestamp_millis: Table snapshot timestamp in millis since epoch. Relative -(negative or zero) snapshot times are not allowed. For more details, see -'Table Decorators' in BigQuery docs. -num_partitions: Number of partitions to split the table into. -test_end_point: Do not use. For testing purposes only. -partitions: Serialized table partitions. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/http_request.cc b/tensorflow/core/platform/cloud/http_request.cc index 8a8d1e448a..9267d3ea83 100644 --- a/tensorflow/core/platform/cloud/http_request.cc +++ b/tensorflow/core/platform/cloud/http_request.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/version.h" diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index a2c133b43a..338d9309e8 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -227,19 +227,27 @@ def tf_additional_core_deps(): # TODO(jart, jhseu): Delete when GCP is default on. def tf_additional_cloud_op_deps(): deps = [] - # TODO(hormati): Remove the comments below to enable BigQuery op. The op is - # not linked for now because it is under perf testing. - #if WITH_GCP_SUPPORT: - # deps = if_not_mobile(["//tensorflow/core/kernels/cloud:bigquery_reader_ops"]) + if WITH_GCP_SUPPORT: + deps = select({ + "//tensorflow:windows": [], + "//tensorflow:android": [], + "//tensorflow:ios": [], + "//conditions:default": + ["//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib"], + }) return deps # TODO(jart, jhseu): Delete when GCP is default on. def tf_additional_cloud_kernel_deps(): deps = [] - # TODO(hormati): Remove the comments below to enable BigQuery op. The op is - # not linked for now because it is under perf testing. - #if WITH_GCP_SUPPORT: - # deps = if_not_mobile(["//tensorflow/core:cloud_ops_op_lib"]) + if WITH_GCP_SUPPORT: + deps = select({ + "//tensorflow:windows": [], + "//tensorflow:android": [], + "//tensorflow:ios": [], + "//conditions:default": + ["//tensorflow/contrib/cloud/kernels:bigquery_reader_ops"], + }) return deps def tf_lib_proto_parsing_deps(): diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 40ec2569de..7f416dc609 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -45,7 +45,6 @@ py_library( ":check_ops", ":client", ":client_testlib", - ":cloud_ops", ":confusion_matrix", ":control_flow_ops", ":errors", @@ -124,38 +123,6 @@ py_library( deps = [":platform_benchmark"], ) -py_library( - name = "cloud_ops", - srcs = [ - "ops/cloud/__init__.py", - "ops/cloud/bigquery_reader_ops.py", - "ops/cloud/cloud.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":cloud_ops_gen", - ":framework_for_generated_wrappers", - ], -) - -tf_py_test( - name = "bigquery_reader_ops_test", - size = "small", - srcs = ["ops/cloud/bigquery_reader_ops_test.py"], - additional_deps = [ - ":array_ops", - ":client_testlib", - ":cloud_ops", - ":data_flow_ops", - ":io_ops", - ":parsing_ops", - ":util", - "//tensorflow/core/kernels/cloud:bigquery_reader_ops", - "//tensorflow/core:cloud_ops_op_lib", - ], - tags = ["manual"], -) - tf_py_test( name = "resource_loader_test", size = "small", @@ -1018,11 +985,6 @@ tf_gen_op_wrapper_private_py( visibility = ["//learning/brain/python/ops:__pkg__"], ) -tf_gen_op_wrapper_private_py( - name = "cloud_ops_gen", - require_shape_functions = True, -) - tf_gen_op_wrapper_private_py( name = "control_flow_ops_gen", require_shape_functions = True, diff --git a/tensorflow/python/ops/cloud/__init__.py b/tensorflow/python/ops/cloud/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tensorflow/python/ops/cloud/bigquery_reader_ops.py b/tensorflow/python/ops/cloud/bigquery_reader_ops.py deleted file mode 100644 index 7786aea025..0000000000 --- a/tensorflow/python/ops/cloud/bigquery_reader_ops.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2016 The 
TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""BigQuery reading support for TensorFlow.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_cloud_ops -from tensorflow.python.ops import io_ops - - -class BigQueryReader(io_ops.ReaderBase): - """A Reader that outputs keys and tf.Example values from a BigQuery table. - - Note(1): This op is currently not linked into the binary. It will be linked - by default after more perf testing. - - Note(2): This op currently returns example proto as its output. This is not - final and we are experimenting with adding support for returning csv. Support - for example proto may be deprecated after that. - - Example use: - ```python - # Assume a BigQuery has the following schema, - # name STRING, - # age INT, - # state STRING - - # Create the parse_examples list of features. - features = dict( - name=tf.FixedLenFeature([1], tf.string), - age=tf.FixedLenFeature([1], tf.int32), - state=tf.FixedLenFeature([1], dtype=tf.string, default_value="UNK")) - - # Create a Reader. - reader = bigquery_reader_ops.BigQueryReader(project_id=PROJECT, - dataset_id=DATASET, - table_id=TABLE, - timestamp_millis=TIME, - num_partitions=NUM_PARTITIONS, - features=features) - - # Populate a queue with the BigQuery Table partitions. - queue = tf.training.string_input_producer(reader.partitions()) - - # Read and parse examples. - row_id, examples_serialized = reader.read(queue) - examples = tf.parse_example(examples_serialized, features=features) - - # Process the Tensors examples["name"], examples["age"], etc... - ``` - - Note that to create a reader a snapshot timestamp is necessary. This - will enable the reader to look at a consistent snapshot of the table. - For more information, see 'Table Decorators' in BigQuery docs. - - See ReaderBase for supported methods. - """ - - def __init__(self, - project_id, - dataset_id, - table_id, - timestamp_millis, - num_partitions, - features=None, - columns=None, - test_end_point=None, - name=None): - """Creates a BigQueryReader. - - Args: - project_id: GCP project ID. - dataset_id: BigQuery dataset ID. - table_id: BigQuery table ID. - timestamp_millis: timestamp to snapshot the table in milliseconds since - the epoch. Relative (negative or zero) snapshot times are not allowed. - For more details, see 'Table Decorators' in BigQuery docs. - num_partitions: Number of non-overlapping partitions to read from. - features: parse_example compatible dict from keys to `VarLenFeature` and - `FixedLenFeature` objects. Keys are read as columns from the db. - columns: list of columns to read, can be set iff features is None. - test_end_point: Used only for testing purposes (optional). - name: a name for the operation (optional). 
- - Raises: - TypeError: - If features is neither None nor a dict or - - If columns is is neither None nor a list or - - If both features and columns are None or set. - """ - if (features is None) == (columns is None): - raise TypeError("exactly one of features and columns must be set.") - - if features is not None: - if not isinstance(features, dict): - raise TypeError("features must be a dict.") - self._columns = list(features.keys()) - elif columns is not None: - if not isinstance(columns, list): - raise TypeError("columns must be a list.") - self._columns = columns - - self._project_id = project_id - self._dataset_id = dataset_id - self._table_id = table_id - self._timestamp_millis = timestamp_millis - self._num_partitions = num_partitions - self._test_end_point = test_end_point - - reader = gen_cloud_ops.big_query_reader( - name=name, - project_id=self._project_id, - dataset_id=self._dataset_id, - table_id=self._table_id, - timestamp_millis=self._timestamp_millis, - columns=self._columns, - test_end_point=self._test_end_point) - super(BigQueryReader, self).__init__(reader) - - def partitions(self, name=None): - """Returns serialized BigQueryTablePartition messages. - - These messages represent a non-overlapping division of a table for a - bulk read. - - Args: - name: a name for the operation (optional). - - Returns: - `1-D` string `Tensor` of serialized `BigQueryTablePartition` messages. - """ - return gen_cloud_ops.generate_big_query_reader_partitions( - name=name, - project_id=self._project_id, - dataset_id=self._dataset_id, - table_id=self._table_id, - timestamp_millis=self._timestamp_millis, - num_partitions=self._num_partitions, - test_end_point=self._test_end_point, - columns=self._columns) - - -ops.NotDifferentiable("BigQueryReader") diff --git a/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py b/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py deleted file mode 100644 index 141b0af901..0000000000 --- a/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py +++ /dev/null @@ -1,286 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for BigQueryReader Op.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import json -import os -import re -import threading - -from six.moves import SimpleHTTPServer -from six.moves import socketserver - -from tensorflow.core.example import example_pb2 -from tensorflow.core.framework import types_pb2 -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import data_flow_ops -from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops.cloud import cloud -from tensorflow.python.platform import test -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.util import compat - -_PROJECT = "test-project" -_DATASET = "test-dataset" -_TABLE = "test-table" -# List representation of the test rows in the 'test-table' in BigQuery. -# The schema for each row is: [int64, string, float]. -# The values for rows are generated such that some columns have null values. The -# general formula here is: -# - The int64 column is present in every row. -# - The string column is only avaiable in even rows. -# - The float column is only available in every third row. -_ROWS = [[0, "s_0", 0.1], [1, None, None], [2, "s_2", None], [3, None, 3.1], - [4, "s_4", None], [5, None, None], [6, "s_6", 6.1], [7, None, None], - [8, "s_8", None], [9, None, 9.1]] -# Schema for 'test-table'. -# The schema currently has three columns: int64, string, and float -_SCHEMA = { - "kind": "bigquery#table", - "id": "test-project:test-dataset.test-table", - "schema": { - "fields": [{ - "name": "int64_col", - "type": "INTEGER", - "mode": "NULLABLE" - }, { - "name": "string_col", - "type": "STRING", - "mode": "NULLABLE" - }, { - "name": "float_col", - "type": "FLOAT", - "mode": "NULLABLE" - }] - } -} - - -def _ConvertRowToExampleProto(row): - """Converts the input row to an Example proto. - - Args: - row: Input Row instance. - - Returns: - An Example proto initialized with row values. - """ - - example = example_pb2.Example() - example.features.feature["int64_col"].int64_list.value.append(row[0]) - if row[1] is not None: - example.features.feature["string_col"].bytes_list.value.append( - compat.as_bytes(row[1])) - if row[2] is not None: - example.features.feature["float_col"].float_list.value.append(row[2]) - return example - - -class FakeBigQueryServer(threading.Thread): - """Fake http server to return schema and data for sample table.""" - - def __init__(self, address, port): - """Creates a FakeBigQueryServer. - - Args: - address: Server address - port: Server port. Pass 0 to automatically pick an empty port. - """ - threading.Thread.__init__(self) - self.handler = BigQueryRequestHandler - self.httpd = socketserver.TCPServer((address, port), self.handler) - - def run(self): - self.httpd.serve_forever() - - def shutdown(self): - self.httpd.shutdown() - self.httpd.socket.close() - - -class BigQueryRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler): - """Responds to BigQuery HTTP requests. - - Attributes: - num_rows: num_rows in the underlying table served by this class. - """ - - num_rows = 0 - - def do_GET(self): - if "data?maxResults=" not in self.path: - # This is a schema request. - _SCHEMA["numRows"] = self.num_rows - response = json.dumps(_SCHEMA) - else: - # This is a data request. - # - # Extract max results and start index. 
- max_results = int(re.findall(r"maxResults=(\d+)", self.path)[0]) - start_index = int(re.findall(r"startIndex=(\d+)", self.path)[0]) - - # Send the rows as JSON. - rows = [] - for row in _ROWS[start_index:start_index + max_results]: - row_json = { - "f": [{ - "v": str(row[0]) - }, { - "v": str(row[1]) if row[1] is not None else None - }, { - "v": str(row[2]) if row[2] is not None else None - }] - } - rows.append(row_json) - response = json.dumps({ - "kind": "bigquery#table", - "id": "test-project:test-dataset.test-table", - "rows": rows - }) - self.send_response(200) - self.end_headers() - self.wfile.write(compat.as_bytes(response)) - - -def _SetUpQueue(reader): - """Sets up a queue for a reader.""" - queue = data_flow_ops.FIFOQueue(8, [types_pb2.DT_STRING], shapes=()) - key, value = reader.read(queue) - queue.enqueue_many(reader.partitions()).run() - queue.close().run() - return key, value - - -class BigQueryReaderOpsTest(test.TestCase): - - def setUp(self): - super(BigQueryReaderOpsTest, self).setUp() - self.server = FakeBigQueryServer("127.0.0.1", 0) - self.server.start() - logging.info("server address is %s:%s", self.server.httpd.server_address[0], - self.server.httpd.server_address[1]) - # An override to bypass the GCP auth token retrieval logic - # in google_auth_provider.cc. - os.environ["GOOGLE_AUTH_TOKEN_FOR_TESTING"] = "not-used" - - def tearDown(self): - self.server.shutdown() - super(BigQueryReaderOpsTest, self).tearDown() - - def _ReadAndCheckRowsUsingFeatures(self, num_rows): - self.server.handler.num_rows = num_rows - - with self.test_session() as sess: - feature_configs = { - "int64_col": - parsing_ops.FixedLenFeature( - [1], dtype=dtypes.int64), - "string_col": - parsing_ops.FixedLenFeature( - [1], dtype=dtypes.string, default_value="s_default"), - } - reader = cloud.BigQueryReader( - project_id=_PROJECT, - dataset_id=_DATASET, - table_id=_TABLE, - num_partitions=4, - features=feature_configs, - timestamp_millis=1, - test_end_point=("%s:%s" % (self.server.httpd.server_address[0], - self.server.httpd.server_address[1]))) - - key, value = _SetUpQueue(reader) - - seen_rows = [] - features = parsing_ops.parse_example( - array_ops.reshape(value, [1]), feature_configs) - for _ in range(num_rows): - int_value, str_value = sess.run( - [features["int64_col"], features["string_col"]]) - - # Parse values returned from the session. - self.assertEqual(int_value.shape, (1, 1)) - self.assertEqual(str_value.shape, (1, 1)) - int64_col = int_value[0][0] - string_col = str_value[0][0] - seen_rows.append(int64_col) - - # Compare. 
- expected_row = _ROWS[int64_col] - self.assertEqual(int64_col, expected_row[0]) - self.assertEqual( - compat.as_str(string_col), ("s_%d" % int64_col) if expected_row[1] - else "s_default") - - self.assertItemsEqual(seen_rows, range(num_rows)) - - with self.assertRaisesOpError("is closed and has insufficient elements " - "\\(requested 1, current size 0\\)"): - sess.run([key, value]) - - def testReadingSingleRowUsingFeatures(self): - self._ReadAndCheckRowsUsingFeatures(1) - - def testReadingMultipleRowsUsingFeatures(self): - self._ReadAndCheckRowsUsingFeatures(10) - - def testReadingMultipleRowsUsingColumns(self): - num_rows = 10 - self.server.handler.num_rows = num_rows - - with self.test_session() as sess: - reader = cloud.BigQueryReader( - project_id=_PROJECT, - dataset_id=_DATASET, - table_id=_TABLE, - num_partitions=4, - columns=["int64_col", "float_col", "string_col"], - timestamp_millis=1, - test_end_point=("%s:%s" % (self.server.httpd.server_address[0], - self.server.httpd.server_address[1]))) - key, value = _SetUpQueue(reader) - seen_rows = [] - for row_index in range(num_rows): - returned_row_id, example_proto = sess.run([key, value]) - example = example_pb2.Example() - example.ParseFromString(example_proto) - self.assertIn("int64_col", example.features.feature) - feature = example.features.feature["int64_col"] - self.assertEqual(len(feature.int64_list.value), 1) - int64_col = feature.int64_list.value[0] - seen_rows.append(int64_col) - - # Create our expected Example. - expected_example = example_pb2.Example() - expected_example = _ConvertRowToExampleProto(_ROWS[int64_col]) - - # Compare. - self.assertProtoEquals(example, expected_example) - self.assertEqual(row_index, int(returned_row_id)) - - self.assertItemsEqual(seen_rows, range(num_rows)) - - with self.assertRaisesOpError("is closed and has insufficient elements " - "\\(requested 1, current size 0\\)"): - sess.run([key, value]) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/python/ops/cloud/cloud.py b/tensorflow/python/ops/cloud/cloud.py deleted file mode 100644 index eb917a987e..0000000000 --- a/tensorflow/python/ops/cloud/cloud.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Import cloud ops.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import sys - -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.python.ops.cloud.bigquery_reader_ops import * -# pylint: enable=wildcard-import - -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = ['BigQueryReader'] -remove_undocumented(__name__, _allowed_symbols, [sys.modules[__name__]]) -- cgit v1.2.3
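
Usage note (not part of the patch itself): after this move the reader is meant to be imported from the contrib package rather than from tensorflow.python.ops.cloud. The sketch below shows one way to drive the relocated reader end to end in "columns" mode. It is a minimal sketch under two assumptions: that the class is exposed as tf.contrib.cloud.BigQueryReader once this change lands, and that PROJECT, DATASET, TABLE, and TIME_MILLIS are placeholder values, not identifiers defined anywhere in this change; the column names are borrowed from the test schema above.

import tensorflow as tf

# Placeholders -- substitute a real GCP project, dataset, table, and a fixed
# snapshot timestamp in milliseconds since epoch (zero or negative values are
# rejected by the op).
PROJECT = "my-project"
DATASET = "my-dataset"
TABLE = "my-table"
TIME_MILLIS = 1489000000000

# "columns" mode: request raw columns and get back serialized tf.Example
# protos; the alternative is to pass a `features` dict instead of `columns`.
reader = tf.contrib.cloud.BigQueryReader(
    project_id=PROJECT,
    dataset_id=DATASET,
    table_id=TABLE,
    timestamp_millis=TIME_MILLIS,
    num_partitions=4,
    columns=["int64_col", "string_col", "float_col"])

# Each partition is a serialized BigQueryTablePartition message describing a
# disjoint row range; queue them so read() can consume the table in chunks.
partition_queue = tf.train.string_input_producer(reader.partitions())
row_id, serialized_example = reader.read(partition_queue)

with tf.Session() as sess:
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  print(sess.run([row_id, serialized_example]))
  coord.request_stop()
  coord.join(threads)

In features mode, the serialized protos would typically be fed through tf.parse_example with the same feature spec that was handed to the reader, as the Python test in this patch does.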