/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/python/client/session_ref.h"

#include <stdlib.h>

#include <memory>
#include <utility>

#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/named_tensor.pb.h"
#include "tensorflow/core/protobuf/replay_log.pb.h"

namespace tensorflow {

namespace {

// Scope helper to track active calls and manage session lifetime.
// SessionRef blocks closing until all active calls complete or are cancelled.
struct RunCounter {
  std::shared_ptr<Session> session;
  uint64* value;
  mutex* m;
  condition_variable* cv;

  explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m,
                      condition_variable* cv)
      : session(std::move(s)), value(v), m(m), cv(cv) {
    mutex_lock l(*m);
    ++*value;
  }

  ~RunCounter() {
    mutex_lock l(*m);
    if (--*value == 0) {
      cv->notify_all();
    }
  }
};

std::string SessionToHandle(Session* session) {
  return strings::Printf("%llu", reinterpret_cast<uint64>(session));
}

// The Session interface has many methods of the form:
//
// X(a, b);
// X(RunOptions, a, b);
//
// Not all sessions support the second case (with an empty RunOptions()).
// We use this variable as a sentinel to dispatch to the correct call.
RunOptions* kEmptyRunOptions() {
  static RunOptions* options = new RunOptions();
  return options;
}

}  // namespace

// Run the given session operation, recording start and end timestamps.
// If the operation returns a bad status, return after flushing the current
// log request.  This should be run _after_ all request information has been
// added to the current op.
#define RUN_WITH_TIMESTAMP(OpName, ...)              \
  op.set_start_time_us(Env::Default()->NowMicros()); \
  Status status = session->OpName(__VA_ARGS__);      \
  op.set_end_time_us(Env::Default()->NowMicros());   \
  if (!status.ok()) {                                \
    Flush(op).IgnoreError();                         \
    return status;                                   \
  }
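// For illustration, a call such as RUN_WITH_TIMESTAMP(Create, graph) inside
// a Record* method below expands (roughly) to:
//
//   op.set_start_time_us(Env::Default()->NowMicros());
//   Status status = session->Create(graph);
//   op.set_end_time_us(Env::Default()->NowMicros());
//   if (!status.ok()) {
//     Flush(op).IgnoreError();
//     return status;
//   }
//
// i.e. the macro both times the underlying call and guarantees the request
// is flushed to the log even when the call fails.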
// Records requests (and optionally responses) performed against a session.
// The resulting replay log can be used with the `tf_replay` tool to replicate
// the operations against a simulated environment, without requiring the
// original code or cluster setup.
//
// Session logging is enabled by setting the TF_REPLAY_LOG_FILE environment
// variable.
class SessionLogger {
 public:
  SessionLogger() {
    std::string log_name = getenv("TF_REPLAY_LOG_FILE");
    LOG(INFO) << "Constructing new session logger for " << log_name;
    TF_CHECK_OK(
        Env::Default()->RecursivelyCreateDir(string(io::Dirname(log_name))));
    Env::Default()->DeleteFile(log_name).IgnoreError();
    TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_));
    log_writer_ = absl::make_unique<io::RecordWriter>(log_file_.get());
  }

  ~SessionLogger() {
    log_writer_->Close().IgnoreError();
    log_writer_.release();
    log_file_->Close().IgnoreError();
  }

  Status RecordNewSession(Session* session) {
    LOG(INFO) << "New session discovered.  Capturing devices...";
    ReplayOp op;
    NewReplaySession* req = op.mutable_new_replay_session();

    std::vector<DeviceAttributes> devices;
    Status status = session->ListDevices(&devices);
    if (status.ok()) {
      LOG(INFO) << "Found: " << devices.size() << " devices.";
      for (const DeviceAttributes& dev : devices) {
        *req->mutable_devices()->add_local_device() = dev;
      }
    } else {
      LOG(WARNING) << "Failed to list devices on session.  Continuing.";
    }

    req->set_session_handle(SessionToHandle(session));
    return Flush(op);
  }

  Status RecordRun(Session* session,
                   const std::vector<std::pair<string, Tensor> >& inputs,
                   const std::vector<string>& output_tensor_names,
                   const std::vector<string>& target_node_names,
                   std::vector<Tensor>* outputs) {
    return RecordRun(session, *kEmptyRunOptions(), inputs, output_tensor_names,
                     target_node_names, outputs, nullptr);
  }

  Status RecordRun(Session* session, const RunOptions& run_options,
                   const std::vector<std::pair<string, Tensor> >& inputs,
                   const std::vector<string>& output_tensor_names,
                   const std::vector<string>& target_node_names,
                   std::vector<Tensor>* outputs, RunMetadata* run_metadata) {
    ReplayOp op;
    RunStepRequest* req = op.mutable_run_step();
    RunStepResponse* resp = op.mutable_run_step_response();

    req->set_session_handle(SessionToHandle(session));
    *req->mutable_options() = run_options;

    for (const auto& it : inputs) {
      NamedTensorProto* feed = req->add_feed();
      feed->set_name(it.first);
      it.second.AsProtoField(feed->mutable_tensor());
    }

    // Build an index from fetch tensor name to first index in
    // output_tensor_names; only the first occurrence of each name is added
    // to the request.
    std::unordered_map<string, int> output_name_to_offset;
    for (int i = 0; i < output_tensor_names.size(); ++i) {
      const string& name = output_tensor_names[i];
      if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
        req->add_fetch(name);
      }
    }
    for (const string& target : target_node_names) {
      req->add_target(target);
    }

    if (&run_options == kEmptyRunOptions()) {
      RUN_WITH_TIMESTAMP(Run, inputs, output_tensor_names, target_node_names,
                         outputs);
    } else {
      RUN_WITH_TIMESTAMP(Run, run_options, inputs, output_tensor_names,
                         target_node_names, outputs, run_metadata);
    }

    for (size_t i = 0; i < outputs->size(); ++i) {
      const Tensor& tensor = (*outputs)[i];
      NamedTensorProto* tproto = resp->add_tensor();
      tensor.AsProtoField(tproto->mutable_tensor());
      tproto->set_name(output_tensor_names[i]);
    }

    if (run_metadata) {
      *resp->mutable_metadata() = *run_metadata;
    }

    return Flush(op);
  }
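  // Note on RecordRun above: callers may legitimately pass duplicate fetch
  // names, e.g. output_tensor_names = {"logits:0", "loss:0", "logits:0"}
  // (names here are hypothetical). Only the deduplicated fetches
  // {"logits:0", "loss:0"} are stored in the request, but one tensor per
  // entry in `outputs` is still recorded in the response, so a replay can
  // reconstruct the original call exactly.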
  Status RecordCreate(Session* session, const GraphDef& graph) {
    return RecordCreate(session, *kEmptyRunOptions(), graph);
  }

  // N.B. RunOptions is not stored (it has no entry in CreateRequest)
  Status RecordCreate(Session* session, const RunOptions& run_options,
                      const GraphDef& graph) {
    ReplayOp op;
    CreateSessionRequest* req = op.mutable_create_session();
    *req->mutable_graph_def() = graph;
    CreateSessionResponse* resp = op.mutable_create_session_response();

    if (&run_options == kEmptyRunOptions()) {
      RUN_WITH_TIMESTAMP(Create, graph);
    } else {
      RUN_WITH_TIMESTAMP(Create, run_options, graph);
    }

    resp->set_session_handle(SessionToHandle(session));
    return Flush(op);
  }

  Status RecordExtend(Session* session, const GraphDef& graph) {
    return RecordExtend(session, *kEmptyRunOptions(), graph);
  }

  // N.B. RunOptions is not stored (it has no entry in ExtendRequest)
  Status RecordExtend(Session* session, const RunOptions& run_options,
                      const GraphDef& graph) {
    ReplayOp op;
    ExtendSessionRequest* req = op.mutable_extend_session();
    op.mutable_extend_session_response();
    req->set_session_handle(SessionToHandle(session));
    *req->mutable_graph_def() = graph;

    if (&run_options == kEmptyRunOptions()) {
      RUN_WITH_TIMESTAMP(Extend, graph);
    } else {
      RUN_WITH_TIMESTAMP(Extend, run_options, graph);
    }

    return Flush(op);
  }

  Status RecordClose(Session* session) {
    return RecordClose(session, *kEmptyRunOptions());
  }

  // N.B. RunOptions is not stored (it has no entry in CloseRequest)
  Status RecordClose(Session* session, const RunOptions& run_options) {
    ReplayOp op;
    CloseSessionRequest* req = op.mutable_close_session();
    req->set_session_handle(SessionToHandle(session));
    op.mutable_close_session_response();
    if (&run_options == kEmptyRunOptions()) {
      RUN_WITH_TIMESTAMP(Close);
    } else {
      RUN_WITH_TIMESTAMP(Close, run_options);
    }
    return Flush(op);
  }

  Status RecordListDevices(Session* session,
                           std::vector<DeviceAttributes>* response) {
    ReplayOp op;
    ListDevicesRequest* req = op.mutable_list_devices();
    ListDevicesResponse* resp = op.mutable_list_devices_response();
    req->set_session_handle(SessionToHandle(session));
    RUN_WITH_TIMESTAMP(ListDevices, response);

    // TODO(power) -- local vs remote device distinction is lost here!
    *resp->mutable_local_device() = {response->begin(), response->end()};
    return Flush(op);
  }

  Status RecordPRunSetup(Session* session,
                         const std::vector<string>& input_names,
                         const std::vector<string>& output_names,
                         const std::vector<string>& target_nodes,
                         string* handle) {
    ReplayOp op;
    PartialRunSetupRequest* req = op.mutable_partial_run_setup();
    req->set_session_handle(SessionToHandle(session));
    for (auto& input : input_names) {
      req->add_feed(input);
    }
    for (auto& output : output_names) {
      req->add_fetch(output);
    }
    for (auto& target : target_nodes) {
      req->add_target(target);
    }
    RUN_WITH_TIMESTAMP(PRunSetup, input_names, output_names, target_nodes,
                       handle);
    op.mutable_partial_run_setup_response()->set_partial_run_handle(*handle);
    return Flush(op);
  }

  Status RecordPRun(Session* session, const string& handle,
                    const std::vector<std::pair<string, Tensor> >& inputs,
                    const std::vector<string>& output_names,
                    std::vector<Tensor>* outputs) {
    ReplayOp op;
    RunStepRequest* req = op.mutable_run_step();
    RunStepResponse* resp = op.mutable_run_step_response();

    req->set_session_handle(SessionToHandle(session));
    // Mark this step as a partial run for replay.
    req->set_partial_run_handle(handle);
    for (auto& input : inputs) {
      auto* feed = req->add_feed();
      feed->set_name(input.first);
      input.second.AsProtoField(feed->mutable_tensor());
    }

    for (auto& output : output_names) {
      req->add_fetch(output);
    }

    RUN_WITH_TIMESTAMP(PRun, handle, inputs, output_names, outputs);

    for (size_t i = 0; i < outputs->size(); ++i) {
      const Tensor& tensor = (*outputs)[i];
      NamedTensorProto* tproto = resp->add_tensor();
      tensor.AsProtoField(tproto->mutable_tensor());
      tproto->set_name(output_names[i]);
    }

    return Flush(op);
  }
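  // A minimal sketch of the client sequence that reaches RecordPRunSetup and
  // RecordPRun (tensor and node names here are hypothetical):
  //
  //   string handle;
  //   TF_RETURN_IF_ERROR(
  //       session->PRunSetup({"a:0", "b:0"}, {"out:0"}, {}, &handle));
  //   std::vector<Tensor> outputs;
  //   TF_RETURN_IF_ERROR(
  //       session->PRun(handle, {{"a:0", tensor_a}}, {"out:0"}, &outputs));
  //
  // The same handle ties the setup record to each subsequent partial-run
  // step in the replay log.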
  Status RecordMakeCallable(Session* session,
                            const CallableOptions& callable_options,
                            Session::CallableHandle* handle) {
    ReplayOp op;
    MakeCallableRequest* req = op.mutable_make_callable();
    req->set_session_handle(SessionToHandle(session));
    *req->mutable_options() = callable_options;

    RUN_WITH_TIMESTAMP(MakeCallable, callable_options, handle);

    MakeCallableResponse* resp = op.mutable_make_callable_response();
    resp->set_handle(*handle);

    return Flush(op);
  }

  Status RecordRunCallable(Session* session, Session::CallableHandle handle,
                           const std::vector<Tensor>& feed_tensors,
                           std::vector<Tensor>* fetch_tensors,
                           RunMetadata* run_metadata) {
    ReplayOp op;
    RunCallableRequest* req = op.mutable_run_callable();
    req->set_session_handle(SessionToHandle(session));
    req->set_handle(handle);
    for (auto& tensor : feed_tensors) {
      tensor.AsProtoField(req->add_feed());
    }
    RUN_WITH_TIMESTAMP(RunCallable, handle, feed_tensors, fetch_tensors,
                       run_metadata);

    RunCallableResponse* resp = op.mutable_run_callable_response();
    if (run_metadata) {
      *resp->mutable_metadata() = *run_metadata;
    }
    for (const Tensor& tensor : *fetch_tensors) {
      tensor.AsProtoTensorContent(resp->add_fetch());
    }
    return Flush(op);
  }

  Status RecordReleaseCallable(Session* session,
                               Session::CallableHandle handle) {
    ReplayOp op;
    ReleaseCallableRequest* req = op.mutable_release_callable();
    req->set_session_handle(SessionToHandle(session));
    req->set_handle(handle);
    RUN_WITH_TIMESTAMP(ReleaseCallable, handle);
    return Flush(op);
  }

 private:
  Status Flush(const ReplayOp& op) {
    mutex_lock l(log_mutex_);

    string buf;
    op.SerializeToString(&buf);
    TF_RETURN_IF_ERROR(log_writer_->WriteRecord(buf));

    // TODO(b/116624106): Not all file-systems respect calls to `Sync()`
    return log_file_->Sync();
  }

  std::unique_ptr<WritableFile> log_file_;
  std::unique_ptr<io::RecordWriter> log_writer_;
  mutex log_mutex_;
};

static SessionLogger* global_session_logger() {
  static SessionLogger* logger = new SessionLogger();
  return logger;
}

SessionRef::SessionRef(Session* session) : session_(session) {
  if (getenv("TF_REPLAY_LOG_FILE") != nullptr) {
    logger_ = global_session_logger();
    logger_->RecordNewSession(this->session_.get()).IgnoreError();
  } else {
    logger_ = nullptr;
  }
}

SessionRef::~SessionRef() = default;

Status SessionRef::CheckNotClosed() {
  mutex_lock l(run_lock_);
  if (session_ == nullptr) {
    return errors::Cancelled("Session has been closed.");
  }
  return ::tensorflow::Status::OK();
}

// If logging is active, log the start and end time of the operation along
// with the request and response.
#define LOG_AND_RUN_OPERATION(OpName, ...)                          \
  TF_RETURN_IF_ERROR(CheckNotClosed());                             \
  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); \
  if (!logger_) {                                                   \
    return rc.session->OpName(__VA_ARGS__);                         \
  }                                                                 \
  return logger_->Record##OpName(rc.session.get(), __VA_ARGS__);
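// For illustration, LOG_AND_RUN_OPERATION(Create, graph) expands (roughly)
// to:
//
//   TF_RETURN_IF_ERROR(CheckNotClosed());
//   RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
//   if (!logger_) {
//     return rc.session->Create(graph);
//   }
//   return logger_->RecordCreate(rc.session.get(), graph);
//
// The RunCounter keeps the underlying session alive (and Close() blocked)
// for the duration of the call.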
Status SessionRef::Run(const RunOptions& run_options,
                       const std::vector<std::pair<string, Tensor> >& inputs,
                       const std::vector<string>& output_tensor_names,
                       const std::vector<string>& target_node_names,
                       std::vector<Tensor>* outputs,
                       RunMetadata* run_metadata) {
  LOG_AND_RUN_OPERATION(Run, run_options, inputs, output_tensor_names,
                        target_node_names, outputs, run_metadata);
}

Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
                       const std::vector<string>& output_tensor_names,
                       const std::vector<string>& target_node_names,
                       std::vector<Tensor>* outputs) {
  LOG_AND_RUN_OPERATION(Run, inputs, output_tensor_names, target_node_names,
                        outputs);
}

Status SessionRef::Create(const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Create, graph);
}

Status SessionRef::Create(const RunOptions& run_options,
                          const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Create, run_options, graph);
}

Status SessionRef::Extend(const RunOptions& run_options,
                          const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Extend, run_options, graph);
}

Status SessionRef::Extend(const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Extend, graph);
}

Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
  LOG_AND_RUN_OPERATION(ListDevices, response);
}

Status SessionRef::PRunSetup(const std::vector<string>& input_names,
                             const std::vector<string>& output_names,
                             const std::vector<string>& target_nodes,
                             string* handle) {
  LOG_AND_RUN_OPERATION(PRunSetup, input_names, output_names, target_nodes,
                        handle);
}

Status SessionRef::PRun(const string& handle,
                        const std::vector<std::pair<string, Tensor> >& inputs,
                        const std::vector<string>& output_names,
                        std::vector<Tensor>* outputs) {
  LOG_AND_RUN_OPERATION(PRun, handle, inputs, output_names, outputs);
}

Status SessionRef::MakeCallable(const CallableOptions& callable_options,
                                CallableHandle* out_handle) {
  LOG_AND_RUN_OPERATION(MakeCallable, callable_options, out_handle);
}

Status SessionRef::RunCallable(CallableHandle handle,
                               const std::vector<Tensor>& feed_tensors,
                               std::vector<Tensor>* fetch_tensors,
                               RunMetadata* run_metadata) {
  LOG_AND_RUN_OPERATION(RunCallable, handle, feed_tensors, fetch_tensors,
                        run_metadata);
}

Status SessionRef::ReleaseCallable(CallableHandle handle) {
  LOG_AND_RUN_OPERATION(ReleaseCallable, handle);
}

Status SessionRef::Close(const RunOptions& run_options) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  mutex_lock l(run_lock_);
  Status status;
  if (logger_) {
    status = logger_->RecordClose(session_.get(), run_options);
  } else {
    status = session_->Close(run_options);
  }
  session_.reset();
  while (run_count_ > 0) {
    run_finished_.wait(l);
  }
  return status;
}

Status SessionRef::Close() {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  mutex_lock l(run_lock_);
  Status status;
  if (logger_) {
    status = logger_->RecordClose(session_.get());
  } else {
    status = session_->Close();
  }
  session_.reset();
  while (run_count_ > 0) {
    run_finished_.wait(l);
  }
  return status;
}

}  // namespace tensorflow