diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-07-09 15:34:27 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-09 15:37:56 -0700 |
commit | 7e02e69e138511065df9bfc51542f411db2bd298 (patch) | |
tree | c70333b67f6aeec55e2f291ea1173e92962dbe28 /tensorflow/core/distributed_runtime | |
parent | 3656bb80c0e2a0066f7fb8aafb48f10f821da301 (diff) |
Allow passing in an IPv6 address in server def.
I belive this will be required if (when?) the TPUClusterResolver returns IPv6 addresses.
PiperOrigin-RevId: 203842540
Diffstat (limited to 'tensorflow/core/distributed_runtime')
3 files changed, 18 insertions, 13 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc index 0ebc084cb6..b7eb3c9015 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc @@ -42,12 +42,12 @@ string MakeAddress(const string& job, int task) { return strings::StrCat("/job:", job, "/replica:0/task:", task); } +// Allows the host to be a raw IP (either v4 or v6). Status ValidateHostPortPair(const string& host_port) { uint32 port; - std::vector<string> parts = str_util::Split(host_port, ':'); - // Must be host:port, port must be a number, host must not contain a '/'. - if (parts.size() != 2 || !strings::safe_strtou32(parts[1], &port) || - parts[0].find("/") != string::npos) { + auto colon_index = host_port.find_last_of(':'); + if (!strings::safe_strtou32(host_port.substr(colon_index + 1), &port) || + host_port.substr(0, colon_index).find("/") != string::npos) { return errors::InvalidArgument("Could not interpret \"", host_port, "\" as a host-port pair."); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc index a17acc85b3..f07a5a0974 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc @@ -150,10 +150,15 @@ TEST(GrpcChannelTest, NewHostPortGrpcChannelValidation) { EXPECT_TRUE(NewHostPortGrpcChannel("127.0.0.1:2222", &mock_ptr).ok()); EXPECT_TRUE(NewHostPortGrpcChannel("example.com:2222", &mock_ptr).ok()); EXPECT_TRUE(NewHostPortGrpcChannel("fqdn.example.com.:2222", &mock_ptr).ok()); + EXPECT_TRUE(NewHostPortGrpcChannel("[2002:a9c:258e::]:2222", &mock_ptr).ok()); + EXPECT_TRUE(NewHostPortGrpcChannel("[::]:2222", &mock_ptr).ok()); EXPECT_FALSE(NewHostPortGrpcChannel("example.com/abc:2222", &mock_ptr).ok()); EXPECT_FALSE(NewHostPortGrpcChannel("127.0.0.1:2222/", &mock_ptr).ok()); EXPECT_FALSE(NewHostPortGrpcChannel("example.com/abc:", &mock_ptr).ok()); + EXPECT_FALSE(NewHostPortGrpcChannel("[::]/:2222", &mock_ptr).ok()); + EXPECT_FALSE(NewHostPortGrpcChannel("[::]:2222/", &mock_ptr).ok()); + EXPECT_FALSE(NewHostPortGrpcChannel("[::]:", &mock_ptr).ok()); } } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index 2c833d11a9..db14f6473e 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -152,16 +152,14 @@ Status GrpcServer::Init( " was not defined in job \"", server_def_.job_name(), "\""); } - const std::vector<string> hostname_port = - str_util::Split(iter->second, ':'); - if (hostname_port.size() != 2 || - !strings::safe_strto32(hostname_port[1], &requested_port)) { + auto colon_index = iter->second.find_last_of(':'); + if (!strings::safe_strto32(iter->second.substr(colon_index + 1), + &requested_port)) { return errors::InvalidArgument( "Could not parse port for local server from \"", iter->second, - "\""); - } else { - break; + "\"."); } + break; } } if (requested_port == -1) { @@ -343,11 +341,13 @@ Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options, const string host_port = channel_cache_->TranslateTask(name_prefix); int requested_port; - if (!strings::safe_strto32(str_util::Split(host_port, ':')[1], + auto colon_index = host_port.find_last_of(':'); + if (!strings::safe_strto32(host_port.substr(colon_index + 1), &requested_port)) { return errors::Internal("Could not parse port for local server from \"", - channel_cache_->TranslateTask(name_prefix), "\"."); + host_port, "\"."); } + if (requested_port != bound_port_) { return errors::InvalidArgument("Requested port ", requested_port, " differs from expected port ", bound_port_); |