aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/debug
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-09-20 10:30:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-20 10:33:38 -0700
commit639661f1a7ddb8c82898d0b4247bd1892f03c7ae (patch)
tree28c9f873c8514ba0f80ca7d6d87da49461cea94b /tensorflow/core/debug
parent2a4ddfb229a6b890624792fff630cc71a33ce71d (diff)
tfdbg: fix a few bugs in grpc_debug_server.py
1. Remove an extraneous EventReply yielding from the core metadata case. This bug prevented request_watch() and request_unwatch() calls made in on_core_metadata() of a stream handler from taking effect. The fix is covered by a new option in start_server_on_separate_thread in grpc_debug_test_server.py and a new test: SessionDebugGrpcGatingTest.testToggleWatchesOnCoreMetadata in session_debug_grpc_test.py.
2. Fix a typo related to op state: READ_WRITE --> READ_ONLY.
3. Fix a race condition in DebugGrpcChannel::ReadEventReply by moving the mutex_lock to the right place. This prevented multiple breakpoints from working, but did not affect multiple watchpoints (which is what's used by NanChucks). The fix is covered by watching two nodes (instead of one) at a time in SessionDebugGrpcGatingTest.testToggleBreakpointsWorks in session_debug_grpc_test.py.

Also:
* Refactor the code for getting or creating a DebugGrpcChannel from a grpc:// URL into a helper method: GetOrCreateDebugGrpcChannel().

PiperOrigin-RevId: 169413046
Diffstat (limited to 'tensorflow/core/debug')
-rw-r--r--tensorflow/core/debug/debug_io_utils.cc62
-rw-r--r--tensorflow/core/debug/debug_io_utils.h15
2 files changed, 46 insertions, 31 deletions
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc
index 86f66f909e..85d04daa65 100644
--- a/tensorflow/core/debug/debug_io_utils.cc
+++ b/tensorflow/core/debug/debug_io_utils.cc
@@ -699,17 +699,15 @@ Status DebugGrpcChannel::Connect(const int64 timeout_micros) {
bool DebugGrpcChannel::WriteEvent(const Event& event) {
mutex_lock l(mu_);
-
return reader_writer_->Write(event);
}
bool DebugGrpcChannel::ReadEventReply(EventReply* event_reply) {
+ mutex_lock l(mu_);
return reader_writer_->Read(event_reply);
}
void DebugGrpcChannel::ReceiveAndProcessEventReplies(const size_t max_replies) {
- mutex_lock l(mu_);
-
EventReply event_reply;
size_t num_replies = 0;
while ((max_replies == 0 || ++num_replies <= max_replies) &&
@@ -747,11 +745,11 @@ const size_t DebugGrpcIO::kGrpcMessageSizeLimitBytes = 4000 * 1024;
const size_t DebugGrpcIO::kGrpcMaxVarintLengthSize = 6;
-std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>*
+std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
DebugGrpcIO::GetStreamChannels() {
- static std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>*
+ static std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
stream_channels =
- new std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>();
+ new std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>();
return stream_channels;
}
@@ -771,9 +769,10 @@ Status DebugGrpcIO::SendTensorThroughGrpcStream(
SendEventProtoThroughGrpcStream(event, grpc_stream_url));
}
if (IsWriteGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) {
- EventReply event_reply;
- TF_RETURN_IF_ERROR(ReceiveEventReplyProtoThroughGrpcStream(
- &event_reply, grpc_stream_url));
+ DebugGrpcChannel* debug_grpc_channel = nullptr;
+ TF_RETURN_IF_ERROR(
+ GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel));
+ debug_grpc_channel->ReceiveAndProcessEventReplies(1);
// TODO(cais): Support new tensor value carried in the EventReply for
// overriding the value of the tensor being published.
}
@@ -783,13 +782,9 @@ Status DebugGrpcIO::SendTensorThroughGrpcStream(
Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream(
EventReply* event_reply, const string& grpc_stream_url) {
- std::shared_ptr<DebugGrpcChannel> debug_grpc_channel;
- {
- mutex_lock l(streams_mu);
- std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>*
- stream_channels = GetStreamChannels();
- debug_grpc_channel = (*stream_channels)[grpc_stream_url];
- }
+ DebugGrpcChannel* debug_grpc_channel = nullptr;
+ TF_RETURN_IF_ERROR(
+ GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel));
if (debug_grpc_channel->ReadEventReply(event_reply)) {
return Status::OK();
} else {
@@ -798,29 +793,36 @@ Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream(
}
}
-Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
- const Event& event_proto, const string& grpc_stream_url,
- const bool receive_reply) {
+Status DebugGrpcIO::GetOrCreateDebugGrpcChannel(
+ const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel) {
const string addr_with_path =
grpc_stream_url.find(DebugIO::kGrpcURLScheme) == 0
? grpc_stream_url.substr(strlen(DebugIO::kGrpcURLScheme))
: grpc_stream_url;
const string server_stream_addr =
addr_with_path.substr(0, addr_with_path.find('/'));
- std::shared_ptr<DebugGrpcChannel> debug_grpc_channel;
{
mutex_lock l(streams_mu);
- std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>*
+ std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
stream_channels = GetStreamChannels();
if (stream_channels->find(grpc_stream_url) == stream_channels->end()) {
- debug_grpc_channel.reset(new DebugGrpcChannel(server_stream_addr));
- TF_RETURN_IF_ERROR(
- debug_grpc_channel->Connect(channel_connection_timeout_micros));
- (*stream_channels)[grpc_stream_url] = debug_grpc_channel;
- } else {
- debug_grpc_channel = (*stream_channels)[grpc_stream_url];
+ std::unique_ptr<DebugGrpcChannel> channel(
+ new DebugGrpcChannel(server_stream_addr));
+ TF_RETURN_IF_ERROR(channel->Connect(channel_connection_timeout_micros));
+ stream_channels->insert(
+ std::make_pair(grpc_stream_url, std::move(channel)));
}
+ *debug_grpc_channel = (*stream_channels)[grpc_stream_url].get();
}
+ return Status::OK();
+}
+
+Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
+ const Event& event_proto, const string& grpc_stream_url,
+ const bool receive_reply) {
+ DebugGrpcChannel* debug_grpc_channel;
+ TF_RETURN_IF_ERROR(
+ GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel));
bool write_ok = debug_grpc_channel->WriteEvent(event_proto);
if (!write_ok) {
@@ -857,13 +859,13 @@ bool DebugGrpcIO::IsWriteGateOpen(const string& grpc_debug_url,
Status DebugGrpcIO::CloseGrpcStream(const string& grpc_stream_url) {
mutex_lock l(streams_mu);
- std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>*
+ std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
stream_channels = GetStreamChannels();
if (stream_channels->find(grpc_stream_url) != stream_channels->end()) {
// Stream of the specified address exists. Close it and remove it from
// record.
- Status s;
- s = (*stream_channels)[grpc_stream_url]->ReceiveServerRepliesAndClose();
+ Status s =
+ (*stream_channels)[grpc_stream_url]->ReceiveServerRepliesAndClose();
(*stream_channels).erase(grpc_stream_url);
return s;
} else {
diff --git a/tensorflow/core/debug/debug_io_utils.h b/tensorflow/core/debug/debug_io_utils.h
index 023d7a7ee0..c974a47051 100644
--- a/tensorflow/core/debug/debug_io_utils.h
+++ b/tensorflow/core/debug/debug_io_utils.h
@@ -361,9 +361,22 @@ class DebugGrpcIO {
// Returns a global map from grpc debug URLs to the corresponding
// DebugGrpcChannels.
- static std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>*
+ static std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
GetStreamChannels();
+ // Get a DebugGrpcChannel object at a given URL, creating one if necessary.
+ //
+ // Args:
+ // grpc_stream_url: grpc:// URL of the stream, e.g., "grpc://localhost:6064"
+ // debug_grpc_channel: A pointer to the DebugGrpcChannel object, passed as a
+ // pointer to the pointer. The DebugGrpcChannel object is owned
+ // statically elsewhere, not by the caller of this function.
+ //
+ // Returns:
+ // Status of this operation.
+ static Status GetOrCreateDebugGrpcChannel(
+ const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel);
+
// Returns a map from debug URL to a map from debug op name to enabled state.
static std::unordered_map<string, DebugNodeName2State>*
GetEnabledDebugOpStates();