about · summary · refs · log · tree · commit · diff · homepage
path: root/tensorflow/contrib/bigtable
diff options
context:
space:
mode:
author: Brennan Saeta <saeta@google.com> 2018-07-07 00:36:37 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-07-07 21:38:24 -0700
commit6d5b8b7cae669372df3f756f827c40b08a0d14a9 (patch)
tree9776fc3818bdea2bb9a154594d56e7e5c092317a /tensorflow/contrib/bigtable
parent82d122cd1efc905070ed86ed55951b7437209523 (diff)
[tf.data / Bigtable] Add max_receive_message_size connection parameter.
When storing images in Cloud Bigtable, the resulting gRPC messages are often larger than the default receive message max size value. This change makes the maximum receive message size configurable, and sets a more reasonable default for general TensorFlow use.

PiperOrigin-RevId: 203569796
Diffstat (limited to 'tensorflow/contrib/bigtable')
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc18
-rw-r--r--tensorflow/contrib/bigtable/ops/bigtable_ops.cc1
-rw-r--r--tensorflow/contrib/bigtable/python/ops/bigtable_api.py20
3 files changed, 34 insertions, 5 deletions
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
index 24be38758f..9b276ec676 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
@@ -40,7 +40,16 @@ class BigtableClientOp : public OpKernel {
if (connection_pool_size_ == -1) {
connection_pool_size_ = 100;
}
- OP_REQUIRES(ctx, connection_pool_size_ > 0,
+
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("max_receive_message_size",
+ &max_receive_message_size_));
+ // If left unset by the client code, set it to a default of 16 MB. Note:
+ // the gRPC default maximum receive message size (4 MB) is too small when
+ // large values (e.g. images) are stored in Cloud Bigtable.
+ if (max_receive_message_size_ == -1) {
+ max_receive_message_size_ = 1 << 24; // 16 MBytes
+ }
+ OP_REQUIRES(ctx, max_receive_message_size_ > 0,
errors::InvalidArgument("max_receive_message_size must be > 0"));
}
@@ -68,6 +77,12 @@ class BigtableClientOp : public OpKernel {
[this, ctx](
BigtableClientResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
auto client_options = google::cloud::bigtable::ClientOptions();
+ client_options.set_connection_pool_size(connection_pool_size_);
+ auto channel_args = client_options.channel_arguments();
+ channel_args.SetMaxReceiveMessageSize(
+ max_receive_message_size_);
+ channel_args.SetUserAgentPrefix("tensorflow");
+ client_options.set_channel_arguments(channel_args);
std::shared_ptr<google::cloud::bigtable::DataClient> client =
google::cloud::bigtable::CreateDefaultDataClient(
project_id_, instance_id_, std::move(client_options));
@@ -87,6 +102,7 @@ class BigtableClientOp : public OpKernel {
string project_id_;
string instance_id_;
int64 connection_pool_size_;
+ int32 max_receive_message_size_;
mutex mu_;
ContainerInfo cinfo_ GUARDED_BY(mu_);
diff --git a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc
index 179963457c..36a392f2a4 100644
--- a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc
+++ b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc
@@ -23,6 +23,7 @@ REGISTER_OP("BigtableClient")
.Attr("project_id: string")
.Attr("instance_id: string")
.Attr("connection_pool_size: int")
+ .Attr("max_receive_message_size: int = -1")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Output("client: resource")
diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
index acf4d34e9d..a7ec3a1142 100644
--- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
+++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
@@ -49,7 +49,11 @@ class BigtableClient(object):
`table` method to open a Bigtable Table.
"""
- def __init__(self, project_id, instance_id, connection_pool_size=None):
+ def __init__(self,
+ project_id,
+ instance_id,
+ connection_pool_size=None,
+ max_receive_message_size=None):
"""Creates a BigtableClient that can be used to open connections to tables.
Args:
@@ -57,6 +61,8 @@ class BigtableClient(object):
instance_id: A string representing the Bigtable instance to connect to.
connection_pool_size: (Optional.) A number representing the number of
concurrent connections to the Cloud Bigtable service to make.
+ max_receive_message_size: (Optional.) The maximum bytes received in a
+ single gRPC response.
Raises:
ValueError: if the arguments are invalid (e.g. wrong type, or out of
@@ -74,10 +80,16 @@ class BigtableClient(object):
connection_pool_size = -1
elif connection_pool_size < 1:
raise ValueError("`connection_pool_size` must be positive")
+
+ if max_receive_message_size is None:
+ max_receive_message_size = -1
+ elif max_receive_message_size < 1:
+ raise ValueError("`max_receive_message_size` must be positive")
+
self._connection_pool_size = connection_pool_size
- self._resource = gen_bigtable_ops.bigtable_client(project_id, instance_id,
- connection_pool_size)
+ self._resource = gen_bigtable_ops.bigtable_client(
+ project_id, instance_id, connection_pool_size, max_receive_message_size)
def table(self, name, snapshot=None):
"""Opens a table and returns a `BigTable` object.
@@ -452,7 +464,7 @@ class _BigtableSampleKeysDataset(_BigtableKeyDataset):
def _as_variant_tensor(self):
return gen_bigtable_ops.bigtable_sample_keys_dataset(
- table=self._table._resource) # pylint: disable=protected_access
+ table=self._table._resource) # pylint: disable=protected-access
class _BigtableLookupDataset(dataset_ops.Dataset):