aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Russell Power <power@google.com>2018-09-20 13:48:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 13:53:34 -0700
commitd388770922ad1afa95e55597a33836fe74035c75 (patch)
tree3c15e17c357645e3e872ea46daaea2c91b00e9c1 /tensorflow
parent1f1e5ac6154583d5f87c846d1d7c9c59a77d6e0c (diff)
Implement TF graph capture.
PiperOrigin-RevId: 213875284
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/BUILD16
-rw-r--r--tensorflow/core/common_runtime/session_ref.cc170
-rw-r--r--tensorflow/core/protobuf/replay_log.proto47
-rw-r--r--tensorflow/python/BUILD15
-rw-r--r--tensorflow/python/client/session_ref.cc515
-rw-r--r--tensorflow/python/client/session_ref.h (renamed from tensorflow/core/common_runtime/session_ref.h)15
-rw-r--r--tensorflow/python/client/tf_session_helper.cc2
7 files changed, 597 insertions, 183 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index ed1818f834..85b6d4ff68 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2988,12 +2988,16 @@ tf_cuda_library(
] + tf_additional_device_tracer_deps(),
)
-cc_library(
- name = "session_ref",
- srcs = ["common_runtime/session_ref.cc"],
- hdrs = ["common_runtime/session_ref.h"],
- copts = tf_copts(),
- deps = [":core_cpu_base"],
+tf_proto_library_cc(
+ name = "replay_log_proto",  # Replay-log messages (ReplayOp, NewReplaySession).
+ srcs = ["protobuf/replay_log.proto"],
+ cc_api_version = 2,
+ protodeps = [
+ ":master_proto",  # replay_log.proto imports master.proto for the request/response messages.
+ ] + tf_additional_all_protos(),
+ visibility = [
+ "//tensorflow:internal",
+ ],
 )
cc_library(
diff --git a/tensorflow/core/common_runtime/session_ref.cc b/tensorflow/core/common_runtime/session_ref.cc
deleted file mode 100644
index b931ef4229..0000000000
--- a/tensorflow/core/common_runtime/session_ref.cc
+++ /dev/null
@@ -1,170 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/common_runtime/session_ref.h"
-
-#include <utility>
-
-namespace tensorflow {
-
-namespace {
-
-// Scope helper to track active calls and manage session lifetime.
-struct RunCounter {
- std::shared_ptr<Session> session;
- uint64* value;
- mutex* m;
- condition_variable* cv;
-
- explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m,
- condition_variable* cv)
- : session(std::move(s)), value(v), m(m), cv(cv) {
- mutex_lock l(*m);
- ++*value;
- }
-
- ~RunCounter() {
- mutex_lock l(*m);
- if (--*value == 0) {
- cv->notify_all();
- }
- }
-};
-
-} // namespace
-
-Status SessionRef::CheckNotClosed() {
- mutex_lock l(run_lock_);
- if (session_ == nullptr) return errors::Cancelled("Session has been closed.");
- return ::tensorflow::Status::OK();
-}
-
-Status SessionRef::Run(const RunOptions& run_options,
- const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_tensor_names,
- const std::vector<string>& target_node_names,
- std::vector<Tensor>* outputs,
- RunMetadata* run_metadata) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Run(run_options, inputs, output_tensor_names,
- target_node_names, outputs, run_metadata);
-}
-
-Status SessionRef::Create(const GraphDef& graph) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Create(graph);
-}
-
-Status SessionRef::Create(const RunOptions& run_options,
- const GraphDef& graph) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Create(run_options, graph);
-}
-
-Status SessionRef::Extend(const RunOptions& run_options,
- const GraphDef& graph) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Extend(run_options, graph);
-}
-
-Status SessionRef::Extend(const GraphDef& graph) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Extend(graph);
-}
-
-Status SessionRef::Close(const RunOptions& run_options) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- mutex_lock l(run_lock_);
- Status status = session_->Close(run_options);
- session_.reset();
- while (run_count_ > 0) {
- run_finished_.wait(l);
- }
- return status;
-}
-
-Status SessionRef::Close() {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- mutex_lock l(run_lock_);
- Status status = session_->Close();
- session_.reset();
- while (run_count_ > 0) {
- run_finished_.wait(l);
- }
- return status;
-}
-
-Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_tensor_names,
- const std::vector<string>& target_node_names,
- std::vector<Tensor>* outputs) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Run(inputs, output_tensor_names, target_node_names,
- outputs);
-}
-
-Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->ListDevices(response);
-}
-
-Status SessionRef::PRunSetup(const std::vector<string>& input_names,
- const std::vector<string>& output_names,
- const std::vector<string>& target_nodes,
- string* handle) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->PRunSetup(input_names, output_names, target_nodes, handle);
-}
-
-Status SessionRef::PRun(const string& handle,
- const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_names,
- std::vector<Tensor>* outputs) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->PRun(handle, inputs, output_names, outputs);
-}
-
-Status SessionRef::MakeCallable(const CallableOptions& callable_options,
- CallableHandle* out_handle) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->MakeCallable(callable_options, out_handle);
-}
-
-Status SessionRef::RunCallable(CallableHandle handle,
- const std::vector<Tensor>& feed_tensors,
- std::vector<Tensor>* fetch_tensors,
- RunMetadata* run_metadata) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->RunCallable(handle, feed_tensors, fetch_tensors,
- run_metadata);
-}
-
-Status SessionRef::ReleaseCallable(CallableHandle handle) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->ReleaseCallable(handle);
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/core/protobuf/replay_log.proto b/tensorflow/core/protobuf/replay_log.proto
new file mode 100644
index 0000000000..7644314fc9
--- /dev/null
+++ b/tensorflow/core/protobuf/replay_log.proto
@@ -0,0 +1,47 @@
+syntax = "proto3";
+
+option cc_enable_arenas = true;
+package tensorflow;
+
+import "tensorflow/core/framework/graph.proto";
+import "tensorflow/core/protobuf/cluster.proto";
+import "tensorflow/core/protobuf/master.proto";
+
+// Records the creation of a new replay session. The device listing is stored
+// here to capture the state of the cluster as seen by that session.
+message NewReplaySession {
+ ListDevicesResponse devices = 1;  // Devices reported by the session's ListDevices().
+ string session_handle = 2;  // Identifier the logger assigns to the session (see SessionToHandle).
+}
+
+message ReplayOp {  // One logged session operation: timing, the request, and its response.
+ double start_time_us = 31;  // Env::NowMicros() taken just before the session call.
+ double end_time_us = 32;  // Env::NowMicros() taken just after the session call returns.
+
+ oneof op {  // The request that was issued; exactly one is set per record.
+ CreateSessionRequest create_session = 1;
+ ExtendSessionRequest extend_session = 2;
+ PartialRunSetupRequest partial_run_setup = 3;
+ RunStepRequest run_step = 4;
+ CloseSessionRequest close_session = 5;
+ ListDevicesRequest list_devices = 6;
+ ResetRequest reset_request = 7;
+ MakeCallableRequest make_callable = 8;
+ RunCallableRequest run_callable = 9;
+ ReleaseCallableRequest release_callable = 10;
+ NewReplaySession new_replay_session = 11;  // Logger-generated marker; has no matching response field.
+ }
+
+ oneof response {  // Result of the request; each field number is the request's number + 20.
+ CreateSessionResponse create_session_response = 21;
+ ExtendSessionResponse extend_session_response = 22;
+ PartialRunSetupResponse partial_run_setup_response = 23;
+ RunStepResponse run_step_response = 24;
+ CloseSessionResponse close_session_response = 25;
+ ListDevicesResponse list_devices_response = 26;
+ ResetResponse reset_request_response = 27;
+ MakeCallableResponse make_callable_response = 28;
+ RunCallableResponse run_callable_response = 29;
+ ReleaseCallableResponse release_callable_response = 30;
+ }
+}
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 9730e9933a..79f14466e6 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3763,6 +3763,19 @@ cuda_py_tests(
],
)
+cc_library(
+ name = "session_ref",  # SessionRef wrapper plus the optional replay logger (SessionLogger).
+ srcs = ["client/session_ref.cc"],
+ hdrs = ["client/session_ref.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:master_proto_cc",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:replay_log_proto_cc",  # ReplayOp / NewReplaySession messages written to the log.
+ ],
+)
+
tf_cuda_library(
name = "tf_session_helper",
srcs = ["client/tf_session_helper.cc"],
@@ -3773,6 +3786,7 @@ tf_cuda_library(
":ndarray_tensor_bridge",
":numpy_lib",
":safe_ptr",
+ ":session_ref",
":test_ops_kernels",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
@@ -3785,7 +3799,6 @@ tf_cuda_library(
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/core:session_ref",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
],
diff --git a/tensorflow/python/client/session_ref.cc b/tensorflow/python/client/session_ref.cc
new file mode 100644
index 0000000000..b2300df0b6
--- /dev/null
+++ b/tensorflow/python/client/session_ref.cc
@@ -0,0 +1,515 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/python/client/session_ref.h"
+
+#include <stdlib.h>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/protobuf/master.pb.h"
+#include "tensorflow/core/protobuf/named_tensor.pb.h"
+#include "tensorflow/core/protobuf/replay_log.pb.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Scoped bookkeeping for one in-flight call made through a SessionRef.
+// Construction bumps the shared call count; destruction drops it again and
+// wakes every waiter once the count reaches zero. The shared_ptr held here
+// keeps the wrapped session alive for as long as the call is running.
+struct RunCounter {
+  std::shared_ptr<Session> session;
+  uint64* value;
+  mutex* m;
+  condition_variable* cv;
+
+  explicit RunCounter(std::shared_ptr<Session> wrapped, uint64* counter,
+                      mutex* counter_lock, condition_variable* drained)
+      : session(std::move(wrapped)),
+        value(counter),
+        m(counter_lock),
+        cv(drained) {
+    mutex_lock guard(*m);
+    *value += 1;
+  }
+
+  ~RunCounter() {
+    mutex_lock guard(*m);
+    *value -= 1;
+    // Only the call that brings the count to zero needs to signal.
+    if (*value == 0) {
+      cv->notify_all();
+    }
+  }
+};
+
+// Returns the identifier used for `session` in the replay log: the decimal
+// value of the object's address. It is only meaningful for correlating
+// records within a single log.
+std::string SessionToHandle(Session* session) {
+  // "%llu" requires exactly `unsigned long long`; cast explicitly so the
+  // format stays correct whatever integer type `uint64` is an alias for.
+  return strings::Printf(
+      "%llu",
+      static_cast<unsigned long long>(reinterpret_cast<uint64>(session)));
+}
+
+// Many Session methods come in two flavours:
+//
+//   X(a, b);
+//   X(RunOptions, a, b);
+//
+// Some sessions do not support the second form, even with a
+// default-constructed RunOptions. The object returned here is used purely as
+// an address sentinel: callers compare `&run_options` against it to decide
+// which overload to dispatch to.
+RunOptions* kEmptyRunOptions() {
+  // Intentionally leaked so the sentinel address stays valid for the whole
+  // process lifetime.
+  static RunOptions* const sentinel = new RunOptions();
+  return sentinel;
+}
+
+} // namespace
+
+// Runs `session->OpName(args...)`, storing start/end timestamps on the local
+// `op`; needs `op`, `session` and Flush() in scope and declares `Status status`.
+// On a bad status it flushes `op` (ignoring flush errors) and returns that
+// status. Use only _after_ all request information has been added to `op`.
+#define RUN_WITH_TIMESTAMP(OpName, ...) \
+ op.set_start_time_us(Env::Default()->NowMicros()); \
+ Status status = session->OpName(__VA_ARGS__); \
+ op.set_end_time_us(Env::Default()->NowMicros()); \
+ if (!status.ok()) { \
+ Flush(op).IgnoreError(); \
+ return status; \
+ }
+
+// Records requests (and optionally responses) performed against a session.
+// The resulting replay log can be used with the `tf_replay` tool to replicate
+// the operations against a simulated environment, without requiring the
+// original code or cluster setup.
+//
+// Session logging is enabled by setting the TF_REPLAY_LOG_FILE environment
+// variable to the path of the log file to write.
+//
+// Thread-safety: all Record* methods may be called concurrently. Each call
+// builds its own ReplayOp; only the append to the shared log file (see
+// Flush) is serialized by `log_mutex_`. The wrapped session call itself never
+// runs while the log lock is held, so logging does not serialize independent
+// session operations.
+class SessionLogger {
+ public:
+  // Opens the log file named by TF_REPLAY_LOG_FILE, replacing any existing
+  // file and creating parent directories as needed. Aborts the process if the
+  // variable is unset or the file cannot be created.
+  SessionLogger() {
+    const char* log_path = getenv("TF_REPLAY_LOG_FILE");
+    // Constructing a std::string from a null pointer is undefined behaviour;
+    // fail with a clear message instead.
+    CHECK(log_path != nullptr)
+        << "TF_REPLAY_LOG_FILE must be set to enable session logging.";
+    const std::string log_name = log_path;
+    TF_CHECK_OK(
+        Env::Default()->RecursivelyCreateDir(string(io::Dirname(log_name))));
+    Env::Default()->DeleteFile(log_name).IgnoreError();
+    TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_));
+
+    log_writer_ = absl::make_unique<io::RecordWriter>(log_file_.get());
+  }
+
+  // Logs a NewReplaySession record holding the session's device listing and
+  // handle. Returns the ListDevices or log-write error, if any, rather than
+  // aborting: callers treat this record as best-effort.
+  Status RecordCreateSession(Session* session) {
+    LOG(INFO) << "Capturing devices for session.";
+    ReplayOp op;
+    NewReplaySession* req = op.mutable_new_replay_session();
+
+    std::vector<DeviceAttributes> devices;
+    TF_RETURN_IF_ERROR(session->ListDevices(&devices));
+    for (const DeviceAttributes& dev : devices) {
+      *req->mutable_devices()->add_local_device() = dev;
+    }
+
+    req->set_session_handle(SessionToHandle(session));
+    return Flush(op);
+  }
+
+  // Runs and logs Session::Run(inputs, ...) (the overload without
+  // RunOptions).
+  Status RecordRun(Session* session,
+                   const std::vector<std::pair<string, Tensor> >& inputs,
+                   const std::vector<string>& output_tensor_names,
+                   const std::vector<string>& target_node_names,
+                   std::vector<Tensor>* outputs) {
+    return RecordRun(session, *kEmptyRunOptions(), inputs, output_tensor_names,
+                     target_node_names, outputs, nullptr);
+  }
+
+  // Runs and logs Session::Run. Feeds, de-duplicated fetch names, targets and
+  // options go into the request; on success the output tensors and (if
+  // provided) run metadata go into the response.
+  Status RecordRun(Session* session, const RunOptions& run_options,
+                   const std::vector<std::pair<string, Tensor> >& inputs,
+                   const std::vector<string>& output_tensor_names,
+                   const std::vector<string>& target_node_names,
+                   std::vector<Tensor>* outputs, RunMetadata* run_metadata) {
+    ReplayOp op;
+    RunStepRequest* req = op.mutable_run_step();
+    RunStepResponse* resp = op.mutable_run_step_response();
+
+    req->set_session_handle(SessionToHandle(session));
+    *req->mutable_options() = run_options;
+
+    for (const auto& it : inputs) {
+      NamedTensorProto* feed = req->add_feed();
+      feed->set_name(it.first);
+      it.second.AsProtoField(feed->mutable_tensor());
+    }
+
+    // Build an index from fetch tensor name to first index in
+    // output_tensor_names, so that each distinct fetch is logged only once.
+    std::unordered_map<string, int> output_name_to_offset;
+    for (size_t i = 0; i < output_tensor_names.size(); ++i) {
+      const string& name = output_tensor_names[i];
+      if (output_name_to_offset
+              .insert(std::make_pair(name, static_cast<int>(i)))
+              .second) {
+        req->add_fetch(name);
+      }
+    }
+    for (const string& target : target_node_names) {
+      req->add_target(target);
+    }
+
+    // The sentinel address selects the overload without RunOptions.
+    if (&run_options == kEmptyRunOptions()) {
+      RUN_WITH_TIMESTAMP(Run, inputs, output_tensor_names, target_node_names,
+                         outputs);
+    } else {
+      RUN_WITH_TIMESTAMP(Run, run_options, inputs, output_tensor_names,
+                         target_node_names, outputs, run_metadata);
+    }
+
+    for (size_t i = 0; i < outputs->size(); ++i) {
+      const Tensor& tensor = (*outputs)[i];
+      NamedTensorProto* tproto = resp->add_tensor();
+      tensor.AsProtoField(tproto->mutable_tensor());
+      tproto->set_name(output_tensor_names[i]);
+    }
+
+    if (run_metadata) {
+      *resp->mutable_metadata() = *run_metadata;
+    }
+
+    return Flush(op);
+  }
+
+  // Runs and logs Session::Create(graph).
+  Status RecordCreate(Session* session, const GraphDef& graph) {
+    return RecordCreate(session, *kEmptyRunOptions(), graph);
+  }
+
+  // Runs and logs Session::Create.
+  // N.B. RunOptions is not stored (it has no entry in CreateRequest)
+  Status RecordCreate(Session* session, const RunOptions& run_options,
+                      const GraphDef& graph) {
+    ReplayOp op;
+    CreateSessionRequest* req = op.mutable_create_session();
+    *req->mutable_graph_def() = graph;
+
+    CreateSessionResponse* resp = op.mutable_create_session_response();
+    if (&run_options == kEmptyRunOptions()) {
+      RUN_WITH_TIMESTAMP(Create, graph);
+    } else {
+      RUN_WITH_TIMESTAMP(Create, run_options, graph);
+    }
+    resp->set_session_handle(SessionToHandle(session));
+    return Flush(op);
+  }
+
+  // Runs and logs Session::Extend(graph).
+  Status RecordExtend(Session* session, const GraphDef& graph) {
+    return RecordExtend(session, *kEmptyRunOptions(), graph);
+  }
+
+  // Runs and logs Session::Extend.
+  // N.B. RunOptions is not stored (it has no entry in ExtendRequest)
+  Status RecordExtend(Session* session, const RunOptions& run_options,
+                      const GraphDef& graph) {
+    ReplayOp op;
+    ExtendSessionRequest* req = op.mutable_extend_session();
+    op.mutable_extend_session_response();
+    req->set_session_handle(SessionToHandle(session));
+    *req->mutable_graph_def() = graph;
+    if (&run_options == kEmptyRunOptions()) {
+      RUN_WITH_TIMESTAMP(Extend, graph);
+    } else {
+      RUN_WITH_TIMESTAMP(Extend, run_options, graph);
+    }
+
+    return Flush(op);
+  }
+
+  // Runs and logs Session::Close().
+  Status RecordClose(Session* session) {
+    return RecordClose(session, *kEmptyRunOptions());
+  }
+
+  // Runs and logs Session::Close.
+  // N.B. RunOptions is not stored (it has no entry in CloseRequest)
+  Status RecordClose(Session* session, const RunOptions& run_options) {
+    ReplayOp op;
+    CloseSessionRequest* req = op.mutable_close_session();
+    req->set_session_handle(SessionToHandle(session));
+    op.mutable_close_session_response();
+    if (&run_options == kEmptyRunOptions()) {
+      RUN_WITH_TIMESTAMP(Close);
+    } else {
+      RUN_WITH_TIMESTAMP(Close, run_options);
+    }
+    return Flush(op);
+  }
+
+  // Runs and logs Session::ListDevices; on success the returned devices are
+  // copied into the response.
+  Status RecordListDevices(Session* session,
+                           std::vector<DeviceAttributes>* response) {
+    ReplayOp op;
+    ListDevicesRequest* req = op.mutable_list_devices();
+    ListDevicesResponse* resp = op.mutable_list_devices_response();
+    req->set_session_handle(SessionToHandle(session));
+    RUN_WITH_TIMESTAMP(ListDevices, response);
+
+    // TODO(power) -- local vs remote device distinction is lost here!
+    *resp->mutable_local_device() = {response->begin(), response->end()};
+    return Flush(op);
+  }
+
+  // Runs and logs Session::PRunSetup; on success the partial-run handle is
+  // stored in the response.
+  Status RecordPRunSetup(Session* session,
+                         const std::vector<string>& input_names,
+                         const std::vector<string>& output_names,
+                         const std::vector<string>& target_nodes,
+                         string* handle) {
+    ReplayOp op;
+    PartialRunSetupRequest* req = op.mutable_partial_run_setup();
+    req->set_session_handle(SessionToHandle(session));
+    for (const auto& input : input_names) {
+      req->add_feed(input);
+    }
+    for (const auto& output : output_names) {
+      req->add_fetch(output);
+    }
+    for (const auto& target : target_nodes) {
+      req->add_target(target);
+    }
+    RUN_WITH_TIMESTAMP(PRunSetup, input_names, output_names, target_nodes,
+                       handle);
+    op.mutable_partial_run_setup_response()->set_partial_run_handle(*handle);
+    return Flush(op);
+  }
+
+  // Runs and logs Session::PRun as a RunStepRequest carrying the partial-run
+  // handle.
+  Status RecordPRun(Session* session, const string& handle,
+                    const std::vector<std::pair<string, Tensor> >& inputs,
+                    const std::vector<string>& output_names,
+                    std::vector<Tensor>* outputs) {
+    ReplayOp op;
+    RunStepRequest* req = op.mutable_run_step();
+    RunStepResponse* resp = op.mutable_run_step_response();
+    req->set_session_handle(SessionToHandle(session));
+
+    // Mark this step as a partial run for replay.
+    req->set_partial_run_handle(handle);
+    for (const auto& input : inputs) {
+      auto* feed = req->add_feed();
+      feed->set_name(input.first);
+      input.second.AsProtoField(feed->mutable_tensor());
+    }
+
+    for (const auto& output : output_names) {
+      req->add_fetch(output);
+    }
+
+    RUN_WITH_TIMESTAMP(PRun, handle, inputs, output_names, outputs);
+
+    for (size_t i = 0; i < outputs->size(); ++i) {
+      const Tensor& tensor = (*outputs)[i];
+      NamedTensorProto* tproto = resp->add_tensor();
+      tensor.AsProtoField(tproto->mutable_tensor());
+      tproto->set_name(output_names[i]);
+    }
+
+    return Flush(op);
+  }
+
+  // Runs and logs Session::MakeCallable; on success the new handle is stored
+  // in the response.
+  Status RecordMakeCallable(Session* session,
+                            const CallableOptions& callable_options,
+                            Session::CallableHandle* handle) {
+    ReplayOp op;
+    MakeCallableRequest* req = op.mutable_make_callable();
+    req->set_session_handle(SessionToHandle(session));
+    *req->mutable_options() = callable_options;
+
+    RUN_WITH_TIMESTAMP(MakeCallable, callable_options, handle);
+
+    MakeCallableResponse* resp = op.mutable_make_callable_response();
+    resp->set_handle(*handle);
+
+    return Flush(op);
+  }
+
+  // Runs and logs Session::RunCallable, including feed tensors and, on
+  // success, fetch tensors and (if provided) run metadata.
+  Status RecordRunCallable(Session* session, Session::CallableHandle handle,
+                           const std::vector<Tensor>& feed_tensors,
+                           std::vector<Tensor>* fetch_tensors,
+                           RunMetadata* run_metadata) {
+    ReplayOp op;
+    RunCallableRequest* req = op.mutable_run_callable();
+    req->set_session_handle(SessionToHandle(session));
+    req->set_handle(handle);
+    for (const auto& tensor : feed_tensors) {
+      tensor.AsProtoField(req->add_feed());
+    }
+    RUN_WITH_TIMESTAMP(RunCallable, handle, feed_tensors, fetch_tensors,
+                       run_metadata);
+
+    RunCallableResponse* resp = op.mutable_run_callable_response();
+    if (run_metadata) {
+      *resp->mutable_metadata() = *run_metadata;
+    }
+    for (const Tensor& tensor : *fetch_tensors) {
+      tensor.AsProtoTensorContent(resp->add_fetch());
+    }
+    return Flush(op);
+  }
+
+  // Runs and logs Session::ReleaseCallable.
+  Status RecordReleaseCallable(Session* session,
+                               Session::CallableHandle handle) {
+    ReplayOp op;
+    ReleaseCallableRequest* req = op.mutable_release_callable();
+    req->set_session_handle(SessionToHandle(session));
+    req->set_handle(handle);
+    RUN_WITH_TIMESTAMP(ReleaseCallable, handle);
+    return Flush(op);
+  }
+
+ private:
+  // Serializes `op` and appends it to the log as one record, then flushes
+  // both the record writer and the underlying file. This is the only place
+  // that touches the shared writer, so it is the only place that takes
+  // `log_mutex_`; serialization happens before the lock is acquired.
+  Status Flush(const ReplayOp& op) {
+    string buf;
+    if (!op.SerializeToString(&buf)) {
+      return errors::Internal("Failed to serialize ReplayOp for replay log.");
+    }
+
+    mutex_lock l(log_mutex_);
+    TF_RETURN_IF_ERROR(log_writer_->WriteRecord(buf));
+
+    // Flushing the RecordWriter _does not_ flush the underlying file.
+    TF_RETURN_IF_ERROR(log_writer_->Flush());
+    return log_file_->Flush();
+  }
+
+  mutex log_mutex_;
+  // `log_writer_` holds a raw pointer to `*log_file_`, so it is declared
+  // after the file: members are destroyed in reverse declaration order, which
+  // tears the writer down before the file it points at.
+  std::unique_ptr<WritableFile> log_file_ GUARDED_BY(log_mutex_);
+  std::unique_ptr<io::RecordWriter> log_writer_ GUARDED_BY(log_mutex_);
+};
+
+// Returns the process-wide SessionLogger, constructing it on first use. The
+// instance is allocated with `new` and never deleted.
+static SessionLogger* global_session_logger() {
+  static SessionLogger* const instance = new SessionLogger();
+  return instance;
+}
+
+// Takes ownership of `session`. When TF_REPLAY_LOG_FILE is set, attaches the
+// global logger and records a new-session entry (errors from that record are
+// ignored); otherwise logging stays disabled for this SessionRef.
+SessionRef::SessionRef(Session* session)
+    : session_(session), logger_(nullptr) {
+  if (getenv("TF_REPLAY_LOG_FILE") == nullptr) {
+    return;
+  }
+  logger_ = global_session_logger();
+  logger_->RecordCreateSession(this->session_.get()).IgnoreError();
+}
+
+SessionRef::~SessionRef() = default;
+
+// Returns OK while a session is still attached, and errors::Cancelled once
+// the session pointer has been cleared. `run_lock_` is held only for the
+// duration of this check.
+Status SessionRef::CheckNotClosed() {
+  mutex_lock l(run_lock_);
+  if (session_ != nullptr) {
+    return ::tensorflow::Status::OK();
+  }
+  return errors::Cancelled("Session has been closed.");
+}
+
+// Forwards `OpName(args...)` to the wrapped session. If logging is active,
+// the call goes through the logger, which records the start and end time of
+// the operation along with the request and response.
+//
+// `session_` is snapshotted while holding `run_lock_`: Close() resets it
+// under that lock, so reading it unlocked would be a data race and could hand
+// a null session to the call. If the session has already been closed the
+// enclosing method returns errors::Cancelled. The RunCounter then keeps the
+// snapshot alive and makes Close() wait for this call to finish.
+//
+// NOTE: a Close() that runs between the snapshot and the RunCounter's
+// increment will not wait for this call; the shared_ptr still keeps the
+// (closed) session object alive, so the call is memory-safe.
+#define LOG_AND_RUN_OPERATION(OpName, ...)                            \
+  std::shared_ptr<Session> active_session;                            \
+  {                                                                   \
+    mutex_lock l(run_lock_);                                          \
+    active_session = session_;                                        \
+  }                                                                   \
+  if (active_session == nullptr) {                                    \
+    return errors::Cancelled("Session has been closed.");             \
+  }                                                                   \
+  RunCounter rc(std::move(active_session), &run_count_, &run_lock_,   \
+                &run_finished_);                                      \
+  if (!logger_) {                                                     \
+    return rc.session->OpName(__VA_ARGS__);                           \
+  }                                                                   \
+  return logger_->Record##OpName(rc.session.get(), __VA_ARGS__);
+
+// Each Session override below consists solely of LOG_AND_RUN_OPERATION and so
+// shares its closed-session check, call counting and optional logging.
+
+Status SessionRef::Run(const RunOptions& run_options,
+                       const std::vector<std::pair<string, Tensor> >& inputs,
+                       const std::vector<string>& output_tensor_names,
+                       const std::vector<string>& target_node_names,
+                       std::vector<Tensor>* outputs,
+                       RunMetadata* run_metadata) {
+  LOG_AND_RUN_OPERATION(Run, run_options, inputs, output_tensor_names,
+                        target_node_names, outputs, run_metadata);
+}
+
+Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
+                       const std::vector<string>& output_tensor_names,
+                       const std::vector<string>& target_node_names,
+                       std::vector<Tensor>* outputs) {
+  LOG_AND_RUN_OPERATION(Run, inputs, output_tensor_names, target_node_names,
+                        outputs);
+}
+
+Status SessionRef::Create(const GraphDef& graph) {
+  LOG_AND_RUN_OPERATION(Create, graph);
+}
+
+Status SessionRef::Create(const RunOptions& run_options,
+                          const GraphDef& graph) {
+  LOG_AND_RUN_OPERATION(Create, run_options, graph);
+}
+
+Status SessionRef::Extend(const RunOptions& run_options,
+                          const GraphDef& graph) {
+  LOG_AND_RUN_OPERATION(Extend, run_options, graph);
+}
+
+Status SessionRef::Extend(const GraphDef& graph) {
+  LOG_AND_RUN_OPERATION(Extend, graph);
+}
+
+Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
+  LOG_AND_RUN_OPERATION(ListDevices, response);
+}
+
+Status SessionRef::PRunSetup(const std::vector<string>& input_names,
+                             const std::vector<string>& output_names,
+                             const std::vector<string>& target_nodes,
+                             string* handle) {
+  LOG_AND_RUN_OPERATION(PRunSetup, input_names, output_names, target_nodes,
+                        handle);
+}
+
+Status SessionRef::PRun(const string& handle,
+                        const std::vector<std::pair<string, Tensor> >& inputs,
+                        const std::vector<string>& output_names,
+                        std::vector<Tensor>* outputs) {
+  LOG_AND_RUN_OPERATION(PRun, handle, inputs, output_names, outputs);
+}
+
+Status SessionRef::MakeCallable(const CallableOptions& callable_options,
+                                CallableHandle* out_handle) {
+  LOG_AND_RUN_OPERATION(MakeCallable, callable_options, out_handle);
+}
+
+Status SessionRef::RunCallable(CallableHandle handle,
+                               const std::vector<Tensor>& feed_tensors,
+                               std::vector<Tensor>* fetch_tensors,
+                               RunMetadata* run_metadata) {
+  LOG_AND_RUN_OPERATION(RunCallable, handle, feed_tensors, fetch_tensors,
+                        run_metadata);
+}
+
+Status SessionRef::ReleaseCallable(CallableHandle handle) {
+  LOG_AND_RUN_OPERATION(ReleaseCallable, handle);
+}
+
+// Closes the wrapped session (through the logger when logging is active),
+// drops this SessionRef's reference to it, and blocks until every in-flight
+// operation has finished. Returns errors::Cancelled if the session has
+// already been closed, otherwise the status of the underlying Close call.
+Status SessionRef::Close(const RunOptions& run_options) {
+  // The closed check and the close itself must happen under one hold of
+  // run_lock_. Checking first and locking afterwards would let two concurrent
+  // Close() calls both pass the check, with the second one dereferencing the
+  // null `session_` left behind by the first.
+  mutex_lock l(run_lock_);
+  if (session_ == nullptr) {
+    return errors::Cancelled("Session has been closed.");
+  }
+  Status status;
+  if (logger_) {
+    status = logger_->RecordClose(session_.get(), run_options);
+  } else {
+    status = session_->Close(run_options);
+  }
+  session_.reset();
+  // wait() releases run_lock_ while blocked, so finishing operations can
+  // decrement run_count_ and signal run_finished_.
+  while (run_count_ > 0) {
+    run_finished_.wait(l);
+  }
+  return status;
+}
+
+// Closes the wrapped session (through the logger when logging is active),
+// drops this SessionRef's reference to it, and blocks until every in-flight
+// operation has finished. Returns errors::Cancelled if the session has
+// already been closed, otherwise the status of the underlying Close call.
+Status SessionRef::Close() {
+  // The closed check and the close itself must happen under one hold of
+  // run_lock_. Checking first and locking afterwards would let two concurrent
+  // Close() calls both pass the check, with the second one dereferencing the
+  // null `session_` left behind by the first.
+  mutex_lock l(run_lock_);
+  if (session_ == nullptr) {
+    return errors::Cancelled("Session has been closed.");
+  }
+  Status status;
+  if (logger_) {
+    status = logger_->RecordClose(session_.get());
+  } else {
+    status = session_->Close();
+  }
+  session_.reset();
+  // wait() releases run_lock_ while blocked, so finishing operations can
+  // decrement run_count_ and signal run_finished_.
+  while (run_count_ > 0) {
+    run_finished_.wait(l);
+  }
+  return status;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/session_ref.h b/tensorflow/python/client/session_ref.h
index 9459e7edbe..b0fb12b189 100644
--- a/tensorflow/core/common_runtime/session_ref.h
+++ b/tensorflow/python/client/session_ref.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
+#ifndef TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_
+#define TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_
#include <memory>
@@ -22,6 +22,8 @@ limitations under the License.
namespace tensorflow {
+class SessionLogger;
+
// A `SessionRef` manages the lifetime of a wrapped `Session` pointer.
//
// SessionRef blocks the return of Close() until all pending operations have
@@ -29,8 +31,8 @@ namespace tensorflow {
// subsequent operations on the SessionRef object will return errors::Cancelled.
class SessionRef : public Session {
public:
- SessionRef(Session* session) : session_(session) {}
- virtual ~SessionRef() {}
+ explicit SessionRef(Session* session);
+ ~SessionRef() override;
Status Create(const GraphDef& graph) override;
Status Extend(const GraphDef& graph) override;
@@ -78,9 +80,12 @@ class SessionRef : public Session {
uint64 run_count_ GUARDED_BY(run_lock_) = {0};
std::shared_ptr<Session> session_;
+ // Borrowed reference to global session logger.
+ SessionLogger* logger_;
+
Status CheckNotClosed();
};
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
+#endif // TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index bcd4af2912..dc0c10bab7 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
-#include "tensorflow/core/common_runtime/session_ref.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
@@ -31,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/equal_graph_def.h"
+#include "tensorflow/python/client/session_ref.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include "tensorflow/python/lib/core/safe_ptr.h"