aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/debug
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-09-13 15:53:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-13 15:56:51 -0700
commitecae82d1343df293fa36e67949e5404111817110 (patch)
treefd7044715524d8c5df048ee92755d959c7117f22 /tensorflow/core/debug
parentf95b1cf115330774c01a5b2cefc8b81b26552190 (diff)
tfdbg: change grpc protocol to make it possible to switch on/off gated_grpc debug ops at the first Session.run()
Previously, the request_watch() and request_unwatch() calls will not take effect the current Session.run() has ended. This CL eliminates that limitation by letting PublishCoreMeatadata() and PublishGraph() read an EventReply from the server before they return. For the Python API, this change is backward compatible. The on_core_meatadata() and on_graph_def() methods can continue to return nothing (None), in which case the base server class will just replace the None with a default-constructor EventReply. PiperOrigin-RevId: 168608272
Diffstat (limited to 'tensorflow/core/debug')
-rw-r--r--tensorflow/core/debug/debug_grpc_testlib.cc2
-rw-r--r--tensorflow/core/debug/debug_io_utils.cc69
-rw-r--r--tensorflow/core/debug/debug_io_utils.h26
3 files changed, 48 insertions, 49 deletions
diff --git a/tensorflow/core/debug/debug_grpc_testlib.cc b/tensorflow/core/debug/debug_grpc_testlib.cc
index aa80ea84e3..a312f789d8 100644
--- a/tensorflow/core/debug/debug_grpc_testlib.cc
+++ b/tensorflow/core/debug/debug_grpc_testlib.cc
@@ -37,8 +37,10 @@ namespace test {
while (stream->Read(&event)) {
if (event.has_log_message()) {
debug_metadata_strings.push_back(event.log_message().message());
+ stream->Write(EventReply());
} else if (!event.graph_def().empty()) {
encoded_graph_defs.push_back(event.graph_def());
+ stream->Write(EventReply());
} else if (event.has_summary()) {
const Summary::Value& val = event.summary().value(0);
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc
index c9f2c24732..546cde4c16 100644
--- a/tensorflow/core/debug/debug_io_utils.cc
+++ b/tensorflow/core/debug/debug_io_utils.cc
@@ -263,10 +263,13 @@ Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def,
event.set_graph_def(strings::StrCat(hash, ",", device_name, ",", wall_time,
"|", i, "|", num_chunks, "|",
encoded_graph_def.substr(pos, len)));
- if (!DebugGrpcIO::SendEventProtoThroughGrpcStream(event, debug_url).ok()) {
+ const Status s = DebugGrpcIO::SendEventProtoThroughGrpcStream(
+ event, debug_url, num_chunks - 1 == i);
+ if (!s.ok()) {
return errors::FailedPrecondition(
"Failed to send chunk ", i, " of ", num_chunks,
- " of encoded GraphDef of size ", encoded_graph_def.size(), " bytes");
+ " of encoded GraphDef of size ", encoded_graph_def.size(), " bytes, ",
+ "due to: ", s.error_message());
}
}
return Status::OK();
@@ -275,22 +278,16 @@ Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def,
} // namespace
-// static
const char* const DebugIO::kDebuggerPluginName = "debugger";
-// static
const char* const DebugIO::kMetadataFilePrefix = "_tfdbg_";
-// static
const char* const DebugIO::kCoreMetadataTag = "core_metadata_";
-// static
const char* const DebugIO::kDeviceTag = "device_";
-// static
const char* const DebugIO::kGraphTag = "graph_";
-// static
const char* const DebugIO::kHashTag = "hash";
DebugNodeKey::DebugNodeKey(const string& device_name, const string& node_name,
@@ -341,7 +338,6 @@ Status ReadEventFromFile(const string& dump_file_path, Event* event) {
return Status::OK();
}
-// static
const string DebugNodeKey::DeviceNameToDevicePath(const string& device_name) {
return strings::StrCat(
DebugIO::kMetadataFilePrefix, DebugIO::kDeviceTag,
@@ -350,13 +346,10 @@ const string DebugNodeKey::DeviceNameToDevicePath(const string& device_name) {
true));
}
-// static
const char* const DebugIO::kFileURLScheme = "file://";
-// static
const char* const DebugIO::kGrpcURLScheme = "grpc://";
// Publishes debug metadata to a set of debug URLs.
-// static
Status DebugIO::PublishDebugMetadata(
const int64 global_step, const int64 session_run_index,
const int64 executor_step_index, const std::vector<string>& input_names,
@@ -421,7 +414,7 @@ Status DebugIO::PublishDebugMetadata(
",\"grpc_path\":\"", path, "\"}"));
status.Update(
- DebugGrpcIO::SendEventProtoThroughGrpcStream(grpc_event, url));
+ DebugGrpcIO::SendEventProtoThroughGrpcStream(grpc_event, url, true));
#else
GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
#endif
@@ -443,7 +436,6 @@ Status DebugIO::PublishDebugMetadata(
return status;
}
-// static
Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
const Tensor& tensor,
const uint64 wall_time_us,
@@ -494,7 +486,6 @@ Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
}
}
-// static
Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
const Tensor& tensor,
const uint64 wall_time_us,
@@ -503,7 +494,6 @@ Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
false);
}
-// static
Status DebugIO::PublishGraph(const Graph& graph, const string& device_name,
const std::unordered_set<string>& debug_urls) {
GraphDef graph_def;
@@ -543,7 +533,6 @@ Status DebugIO::PublishGraph(const Graph& graph, const string& device_name,
return status;
}
-// static
bool DebugIO::IsCopyNodeGateOpen(
const std::vector<DebugWatchAndURLSpec>& specs) {
#ifndef PLATFORM_WINDOWS
@@ -563,7 +552,6 @@ bool DebugIO::IsCopyNodeGateOpen(
#endif
}
-// static
bool DebugIO::IsDebugNodeGateOpen(const string& watch_key,
const std::vector<string>& debug_urls) {
#ifndef PLATFORM_WINDOWS
@@ -583,7 +571,6 @@ bool DebugIO::IsDebugNodeGateOpen(const string& watch_key,
#endif
}
-// static
bool DebugIO::IsDebugURLGateOpen(const string& watch_key,
const string& debug_url) {
#ifndef PLATFORM_WINDOWS
@@ -597,7 +584,6 @@ bool DebugIO::IsDebugURLGateOpen(const string& watch_key,
#endif
}
-// static
Status DebugIO::CloseDebugURL(const string& debug_url) {
if (debug_url.find(DebugIO::kGrpcURLScheme) == 0) {
#ifndef PLATFORM_WINDOWS
@@ -611,10 +597,8 @@ Status DebugIO::CloseDebugURL(const string& debug_url) {
}
}
-// static
static Status CloseDebugURL(const string& debug_url) { return Status::OK(); }
-// static
Status DebugFileIO::DumpTensorToDir(const DebugNodeKey& debug_node_key,
const Tensor& tensor,
const uint64 wall_time_us,
@@ -630,7 +614,6 @@ Status DebugFileIO::DumpTensorToDir(const DebugNodeKey& debug_node_key,
return DumpTensorToEventFile(debug_node_key, tensor, wall_time_us, file_path);
}
-// static
string DebugFileIO::GetDumpFilePath(const string& dump_root_dir,
const DebugNodeKey& debug_node_key,
const uint64 wall_time_us) {
@@ -642,7 +625,6 @@ string DebugFileIO::GetDumpFilePath(const string& dump_root_dir,
wall_time_us);
}
-// static
Status DebugFileIO::DumpEventProtoToFile(const Event& event_proto,
const string& dir_name,
const string& file_name) {
@@ -668,7 +650,6 @@ Status DebugFileIO::DumpEventProtoToFile(const Event& event_proto,
return Status::OK();
}
-// static
Status DebugFileIO::DumpTensorToEventFile(const DebugNodeKey& debug_node_key,
const Tensor& tensor,
const uint64 wall_time_us,
@@ -680,7 +661,6 @@ Status DebugFileIO::DumpTensorToEventFile(const DebugNodeKey& debug_node_key,
io::Basename(file_path).ToString());
}
-// static
Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
if (env->FileExists(dir).ok() && env->IsDirectory(dir).ok()) {
// The path already exists as a directory. Return OK right away.
@@ -747,19 +727,16 @@ bool DebugGrpcChannel::WriteEvent(const Event& event) {
}
bool DebugGrpcChannel::ReadEventReply(EventReply* event_reply) {
- mutex_lock l(mu_);
-
return reader_writer_->Read(event_reply);
}
-Status DebugGrpcChannel::ReceiveServerRepliesAndClose() {
+void DebugGrpcChannel::ReceiveAndProcessEventReplies(const size_t max_replies) {
mutex_lock l(mu_);
- reader_writer_->WritesDone();
-
- // Read all EventReply messages (if any) from the server.
EventReply event_reply;
- while (reader_writer_->Read(&event_reply)) {
+ size_t num_replies = 0;
+ while ((max_replies == 0 || ++num_replies <= max_replies) &&
+ ReadEventReply(&event_reply)) {
for (const EventReply::DebugOpStateChange& debug_op_state_change :
event_reply.debug_op_state_changes()) {
string watch_key = strings::StrCat(debug_op_state_change.node_name(), ":",
@@ -769,6 +746,12 @@ Status DebugGrpcChannel::ReceiveServerRepliesAndClose() {
debug_op_state_change.state());
}
}
+}
+
+Status DebugGrpcChannel::ReceiveServerRepliesAndClose() {
+ reader_writer_->WritesDone();
+ // Read all EventReply messages (if any) from the server.
+ ReceiveAndProcessEventReplies(0);
if (reader_writer_->Finish().ok()) {
return Status::OK();
@@ -778,20 +761,15 @@ Status DebugGrpcChannel::ReceiveServerRepliesAndClose() {
}
}
-// static
mutex DebugGrpcIO::streams_mu;
-// static
int64 DebugGrpcIO::channel_connection_timeout_micros = 900 * 1000 * 1000;
// TODO(cais): Make this configurable?
-// static
const size_t DebugGrpcIO::kGrpcMessageSizeLimitBytes = 4000 * 1024;
-// static
const size_t DebugGrpcIO::kGrpcMaxVarintLengthSize = 6;
-// static
std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>*
DebugGrpcIO::GetStreamChannels() {
static std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>*
@@ -800,7 +778,6 @@ DebugGrpcIO::GetStreamChannels() {
return stream_channels;
}
-// static
Status DebugGrpcIO::SendTensorThroughGrpcStream(
const DebugNodeKey& debug_node_key, const Tensor& tensor,
const uint64 wall_time_us, const string& grpc_stream_url,
@@ -827,7 +804,6 @@ Status DebugGrpcIO::SendTensorThroughGrpcStream(
}
}
-// static
Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream(
EventReply* event_reply, const string& grpc_stream_url) {
std::shared_ptr<DebugGrpcChannel> debug_grpc_channel;
@@ -845,9 +821,9 @@ Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream(
}
}
-// static
Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
- const Event& event_proto, const string& grpc_stream_url) {
+ const Event& event_proto, const string& grpc_stream_url,
+ const bool receive_reply) {
const string addr_with_path =
grpc_stream_url.find(DebugIO::kGrpcURLScheme) == 0
? grpc_stream_url.substr(strlen(DebugIO::kGrpcURLScheme))
@@ -875,6 +851,10 @@ Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
grpc_stream_url, " failed."));
}
+ if (receive_reply) {
+ debug_grpc_channel->ReceiveAndProcessEventReplies(1);
+ }
+
return Status::OK();
}
@@ -897,7 +877,6 @@ bool DebugGrpcIO::IsWriteGateOpen(const string& grpc_debug_url,
}
}
-// static
Status DebugGrpcIO::CloseGrpcStream(const string& grpc_stream_url) {
mutex_lock l(streams_mu);
@@ -916,7 +895,6 @@ Status DebugGrpcIO::CloseGrpcStream(const string& grpc_stream_url) {
}
}
-// static
std::unordered_map<string, DebugGrpcIO::DebugNodeName2State>*
DebugGrpcIO::GetEnabledDebugOpStates() {
static std::unordered_map<string, DebugNodeName2State>*
@@ -925,7 +903,6 @@ DebugGrpcIO::GetEnabledDebugOpStates() {
return enabled_debug_op_states;
}
-// static
DebugGrpcIO::DebugNodeName2State* DebugGrpcIO::GetEnabledDebugOpStatesAtUrl(
const string& grpc_debug_url) {
static mutex* debug_ops_state_mu = new mutex();
@@ -940,7 +917,6 @@ DebugGrpcIO::DebugNodeName2State* DebugGrpcIO::GetEnabledDebugOpStatesAtUrl(
return &(*states)[grpc_debug_url];
}
-// static
void DebugGrpcIO::SetDebugNodeKeyGrpcState(
const string& grpc_debug_url, const string& watch_key,
const EventReply::DebugOpStateChange::State new_state) {
@@ -957,7 +933,6 @@ void DebugGrpcIO::SetDebugNodeKeyGrpcState(
}
}
-// static
void DebugGrpcIO::ClearEnabledWatchKeys() {
GetEnabledDebugOpStates()->clear();
}
diff --git a/tensorflow/core/debug/debug_io_utils.h b/tensorflow/core/debug/debug_io_utils.h
index 35e735172b..75fc2b07f3 100644
--- a/tensorflow/core/debug/debug_io_utils.h
+++ b/tensorflow/core/debug/debug_io_utils.h
@@ -65,6 +65,7 @@ struct DebugNodeKey {
const string device_path;
};
+// TODO(cais): Put static functions and members in a namespace, not a class.
class DebugIO {
public:
static const char* const kDebuggerPluginName;
@@ -295,6 +296,16 @@ class DebugGrpcChannel {
// True iff the read is successful.
bool ReadEventReply(EventReply* event_reply);
+ // Receive and process EventReply protos from the gRPC debug server.
+ //
+ // The processing includes setting debug watch key states using the
+ // DebugOpStateChange fields of the EventReply.
+ //
+ // Args:
+ // max_replies: Maximum number of replies to receive. Will receive all
+ // remaining replies iff max_replies == 0.
+ void ReceiveAndProcessEventReplies(size_t max_replies);
+
// Receive EventReplies from server (if any) and close the stream and the
// channel.
Status ReceiveServerRepliesAndClose();
@@ -326,8 +337,19 @@ class DebugGrpcIO {
// Sends an Event proto through a debug gRPC stream.
// Thread-safety: Safe with respect to other calls to the same method and
// calls to CloseGrpcStream().
- static Status SendEventProtoThroughGrpcStream(const Event& event_proto,
- const string& grpc_stream_url);
+ //
+ // Args:
+ // event_proto: The Event proto to be sent.
+ // grpc_stream_url: The grpc:// URL of the stream to use, e.g.,
+ // "grpc://localhost:11011", "localhost:22022".
+ // receive_reply: Whether an EventReply proto will be read after event_proto
+ // is sent and before the function returns.
+ //
+ // Returns:
+ // The Status of the operation.
+ static Status SendEventProtoThroughGrpcStream(
+ const Event& event_proto, const string& grpc_stream_url,
+ const bool receive_reply = false);
// Receive an EventReply proto through a debug gRPC stream.
static Status ReceiveEventReplyProtoThroughGrpcStream(