diff options
author | Shanqing Cai <cais@google.com> | 2017-07-19 07:23:53 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-19 07:28:28 -0700 |
commit | 41803db36d4f4a3239bd81e5d460eb0e6e2eea88 (patch) | |
tree | 4c677032f4a6b607ec1bfa968f2a36efc01ba04d /tensorflow/core/debug | |
parent | ac7530e54ddaa17c17e070e5a141002c43b86275 (diff) |
tfdbg: open-source C++ and Python libraries of gRPC debugger mode
with the exception of Windows
the session_debug_grpc_test is temporarily disabled on mac pending pip install of futures and grpcio on all slaves.
PiperOrigin-RevId: 162482541
Diffstat (limited to 'tensorflow/core/debug')
-rw-r--r-- | tensorflow/core/debug/BUILD | 42 | ||||
-rw-r--r-- | tensorflow/core/debug/debug_grpc_io_utils_test.cc | 432 | ||||
-rw-r--r-- | tensorflow/core/debug/debug_io_utils.cc | 45 | ||||
-rw-r--r-- | tensorflow/core/debug/debug_io_utils.h | 4 |
4 files changed, 489 insertions, 34 deletions
diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD index dd26177d25..971576698e 100644 --- a/tensorflow/core/debug/BUILD +++ b/tensorflow/core/debug/BUILD @@ -15,7 +15,6 @@ package( licenses(["notice"]) # Apache 2.0 -# Google-internal rules omitted. load( "//tensorflow:tensorflow.bzl", "check_deps", @@ -28,6 +27,7 @@ load( load( "//tensorflow/core:platform/default/build_config.bzl", "tf_kernel_tests_linkstatic", + "tf_proto_library", "tf_proto_library_cc", ) load( @@ -42,13 +42,25 @@ check_deps( deps = ["//tensorflow/core:tensorflow"], ) -tf_proto_library_cc( +tf_proto_library( name = "debug_service_proto", - srcs = ["debug_service.proto"], + srcs = [ + "debug_service.proto", + ], has_services = 1, cc_api_version = 2, cc_grpc_version = 1, - protodeps = ["//tensorflow/core:protos_all"], + protodeps = [ + ":debugger_event_metadata_proto", + "//tensorflow/core:protos_all", + ], + visibility = ["//tensorflow:__subpackages__"], +) + +tf_proto_library( + name = "debugger_event_metadata_proto", + srcs = ["debugger_event_metadata.proto"], + cc_api_version = 2, ) cc_library( @@ -196,6 +208,7 @@ tf_cc_test( deps = [ ":debug_grpc_testlib", ":debug_io_utils", + ":debug_service_proto_cc", ":debugger_event_metadata_proto_cc", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -254,10 +267,23 @@ tf_cc_test( ], ) -tf_proto_library_cc( - name = "debugger_event_metadata_proto", - srcs = ["debugger_event_metadata.proto"], - cc_api_version = 2, +tf_cc_test( + name = "debug_grpc_io_utils_test", + size = "small", + srcs = ["debug_grpc_io_utils_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":debug_graph_utils", + ":debug_grpc_testlib", + ":debug_io_utils", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], ) # TODO(cais): Add the following back in when tfdbg is supported on Android. diff --git a/tensorflow/core/debug/debug_grpc_io_utils_test.cc b/tensorflow/core/debug/debug_grpc_io_utils_test.cc new file mode 100644 index 0000000000..6510424182 --- /dev/null +++ b/tensorflow/core/debug/debug_grpc_io_utils_test.cc @@ -0,0 +1,432 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/debug/debug_graph_utils.h" +#include "tensorflow/core/debug/debug_grpc_testlib.h" +#include "tensorflow/core/debug/debug_io_utils.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/tracing.h" + +namespace tensorflow { + +class GrpcDebugTest : public ::testing::Test { + protected: + struct ServerData { + int port; + string url; + std::unique_ptr<test::TestEventListenerImpl> server; + std::unique_ptr<thread::ThreadPool> thread_pool; + }; + + void SetUp() override { + ClearEnabledWatchKeys(); + SetUpInProcessServer(&server_data_, 0); + } + + void TearDown() override { TearDownInProcessServer(&server_data_); } + + void SetUpInProcessServer(ServerData* server_data, + int64 server_start_delay_micros) { + server_data->port = testing::PickUnusedPortOrDie(); + server_data->url = strings::StrCat("grpc://localhost:", server_data->port); + server_data->server.reset(new test::TestEventListenerImpl()); + + server_data->thread_pool.reset( + new thread::ThreadPool(Env::Default(), "test_server", 1)); + server_data->thread_pool->Schedule( + [server_data, server_start_delay_micros]() { + Env::Default()->SleepForMicroseconds(server_start_delay_micros); + server_data->server->RunServer(server_data->port); + }); + } + + void TearDownInProcessServer(ServerData* server_data) { + server_data->server->StopServer(); + server_data->thread_pool.reset(); + } + + void ClearEnabledWatchKeys() { DebugGrpcIO::ClearEnabledWatchKeys(); } + + void CreateEmptyEnabledSet(const string& grpc_debug_url) { + DebugGrpcIO::CreateEmptyEnabledSet(grpc_debug_url); + } + + const int64 GetChannelConnectionTimeoutMicros() { + return DebugGrpcIO::channel_connection_timeout_micros; + } + + void SetChannelConnectionTimeoutMicros(const int64 timeout) { + DebugGrpcIO::channel_connection_timeout_micros = timeout; + } + + ServerData server_data_; +}; + +TEST_F(GrpcDebugTest, ConnectionTimeoutWorks) { + // Use a short timeout so the test won't take too long. + const int64 kOriginalTimeoutMicros = GetChannelConnectionTimeoutMicros(); + const int64 kShortTimeoutMicros = 500 * 1000; + SetChannelConnectionTimeoutMicros(kShortTimeoutMicros); + ASSERT_EQ(kShortTimeoutMicros, GetChannelConnectionTimeoutMicros()); + + const string& kInvalidGrpcUrl = + strings::StrCat("grpc://localhost:", testing::PickUnusedPortOrDie()); + Tensor tensor(DT_FLOAT, TensorShape({1, 1})); + tensor.flat<float>()(0) = 42.0; + Status publish_status = DebugIO::PublishDebugTensor( + DebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "foo_tensor", 0, + "DebugIdentity"), + tensor, Env::Default()->NowMicros(), {kInvalidGrpcUrl}); + SetChannelConnectionTimeoutMicros(kOriginalTimeoutMicros); + TF_ASSERT_OK(DebugIO::CloseDebugURL(kInvalidGrpcUrl)); + + ASSERT_FALSE(publish_status.ok()); + const string expected_error_msg = strings::StrCat( + "Failed to connect to gRPC channel at ", kInvalidGrpcUrl.substr(7), + " within a timeout of ", kShortTimeoutMicros / 1e6, " s"); + ASSERT_NE(string::npos, + publish_status.error_message().find(expected_error_msg)); +} + +TEST_F(GrpcDebugTest, ConnectionToDelayedStartingServerWorks) { + ServerData server_data; + // Server start will be delayed for 1 second. + SetUpInProcessServer(&server_data, 1 * 1000 * 1000); + + Tensor tensor(DT_FLOAT, TensorShape({1, 1})); + tensor.flat<float>()(0) = 42.0; + const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", + "foo_tensor", 0, "DebugIdentity"); + Status publish_status = DebugIO::PublishDebugTensor( + kDebugNodeKey, tensor, Env::Default()->NowMicros(), {server_data.url}); + ASSERT_TRUE(publish_status.ok()); + TF_ASSERT_OK(DebugIO::CloseDebugURL(server_data.url)); + + ASSERT_EQ(1, server_data.server->node_names.size()); + ASSERT_EQ(1, server_data.server->output_slots.size()); + ASSERT_EQ(1, server_data.server->debug_ops.size()); + EXPECT_EQ(kDebugNodeKey.device_name, server_data.server->device_names[0]); + EXPECT_EQ(kDebugNodeKey.node_name, server_data.server->node_names[0]); + EXPECT_EQ(kDebugNodeKey.output_slot, server_data.server->output_slots[0]); + EXPECT_EQ(kDebugNodeKey.debug_op, server_data.server->debug_ops[0]); + TearDownInProcessServer(&server_data); +} + +TEST_F(GrpcDebugTest, SendSingleDebugTensorViaGrpcTest) { + Tensor tensor(DT_FLOAT, TensorShape({1, 1})); + tensor.flat<float>()(0) = 42.0; + const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", + "foo_tensor", 0, "DebugIdentity"); + TF_ASSERT_OK(DebugIO::PublishDebugTensor( + kDebugNodeKey, tensor, Env::Default()->NowMicros(), {server_data_.url})); + TF_ASSERT_OK(DebugIO::CloseDebugURL(server_data_.url)); + + // Verify that the expected debug tensor sending happened. + ASSERT_EQ(1, server_data_.server->node_names.size()); + ASSERT_EQ(1, server_data_.server->output_slots.size()); + ASSERT_EQ(1, server_data_.server->debug_ops.size()); + EXPECT_EQ(kDebugNodeKey.device_name, server_data_.server->device_names[0]); + EXPECT_EQ(kDebugNodeKey.node_name, server_data_.server->node_names[0]); + EXPECT_EQ(kDebugNodeKey.output_slot, server_data_.server->output_slots[0]); + EXPECT_EQ(kDebugNodeKey.debug_op, server_data_.server->debug_ops[0]); +} + +TEST_F(GrpcDebugTest, SendDebugTensorWithLargeStringAtIndex0ViaGrpcTest) { + Tensor tensor(DT_STRING, TensorShape({1, 1})); + tensor.flat<string>()(0) = string(5000 * 1024, 'A'); + const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", + "foo_tensor", 0, "DebugIdentity"); + const Status status = DebugIO::PublishDebugTensor( + kDebugNodeKey, tensor, Env::Default()->NowMicros(), {server_data_.url}); + ASSERT_FALSE(status.ok()); + ASSERT_NE(status.error_message().find("string value at index 0 from debug " + "node foo_tensor:0:DebugIdentity does " + "not fit gRPC message size limit"), + string::npos); + TF_ASSERT_OK(DebugIO::CloseDebugURL(server_data_.url)); +} + +TEST_F(GrpcDebugTest, SendDebugTensorWithLargeStringAtIndex1ViaGrpcTest) { + Tensor tensor(DT_STRING, TensorShape({1, 2})); + tensor.flat<string>()(0) = "A"; + tensor.flat<string>()(1) = string(5000 * 1024, 'A'); + const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", + "foo_tensor", 0, "DebugIdentity"); + const Status status = DebugIO::PublishDebugTensor( + kDebugNodeKey, tensor, Env::Default()->NowMicros(), {server_data_.url}); + ASSERT_FALSE(status.ok()); + ASSERT_NE(status.error_message().find("string value at index 1 from debug " + "node foo_tensor:0:DebugIdentity does " + "not fit gRPC message size limit"), + string::npos); + TF_ASSERT_OK(DebugIO::CloseDebugURL(server_data_.url)); +} + +TEST_F(GrpcDebugTest, SendMultipleDebugTensorsSynchronizedViaGrpcTest) { + const int32 kSends = 4; + + // Prepare the tensors to sent. + std::vector<Tensor> tensors; + for (int i = 0; i < kSends; ++i) { + Tensor tensor(DT_INT32, TensorShape({1, 1})); + tensor.flat<int>()(0) = i * i; + tensors.push_back(tensor); + } + + thread::ThreadPool* tp = + new thread::ThreadPool(Env::Default(), "grpc_debug_test", kSends); + + mutex mu; + Notification all_done; + int tensor_count GUARDED_BY(mu) = 0; + std::vector<Status> statuses GUARDED_BY(mu); + + const std::vector<string> urls({server_data_.url}); + + // Set up the concurrent tasks of sending Tensors via an Event stream to the + // server. + auto fn = [this, &mu, &tensor_count, &tensors, &statuses, &all_done, + &urls]() { + int this_count; + { + mutex_lock l(mu); + this_count = tensor_count++; + } + + // Different concurrent tasks will send different tensors. + const uint64 wall_time = Env::Default()->NowMicros(); + Status publish_status = DebugIO::PublishDebugTensor( + DebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", + strings::StrCat("synchronized_node_", this_count), 0, + "DebugIdentity"), + tensors[this_count], wall_time, urls); + + { + mutex_lock l(mu); + statuses.push_back(publish_status); + if (this_count == kSends - 1 && !all_done.HasBeenNotified()) { + all_done.Notify(); + } + } + }; + + // Schedule the concurrent tasks. + for (int i = 0; i < kSends; ++i) { + tp->Schedule(fn); + } + + // Wait for all client tasks to finish. + all_done.WaitForNotification(); + delete tp; + + // Close the debug gRPC stream. + Status close_status = DebugIO::CloseDebugURL(server_data_.url); + ASSERT_TRUE(close_status.ok()); + + // Check all statuses from the PublishDebugTensor calls(). + for (const Status& status : statuses) { + TF_ASSERT_OK(status); + } + + // One prep tensor plus kSends concurrent tensors are expected. + ASSERT_EQ(kSends, server_data_.server->node_names.size()); + for (size_t i = 0; i < server_data_.server->node_names.size(); ++i) { + std::vector<string> items = + str_util::Split(server_data_.server->node_names[i], '_'); + int tensor_index; + strings::safe_strto32(items[2], &tensor_index); + + ASSERT_EQ(TensorShape({1, 1}), + server_data_.server->debug_tensors[i].shape()); + ASSERT_EQ(tensor_index * tensor_index, + server_data_.server->debug_tensors[i].flat<int>()(0)); + } +} + +TEST_F(GrpcDebugTest, SendeDebugTensorsThroughMultipleRoundsUsingGrpcGating) { + // Prepare the tensor to send. + const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", + "test_namescope/test_node", 0, + "DebugIdentity"); + Tensor tensor(DT_INT32, TensorShape({1, 1})); + tensor.flat<int>()(0) = 42; + + const std::vector<string> urls({server_data_.url}); + for (int i = 0; i < 3; ++i) { + server_data_.server->ClearReceivedDebugData(); + const uint64 wall_time = Env::Default()->NowMicros(); + + // On the 1st send (i == 0), gating is disabled, so data should be sent. + // On the 2nd send (i == 1), gating is enabled, and the server has enabled + // the watch key in the previous send, so data should be sent. + // On the 3rd send (i == 2), gating is enabled, but the server has disabled + // the watch key in the previous send, so data should not be sent. + const bool enable_gated_grpc = (i != 0); + TF_ASSERT_OK(DebugIO::PublishDebugTensor(kDebugNodeKey, tensor, wall_time, + urls, enable_gated_grpc)); + + server_data_.server->RequestDebugOpStateChangeAtNextStream(i == 0, + kDebugNodeKey); + + // Close the debug gRPC stream. + Status close_status = DebugIO::CloseDebugURL(server_data_.url); + ASSERT_TRUE(close_status.ok()); + + // Check dumped files according to the expected gating results. + if (i < 2) { + ASSERT_EQ(1, server_data_.server->node_names.size()); + ASSERT_EQ(1, server_data_.server->output_slots.size()); + ASSERT_EQ(1, server_data_.server->debug_ops.size()); + EXPECT_EQ(kDebugNodeKey.device_name, + server_data_.server->device_names[0]); + EXPECT_EQ(kDebugNodeKey.node_name, server_data_.server->node_names[0]); + EXPECT_EQ(kDebugNodeKey.output_slot, + server_data_.server->output_slots[0]); + EXPECT_EQ(kDebugNodeKey.debug_op, server_data_.server->debug_ops[0]); + } else { + ASSERT_EQ(0, server_data_.server->node_names.size()); + } + } +} + +TEST_F(GrpcDebugTest, TestGateDebugNodeOnEmptyEnabledSet) { + CreateEmptyEnabledSet("grpc://localhost:3333"); + + ASSERT_FALSE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", + {"grpc://localhost:3333"})); + + // file:// debug URLs are not subject to grpc gating. + ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen( + "foo:0:DebugIdentity", {"grpc://localhost:3333", "file:///tmp/tfdbg_1"})); +} + +TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSet) { + const string kGrpcUrl1 = "grpc://localhost:3333"; + const string kGrpcUrl2 = "grpc://localhost:3334"; + + DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "foo:0:DebugIdentity"); + DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "bar:0:DebugIdentity"); + + ASSERT_FALSE( + DebugIO::IsDebugNodeGateOpen("foo:1:DebugIdentity", {kGrpcUrl1})); + ASSERT_FALSE( + DebugIO::IsDebugNodeGateOpen("foo:1:DebugNumericSummary", {kGrpcUrl1})); + ASSERT_FALSE( + DebugIO::IsDebugNodeGateOpen("qux:0:DebugIdentity", {kGrpcUrl1})); + ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl1})); + ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl1})); + + // Wrong grpc:// debug URLs. + ASSERT_FALSE( + DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl2})); + ASSERT_FALSE( + DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl2})); + + // file:// debug URLs are not subject to grpc gating. + ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("qux:0:DebugIdentity", + {"file:///tmp/tfdbg_1", kGrpcUrl1})); +} + +TEST_F(GrpcDebugTest, TestGateDebugNodeOnMultipleEmptyEnabledSets) { + const string kGrpcUrl1 = "grpc://localhost:3333"; + const string kGrpcUrl2 = "grpc://localhost:3334"; + const string kGrpcUrl3 = "grpc://localhost:3335"; + + DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "foo:0:DebugIdentity"); + DebugGrpcIO::EnableWatchKey(kGrpcUrl2, "bar:0:DebugIdentity"); + CreateEmptyEnabledSet(kGrpcUrl3); + + ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl1})); + ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl2})); + ASSERT_FALSE( + DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl2})); + ASSERT_FALSE( + DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl1})); + ASSERT_FALSE( + DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl3})); + ASSERT_FALSE( + DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl3})); + ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", + {kGrpcUrl1, kGrpcUrl2})); + ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", + {kGrpcUrl1, kGrpcUrl2})); + ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", + {kGrpcUrl1, kGrpcUrl3})); + ASSERT_FALSE(DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", + {kGrpcUrl1, kGrpcUrl3})); +} + +TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSetAndEmptyURLs) { + DebugGrpcIO::EnableWatchKey("grpc://localhost:3333", "foo:0:DebugIdentity"); + + std::vector<string> debug_urls_1; + ASSERT_FALSE( + DebugIO::IsDebugNodeGateOpen("foo:1:DebugIdentity", debug_urls_1)); +} + +TEST_F(GrpcDebugTest, TestGateCopyNodeOnEmptyEnabledSet) { + const string kGrpcUrl1 = "grpc://localhost:3333"; + const string kWatch1 = "foo:0:DebugIdentity"; + CreateEmptyEnabledSet(kGrpcUrl1); + + ASSERT_FALSE(DebugIO::IsCopyNodeGateOpen( + {DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true)})); + ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen( + {DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, false)})); + + // file:// debug URLs are not subject to grpc gating. + ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen( + {DebugWatchAndURLSpec("foo:0:DebugIdentity", kGrpcUrl1, true), + DebugWatchAndURLSpec("foo:0:DebugIdentity", "file:///tmp/tfdbg_1", + false)})); +} + +TEST_F(GrpcDebugTest, TestGateCopyNodeOnNonEmptyEnabledSet) { + const string kGrpcUrl1 = "grpc://localhost:3333"; + const string kGrpcUrl2 = "grpc://localhost:3334"; + const string kWatch1 = "foo:0:DebugIdentity"; + const string kWatch2 = "foo:1:DebugIdentity"; + CreateEmptyEnabledSet(kGrpcUrl1); + CreateEmptyEnabledSet(kGrpcUrl2); + DebugGrpcIO::EnableWatchKey(kGrpcUrl1, kWatch1); + + ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen( + {DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true)})); + + ASSERT_FALSE(DebugIO::IsCopyNodeGateOpen( + {DebugWatchAndURLSpec(kWatch1, kGrpcUrl2, true)})); + ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen( + {DebugWatchAndURLSpec(kWatch1, kGrpcUrl2, false)})); + + ASSERT_FALSE(DebugIO::IsCopyNodeGateOpen( + {DebugWatchAndURLSpec(kWatch2, kGrpcUrl1, true)})); + ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen( + {DebugWatchAndURLSpec(kWatch2, kGrpcUrl1, false)})); + + ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen( + {DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true), + DebugWatchAndURLSpec(kWatch1, kGrpcUrl2, true)})); + ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen( + {DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true), + DebugWatchAndURLSpec(kWatch2, kGrpcUrl2, true)})); +} + +} // namespace tensorflow diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc index 4e185b37a8..f4208a0bbc 100644 --- a/tensorflow/core/debug/debug_io_utils.cc +++ b/tensorflow/core/debug/debug_io_utils.cc @@ -17,14 +17,12 @@ limitations under the License. #include <vector> -#if defined(PLATFORM_GOOGLE) +#ifndef PLATFORM_WINDOWS #include "grpc++/create_channel.h" -#endif - -#if defined(PLATFORM_WINDOWS) +#else // winsock2.h is used in grpc, so Ws2_32.lib is needed #pragma comment(lib,"Ws2_32.lib") -#endif +#endif // #ifndef PLATFORM_WINDOWS #include "tensorflow/core/debug/debugger_event_metadata.pb.h" #include "tensorflow/core/framework/graph.pb.h" @@ -37,10 +35,9 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/util/event.pb.h" -#define GRPC_OSS_UNIMPLEMENTED_ERROR \ - return errors::Unimplemented( \ - kGrpcURLScheme, \ - " debug URL scheme is not implemented in open source yet.") +#define GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR \ + return errors::Unimplemented( \ + kGrpcURLScheme, " debug URL scheme is not implemented on Windows yet.") namespace tensorflow { @@ -234,7 +231,7 @@ string AppendTimestampToFilePath(const string& in, const uint64 timestamp) { return out; } -#if defined(PLATFORM_GOOGLE) +#ifndef PLATFORM_WINDOWS // Publishes encoded GraphDef through a gRPC debugger stream, in chunks, // conforming to the gRPC message size limit. Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def, @@ -268,7 +265,7 @@ Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def, } return Status::OK(); } -#endif +#endif // #ifndef PLATFORM_WINDOWS } // namespace @@ -393,7 +390,7 @@ Status DebugIO::PublishDebugMetadata( Status status; for (const string& url : debug_urls) { if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) { -#if defined(PLATFORM_GOOGLE) +#ifndef PLATFORM_WINDOWS Event grpc_event; // Determine the path (if any) in the grpc:// URL, and add it as a field @@ -411,7 +408,7 @@ Status DebugIO::PublishDebugMetadata( status.Update( DebugGrpcIO::SendEventProtoThroughGrpcStream(grpc_event, url)); #else - GRPC_OSS_UNIMPLEMENTED_ERROR; + GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR; #endif } else if (str_util::Lowercase(url).find(kFileURLScheme) == 0) { const string dump_root_dir = url.substr(strlen(kFileURLScheme)); @@ -450,7 +447,7 @@ Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key, fail_statuses.push_back(s); } } else if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) { -#if defined(PLATFORM_GOOGLE) +#ifndef PLATFORM_WINDOWS Status s = DebugGrpcIO::SendTensorThroughGrpcStream( debug_node_key, tensor, wall_time_us, url, gated_grpc); @@ -459,7 +456,7 @@ Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key, fail_statuses.push_back(s); } #else - GRPC_OSS_UNIMPLEMENTED_ERROR; + GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR; #endif } else { return Status(error::UNAVAILABLE, @@ -519,11 +516,11 @@ Status DebugIO::PublishGraph(const Graph& graph, const string& device_name, status.Update( DebugFileIO::DumpEventProtoToFile(event, dump_root_dir, file_name)); } else if (debug_url.find(kGrpcURLScheme) == 0) { -#if defined(PLATFORM_GOOGLE) +#ifndef PLATFORM_WINDOWS status.Update(PublishEncodedGraphDefInChunks(buf, device_name, now_micros, debug_url)); #else - GRPC_OSS_UNIMPLEMENTED_ERROR; + GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR; #endif } } @@ -534,7 +531,7 @@ Status DebugIO::PublishGraph(const Graph& graph, const string& device_name, // static bool DebugIO::IsCopyNodeGateOpen( const std::vector<DebugWatchAndURLSpec>& specs) { -#if defined(PLATFORM_GOOGLE) +#ifndef PLATFORM_WINDOWS for (const DebugWatchAndURLSpec& spec : specs) { if (!spec.gated_grpc || spec.url.compare(0, strlen(DebugIO::kGrpcURLScheme), DebugIO::kGrpcURLScheme)) { @@ -554,7 +551,7 @@ bool DebugIO::IsCopyNodeGateOpen( // static bool DebugIO::IsDebugNodeGateOpen(const string& watch_key, const std::vector<string>& debug_urls) { -#if defined(PLATFORM_GOOGLE) +#ifndef PLATFORM_WINDOWS for (const string& debug_url : debug_urls) { if (debug_url.compare(0, strlen(DebugIO::kGrpcURLScheme), DebugIO::kGrpcURLScheme)) { @@ -574,7 +571,7 @@ bool DebugIO::IsDebugNodeGateOpen(const string& watch_key, // static bool DebugIO::IsDebugURLGateOpen(const string& watch_key, const string& debug_url) { -#if defined(PLATFORM_GOOGLE) +#ifndef PLATFORM_WINDOWS if (debug_url.find(kGrpcURLScheme) != 0) { return true; } else { @@ -588,10 +585,10 @@ bool DebugIO::IsDebugURLGateOpen(const string& watch_key, // static Status DebugIO::CloseDebugURL(const string& debug_url) { if (debug_url.find(DebugIO::kGrpcURLScheme) == 0) { -#if defined(PLATFORM_GOOGLE) +#ifndef PLATFORM_WINDOWS return DebugGrpcIO::CloseGrpcStream(debug_url); #else - GRPC_OSS_UNIMPLEMENTED_ERROR; + GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR; #endif } else { // No-op for non-gRPC URLs. @@ -703,7 +700,7 @@ Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) { } } -#if defined(PLATFORM_GOOGLE) +#ifndef PLATFORM_WINDOWS DebugGrpcChannel::DebugGrpcChannel(const string& server_stream_addr) : server_stream_addr_(server_stream_addr), url_(strings::StrCat(DebugIO::kGrpcURLScheme, server_stream_addr)) {} @@ -926,6 +923,6 @@ void DebugGrpcIO::CreateEmptyEnabledSet(const string& grpc_debug_url) { } } -#endif // #if defined(PLATFORM_GOOGLE) +#endif // #ifndef PLATFORM_WINDOWS } // namespace tensorflow diff --git a/tensorflow/core/debug/debug_io_utils.h b/tensorflow/core/debug/debug_io_utils.h index 98effed425..caf9f5341d 100644 --- a/tensorflow/core/debug/debug_io_utils.h +++ b/tensorflow/core/debug/debug_io_utils.h @@ -221,7 +221,7 @@ class DebugFileIO { // TODO(cais): Support grpc:// debug URLs in open source once Python grpc // genrule becomes available. See b/23796275. -#if defined(PLATFORM_GOOGLE) +#ifndef PLATFORM_WINDOWS #include "tensorflow/core/debug/debug_service.grpc.pb.h" namespace tensorflow { @@ -345,6 +345,6 @@ class DebugGrpcIO { }; } // namespace tensorflow -#endif // #if defined(PLATFORM_GOOGLE) +#endif // #ifndef(PLATFORM_WINDOWS) #endif // TENSORFLOW_DEBUG_IO_UTILS_H_ |