aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/debug/debug_io_utils.h
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2016-11-22 17:30:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-22 17:44:29 -0800
commitef2a926ec05dfd337d84279aafa58b22f0f36123 (patch)
treeb8c46584d7a585698ab6ba308d43c5ce20e8c749 /tensorflow/core/debug/debug_io_utils.h
parentcce0d12e13e87401a47d05145551b5a87d1167b3 (diff)
tfdbg core: implement gRPC debug URLs
Change: 139976177
Diffstat (limited to 'tensorflow/core/debug/debug_io_utils.h')
-rw-r--r--tensorflow/core/debug/debug_io_utils.h74
1 files changed, 72 insertions, 2 deletions
diff --git a/tensorflow/core/debug/debug_io_utils.h b/tensorflow/core/debug/debug_io_utils.h
index 2366abeda2..860004a02a 100644
--- a/tensorflow/core/debug/debug_io_utils.h
+++ b/tensorflow/core/debug/debug_io_utils.h
@@ -16,6 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_DEBUG_IO_UTILS_H_
#define TENSORFLOW_DEBUG_IO_UTILS_H_
+#include <unordered_map>
+#include <unordered_set>
+
+#include "tensorflow/core/debug/debug_service.grpc.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -23,6 +27,8 @@ limitations under the License.
namespace tensorflow {
+Status ReadEventFromFile(const string& dump_file_path, Event* event);
+
class DebugIO {
public:
// Publish a tensor to a debug target URL.
@@ -36,12 +42,14 @@ class DebugIO {
// tensor: The Tensor object being published.
// wall_time_us: Time stamp for the Tensor. Unit: microseconds (us).
// debug_urls: An array of debug target URLs, e.g.,
- // "file:///foo/tfdbg_dump", "grpc://localhot:11011"
+ // "file:///foo/tfdbg_dump", "grpc://localhost:11011"
static Status PublishDebugTensor(const string& tensor_name,
const string& debug_op, const Tensor& tensor,
const uint64 wall_time_us,
const gtl::ArraySlice<string>& debug_urls);
+ static Status CloseDebugURL(const string& debug_url);
+
private:
static const char* const kFileURLScheme;
static const char* const kGrpcURLScheme;
@@ -70,7 +78,7 @@ class DebugFileIO {
// tensor: The Tensor object to be dumped to file.
// wall_time_us: Wall time at which the Tensor is generated during graph
// execution. Unit: microseconds (us).
- // dump_root_dir: Root diretory for dumping the tensor.
+ // dump_root_dir: Root directory for dumping the tensor.
// dump_file_path: The actual dump file path (passed as reference).
static Status DumpTensorToDir(const string& node_name,
const int32 output_slot, const string& debug_op,
@@ -104,6 +112,68 @@ class DebugFileIO {
static Status RecursiveCreateDir(Env* env, const string& dir);
};
+class DebugGrpcChannel {
+ public:
+ // Constructor of DebugGrpcChannel.
+ //
+ // Args:
+ // server_stream_addr: Address (host name and port) of the debug stream
+ // server implementing the EventListener service (see
+ // debug_service.proto). E.g., "127.0.0.1:12345".
+ DebugGrpcChannel(const string& server_stream_addr);
+
+ virtual ~DebugGrpcChannel() {}
+
+ // Query whether the gRPC channel is ready for use.
+ bool is_channel_ready();
+
+ // Write an Event proto to the debug gRPC stream.
+ //
+ // Thread-safety: Safe with respect to other calls to the same method and
+ // call to Close().
+ // Args:
+ // event: The event proto to be written to the stream.
+ //
+ // Returns:
+ // True iff the write is successful.
+ bool WriteEvent(const Event& event);
+
+ // Close the stream and the channel.
+ Status Close();
+
+ private:
+ ::grpc::ClientContext ctx_;
+ std::shared_ptr<::grpc::Channel> channel_;
+ std::unique_ptr<EventListener::Stub> stub_;
+ std::unique_ptr<::grpc::ClientReaderWriterInterface<Event, EventReply>>
+ reader_writer_;
+
+ mutex mu_;
+};
+
+class DebugGrpcIO {
+ public:
+ // Send a tensor through a debug gRPC stream.
+ // Thread-safety: Safe with respect to other calls to the same method and
+ // calls to CloseGrpcStream().
+ static Status SendTensorThroughGrpcStream(const string& node_name,
+ const int32 output_slot,
+ const string& debug_op,
+ const Tensor& tensor,
+ const uint64 wall_time_us,
+ const string& server_stream_addr);
+
+ // Close a gRPC stream to the given address, if it exists.
+ // Thread-safety: Safe with respect to other calls to the same method and
+ // calls to SendTensorThroughGrpcStream().
+ static Status CloseGrpcStream(const string& server_stream_addr);
+
+ private:
+ static mutex streams_mu;
+ static std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>
+ stream_channels GUARDED_BY(streams_mu);
+};
+
} // namespace tensorflow
#endif // TENSORFLOW_DEBUG_IO_UTILS_H_