aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/debug
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-07-19 07:23:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-19 07:28:28 -0700
commit41803db36d4f4a3239bd81e5d460eb0e6e2eea88 (patch)
tree4c677032f4a6b607ec1bfa968f2a36efc01ba04d /tensorflow/core/debug
parentac7530e54ddaa17c17e070e5a141002c43b86275 (diff)
tfdbg: open-source C++ and Python libraries of gRPC debugger mode
with the exception of Windows the session_debug_grpc_test is temporarily disabled on mac pending pip install of futures and grpcio on all slaves. PiperOrigin-RevId: 162482541
Diffstat (limited to 'tensorflow/core/debug')
-rw-r--r--tensorflow/core/debug/BUILD42
-rw-r--r--tensorflow/core/debug/debug_grpc_io_utils_test.cc432
-rw-r--r--tensorflow/core/debug/debug_io_utils.cc45
-rw-r--r--tensorflow/core/debug/debug_io_utils.h4
4 files changed, 489 insertions, 34 deletions
diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD
index dd26177d25..971576698e 100644
--- a/tensorflow/core/debug/BUILD
+++ b/tensorflow/core/debug/BUILD
@@ -15,7 +15,6 @@ package(
licenses(["notice"]) # Apache 2.0
-# Google-internal rules omitted.
load(
"//tensorflow:tensorflow.bzl",
"check_deps",
@@ -28,6 +27,7 @@ load(
load(
"//tensorflow/core:platform/default/build_config.bzl",
"tf_kernel_tests_linkstatic",
+ "tf_proto_library",
"tf_proto_library_cc",
)
load(
@@ -42,13 +42,25 @@ check_deps(
deps = ["//tensorflow/core:tensorflow"],
)
-tf_proto_library_cc(
+tf_proto_library(
name = "debug_service_proto",
- srcs = ["debug_service.proto"],
+ srcs = [
+ "debug_service.proto",
+ ],
has_services = 1,
cc_api_version = 2,
cc_grpc_version = 1,
- protodeps = ["//tensorflow/core:protos_all"],
+ protodeps = [
+ ":debugger_event_metadata_proto",
+ "//tensorflow/core:protos_all",
+ ],
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+tf_proto_library(
+ name = "debugger_event_metadata_proto",
+ srcs = ["debugger_event_metadata.proto"],
+ cc_api_version = 2,
)
cc_library(
@@ -196,6 +208,7 @@ tf_cc_test(
deps = [
":debug_grpc_testlib",
":debug_io_utils",
+ ":debug_service_proto_cc",
":debugger_event_metadata_proto_cc",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
@@ -254,10 +267,23 @@ tf_cc_test(
],
)
-tf_proto_library_cc(
- name = "debugger_event_metadata_proto",
- srcs = ["debugger_event_metadata.proto"],
- cc_api_version = 2,
+tf_cc_test(
+ name = "debug_grpc_io_utils_test",
+ size = "small",
+ srcs = ["debug_grpc_io_utils_test.cc"],
+ linkstatic = tf_kernel_tests_linkstatic(),
+ deps = [
+ ":debug_graph_utils",
+ ":debug_grpc_testlib",
+ ":debug_io_utils",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
)
# TODO(cais): Add the following back in when tfdbg is supported on Android.
diff --git a/tensorflow/core/debug/debug_grpc_io_utils_test.cc b/tensorflow/core/debug/debug_grpc_io_utils_test.cc
new file mode 100644
index 0000000000..6510424182
--- /dev/null
+++ b/tensorflow/core/debug/debug_grpc_io_utils_test.cc
@@ -0,0 +1,432 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/debug/debug_graph_utils.h"
+#include "tensorflow/core/debug/debug_grpc_testlib.h"
+#include "tensorflow/core/debug/debug_io_utils.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/tracing.h"
+
+namespace tensorflow {
+
+class GrpcDebugTest : public ::testing::Test {
+ protected:
+ struct ServerData {
+ int port;
+ string url;
+ std::unique_ptr<test::TestEventListenerImpl> server;
+ std::unique_ptr<thread::ThreadPool> thread_pool;
+ };
+
+ void SetUp() override {
+ ClearEnabledWatchKeys();
+ SetUpInProcessServer(&server_data_, 0);
+ }
+
+ void TearDown() override { TearDownInProcessServer(&server_data_); }
+
+ void SetUpInProcessServer(ServerData* server_data,
+ int64 server_start_delay_micros) {
+ server_data->port = testing::PickUnusedPortOrDie();
+ server_data->url = strings::StrCat("grpc://localhost:", server_data->port);
+ server_data->server.reset(new test::TestEventListenerImpl());
+
+ server_data->thread_pool.reset(
+ new thread::ThreadPool(Env::Default(), "test_server", 1));
+ server_data->thread_pool->Schedule(
+ [server_data, server_start_delay_micros]() {
+ Env::Default()->SleepForMicroseconds(server_start_delay_micros);
+ server_data->server->RunServer(server_data->port);
+ });
+ }
+
+ void TearDownInProcessServer(ServerData* server_data) {
+ server_data->server->StopServer();
+ server_data->thread_pool.reset();
+ }
+
+ void ClearEnabledWatchKeys() { DebugGrpcIO::ClearEnabledWatchKeys(); }
+
+ void CreateEmptyEnabledSet(const string& grpc_debug_url) {
+ DebugGrpcIO::CreateEmptyEnabledSet(grpc_debug_url);
+ }
+
+ const int64 GetChannelConnectionTimeoutMicros() {
+ return DebugGrpcIO::channel_connection_timeout_micros;
+ }
+
+ void SetChannelConnectionTimeoutMicros(const int64 timeout) {
+ DebugGrpcIO::channel_connection_timeout_micros = timeout;
+ }
+
+ ServerData server_data_;
+};
+
+TEST_F(GrpcDebugTest, ConnectionTimeoutWorks) {
+ // Use a short timeout so the test won't take too long.
+ const int64 kOriginalTimeoutMicros = GetChannelConnectionTimeoutMicros();
+ const int64 kShortTimeoutMicros = 500 * 1000;
+ SetChannelConnectionTimeoutMicros(kShortTimeoutMicros);
+ ASSERT_EQ(kShortTimeoutMicros, GetChannelConnectionTimeoutMicros());
+
+ const string& kInvalidGrpcUrl =
+ strings::StrCat("grpc://localhost:", testing::PickUnusedPortOrDie());
+ Tensor tensor(DT_FLOAT, TensorShape({1, 1}));
+ tensor.flat<float>()(0) = 42.0;
+ Status publish_status = DebugIO::PublishDebugTensor(
+ DebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "foo_tensor", 0,
+ "DebugIdentity"),
+ tensor, Env::Default()->NowMicros(), {kInvalidGrpcUrl});
+ SetChannelConnectionTimeoutMicros(kOriginalTimeoutMicros);
+ TF_ASSERT_OK(DebugIO::CloseDebugURL(kInvalidGrpcUrl));
+
+ ASSERT_FALSE(publish_status.ok());
+ const string expected_error_msg = strings::StrCat(
+ "Failed to connect to gRPC channel at ", kInvalidGrpcUrl.substr(7),
+ " within a timeout of ", kShortTimeoutMicros / 1e6, " s");
+ ASSERT_NE(string::npos,
+ publish_status.error_message().find(expected_error_msg));
+}
+
+TEST_F(GrpcDebugTest, ConnectionToDelayedStartingServerWorks) {
+ ServerData server_data;
+ // Server start will be delayed for 1 second.
+ SetUpInProcessServer(&server_data, 1 * 1000 * 1000);
+
+ Tensor tensor(DT_FLOAT, TensorShape({1, 1}));
+ tensor.flat<float>()(0) = 42.0;
+ const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
+ "foo_tensor", 0, "DebugIdentity");
+ Status publish_status = DebugIO::PublishDebugTensor(
+ kDebugNodeKey, tensor, Env::Default()->NowMicros(), {server_data.url});
+ ASSERT_TRUE(publish_status.ok());
+ TF_ASSERT_OK(DebugIO::CloseDebugURL(server_data.url));
+
+ ASSERT_EQ(1, server_data.server->node_names.size());
+ ASSERT_EQ(1, server_data.server->output_slots.size());
+ ASSERT_EQ(1, server_data.server->debug_ops.size());
+ EXPECT_EQ(kDebugNodeKey.device_name, server_data.server->device_names[0]);
+ EXPECT_EQ(kDebugNodeKey.node_name, server_data.server->node_names[0]);
+ EXPECT_EQ(kDebugNodeKey.output_slot, server_data.server->output_slots[0]);
+ EXPECT_EQ(kDebugNodeKey.debug_op, server_data.server->debug_ops[0]);
+ TearDownInProcessServer(&server_data);
+}
+
+TEST_F(GrpcDebugTest, SendSingleDebugTensorViaGrpcTest) {
+ Tensor tensor(DT_FLOAT, TensorShape({1, 1}));
+ tensor.flat<float>()(0) = 42.0;
+ const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
+ "foo_tensor", 0, "DebugIdentity");
+ TF_ASSERT_OK(DebugIO::PublishDebugTensor(
+ kDebugNodeKey, tensor, Env::Default()->NowMicros(), {server_data_.url}));
+ TF_ASSERT_OK(DebugIO::CloseDebugURL(server_data_.url));
+
+ // Verify that the expected debug tensor sending happened.
+ ASSERT_EQ(1, server_data_.server->node_names.size());
+ ASSERT_EQ(1, server_data_.server->output_slots.size());
+ ASSERT_EQ(1, server_data_.server->debug_ops.size());
+ EXPECT_EQ(kDebugNodeKey.device_name, server_data_.server->device_names[0]);
+ EXPECT_EQ(kDebugNodeKey.node_name, server_data_.server->node_names[0]);
+ EXPECT_EQ(kDebugNodeKey.output_slot, server_data_.server->output_slots[0]);
+ EXPECT_EQ(kDebugNodeKey.debug_op, server_data_.server->debug_ops[0]);
+}
+
+TEST_F(GrpcDebugTest, SendDebugTensorWithLargeStringAtIndex0ViaGrpcTest) {
+ Tensor tensor(DT_STRING, TensorShape({1, 1}));
+ tensor.flat<string>()(0) = string(5000 * 1024, 'A');
+ const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
+ "foo_tensor", 0, "DebugIdentity");
+ const Status status = DebugIO::PublishDebugTensor(
+ kDebugNodeKey, tensor, Env::Default()->NowMicros(), {server_data_.url});
+ ASSERT_FALSE(status.ok());
+ ASSERT_NE(status.error_message().find("string value at index 0 from debug "
+ "node foo_tensor:0:DebugIdentity does "
+ "not fit gRPC message size limit"),
+ string::npos);
+ TF_ASSERT_OK(DebugIO::CloseDebugURL(server_data_.url));
+}
+
+TEST_F(GrpcDebugTest, SendDebugTensorWithLargeStringAtIndex1ViaGrpcTest) {
+ Tensor tensor(DT_STRING, TensorShape({1, 2}));
+ tensor.flat<string>()(0) = "A";
+ tensor.flat<string>()(1) = string(5000 * 1024, 'A');
+ const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
+ "foo_tensor", 0, "DebugIdentity");
+ const Status status = DebugIO::PublishDebugTensor(
+ kDebugNodeKey, tensor, Env::Default()->NowMicros(), {server_data_.url});
+ ASSERT_FALSE(status.ok());
+ ASSERT_NE(status.error_message().find("string value at index 1 from debug "
+ "node foo_tensor:0:DebugIdentity does "
+ "not fit gRPC message size limit"),
+ string::npos);
+ TF_ASSERT_OK(DebugIO::CloseDebugURL(server_data_.url));
+}
+
+TEST_F(GrpcDebugTest, SendMultipleDebugTensorsSynchronizedViaGrpcTest) {
+ const int32 kSends = 4;
+
+ // Prepare the tensors to sent.
+ std::vector<Tensor> tensors;
+ for (int i = 0; i < kSends; ++i) {
+ Tensor tensor(DT_INT32, TensorShape({1, 1}));
+ tensor.flat<int>()(0) = i * i;
+ tensors.push_back(tensor);
+ }
+
+ thread::ThreadPool* tp =
+ new thread::ThreadPool(Env::Default(), "grpc_debug_test", kSends);
+
+ mutex mu;
+ Notification all_done;
+ int tensor_count GUARDED_BY(mu) = 0;
+ std::vector<Status> statuses GUARDED_BY(mu);
+
+ const std::vector<string> urls({server_data_.url});
+
+ // Set up the concurrent tasks of sending Tensors via an Event stream to the
+ // server.
+ auto fn = [this, &mu, &tensor_count, &tensors, &statuses, &all_done,
+ &urls]() {
+ int this_count;
+ {
+ mutex_lock l(mu);
+ this_count = tensor_count++;
+ }
+
+ // Different concurrent tasks will send different tensors.
+ const uint64 wall_time = Env::Default()->NowMicros();
+ Status publish_status = DebugIO::PublishDebugTensor(
+ DebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
+ strings::StrCat("synchronized_node_", this_count), 0,
+ "DebugIdentity"),
+ tensors[this_count], wall_time, urls);
+
+ {
+ mutex_lock l(mu);
+ statuses.push_back(publish_status);
+ if (this_count == kSends - 1 && !all_done.HasBeenNotified()) {
+ all_done.Notify();
+ }
+ }
+ };
+
+ // Schedule the concurrent tasks.
+ for (int i = 0; i < kSends; ++i) {
+ tp->Schedule(fn);
+ }
+
+ // Wait for all client tasks to finish.
+ all_done.WaitForNotification();
+ delete tp;
+
+ // Close the debug gRPC stream.
+ Status close_status = DebugIO::CloseDebugURL(server_data_.url);
+ ASSERT_TRUE(close_status.ok());
+
+ // Check all statuses from the PublishDebugTensor calls().
+ for (const Status& status : statuses) {
+ TF_ASSERT_OK(status);
+ }
+
+ // One prep tensor plus kSends concurrent tensors are expected.
+ ASSERT_EQ(kSends, server_data_.server->node_names.size());
+ for (size_t i = 0; i < server_data_.server->node_names.size(); ++i) {
+ std::vector<string> items =
+ str_util::Split(server_data_.server->node_names[i], '_');
+ int tensor_index;
+ strings::safe_strto32(items[2], &tensor_index);
+
+ ASSERT_EQ(TensorShape({1, 1}),
+ server_data_.server->debug_tensors[i].shape());
+ ASSERT_EQ(tensor_index * tensor_index,
+ server_data_.server->debug_tensors[i].flat<int>()(0));
+ }
+}
+
+TEST_F(GrpcDebugTest, SendeDebugTensorsThroughMultipleRoundsUsingGrpcGating) {
+ // Prepare the tensor to send.
+ const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
+ "test_namescope/test_node", 0,
+ "DebugIdentity");
+ Tensor tensor(DT_INT32, TensorShape({1, 1}));
+ tensor.flat<int>()(0) = 42;
+
+ const std::vector<string> urls({server_data_.url});
+ for (int i = 0; i < 3; ++i) {
+ server_data_.server->ClearReceivedDebugData();
+ const uint64 wall_time = Env::Default()->NowMicros();
+
+ // On the 1st send (i == 0), gating is disabled, so data should be sent.
+ // On the 2nd send (i == 1), gating is enabled, and the server has enabled
+ // the watch key in the previous send, so data should be sent.
+ // On the 3rd send (i == 2), gating is enabled, but the server has disabled
+ // the watch key in the previous send, so data should not be sent.
+ const bool enable_gated_grpc = (i != 0);
+ TF_ASSERT_OK(DebugIO::PublishDebugTensor(kDebugNodeKey, tensor, wall_time,
+ urls, enable_gated_grpc));
+
+ server_data_.server->RequestDebugOpStateChangeAtNextStream(i == 0,
+ kDebugNodeKey);
+
+ // Close the debug gRPC stream.
+ Status close_status = DebugIO::CloseDebugURL(server_data_.url);
+ ASSERT_TRUE(close_status.ok());
+
+ // Check dumped files according to the expected gating results.
+ if (i < 2) {
+ ASSERT_EQ(1, server_data_.server->node_names.size());
+ ASSERT_EQ(1, server_data_.server->output_slots.size());
+ ASSERT_EQ(1, server_data_.server->debug_ops.size());
+ EXPECT_EQ(kDebugNodeKey.device_name,
+ server_data_.server->device_names[0]);
+ EXPECT_EQ(kDebugNodeKey.node_name, server_data_.server->node_names[0]);
+ EXPECT_EQ(kDebugNodeKey.output_slot,
+ server_data_.server->output_slots[0]);
+ EXPECT_EQ(kDebugNodeKey.debug_op, server_data_.server->debug_ops[0]);
+ } else {
+ ASSERT_EQ(0, server_data_.server->node_names.size());
+ }
+ }
+}
+
+TEST_F(GrpcDebugTest, TestGateDebugNodeOnEmptyEnabledSet) {
+ CreateEmptyEnabledSet("grpc://localhost:3333");
+
+ ASSERT_FALSE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity",
+ {"grpc://localhost:3333"}));
+
+ // file:// debug URLs are not subject to grpc gating.
+ ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen(
+ "foo:0:DebugIdentity", {"grpc://localhost:3333", "file:///tmp/tfdbg_1"}));
+}
+
+TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSet) {
+ const string kGrpcUrl1 = "grpc://localhost:3333";
+ const string kGrpcUrl2 = "grpc://localhost:3334";
+
+ DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "foo:0:DebugIdentity");
+ DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "bar:0:DebugIdentity");
+
+ ASSERT_FALSE(
+ DebugIO::IsDebugNodeGateOpen("foo:1:DebugIdentity", {kGrpcUrl1}));
+ ASSERT_FALSE(
+ DebugIO::IsDebugNodeGateOpen("foo:1:DebugNumericSummary", {kGrpcUrl1}));
+ ASSERT_FALSE(
+ DebugIO::IsDebugNodeGateOpen("qux:0:DebugIdentity", {kGrpcUrl1}));
+ ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl1}));
+ ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl1}));
+
+ // Wrong grpc:// debug URLs.
+ ASSERT_FALSE(
+ DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl2}));
+ ASSERT_FALSE(
+ DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl2}));
+
+ // file:// debug URLs are not subject to grpc gating.
+ ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("qux:0:DebugIdentity",
+ {"file:///tmp/tfdbg_1", kGrpcUrl1}));
+}
+
+TEST_F(GrpcDebugTest, TestGateDebugNodeOnMultipleEmptyEnabledSets) {
+ const string kGrpcUrl1 = "grpc://localhost:3333";
+ const string kGrpcUrl2 = "grpc://localhost:3334";
+ const string kGrpcUrl3 = "grpc://localhost:3335";
+
+ DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "foo:0:DebugIdentity");
+ DebugGrpcIO::EnableWatchKey(kGrpcUrl2, "bar:0:DebugIdentity");
+ CreateEmptyEnabledSet(kGrpcUrl3);
+
+ ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl1}));
+ ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl2}));
+ ASSERT_FALSE(
+ DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl2}));
+ ASSERT_FALSE(
+ DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl1}));
+ ASSERT_FALSE(
+ DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl3}));
+ ASSERT_FALSE(
+ DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl3}));
+ ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity",
+ {kGrpcUrl1, kGrpcUrl2}));
+ ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity",
+ {kGrpcUrl1, kGrpcUrl2}));
+ ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity",
+ {kGrpcUrl1, kGrpcUrl3}));
+ ASSERT_FALSE(DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity",
+ {kGrpcUrl1, kGrpcUrl3}));
+}
+
+TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSetAndEmptyURLs) {
+ DebugGrpcIO::EnableWatchKey("grpc://localhost:3333", "foo:0:DebugIdentity");
+
+ std::vector<string> debug_urls_1;
+ ASSERT_FALSE(
+ DebugIO::IsDebugNodeGateOpen("foo:1:DebugIdentity", debug_urls_1));
+}
+
+TEST_F(GrpcDebugTest, TestGateCopyNodeOnEmptyEnabledSet) {
+ const string kGrpcUrl1 = "grpc://localhost:3333";
+ const string kWatch1 = "foo:0:DebugIdentity";
+ CreateEmptyEnabledSet(kGrpcUrl1);
+
+ ASSERT_FALSE(DebugIO::IsCopyNodeGateOpen(
+ {DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true)}));
+ ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen(
+ {DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, false)}));
+
+ // file:// debug URLs are not subject to grpc gating.
+ ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen(
+ {DebugWatchAndURLSpec("foo:0:DebugIdentity", kGrpcUrl1, true),
+ DebugWatchAndURLSpec("foo:0:DebugIdentity", "file:///tmp/tfdbg_1",
+ false)}));
+}
+
+TEST_F(GrpcDebugTest, TestGateCopyNodeOnNonEmptyEnabledSet) {
+ const string kGrpcUrl1 = "grpc://localhost:3333";
+ const string kGrpcUrl2 = "grpc://localhost:3334";
+ const string kWatch1 = "foo:0:DebugIdentity";
+ const string kWatch2 = "foo:1:DebugIdentity";
+ CreateEmptyEnabledSet(kGrpcUrl1);
+ CreateEmptyEnabledSet(kGrpcUrl2);
+ DebugGrpcIO::EnableWatchKey(kGrpcUrl1, kWatch1);
+
+ ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen(
+ {DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true)}));
+
+ ASSERT_FALSE(DebugIO::IsCopyNodeGateOpen(
+ {DebugWatchAndURLSpec(kWatch1, kGrpcUrl2, true)}));
+ ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen(
+ {DebugWatchAndURLSpec(kWatch1, kGrpcUrl2, false)}));
+
+ ASSERT_FALSE(DebugIO::IsCopyNodeGateOpen(
+ {DebugWatchAndURLSpec(kWatch2, kGrpcUrl1, true)}));
+ ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen(
+ {DebugWatchAndURLSpec(kWatch2, kGrpcUrl1, false)}));
+
+ ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen(
+ {DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true),
+ DebugWatchAndURLSpec(kWatch1, kGrpcUrl2, true)}));
+ ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen(
+ {DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true),
+ DebugWatchAndURLSpec(kWatch2, kGrpcUrl2, true)}));
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc
index 4e185b37a8..f4208a0bbc 100644
--- a/tensorflow/core/debug/debug_io_utils.cc
+++ b/tensorflow/core/debug/debug_io_utils.cc
@@ -17,14 +17,12 @@ limitations under the License.
#include <vector>
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
#include "grpc++/create_channel.h"
-#endif
-
-#if defined(PLATFORM_WINDOWS)
+#else
// winsock2.h is used in grpc, so Ws2_32.lib is needed
#pragma comment(lib,"Ws2_32.lib")
-#endif
+#endif // #ifndef PLATFORM_WINDOWS
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
@@ -37,10 +35,9 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/event.pb.h"
-#define GRPC_OSS_UNIMPLEMENTED_ERROR \
- return errors::Unimplemented( \
- kGrpcURLScheme, \
- " debug URL scheme is not implemented in open source yet.")
+#define GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR \
+ return errors::Unimplemented( \
+ kGrpcURLScheme, " debug URL scheme is not implemented on Windows yet.")
namespace tensorflow {
@@ -234,7 +231,7 @@ string AppendTimestampToFilePath(const string& in, const uint64 timestamp) {
return out;
}
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
// Publishes encoded GraphDef through a gRPC debugger stream, in chunks,
// conforming to the gRPC message size limit.
Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def,
@@ -268,7 +265,7 @@ Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def,
}
return Status::OK();
}
-#endif
+#endif // #ifndef PLATFORM_WINDOWS
} // namespace
@@ -393,7 +390,7 @@ Status DebugIO::PublishDebugMetadata(
Status status;
for (const string& url : debug_urls) {
if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) {
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
Event grpc_event;
// Determine the path (if any) in the grpc:// URL, and add it as a field
@@ -411,7 +408,7 @@ Status DebugIO::PublishDebugMetadata(
status.Update(
DebugGrpcIO::SendEventProtoThroughGrpcStream(grpc_event, url));
#else
- GRPC_OSS_UNIMPLEMENTED_ERROR;
+ GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
#endif
} else if (str_util::Lowercase(url).find(kFileURLScheme) == 0) {
const string dump_root_dir = url.substr(strlen(kFileURLScheme));
@@ -450,7 +447,7 @@ Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
fail_statuses.push_back(s);
}
} else if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) {
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
Status s = DebugGrpcIO::SendTensorThroughGrpcStream(
debug_node_key, tensor, wall_time_us, url, gated_grpc);
@@ -459,7 +456,7 @@ Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
fail_statuses.push_back(s);
}
#else
- GRPC_OSS_UNIMPLEMENTED_ERROR;
+ GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
#endif
} else {
return Status(error::UNAVAILABLE,
@@ -519,11 +516,11 @@ Status DebugIO::PublishGraph(const Graph& graph, const string& device_name,
status.Update(
DebugFileIO::DumpEventProtoToFile(event, dump_root_dir, file_name));
} else if (debug_url.find(kGrpcURLScheme) == 0) {
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
status.Update(PublishEncodedGraphDefInChunks(buf, device_name, now_micros,
debug_url));
#else
- GRPC_OSS_UNIMPLEMENTED_ERROR;
+ GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
#endif
}
}
@@ -534,7 +531,7 @@ Status DebugIO::PublishGraph(const Graph& graph, const string& device_name,
// static
bool DebugIO::IsCopyNodeGateOpen(
const std::vector<DebugWatchAndURLSpec>& specs) {
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
for (const DebugWatchAndURLSpec& spec : specs) {
if (!spec.gated_grpc || spec.url.compare(0, strlen(DebugIO::kGrpcURLScheme),
DebugIO::kGrpcURLScheme)) {
@@ -554,7 +551,7 @@ bool DebugIO::IsCopyNodeGateOpen(
// static
bool DebugIO::IsDebugNodeGateOpen(const string& watch_key,
const std::vector<string>& debug_urls) {
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
for (const string& debug_url : debug_urls) {
if (debug_url.compare(0, strlen(DebugIO::kGrpcURLScheme),
DebugIO::kGrpcURLScheme)) {
@@ -574,7 +571,7 @@ bool DebugIO::IsDebugNodeGateOpen(const string& watch_key,
// static
bool DebugIO::IsDebugURLGateOpen(const string& watch_key,
const string& debug_url) {
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
if (debug_url.find(kGrpcURLScheme) != 0) {
return true;
} else {
@@ -588,10 +585,10 @@ bool DebugIO::IsDebugURLGateOpen(const string& watch_key,
// static
Status DebugIO::CloseDebugURL(const string& debug_url) {
if (debug_url.find(DebugIO::kGrpcURLScheme) == 0) {
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
return DebugGrpcIO::CloseGrpcStream(debug_url);
#else
- GRPC_OSS_UNIMPLEMENTED_ERROR;
+ GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
#endif
} else {
// No-op for non-gRPC URLs.
@@ -703,7 +700,7 @@ Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
}
}
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
DebugGrpcChannel::DebugGrpcChannel(const string& server_stream_addr)
: server_stream_addr_(server_stream_addr),
url_(strings::StrCat(DebugIO::kGrpcURLScheme, server_stream_addr)) {}
@@ -926,6 +923,6 @@ void DebugGrpcIO::CreateEmptyEnabledSet(const string& grpc_debug_url) {
}
}
-#endif // #if defined(PLATFORM_GOOGLE)
+#endif // #ifndef PLATFORM_WINDOWS
} // namespace tensorflow
diff --git a/tensorflow/core/debug/debug_io_utils.h b/tensorflow/core/debug/debug_io_utils.h
index 98effed425..caf9f5341d 100644
--- a/tensorflow/core/debug/debug_io_utils.h
+++ b/tensorflow/core/debug/debug_io_utils.h
@@ -221,7 +221,7 @@ class DebugFileIO {
// TODO(cais): Support grpc:// debug URLs in open source once Python grpc
// genrule becomes available. See b/23796275.
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
#include "tensorflow/core/debug/debug_service.grpc.pb.h"
namespace tensorflow {
@@ -345,6 +345,6 @@ class DebugGrpcIO {
};
} // namespace tensorflow
-#endif // #if defined(PLATFORM_GOOGLE)
+#endif // #ifndef(PLATFORM_WINDOWS)
#endif // TENSORFLOW_DEBUG_IO_UTILS_H_