diff options
author | 2016-11-22 17:30:19 -0800 | |
---|---|---|
committer | 2016-11-22 17:44:29 -0800 | |
commit | ef2a926ec05dfd337d84279aafa58b22f0f36123 (patch) | |
tree | b8c46584d7a585698ab6ba308d43c5ce20e8c749 /tensorflow/core/debug/debug_io_utils.cc | |
parent | cce0d12e13e87401a47d05145551b5a87d1167b3 (diff) |
tfdbg core: implement gRPC debug URLs
Change: 139976177
Diffstat (limited to 'tensorflow/core/debug/debug_io_utils.cc')
-rw-r--r-- | tensorflow/core/debug/debug_io_utils.cc | 164 |
1 file changed, 156 insertions, 8 deletions
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc index 3738dd21d3..dc7121e6c3 100644 --- a/tensorflow/core/debug/debug_io_utils.cc +++ b/tensorflow/core/debug/debug_io_utils.cc @@ -17,6 +17,8 @@ limitations under the License. #include <vector> +#include "grpc++/create_channel.h" +#include "tensorflow/core/debug/debug_service.grpc.pb.h" #include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -56,6 +58,35 @@ Event WrapTensorAsEvent(const string& tensor_name, const string& debug_op, } // namespace +Status ReadEventFromFile(const string& dump_file_path, Event* event) { + Env* env(Env::Default()); + + string content; + uint64 file_size = 0; + + Status s = env->GetFileSize(dump_file_path, &file_size); + if (!s.ok()) { + return s; + } + + content.resize(file_size); + + std::unique_ptr<RandomAccessFile> file; + s = env->NewRandomAccessFile(dump_file_path, &file); + if (!s.ok()) { + return s; + } + + StringPiece result; + s = file->Read(0, file_size, &result, &(content)[0]); + if (!s.ok()) { + return s; + } + + event->ParseFromString(content); + return Status::OK(); +} + // static const char* const DebugIO::kFileURLScheme = "file://"; // static @@ -85,6 +116,7 @@ Status DebugIO::PublishDebugTensor(const string& tensor_name, } int num_failed_urls = 0; + std::vector<Status> fail_statuses; for (const string& url : debug_urls) { if (str_util::Lowercase(url).find(kFileURLScheme) == 0) { const string dump_root_dir = url.substr(strlen(kFileURLScheme)); @@ -94,12 +126,18 @@ Status DebugIO::PublishDebugTensor(const string& tensor_name, wall_time_us, dump_root_dir, nullptr); if (!s.ok()) { num_failed_urls++; + fail_statuses.push_back(s); } } else if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) { - // TODO(cais): Implement PublishTensor with grpc urls. 
- return Status(error::UNIMPLEMENTED, - strings::StrCat("Puslishing to GRPC debug target is not ", - "implemented yet")); + const string grpc_server_stream_addr = url.substr(strlen(kGrpcURLScheme)); + Status s = DebugGrpcIO::SendTensorThroughGrpcStream( + node_name, output_slot, debug_op, tensor, wall_time_us, + grpc_server_stream_addr); + + if (!s.ok()) { + num_failed_urls++; + fail_statuses.push_back(s); + } } else { return Status(error::UNAVAILABLE, strings::StrCat("Invalid debug target URL: ", url)); @@ -109,13 +147,31 @@ Status DebugIO::PublishDebugTensor(const string& tensor_name, if (num_failed_urls == 0) { return Status::OK(); } else { - return Status( - error::INTERNAL, - strings::StrCat("Puslishing to ", num_failed_urls, " of ", - debug_urls.size(), " debug target URLs failed")); + string error_message = strings::StrCat( + "Publishing to ", num_failed_urls, " of ", debug_urls.size(), + " debug target URLs failed, due to the following errors:"); + for (Status& status : fail_statuses) { + error_message = + strings::StrCat(error_message, " ", status.error_message(), ";"); + } + + return Status(error::INTERNAL, error_message); } } +Status DebugIO::CloseDebugURL(const string& debug_url) { + if (debug_url.find(DebugIO::kGrpcURLScheme) == 0) { + return DebugGrpcIO::CloseGrpcStream( + debug_url.substr(strlen(DebugIO::kGrpcURLScheme))); + } else { + // No-op for non-gRPC URLs. 
+ return Status::OK(); + } +} + +// static +static Status CloseDebugURL(const string& debug_url) { return Status::OK(); } + // static Status DebugFileIO::DumpTensorToDir( const string& node_name, const int32 output_slot, const string& debug_op, @@ -208,4 +264,96 @@ Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) { } } +DebugGrpcChannel::DebugGrpcChannel(const string& server_stream_addr) + : ctx_(), + channel_(::grpc::CreateCustomChannel(server_stream_addr, + ::grpc::InsecureChannelCredentials(), + ::grpc::ChannelArguments())), + stub_(EventListener::NewStub(channel_)), + reader_writer_(stub_->SendEvents(&ctx_)), + mu_() {} +// TODO(cais): Set GRPC_ARG_MAX_MESSAGE_LENGTH to max if necessary. + +bool DebugGrpcChannel::is_channel_ready() { + return channel_->GetState(false) == GRPC_CHANNEL_READY; +} + +bool DebugGrpcChannel::WriteEvent(const Event& event) { + mutex_lock l(mu_); + + return reader_writer_->Write(event); +} + +Status DebugGrpcChannel::Close() { + mutex_lock l(mu_); + + reader_writer_->WritesDone(); + if (reader_writer_->Finish().ok()) { + std::cout << "Finish() returned ok status" << std::endl; // DEBUG + return Status::OK(); + } else { + std::cout << "Finish() returned non-ok status" << std::endl; // DEBUG + return Status(error::FAILED_PRECONDITION, + "Failed to close debug GRPC stream."); + } +} + +// static +mutex DebugGrpcIO::streams_mu; +std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>> + DebugGrpcIO::stream_channels; + +// static +Status DebugGrpcIO::SendTensorThroughGrpcStream( + const string& node_name, const int32 output_slot, const string& debug_op, + const Tensor& tensor, const uint64 wall_time_us, + const string& server_stream_addr) { + const string tensor_name = strings::StrCat(node_name, ":", output_slot); + + // Prepare tensor Event data to be sent. 
+ Event event = WrapTensorAsEvent(tensor_name, debug_op, tensor, wall_time_us); + + std::shared_ptr<DebugGrpcChannel> debug_grpc_channel; + { + mutex_lock l(streams_mu); + if (stream_channels.find(server_stream_addr) == stream_channels.end()) { + debug_grpc_channel.reset(new DebugGrpcChannel(server_stream_addr)); + + if (!debug_grpc_channel->is_channel_ready()) { + return errors::FailedPrecondition( + strings::StrCat("Channel at the following gRPC address is ", + "not ready: ", server_stream_addr)); + } + + stream_channels[server_stream_addr] = debug_grpc_channel; + } else { + debug_grpc_channel = stream_channels[server_stream_addr]; + } + } + + bool write_ok = debug_grpc_channel->WriteEvent(event); + if (!write_ok) { + return errors::Cancelled(strings::StrCat("Write event to stream URL ", + server_stream_addr, "failed.")); + } + + return Status::OK(); +} + +Status DebugGrpcIO::CloseGrpcStream(const string& server_stream_addr) { + mutex_lock l(streams_mu); + + if (stream_channels.find(server_stream_addr) != stream_channels.end()) { + // Stream of the specified address exists. Close it and remove it from + // record. + Status s; + s = stream_channels[server_stream_addr]->Close(); + stream_channels.erase(server_stream_addr); + return s; + } else { + // Stream of the specified address does not exist. No action. + return Status::OK(); + } +} + } // namespace tensorflow |