diff options
author | 2018-09-25 16:08:24 -0700 | |
---|---|---|
committer | 2018-09-25 16:17:30 -0700 | |
commit | 4b780e46dcb29a1fb7a3ab81d95b3f8376101989 (patch) | |
tree | 2ba85fb065ac920bff0241545cb265df5001a83a /tensorflow/python/client | |
parent | 22776289fbe30ca7f4b1a80d7e23f5bddca391c2 (diff) |
Remove unneeded locks in session logging.
PiperOrigin-RevId: 214521486
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r-- | tensorflow/python/client/session_ref.cc | 40 |
1 files changed, 25 insertions, 15 deletions
diff --git a/tensorflow/python/client/session_ref.cc b/tensorflow/python/client/session_ref.cc index b2300df0b6..4d361612b7 100644 --- a/tensorflow/python/client/session_ref.cc +++ b/tensorflow/python/client/session_ref.cc @@ -93,23 +93,35 @@ class SessionLogger { public: SessionLogger() { std::string log_name = getenv("TF_REPLAY_LOG_FILE"); + LOG(INFO) << "Constructing new session logger for " << log_name; TF_CHECK_OK( Env::Default()->RecursivelyCreateDir(string(io::Dirname(log_name)))); Env::Default()->DeleteFile(log_name).IgnoreError(); - TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_)); + TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_)); log_writer_ = absl::make_unique<io::RecordWriter>(log_file_.get()); } - Status RecordCreateSession(Session* session) { - LOG(INFO) << "Capturing devices for session."; + ~SessionLogger() { + log_writer_->Close().IgnoreError(); + log_writer_.release(); + log_file_->Close().IgnoreError(); + } + + Status RecordNewSession(Session* session) { + LOG(INFO) << "New session discovered. Capturing devices..."; ReplayOp op; NewReplaySession* req = op.mutable_new_replay_session(); std::vector<DeviceAttributes> devices; - TF_CHECK_OK(session->ListDevices(&devices)); - for (const DeviceAttributes& dev : devices) { - *req->mutable_devices()->add_local_device() = dev; + Status status = session->ListDevices(&devices); + if (status.ok()) { + LOG(INFO) << "Found: " << devices.size() << " devices."; + for (const DeviceAttributes& dev : devices) { + *req->mutable_devices()->add_local_device() = dev; + } + } else { + LOG(WARNING) << "Failed to list devices on session. Continuing."; } req->set_session_handle(SessionToHandle(session)); @@ -226,7 +238,6 @@ class SessionLogger { // N.B. RunOptions is not stored (it has no entry in CloseRequest) Status RecordClose(Session* session, const RunOptions& run_options) { - mutex_lock l(log_mutex_); ReplayOp op; CloseSessionRequest* req = op.mutable_close_session(); req->set_session_handle(SessionToHandle(session)); @@ -241,7 +252,6 @@ class SessionLogger { Status RecordListDevices(Session* session, std::vector<DeviceAttributes>* response) { - mutex_lock l(log_mutex_); ReplayOp op; ListDevicesRequest* req = op.mutable_list_devices(); ListDevicesResponse* resp = op.mutable_list_devices_response(); @@ -258,7 +268,6 @@ class SessionLogger { const std::vector<string>& output_names, const std::vector<string>& target_nodes, string* handle) { - mutex_lock l(log_mutex_); ReplayOp op; PartialRunSetupRequest* req = op.mutable_partial_run_setup(); req->set_session_handle(SessionToHandle(session)); @@ -362,18 +371,19 @@ class SessionLogger { private: Status Flush(const ReplayOp& op) { + mutex_lock l(log_mutex_); + string buf; op.SerializeToString(&buf); TF_RETURN_IF_ERROR(log_writer_->WriteRecord(buf)); - // Flushing the RecordWriter _does not_ flush the underlying file. - TF_RETURN_IF_ERROR(log_writer_->Flush()); - return log_file_->Flush(); + // TODO(b/116624106): Not all file-systems respect calls to `Sync()` + return log_file_->Sync(); } - mutex log_mutex_; - std::unique_ptr<io::RecordWriter> log_writer_; std::unique_ptr<WritableFile> log_file_; + std::unique_ptr<io::RecordWriter> log_writer_; + mutex log_mutex_; }; static SessionLogger* global_session_logger() { @@ -384,7 +394,7 @@ static SessionLogger* global_session_logger() { SessionRef::SessionRef(Session* session) : session_(session) { if (getenv("TF_REPLAY_LOG_FILE") != nullptr) { logger_ = global_session_logger(); - logger_->RecordCreateSession(this->session_.get()).IgnoreError(); + logger_->RecordNewSession(this->session_.get()).IgnoreError(); } else { logger_ = nullptr; } |