diff options
author | 2018-09-20 13:48:43 -0700 | |
---|---|---|
committer | 2018-09-20 13:53:34 -0700 | |
commit | d388770922ad1afa95e55597a33836fe74035c75 (patch) | |
tree | 3c15e17c357645e3e872ea46daaea2c91b00e9c1 /tensorflow | |
parent | 1f1e5ac6154583d5f87c846d1d7c9c59a77d6e0c (diff) |
Implement TF graph capture.
PiperOrigin-RevId: 213875284
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/BUILD | 16 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/session_ref.cc | 170 | ||||
-rw-r--r-- | tensorflow/core/protobuf/replay_log.proto | 47 | ||||
-rw-r--r-- | tensorflow/python/BUILD | 15 | ||||
-rw-r--r-- | tensorflow/python/client/session_ref.cc | 515 | ||||
-rw-r--r-- | tensorflow/python/client/session_ref.h (renamed from tensorflow/core/common_runtime/session_ref.h) | 15 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session_helper.cc | 2 |
7 files changed, 597 insertions, 183 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index ed1818f834..85b6d4ff68 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -2988,12 +2988,16 @@ tf_cuda_library( ] + tf_additional_device_tracer_deps(), ) -cc_library( - name = "session_ref", - srcs = ["common_runtime/session_ref.cc"], - hdrs = ["common_runtime/session_ref.h"], - copts = tf_copts(), - deps = [":core_cpu_base"], +tf_proto_library_cc( + name = "replay_log_proto", + srcs = ["protobuf/replay_log.proto"], + cc_api_version = 2, + protodeps = [ + ":master_proto", + ] + tf_additional_all_protos(), + visibility = [ + "//tensorflow:internal", + ], ) cc_library( diff --git a/tensorflow/core/common_runtime/session_ref.cc b/tensorflow/core/common_runtime/session_ref.cc deleted file mode 100644 index b931ef4229..0000000000 --- a/tensorflow/core/common_runtime/session_ref.cc +++ /dev/null @@ -1,170 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/common_runtime/session_ref.h" - -#include <utility> - -namespace tensorflow { - -namespace { - -// Scope helper to track active calls and manage session lifetime. -struct RunCounter { - std::shared_ptr<Session> session; - uint64* value; - mutex* m; - condition_variable* cv; - - explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m, - condition_variable* cv) - : session(std::move(s)), value(v), m(m), cv(cv) { - mutex_lock l(*m); - ++*value; - } - - ~RunCounter() { - mutex_lock l(*m); - if (--*value == 0) { - cv->notify_all(); - } - } -}; - -} // namespace - -Status SessionRef::CheckNotClosed() { - mutex_lock l(run_lock_); - if (session_ == nullptr) return errors::Cancelled("Session has been closed."); - return ::tensorflow::Status::OK(); -} - -Status SessionRef::Run(const RunOptions& run_options, - const std::vector<std::pair<string, Tensor> >& inputs, - const std::vector<string>& output_tensor_names, - const std::vector<string>& target_node_names, - std::vector<Tensor>* outputs, - RunMetadata* run_metadata) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->Run(run_options, inputs, output_tensor_names, - target_node_names, outputs, run_metadata); -} - -Status SessionRef::Create(const GraphDef& graph) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->Create(graph); -} - -Status SessionRef::Create(const RunOptions& run_options, - const GraphDef& graph) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->Create(run_options, graph); -} - -Status SessionRef::Extend(const RunOptions& run_options, - const GraphDef& graph) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->Extend(run_options, graph); -} - -Status SessionRef::Extend(const GraphDef& graph) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->Extend(graph); -} - -Status SessionRef::Close(const RunOptions& run_options) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - mutex_lock l(run_lock_); - Status status = session_->Close(run_options); - session_.reset(); - while (run_count_ > 0) { - run_finished_.wait(l); - } - return status; -} - -Status SessionRef::Close() { - TF_RETURN_IF_ERROR(CheckNotClosed()); - mutex_lock l(run_lock_); - Status status = session_->Close(); - session_.reset(); - while (run_count_ > 0) { - run_finished_.wait(l); - } - return status; -} - -Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs, - const std::vector<string>& output_tensor_names, - const std::vector<string>& target_node_names, - std::vector<Tensor>* outputs) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->Run(inputs, output_tensor_names, target_node_names, - outputs); -} - -Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->ListDevices(response); -} - -Status SessionRef::PRunSetup(const std::vector<string>& input_names, - const std::vector<string>& output_names, - const std::vector<string>& target_nodes, - string* handle) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->PRunSetup(input_names, output_names, target_nodes, handle); -} - -Status SessionRef::PRun(const string& handle, - const std::vector<std::pair<string, Tensor> >& inputs, - const std::vector<string>& output_names, - std::vector<Tensor>* outputs) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->PRun(handle, inputs, output_names, outputs); -} - -Status SessionRef::MakeCallable(const CallableOptions& callable_options, - CallableHandle* out_handle) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->MakeCallable(callable_options, out_handle); -} - -Status SessionRef::RunCallable(CallableHandle handle, - const std::vector<Tensor>& feed_tensors, - std::vector<Tensor>* fetch_tensors, - RunMetadata* run_metadata) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->RunCallable(handle, feed_tensors, fetch_tensors, - run_metadata); -} - -Status SessionRef::ReleaseCallable(CallableHandle handle) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->ReleaseCallable(handle); -} - -} // namespace tensorflow diff --git a/tensorflow/core/protobuf/replay_log.proto b/tensorflow/core/protobuf/replay_log.proto new file mode 100644 index 0000000000..7644314fc9 --- /dev/null +++ b/tensorflow/core/protobuf/replay_log.proto @@ -0,0 +1,47 @@ +syntax = "proto3"; + +option cc_enable_arenas = true; +package tensorflow; + +import "tensorflow/core/framework/graph.proto"; +import "tensorflow/core/protobuf/cluster.proto"; +import "tensorflow/core/protobuf/master.proto"; + +// Records the creation of a new replay session. We record the device listing +// here to capture the state of the cluster. +message NewReplaySession { + ListDevicesResponse devices = 1; + string session_handle = 2; +} + +message ReplayOp { + double start_time_us = 31; + double end_time_us = 32; + + oneof op { + CreateSessionRequest create_session = 1; + ExtendSessionRequest extend_session = 2; + PartialRunSetupRequest partial_run_setup = 3; + RunStepRequest run_step = 4; + CloseSessionRequest close_session = 5; + ListDevicesRequest list_devices = 6; + ResetRequest reset_request = 7; + MakeCallableRequest make_callable = 8; + RunCallableRequest run_callable = 9; + ReleaseCallableRequest release_callable = 10; + NewReplaySession new_replay_session = 11; + } + + oneof response { + CreateSessionResponse create_session_response = 21; + ExtendSessionResponse extend_session_response = 22; + PartialRunSetupResponse partial_run_setup_response = 23; + RunStepResponse run_step_response = 24; + CloseSessionResponse close_session_response = 25; + ListDevicesResponse list_devices_response = 26; + ResetResponse reset_request_response = 27; + MakeCallableResponse make_callable_response = 28; + RunCallableResponse run_callable_response = 29; + ReleaseCallableResponse release_callable_response = 30; + } +} diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 9730e9933a..79f14466e6 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3763,6 +3763,19 @@ cuda_py_tests( ], ) +cc_library( + name = "session_ref", + srcs = ["client/session_ref.cc"], + hdrs = ["client/session_ref.h"], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + "//tensorflow/core:master_proto_cc", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:replay_log_proto_cc", + ], +) + tf_cuda_library( name = "tf_session_helper", srcs = ["client/tf_session_helper.cc"], @@ -3773,6 +3786,7 @@ tf_cuda_library( ":ndarray_tensor_bridge", ":numpy_lib", ":safe_ptr", + ":session_ref", ":test_ops_kernels", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", @@ -3785,7 +3799,6 @@ tf_cuda_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:session_ref", "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", ], diff --git a/tensorflow/python/client/session_ref.cc b/tensorflow/python/client/session_ref.cc new file mode 100644 index 0000000000..b2300df0b6 --- /dev/null +++ b/tensorflow/python/client/session_ref.cc @@ -0,0 +1,515 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/python/client/session_ref.h" + +#include <stdlib.h> +#include <memory> +#include <utility> + +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/io/record_writer.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/protobuf/master.pb.h" +#include "tensorflow/core/protobuf/named_tensor.pb.h" +#include "tensorflow/core/protobuf/replay_log.pb.h" + +namespace tensorflow { + +namespace { + +// Scope helper to track active calls and manage session lifetime. +// SessionRef blocks closing until all active calls complete or are cancelled. +struct RunCounter { + std::shared_ptr<Session> session; + uint64* value; + mutex* m; + condition_variable* cv; + + explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m, + condition_variable* cv) + : session(std::move(s)), value(v), m(m), cv(cv) { + mutex_lock l(*m); + ++*value; + } + + ~RunCounter() { + mutex_lock l(*m); + if (--*value == 0) { + cv->notify_all(); + } + } +}; + +std::string SessionToHandle(Session* session) { + return strings::Printf("%llu", reinterpret_cast<uint64>(session)); +} + +// The Session interface has many methods of the form: +// +// X(a, b); +// X(RunOptions, a, b); +// +// Not all sessions support the second case (with an empty RunOptions()). +// We use this variable as a sentinel to dispatch to the correct call. +RunOptions* kEmptyRunOptions() { + static RunOptions* options = new RunOptions(); + return options; +} + +} // namespace + +// Run the given session operation, recording start and end timestamps. +// If the operation returns a bad status, return after flushing the current +// log request. This should be run _after_ all request information has been +// added to the current op. +#define RUN_WITH_TIMESTAMP(OpName, ...) \ + op.set_start_time_us(Env::Default()->NowMicros()); \ + Status status = session->OpName(__VA_ARGS__); \ + op.set_end_time_us(Env::Default()->NowMicros()); \ + if (!status.ok()) { \ + Flush(op).IgnoreError(); \ + return status; \ + } + +// Records requests (and optionally responses) performed against a session. +// The resulting replay log can be used with the `tf_replay` tool to replicate +// the operations against a simulated environment, without requiring the +// original code or cluster setup. +// +// Session logging by setting the TF_REPLAY_LOG_FILE environment variable. +class SessionLogger { + public: + SessionLogger() { + std::string log_name = getenv("TF_REPLAY_LOG_FILE"); + TF_CHECK_OK( + Env::Default()->RecursivelyCreateDir(string(io::Dirname(log_name)))); + Env::Default()->DeleteFile(log_name).IgnoreError(); + TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_)); + + log_writer_ = absl::make_unique<io::RecordWriter>(log_file_.get()); + } + + Status RecordCreateSession(Session* session) { + LOG(INFO) << "Capturing devices for session."; + ReplayOp op; + NewReplaySession* req = op.mutable_new_replay_session(); + + std::vector<DeviceAttributes> devices; + TF_CHECK_OK(session->ListDevices(&devices)); + for (const DeviceAttributes& dev : devices) { + *req->mutable_devices()->add_local_device() = dev; + } + + req->set_session_handle(SessionToHandle(session)); + return Flush(op); + } + + Status RecordRun(Session* session, + const std::vector<std::pair<string, Tensor> >& inputs, + const std::vector<string>& output_tensor_names, + const std::vector<string>& target_node_names, + std::vector<Tensor>* outputs) { + return RecordRun(session, *kEmptyRunOptions(), inputs, output_tensor_names, + target_node_names, outputs, nullptr); + } + + Status RecordRun(Session* session, const RunOptions& run_options, + const std::vector<std::pair<string, Tensor> >& inputs, + const std::vector<string>& output_tensor_names, + const std::vector<string>& target_node_names, + std::vector<Tensor>* outputs, RunMetadata* run_metadata) { + ReplayOp op; + RunStepRequest* req = op.mutable_run_step(); + RunStepResponse* resp = op.mutable_run_step_response(); + + req->set_session_handle(SessionToHandle(session)); + *req->mutable_options() = run_options; + + for (const auto& it : inputs) { + NamedTensorProto* feed = req->add_feed(); + feed->set_name(it.first); + it.second.AsProtoField(feed->mutable_tensor()); + } + + // Build an index from fetch tensor name to first index in + // output_tensor_names. + std::unordered_map<string, int> output_name_to_offset; + for (int i = 0; i < output_tensor_names.size(); ++i) { + const string& name = output_tensor_names[i]; + if (output_name_to_offset.insert(std::make_pair(name, i)).second) { + req->add_fetch(name); + } + } + for (const string& target : target_node_names) { + req->add_target(target); + } + + if (&run_options == kEmptyRunOptions()) { + RUN_WITH_TIMESTAMP(Run, inputs, output_tensor_names, target_node_names, + outputs); + } else { + RUN_WITH_TIMESTAMP(Run, run_options, inputs, output_tensor_names, + target_node_names, outputs, run_metadata); + } + + for (size_t i = 0; i < outputs->size(); ++i) { + const Tensor& tensor = (*outputs)[i]; + NamedTensorProto* tproto = resp->add_tensor(); + tensor.AsProtoField(tproto->mutable_tensor()); + tproto->set_name(output_tensor_names[i]); + } + + if (run_metadata) { + *resp->mutable_metadata() = *run_metadata; + } + + return Flush(op); + } + + Status RecordCreate(Session* session, const GraphDef& graph) { + return RecordCreate(session, *kEmptyRunOptions(), graph); + } + + // N.B. RunOptions is not stored (it has no entry in CreateRequest) + Status RecordCreate(Session* session, const RunOptions& run_options, + const GraphDef& graph) { + ReplayOp op; + CreateSessionRequest* req = op.mutable_create_session(); + *req->mutable_graph_def() = graph; + + CreateSessionResponse* resp = op.mutable_create_session_response(); + if (&run_options == kEmptyRunOptions()) { + RUN_WITH_TIMESTAMP(Create, graph); + } else { + RUN_WITH_TIMESTAMP(Create, run_options, graph); + } + resp->set_session_handle(SessionToHandle(session)); + return Flush(op); + } + + Status RecordExtend(Session* session, const GraphDef& graph) { + return RecordExtend(session, *kEmptyRunOptions(), graph); + } + + // N.B. RunOptions is not stored (it has no entry in ExtendRequest) + Status RecordExtend(Session* session, const RunOptions& run_options, + const GraphDef& graph) { + ReplayOp op; + ExtendSessionRequest* req = op.mutable_extend_session(); + op.mutable_extend_session_response(); + req->set_session_handle(SessionToHandle(session)); + *req->mutable_graph_def() = graph; + if (&run_options == kEmptyRunOptions()) { + RUN_WITH_TIMESTAMP(Extend, graph); + } else { + RUN_WITH_TIMESTAMP(Extend, run_options, graph); + } + + return Flush(op); + } + + Status RecordClose(Session* session) { + return RecordClose(session, *kEmptyRunOptions()); + } + + // N.B. RunOptions is not stored (it has no entry in CloseRequest) + Status RecordClose(Session* session, const RunOptions& run_options) { + mutex_lock l(log_mutex_); + ReplayOp op; + CloseSessionRequest* req = op.mutable_close_session(); + req->set_session_handle(SessionToHandle(session)); + op.mutable_close_session_response(); + if (&run_options == kEmptyRunOptions()) { + RUN_WITH_TIMESTAMP(Close); + } else { + RUN_WITH_TIMESTAMP(Close, run_options); + } + return Flush(op); + } + + Status RecordListDevices(Session* session, + std::vector<DeviceAttributes>* response) { + mutex_lock l(log_mutex_); + ReplayOp op; + ListDevicesRequest* req = op.mutable_list_devices(); + ListDevicesResponse* resp = op.mutable_list_devices_response(); + req->set_session_handle(SessionToHandle(session)); + RUN_WITH_TIMESTAMP(ListDevices, response); + + // TODO(power) -- local vs remote device distinction is lost here! + *resp->mutable_local_device() = {response->begin(), response->end()}; + return Flush(op); + } + + Status RecordPRunSetup(Session* session, + const std::vector<string>& input_names, + const std::vector<string>& output_names, + const std::vector<string>& target_nodes, + string* handle) { + mutex_lock l(log_mutex_); + ReplayOp op; + PartialRunSetupRequest* req = op.mutable_partial_run_setup(); + req->set_session_handle(SessionToHandle(session)); + for (auto& input : input_names) { + req->add_feed(input); + } + for (auto& output : output_names) { + req->add_fetch(output); + } + for (auto& target : target_nodes) { + req->add_target(target); + } + RUN_WITH_TIMESTAMP(PRunSetup, input_names, output_names, target_nodes, + handle); + op.mutable_partial_run_setup_response()->set_partial_run_handle(*handle); + return Flush(op); + } + + Status RecordPRun(Session* session, const string& handle, + const std::vector<std::pair<string, Tensor> >& inputs, + const std::vector<string>& output_names, + std::vector<Tensor>* outputs) { + ReplayOp op; + RunStepRequest* req = op.mutable_run_step(); + RunStepResponse* resp = op.mutable_run_step_response(); + req->set_session_handle(SessionToHandle(session)); + + // Mark this step as a partial run for replay. + req->set_partial_run_handle(handle); + for (auto& input : inputs) { + auto* feed = req->add_feed(); + feed->set_name(input.first); + input.second.AsProtoField(feed->mutable_tensor()); + } + + for (auto& output : output_names) { + req->add_fetch(output); + } + + RUN_WITH_TIMESTAMP(PRun, handle, inputs, output_names, outputs); + + for (size_t i = 0; i < outputs->size(); ++i) { + const Tensor& tensor = (*outputs)[i]; + NamedTensorProto* tproto = resp->add_tensor(); + tensor.AsProtoField(tproto->mutable_tensor()); + tproto->set_name(output_names[i]); + } + + return Flush(op); + } + + Status RecordMakeCallable(Session* session, + const CallableOptions& callable_options, + Session::CallableHandle* handle) { + ReplayOp op; + MakeCallableRequest* req = op.mutable_make_callable(); + req->set_session_handle(SessionToHandle(session)); + *req->mutable_options() = callable_options; + + RUN_WITH_TIMESTAMP(MakeCallable, callable_options, handle); + + MakeCallableResponse* resp = op.mutable_make_callable_response(); + resp->set_handle(*handle); + + return Flush(op); + } + + Status RecordRunCallable(Session* session, Session::CallableHandle handle, + const std::vector<Tensor>& feed_tensors, + std::vector<Tensor>* fetch_tensors, + RunMetadata* run_metadata) { + ReplayOp op; + RunCallableRequest* req = op.mutable_run_callable(); + req->set_session_handle(SessionToHandle(session)); + req->set_handle(handle); + for (auto& tensor : feed_tensors) { + tensor.AsProtoField(req->add_feed()); + } + RUN_WITH_TIMESTAMP(RunCallable, handle, feed_tensors, fetch_tensors, + run_metadata); + + RunCallableResponse* resp = op.mutable_run_callable_response(); + if (run_metadata) { + *resp->mutable_metadata() = *run_metadata; + } + for (const Tensor& tensor : *fetch_tensors) { + tensor.AsProtoTensorContent(resp->add_fetch()); + } + return Flush(op); + } + + Status RecordReleaseCallable(Session* session, + Session::CallableHandle handle) { + ReplayOp op; + ReleaseCallableRequest* req = op.mutable_release_callable(); + req->set_session_handle(SessionToHandle(session)); + req->set_handle(handle); + RUN_WITH_TIMESTAMP(ReleaseCallable, handle); + return Flush(op); + } + + private: + Status Flush(const ReplayOp& op) { + string buf; + op.SerializeToString(&buf); + TF_RETURN_IF_ERROR(log_writer_->WriteRecord(buf)); + + // Flushing the RecordWriter _does not_ flush the underlying file. + TF_RETURN_IF_ERROR(log_writer_->Flush()); + return log_file_->Flush(); + } + + mutex log_mutex_; + std::unique_ptr<io::RecordWriter> log_writer_; + std::unique_ptr<WritableFile> log_file_; +}; + +static SessionLogger* global_session_logger() { + static SessionLogger* logger = new SessionLogger(); + return logger; +} + +SessionRef::SessionRef(Session* session) : session_(session) { + if (getenv("TF_REPLAY_LOG_FILE") != nullptr) { + logger_ = global_session_logger(); + logger_->RecordCreateSession(this->session_.get()).IgnoreError(); + } else { + logger_ = nullptr; + } +} + +SessionRef::~SessionRef() = default; + +Status SessionRef::CheckNotClosed() { + mutex_lock l(run_lock_); + if (session_ == nullptr) return errors::Cancelled("Session has been closed."); + return ::tensorflow::Status::OK(); +} + +// If logging is active, log the start and end time of the operation along with +// the request and response. +#define LOG_AND_RUN_OPERATION(OpName, ...) \ + TF_RETURN_IF_ERROR(CheckNotClosed()); \ + RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); \ + if (!logger_) { \ + return rc.session->OpName(__VA_ARGS__); \ + } \ + return logger_->Record##OpName(rc.session.get(), __VA_ARGS__); + +Status SessionRef::Run(const RunOptions& run_options, + const std::vector<std::pair<string, Tensor> >& inputs, + const std::vector<string>& output_tensor_names, + const std::vector<string>& target_node_names, + std::vector<Tensor>* outputs, + RunMetadata* run_metadata) { + LOG_AND_RUN_OPERATION(Run, run_options, inputs, output_tensor_names, + target_node_names, outputs, run_metadata); +} + +Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs, + const std::vector<string>& output_tensor_names, + const std::vector<string>& target_node_names, + std::vector<Tensor>* outputs) { + LOG_AND_RUN_OPERATION(Run, inputs, output_tensor_names, target_node_names, + outputs); +} + +Status SessionRef::Create(const GraphDef& graph) { + LOG_AND_RUN_OPERATION(Create, graph); +} + +Status SessionRef::Create(const RunOptions& run_options, + const GraphDef& graph) { + LOG_AND_RUN_OPERATION(Create, run_options, graph); +} + +Status SessionRef::Extend(const RunOptions& run_options, + const GraphDef& graph) { + LOG_AND_RUN_OPERATION(Extend, run_options, graph); +} + +Status SessionRef::Extend(const GraphDef& graph) { + LOG_AND_RUN_OPERATION(Extend, graph); +} + +Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) { + LOG_AND_RUN_OPERATION(ListDevices, response); +} + +Status SessionRef::PRunSetup(const std::vector<string>& input_names, + const std::vector<string>& output_names, + const std::vector<string>& target_nodes, + string* handle) { + LOG_AND_RUN_OPERATION(PRunSetup, input_names, output_names, target_nodes, + handle); +} + +Status SessionRef::PRun(const string& handle, + const std::vector<std::pair<string, Tensor> >& inputs, + const std::vector<string>& output_names, + std::vector<Tensor>* outputs) { + LOG_AND_RUN_OPERATION(PRun, handle, inputs, output_names, outputs); +} + +Status SessionRef::MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle) { + LOG_AND_RUN_OPERATION(MakeCallable, callable_options, out_handle); +} + +Status SessionRef::RunCallable(CallableHandle handle, + const std::vector<Tensor>& feed_tensors, + std::vector<Tensor>* fetch_tensors, + RunMetadata* run_metadata) { + LOG_AND_RUN_OPERATION(RunCallable, handle, feed_tensors, fetch_tensors, + run_metadata); +} + +Status SessionRef::ReleaseCallable(CallableHandle handle) { + LOG_AND_RUN_OPERATION(ReleaseCallable, handle); +} + +Status SessionRef::Close(const RunOptions& run_options) { + TF_RETURN_IF_ERROR(CheckNotClosed()); + mutex_lock l(run_lock_); + Status status; + if (logger_) { + status = logger_->RecordClose(session_.get(), run_options); + } else { + status = session_->Close(run_options); + } + session_.reset(); + while (run_count_ > 0) { + run_finished_.wait(l); + } + return status; +} + +Status SessionRef::Close() { + TF_RETURN_IF_ERROR(CheckNotClosed()); + mutex_lock l(run_lock_); + Status status; + if (logger_) { + status = logger_->RecordClose(session_.get()); + } else { + status = session_->Close(); + } + session_.reset(); + while (run_count_ > 0) { + run_finished_.wait(l); + } + return status; +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/session_ref.h b/tensorflow/python/client/session_ref.h index 9459e7edbe..b0fb12b189 100644 --- a/tensorflow/core/common_runtime/session_ref.h +++ b/tensorflow/python/client/session_ref.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_ -#define TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_ +#ifndef TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_ +#define TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_ #include <memory> @@ -22,6 +22,8 @@ limitations under the License. namespace tensorflow { +class SessionLogger; + // A `SessionRef` manages the lifetime of a wrapped `Session` pointer. // // SessionRef blocks the return of Close() until all pending operations have @@ -29,8 +31,8 @@ namespace tensorflow { // subsequent operations on the SessionRef object will return errors::Cancelled. class SessionRef : public Session { public: - SessionRef(Session* session) : session_(session) {} - virtual ~SessionRef() {} + explicit SessionRef(Session* session); + ~SessionRef() override; Status Create(const GraphDef& graph) override; Status Extend(const GraphDef& graph) override; @@ -78,9 +80,12 @@ class SessionRef : public Session { uint64 run_count_ GUARDED_BY(run_lock_) = {0}; std::shared_ptr<Session> session_; + // Borrowed reference to global session logger. + SessionLogger* logger_; + Status CheckNotClosed(); }; } // namespace tensorflow -#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_ +#endif // TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_ diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index bcd4af2912..dc0c10bab7 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/core/common_runtime/session_ref.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" @@ -31,6 +30,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/equal_graph_def.h" +#include "tensorflow/python/client/session_ref.h" #include "tensorflow/python/lib/core/ndarray_tensor.h" #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" #include "tensorflow/python/lib/core/safe_ptr.h" |