aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/debug/debug_io_utils.cc
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2016-11-22 17:30:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-22 17:44:29 -0800
commitef2a926ec05dfd337d84279aafa58b22f0f36123 (patch)
treeb8c46584d7a585698ab6ba308d43c5ce20e8c749 /tensorflow/core/debug/debug_io_utils.cc
parentcce0d12e13e87401a47d05145551b5a87d1167b3 (diff)
tfdbg core: implement gRPC debug URLs
Change: 139976177
Diffstat (limited to 'tensorflow/core/debug/debug_io_utils.cc')
-rw-r--r--tensorflow/core/debug/debug_io_utils.cc164
1 files changed, 156 insertions, 8 deletions
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc
index 3738dd21d3..dc7121e6c3 100644
--- a/tensorflow/core/debug/debug_io_utils.cc
+++ b/tensorflow/core/debug/debug_io_utils.cc
@@ -17,6 +17,8 @@ limitations under the License.
#include <vector>
+#include "grpc++/create_channel.h"
+#include "tensorflow/core/debug/debug_service.grpc.pb.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -56,6 +58,35 @@ Event WrapTensorAsEvent(const string& tensor_name, const string& debug_op,
} // namespace
+// Reads a serialized Event proto from the file at `dump_file_path` and parses
+// it into `*event`.
+//
+// Args:
+//   dump_file_path: Path of the file holding the serialized Event.
+//   event: Output; receives the parsed proto. Dereferenced unconditionally,
+//     so it must not be null.
+//
+// Returns:
+//   OK on success; the error of the first file-system call that failed; or
+//   DATA_LOSS if the bytes read cannot be parsed as an Event.
+Status ReadEventFromFile(const string& dump_file_path, Event* event) {
+  Env* env(Env::Default());
+
+  uint64 file_size = 0;
+  Status s = env->GetFileSize(dump_file_path, &file_size);
+  if (!s.ok()) {
+    return s;
+  }
+
+  std::unique_ptr<RandomAccessFile> file;
+  s = env->NewRandomAccessFile(dump_file_path, &file);
+  if (!s.ok()) {
+    return s;
+  }
+
+  // Scratch buffer handed to Read(); sized to hold the whole file.
+  string content;
+  content.resize(file_size);
+
+  StringPiece result;
+  s = file->Read(0, file_size, &result, &(content)[0]);
+  if (!s.ok()) {
+    return s;
+  }
+
+  // Parse what Read() reported through `result` rather than the scratch
+  // buffer: `result` is the authoritative view of the bytes read and is not
+  // guaranteed to alias (or fill) `content`.
+  if (!event->ParseFromArray(result.data(),
+                             static_cast<int>(result.size()))) {
+    // Do not report success with a partially populated Event.
+    return errors::DataLoss("Failed to parse Event proto from file ",
+                            dump_file_path);
+  }
+  return Status::OK();
+}
+
// static
const char* const DebugIO::kFileURLScheme = "file://";
// static
@@ -85,6 +116,7 @@ Status DebugIO::PublishDebugTensor(const string& tensor_name,
}
int num_failed_urls = 0;
+ std::vector<Status> fail_statuses;
for (const string& url : debug_urls) {
if (str_util::Lowercase(url).find(kFileURLScheme) == 0) {
const string dump_root_dir = url.substr(strlen(kFileURLScheme));
@@ -94,12 +126,18 @@ Status DebugIO::PublishDebugTensor(const string& tensor_name,
wall_time_us, dump_root_dir, nullptr);
if (!s.ok()) {
num_failed_urls++;
+ fail_statuses.push_back(s);
}
} else if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) {
- // TODO(cais): Implement PublishTensor with grpc urls.
- return Status(error::UNIMPLEMENTED,
- strings::StrCat("Puslishing to GRPC debug target is not ",
- "implemented yet"));
+ const string grpc_server_stream_addr = url.substr(strlen(kGrpcURLScheme));
+ Status s = DebugGrpcIO::SendTensorThroughGrpcStream(
+ node_name, output_slot, debug_op, tensor, wall_time_us,
+ grpc_server_stream_addr);
+
+ if (!s.ok()) {
+ num_failed_urls++;
+ fail_statuses.push_back(s);
+ }
} else {
return Status(error::UNAVAILABLE,
strings::StrCat("Invalid debug target URL: ", url));
@@ -109,13 +147,31 @@ Status DebugIO::PublishDebugTensor(const string& tensor_name,
if (num_failed_urls == 0) {
return Status::OK();
} else {
- return Status(
- error::INTERNAL,
- strings::StrCat("Puslishing to ", num_failed_urls, " of ",
- debug_urls.size(), " debug target URLs failed"));
+ string error_message = strings::StrCat(
+ "Publishing to ", num_failed_urls, " of ", debug_urls.size(),
+ " debug target URLs failed, due to the following errors:");
+ for (Status& status : fail_statuses) {
+ error_message =
+ strings::StrCat(error_message, " ", status.error_message(), ";");
+ }
+
+ return Status(error::INTERNAL, error_message);
}
}
+// Closes any resource held open for `debug_url`.
+//
+// For gRPC URLs the scheme prefix is stripped and the stream registered under
+// the remaining address is closed via DebugGrpcIO::CloseGrpcStream. Every other
+// kind of URL is a no-op.
+//
+// Returns:
+//   The status of DebugGrpcIO::CloseGrpcStream for gRPC URLs, OK otherwise.
+Status DebugIO::CloseDebugURL(const string& debug_url) {
+  // Match the scheme case-insensitively, the same way PublishDebugTensor does,
+  // so that every URL accepted for publishing can also be closed.
+  if (str_util::Lowercase(debug_url).find(DebugIO::kGrpcURLScheme) != 0) {
+    // No-op for non-gRPC URLs.
+    return Status::OK();
+  }
+
+  return DebugGrpcIO::CloseGrpcStream(
+      debug_url.substr(strlen(DebugIO::kGrpcURLScheme)));
+}
+
+// File-local free function (internal linkage), separate from the
+// DebugIO::CloseDebugURL member defined above. It ignores `debug_url` and
+// always reports success.
+// NOTE(review): no caller is visible in this change; this looks like a
+// leftover stub. Confirm it is unused before removing it.
+static Status CloseDebugURL(const string& debug_url) {
+  return Status::OK();
+}
+
// static
Status DebugFileIO::DumpTensorToDir(
const string& node_name, const int32 output_slot, const string& debug_op,
@@ -208,4 +264,96 @@ Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
}
}
+// Sets up a debug event stream to the gRPC server at `server_stream_addr`.
+//
+// Construction creates a channel to the address using insecure channel
+// credentials and default-constructed ChannelArguments, builds an
+// EventListener stub on that channel, and calls SendEvents on the stub with
+// ctx_ to obtain the reader-writer used by WriteEvent() and Close().
+//
+// NOTE(review): C++ initializes members in class-declaration order, not in the
+// order written below. reader_writer_ needs stub_ and ctx_, and stub_ needs
+// channel_; confirm the header declares the members in a compatible order.
+DebugGrpcChannel::DebugGrpcChannel(const string& server_stream_addr)
+ : ctx_(),
+ channel_(::grpc::CreateCustomChannel(server_stream_addr,
+ ::grpc::InsecureChannelCredentials(),
+ ::grpc::ChannelArguments())),
+ stub_(EventListener::NewStub(channel_)),
+ reader_writer_(stub_->SendEvents(&ctx_)),
+ mu_() {}
+// TODO(cais): Set GRPC_ARG_MAX_MESSAGE_LENGTH to max if necessary.
+
+// Reports whether the underlying channel is currently in the
+// GRPC_CHANNEL_READY state. The state is queried with GetState(false).
+// NOTE(review): presumably `false` means "do not try to connect", so a channel
+// that is still connecting is reported as not ready; confirm this is intended.
+bool DebugGrpcChannel::is_channel_ready() {
+  const auto current_state = channel_->GetState(false);
+  return current_state == GRPC_CHANNEL_READY;
+}
+
+// Writes `event` to the stream while holding mu_, so calls on the same channel
+// are serialized with each other and with Close().
+//
+// Returns:
+//   The result of the underlying Write() call.
+bool DebugGrpcChannel::WriteEvent(const Event& event) {
+  mutex_lock lock(mu_);
+  const bool write_succeeded = reader_writer_->Write(event);
+  return write_succeeded;
+}
+
+// Shuts the stream down: calls WritesDone() followed by Finish() on the
+// reader-writer, holding mu_ so the shutdown cannot interleave with
+// WriteEvent().
+//
+// Returns:
+//   OK if Finish() reports an ok status, FAILED_PRECONDITION otherwise.
+Status DebugGrpcChannel::Close() {
+  mutex_lock l(mu_);
+
+  reader_writer_->WritesDone();
+  // The outcome is reported solely through the returned Status; library code
+  // must not print diagnostics to stdout.
+  if (reader_writer_->Finish().ok()) {
+    return Status::OK();
+  }
+
+  return Status(error::FAILED_PRECONDITION,
+                "Failed to close debug GRPC stream.");
+}
+
+// static
+// Guards stream_channels; held by SendTensorThroughGrpcStream while looking up
+// or registering a channel, and by CloseGrpcStream while closing and erasing
+// one.
+mutex DebugGrpcIO::streams_mu;
+// Registry of open debug channels, keyed by the server stream address string
+// passed to SendTensorThroughGrpcStream / CloseGrpcStream.
+std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>
+ DebugGrpcIO::stream_channels;
+
+// static
+// Wraps `tensor` in an Event and writes it to the gRPC debug stream at
+// `server_stream_addr`, creating and registering the channel on first use.
+//
+// Args:
+//   node_name: Name of the node that produced the tensor.
+//   output_slot: Output slot of the node; combined with node_name as
+//     "<node_name>:<output_slot>" to form the tensor name.
+//   debug_op: Name of the debug op, forwarded to WrapTensorAsEvent.
+//   tensor: The tensor value to send.
+//   wall_time_us: Timestamp forwarded to WrapTensorAsEvent.
+//   server_stream_addr: Server address (the gRPC URL without its scheme); also
+//     the key under which the channel is registered.
+//
+// Returns:
+//   OK on success; FAILED_PRECONDITION if a newly created channel is not
+//   ready (the channel is then not registered); CANCELLED if the write fails.
+Status DebugGrpcIO::SendTensorThroughGrpcStream(
+    const string& node_name, const int32 output_slot, const string& debug_op,
+    const Tensor& tensor, const uint64 wall_time_us,
+    const string& server_stream_addr) {
+  const string tensor_name = strings::StrCat(node_name, ":", output_slot);
+
+  // Prepare tensor Event data to be sent.
+  Event event = WrapTensorAsEvent(tensor_name, debug_op, tensor, wall_time_us);
+
+  std::shared_ptr<DebugGrpcChannel> debug_grpc_channel;
+  {
+    mutex_lock l(streams_mu);
+    // Single lookup; the iterator is reused on the hit path.
+    const auto it = stream_channels.find(server_stream_addr);
+    if (it == stream_channels.end()) {
+      debug_grpc_channel.reset(new DebugGrpcChannel(server_stream_addr));
+
+      if (!debug_grpc_channel->is_channel_ready()) {
+        return errors::FailedPrecondition(
+            strings::StrCat("Channel at the following gRPC address is ",
+                            "not ready: ", server_stream_addr));
+      }
+
+      stream_channels[server_stream_addr] = debug_grpc_channel;
+    } else {
+      debug_grpc_channel = it->second;
+    }
+  }
+
+  // The write happens outside the registry lock; the channel serializes
+  // writes with its own mutex.
+  if (!debug_grpc_channel->WriteEvent(event)) {
+    // Leading space in " failed." keeps the address from running into the
+    // word.
+    return errors::Cancelled(strings::StrCat("Write event to stream URL ",
+                                             server_stream_addr, " failed."));
+  }
+
+  return Status::OK();
+}
+
+// Closes the debug stream registered under `server_stream_addr`, if any, and
+// drops it from the registry. The whole operation runs under streams_mu.
+//
+// Returns:
+//   The status of the channel's Close() when a stream was registered (the
+//   entry is erased regardless of that status); OK when none was registered.
+Status DebugGrpcIO::CloseGrpcStream(const string& server_stream_addr) {
+  mutex_lock l(streams_mu);
+
+  const auto entry = stream_channels.find(server_stream_addr);
+  if (entry == stream_channels.end()) {
+    // Nothing is registered for this address, so there is nothing to close.
+    return Status::OK();
+  }
+
+  // Close first, then forget the stream whether or not closing succeeded.
+  const Status close_status = entry->second->Close();
+  stream_channels.erase(entry);
+  return close_status;
+}
+
} // namespace tensorflow