diff options
author | 2016-11-22 17:30:19 -0800 | |
---|---|---|
committer | 2016-11-22 17:44:29 -0800 | |
commit | ef2a926ec05dfd337d84279aafa58b22f0f36123 (patch) | |
tree | b8c46584d7a585698ab6ba308d43c5ce20e8c749 /tensorflow/core/debug/debug_io_utils.h | |
parent | cce0d12e13e87401a47d05145551b5a87d1167b3 (diff) |
tfdbg core: implement gRPC debug URLs
Change: 139976177
Diffstat (limited to 'tensorflow/core/debug/debug_io_utils.h')
-rw-r--r-- | tensorflow/core/debug/debug_io_utils.h | 74 |
1 files changed, 72 insertions, 2 deletions
diff --git a/tensorflow/core/debug/debug_io_utils.h b/tensorflow/core/debug/debug_io_utils.h index 2366abeda2..860004a02a 100644 --- a/tensorflow/core/debug/debug_io_utils.h +++ b/tensorflow/core/debug/debug_io_utils.h @@ -16,6 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_DEBUG_IO_UTILS_H_ #define TENSORFLOW_DEBUG_IO_UTILS_H_ +#include <unordered_map> +#include <unordered_set> + +#include "tensorflow/core/debug/debug_service.grpc.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -23,6 +27,8 @@ limitations under the License. namespace tensorflow { +Status ReadEventFromFile(const string& dump_file_path, Event* event); + class DebugIO { public: // Publish a tensor to a debug target URL. @@ -36,12 +42,14 @@ class DebugIO { // tensor: The Tensor object being published. // wall_time_us: Time stamp for the Tensor. Unit: microseconds (us). // debug_urls: An array of debug target URLs, e.g., - // "file:///foo/tfdbg_dump", "grpc://localhot:11011" + // "file:///foo/tfdbg_dump", "grpc://localhost:11011" static Status PublishDebugTensor(const string& tensor_name, const string& debug_op, const Tensor& tensor, const uint64 wall_time_us, const gtl::ArraySlice<string>& debug_urls); + static Status CloseDebugURL(const string& debug_url); + private: static const char* const kFileURLScheme; static const char* const kGrpcURLScheme; @@ -70,7 +78,7 @@ class DebugFileIO { // tensor: The Tensor object to be dumped to file. // wall_time_us: Wall time at which the Tensor is generated during graph // execution. Unit: microseconds (us). - // dump_root_dir: Root diretory for dumping the tensor. + // dump_root_dir: Root directory for dumping the tensor. // dump_file_path: The actual dump file path (passed as reference). static Status DumpTensorToDir(const string& node_name, const int32 output_slot, const string& debug_op, @@ -104,6 +112,68 @@ class DebugFileIO { static Status RecursiveCreateDir(Env* env, const string& dir); }; +class DebugGrpcChannel { + public: + // Constructor of DebugGrpcChannel. + // + // Args: + // server_stream_addr: Address (host name and port) of the debug stream + // server implementing the EventListener service (see + // debug_service.proto). E.g., "127.0.0.1:12345". + DebugGrpcChannel(const string& server_stream_addr); + + virtual ~DebugGrpcChannel() {} + + // Query whether the gRPC channel is ready for use. + bool is_channel_ready(); + + // Write an Event proto to the debug gRPC stream. + // + // Thread-safety: Safe with respect to other calls to the same method and + // call to Close(). + // Args: + // event: The event proto to be written to the stream. + // + // Returns: + // True iff the write is successful. + bool WriteEvent(const Event& event); + + // Close the stream and the channel. + Status Close(); + + private: + ::grpc::ClientContext ctx_; + std::shared_ptr<::grpc::Channel> channel_; + std::unique_ptr<EventListener::Stub> stub_; + std::unique_ptr<::grpc::ClientReaderWriterInterface<Event, EventReply>> + reader_writer_; + + mutex mu_; +}; + +class DebugGrpcIO { + public: + // Send a tensor through a debug gRPC stream. + // Thread-safety: Safe with respect to other calls to the same method and + // calls to CloseGrpcStream(). + static Status SendTensorThroughGrpcStream(const string& node_name, + const int32 output_slot, + const string& debug_op, + const Tensor& tensor, + const uint64 wall_time_us, + const string& server_stream_addr); + + // Close a gRPC stream to the given address, if it exists. + // Thread-safety: Safe with respect to other calls to the same method and + // calls to SendTensorThroughGrpcStream(). + static Status CloseGrpcStream(const string& server_stream_addr); + + private: + static mutex streams_mu; + static std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>> + stream_channels GUARDED_BY(streams_mu); +}; + } // namespace tensorflow #endif // TENSORFLOW_DEBUG_IO_UTILS_H_ |