aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client
diff options
context:
space:
mode:
authorGravatar Russell Power <power@google.com>2018-09-25 16:08:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 16:17:30 -0700
commit4b780e46dcb29a1fb7a3ab81d95b3f8376101989 (patch)
tree2ba85fb065ac920bff0241545cb265df5001a83a /tensorflow/python/client
parent22776289fbe30ca7f4b1a80d7e23f5bddca391c2 (diff)
Remove unneeded locks in session logging.
PiperOrigin-RevId: 214521486
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r--tensorflow/python/client/session_ref.cc40
1 files changed, 25 insertions, 15 deletions
diff --git a/tensorflow/python/client/session_ref.cc b/tensorflow/python/client/session_ref.cc
index b2300df0b6..4d361612b7 100644
--- a/tensorflow/python/client/session_ref.cc
+++ b/tensorflow/python/client/session_ref.cc
@@ -93,23 +93,35 @@ class SessionLogger {
public:
SessionLogger() {
std::string log_name = getenv("TF_REPLAY_LOG_FILE");
+ LOG(INFO) << "Constructing new session logger for " << log_name;
TF_CHECK_OK(
Env::Default()->RecursivelyCreateDir(string(io::Dirname(log_name))));
Env::Default()->DeleteFile(log_name).IgnoreError();
- TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_));
+ TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_));
log_writer_ = absl::make_unique<io::RecordWriter>(log_file_.get());
}
- Status RecordCreateSession(Session* session) {
- LOG(INFO) << "Capturing devices for session.";
+ ~SessionLogger() {
+ log_writer_->Close().IgnoreError();
+ log_writer_.release();
+ log_file_->Close().IgnoreError();
+ }
+
+ Status RecordNewSession(Session* session) {
+ LOG(INFO) << "New session discovered. Capturing devices...";
ReplayOp op;
NewReplaySession* req = op.mutable_new_replay_session();
std::vector<DeviceAttributes> devices;
- TF_CHECK_OK(session->ListDevices(&devices));
- for (const DeviceAttributes& dev : devices) {
- *req->mutable_devices()->add_local_device() = dev;
+ Status status = session->ListDevices(&devices);
+ if (status.ok()) {
+ LOG(INFO) << "Found: " << devices.size() << " devices.";
+ for (const DeviceAttributes& dev : devices) {
+ *req->mutable_devices()->add_local_device() = dev;
+ }
+ } else {
+ LOG(WARNING) << "Failed to list devices on session. Continuing.";
}
req->set_session_handle(SessionToHandle(session));
@@ -226,7 +238,6 @@ class SessionLogger {
// N.B. RunOptions is not stored (it has no entry in CloseRequest)
Status RecordClose(Session* session, const RunOptions& run_options) {
- mutex_lock l(log_mutex_);
ReplayOp op;
CloseSessionRequest* req = op.mutable_close_session();
req->set_session_handle(SessionToHandle(session));
@@ -241,7 +252,6 @@ class SessionLogger {
Status RecordListDevices(Session* session,
std::vector<DeviceAttributes>* response) {
- mutex_lock l(log_mutex_);
ReplayOp op;
ListDevicesRequest* req = op.mutable_list_devices();
ListDevicesResponse* resp = op.mutable_list_devices_response();
@@ -258,7 +268,6 @@ class SessionLogger {
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
string* handle) {
- mutex_lock l(log_mutex_);
ReplayOp op;
PartialRunSetupRequest* req = op.mutable_partial_run_setup();
req->set_session_handle(SessionToHandle(session));
@@ -362,18 +371,19 @@ class SessionLogger {
private:
Status Flush(const ReplayOp& op) {
+ mutex_lock l(log_mutex_);
+
string buf;
op.SerializeToString(&buf);
TF_RETURN_IF_ERROR(log_writer_->WriteRecord(buf));
- // Flushing the RecordWriter _does not_ flush the underlying file.
- TF_RETURN_IF_ERROR(log_writer_->Flush());
- return log_file_->Flush();
+ // TODO(b/116624106): Not all file-systems respect calls to `Sync()`
+ return log_file_->Sync();
}
- mutex log_mutex_;
- std::unique_ptr<io::RecordWriter> log_writer_;
std::unique_ptr<WritableFile> log_file_;
+ std::unique_ptr<io::RecordWriter> log_writer_;
+ mutex log_mutex_;
};
static SessionLogger* global_session_logger() {
@@ -384,7 +394,7 @@ static SessionLogger* global_session_logger() {
SessionRef::SessionRef(Session* session) : session_(session) {
if (getenv("TF_REPLAY_LOG_FILE") != nullptr) {
logger_ = global_session_logger();
- logger_->RecordCreateSession(this->session_.get()).IgnoreError();
+ logger_->RecordNewSession(this->session_.get()).IgnoreError();
} else {
logger_ = nullptr;
}