aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client
diff options
context:
space:
mode:
authorGravatar Russell Power <power@google.com>2018-09-20 13:48:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 13:53:34 -0700
commitd388770922ad1afa95e55597a33836fe74035c75 (patch)
tree3c15e17c357645e3e872ea46daaea2c91b00e9c1 /tensorflow/python/client
parent1f1e5ac6154583d5f87c846d1d7c9c59a77d6e0c (diff)
Implement TF graph capture.
PiperOrigin-RevId: 213875284
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r--tensorflow/python/client/session_ref.cc515
-rw-r--r--tensorflow/python/client/session_ref.h91
-rw-r--r--tensorflow/python/client/tf_session_helper.cc2
3 files changed, 607 insertions, 1 deletions
diff --git a/tensorflow/python/client/session_ref.cc b/tensorflow/python/client/session_ref.cc
new file mode 100644
index 0000000000..b2300df0b6
--- /dev/null
+++ b/tensorflow/python/client/session_ref.cc
@@ -0,0 +1,515 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/python/client/session_ref.h"
+
+#include <stdlib.h>
+#include <memory>
+#include <utility>
+
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/protobuf/master.pb.h"
+#include "tensorflow/core/protobuf/named_tensor.pb.h"
+#include "tensorflow/core/protobuf/replay_log.pb.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Scope helper to track active calls and manage session lifetime.
+// SessionRef blocks closing until all active calls complete or are cancelled.
+struct RunCounter {
+ std::shared_ptr<Session> session;
+ uint64* value;
+ mutex* m;
+ condition_variable* cv;
+
+ explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m,
+ condition_variable* cv)
+ : session(std::move(s)), value(v), m(m), cv(cv) {
+ mutex_lock l(*m);
+ ++*value;
+ }
+
+ ~RunCounter() {
+ mutex_lock l(*m);
+ if (--*value == 0) {
+ cv->notify_all();
+ }
+ }
+};
+
+std::string SessionToHandle(Session* session) {
+ return strings::Printf("%llu", reinterpret_cast<uint64>(session));
+}
+
+// The Session interface has many methods of the form:
+//
+// X(a, b);
+// X(RunOptions, a, b);
+//
+// Not all sessions support the second case (with an empty RunOptions()).
+// We use this variable as a sentinel to dispatch to the correct call.
+RunOptions* kEmptyRunOptions() {
+ static RunOptions* options = new RunOptions();
+ return options;
+}
+
+} // namespace
+
+// Run the given session operation, recording start and end timestamps.
+// If the operation returns a bad status, return after flushing the current
+// log request. This should be run _after_ all request information has been
+// added to the current op.
+#define RUN_WITH_TIMESTAMP(OpName, ...) \
+ op.set_start_time_us(Env::Default()->NowMicros()); \
+ Status status = session->OpName(__VA_ARGS__); \
+ op.set_end_time_us(Env::Default()->NowMicros()); \
+ if (!status.ok()) { \
+ Flush(op).IgnoreError(); \
+ return status; \
+ }
+
+// Records requests (and optionally responses) performed against a session.
+// The resulting replay log can be used with the `tf_replay` tool to replicate
+// the operations against a simulated environment, without requiring the
+// original code or cluster setup.
+//
+// Session logging by setting the TF_REPLAY_LOG_FILE environment variable.
+class SessionLogger {
+ public:
+ SessionLogger() {
+ std::string log_name = getenv("TF_REPLAY_LOG_FILE");
+ TF_CHECK_OK(
+ Env::Default()->RecursivelyCreateDir(string(io::Dirname(log_name))));
+ Env::Default()->DeleteFile(log_name).IgnoreError();
+ TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_));
+
+ log_writer_ = absl::make_unique<io::RecordWriter>(log_file_.get());
+ }
+
+ Status RecordCreateSession(Session* session) {
+ LOG(INFO) << "Capturing devices for session.";
+ ReplayOp op;
+ NewReplaySession* req = op.mutable_new_replay_session();
+
+ std::vector<DeviceAttributes> devices;
+ TF_CHECK_OK(session->ListDevices(&devices));
+ for (const DeviceAttributes& dev : devices) {
+ *req->mutable_devices()->add_local_device() = dev;
+ }
+
+ req->set_session_handle(SessionToHandle(session));
+ return Flush(op);
+ }
+
+ Status RecordRun(Session* session,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs) {
+ return RecordRun(session, *kEmptyRunOptions(), inputs, output_tensor_names,
+ target_node_names, outputs, nullptr);
+ }
+
+ Status RecordRun(Session* session, const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs, RunMetadata* run_metadata) {
+ ReplayOp op;
+ RunStepRequest* req = op.mutable_run_step();
+ RunStepResponse* resp = op.mutable_run_step_response();
+
+ req->set_session_handle(SessionToHandle(session));
+ *req->mutable_options() = run_options;
+
+ for (const auto& it : inputs) {
+ NamedTensorProto* feed = req->add_feed();
+ feed->set_name(it.first);
+ it.second.AsProtoField(feed->mutable_tensor());
+ }
+
+ // Build an index from fetch tensor name to first index in
+ // output_tensor_names.
+ std::unordered_map<string, int> output_name_to_offset;
+ for (int i = 0; i < output_tensor_names.size(); ++i) {
+ const string& name = output_tensor_names[i];
+ if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
+ req->add_fetch(name);
+ }
+ }
+ for (const string& target : target_node_names) {
+ req->add_target(target);
+ }
+
+ if (&run_options == kEmptyRunOptions()) {
+ RUN_WITH_TIMESTAMP(Run, inputs, output_tensor_names, target_node_names,
+ outputs);
+ } else {
+ RUN_WITH_TIMESTAMP(Run, run_options, inputs, output_tensor_names,
+ target_node_names, outputs, run_metadata);
+ }
+
+ for (size_t i = 0; i < outputs->size(); ++i) {
+ const Tensor& tensor = (*outputs)[i];
+ NamedTensorProto* tproto = resp->add_tensor();
+ tensor.AsProtoField(tproto->mutable_tensor());
+ tproto->set_name(output_tensor_names[i]);
+ }
+
+ if (run_metadata) {
+ *resp->mutable_metadata() = *run_metadata;
+ }
+
+ return Flush(op);
+ }
+
+ Status RecordCreate(Session* session, const GraphDef& graph) {
+ return RecordCreate(session, *kEmptyRunOptions(), graph);
+ }
+
+ // N.B. RunOptions is not stored (it has no entry in CreateRequest)
+ Status RecordCreate(Session* session, const RunOptions& run_options,
+ const GraphDef& graph) {
+ ReplayOp op;
+ CreateSessionRequest* req = op.mutable_create_session();
+ *req->mutable_graph_def() = graph;
+
+ CreateSessionResponse* resp = op.mutable_create_session_response();
+ if (&run_options == kEmptyRunOptions()) {
+ RUN_WITH_TIMESTAMP(Create, graph);
+ } else {
+ RUN_WITH_TIMESTAMP(Create, run_options, graph);
+ }
+ resp->set_session_handle(SessionToHandle(session));
+ return Flush(op);
+ }
+
+ Status RecordExtend(Session* session, const GraphDef& graph) {
+ return RecordExtend(session, *kEmptyRunOptions(), graph);
+ }
+
+ // N.B. RunOptions is not stored (it has no entry in ExtendRequest)
+ Status RecordExtend(Session* session, const RunOptions& run_options,
+ const GraphDef& graph) {
+ ReplayOp op;
+ ExtendSessionRequest* req = op.mutable_extend_session();
+ op.mutable_extend_session_response();
+ req->set_session_handle(SessionToHandle(session));
+ *req->mutable_graph_def() = graph;
+ if (&run_options == kEmptyRunOptions()) {
+ RUN_WITH_TIMESTAMP(Extend, graph);
+ } else {
+ RUN_WITH_TIMESTAMP(Extend, run_options, graph);
+ }
+
+ return Flush(op);
+ }
+
+ Status RecordClose(Session* session) {
+ return RecordClose(session, *kEmptyRunOptions());
+ }
+
+ // N.B. RunOptions is not stored (it has no entry in CloseRequest)
+ Status RecordClose(Session* session, const RunOptions& run_options) {
+ mutex_lock l(log_mutex_);
+ ReplayOp op;
+ CloseSessionRequest* req = op.mutable_close_session();
+ req->set_session_handle(SessionToHandle(session));
+ op.mutable_close_session_response();
+ if (&run_options == kEmptyRunOptions()) {
+ RUN_WITH_TIMESTAMP(Close);
+ } else {
+ RUN_WITH_TIMESTAMP(Close, run_options);
+ }
+ return Flush(op);
+ }
+
+ Status RecordListDevices(Session* session,
+ std::vector<DeviceAttributes>* response) {
+ mutex_lock l(log_mutex_);
+ ReplayOp op;
+ ListDevicesRequest* req = op.mutable_list_devices();
+ ListDevicesResponse* resp = op.mutable_list_devices_response();
+ req->set_session_handle(SessionToHandle(session));
+ RUN_WITH_TIMESTAMP(ListDevices, response);
+
+ // TODO(power) -- local vs remote device distinction is lost here!
+ *resp->mutable_local_device() = {response->begin(), response->end()};
+ return Flush(op);
+ }
+
+ Status RecordPRunSetup(Session* session,
+ const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ string* handle) {
+ mutex_lock l(log_mutex_);
+ ReplayOp op;
+ PartialRunSetupRequest* req = op.mutable_partial_run_setup();
+ req->set_session_handle(SessionToHandle(session));
+ for (auto& input : input_names) {
+ req->add_feed(input);
+ }
+ for (auto& output : output_names) {
+ req->add_fetch(output);
+ }
+ for (auto& target : target_nodes) {
+ req->add_target(target);
+ }
+ RUN_WITH_TIMESTAMP(PRunSetup, input_names, output_names, target_nodes,
+ handle);
+ op.mutable_partial_run_setup_response()->set_partial_run_handle(*handle);
+ return Flush(op);
+ }
+
+ Status RecordPRun(Session* session, const string& handle,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_names,
+ std::vector<Tensor>* outputs) {
+ ReplayOp op;
+ RunStepRequest* req = op.mutable_run_step();
+ RunStepResponse* resp = op.mutable_run_step_response();
+ req->set_session_handle(SessionToHandle(session));
+
+ // Mark this step as a partial run for replay.
+ req->set_partial_run_handle(handle);
+ for (auto& input : inputs) {
+ auto* feed = req->add_feed();
+ feed->set_name(input.first);
+ input.second.AsProtoField(feed->mutable_tensor());
+ }
+
+ for (auto& output : output_names) {
+ req->add_fetch(output);
+ }
+
+ RUN_WITH_TIMESTAMP(PRun, handle, inputs, output_names, outputs);
+
+ for (size_t i = 0; i < outputs->size(); ++i) {
+ const Tensor& tensor = (*outputs)[i];
+ NamedTensorProto* tproto = resp->add_tensor();
+ tensor.AsProtoField(tproto->mutable_tensor());
+ tproto->set_name(output_names[i]);
+ }
+
+ return Flush(op);
+ }
+
+ Status RecordMakeCallable(Session* session,
+ const CallableOptions& callable_options,
+ Session::CallableHandle* handle) {
+ ReplayOp op;
+ MakeCallableRequest* req = op.mutable_make_callable();
+ req->set_session_handle(SessionToHandle(session));
+ *req->mutable_options() = callable_options;
+
+ RUN_WITH_TIMESTAMP(MakeCallable, callable_options, handle);
+
+ MakeCallableResponse* resp = op.mutable_make_callable_response();
+ resp->set_handle(*handle);
+
+ return Flush(op);
+ }
+
+ Status RecordRunCallable(Session* session, Session::CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata) {
+ ReplayOp op;
+ RunCallableRequest* req = op.mutable_run_callable();
+ req->set_session_handle(SessionToHandle(session));
+ req->set_handle(handle);
+ for (auto& tensor : feed_tensors) {
+ tensor.AsProtoField(req->add_feed());
+ }
+ RUN_WITH_TIMESTAMP(RunCallable, handle, feed_tensors, fetch_tensors,
+ run_metadata);
+
+ RunCallableResponse* resp = op.mutable_run_callable_response();
+ if (run_metadata) {
+ *resp->mutable_metadata() = *run_metadata;
+ }
+ for (const Tensor& tensor : *fetch_tensors) {
+ tensor.AsProtoTensorContent(resp->add_fetch());
+ }
+ return Flush(op);
+ }
+
+ Status RecordReleaseCallable(Session* session,
+ Session::CallableHandle handle) {
+ ReplayOp op;
+ ReleaseCallableRequest* req = op.mutable_release_callable();
+ req->set_session_handle(SessionToHandle(session));
+ req->set_handle(handle);
+ RUN_WITH_TIMESTAMP(ReleaseCallable, handle);
+ return Flush(op);
+ }
+
+ private:
+ Status Flush(const ReplayOp& op) {
+ string buf;
+ op.SerializeToString(&buf);
+ TF_RETURN_IF_ERROR(log_writer_->WriteRecord(buf));
+
+ // Flushing the RecordWriter _does not_ flush the underlying file.
+ TF_RETURN_IF_ERROR(log_writer_->Flush());
+ return log_file_->Flush();
+ }
+
+ mutex log_mutex_;
+ std::unique_ptr<io::RecordWriter> log_writer_;
+ std::unique_ptr<WritableFile> log_file_;
+};
+
+static SessionLogger* global_session_logger() {
+ static SessionLogger* logger = new SessionLogger();
+ return logger;
+}
+
+SessionRef::SessionRef(Session* session) : session_(session) {
+ if (getenv("TF_REPLAY_LOG_FILE") != nullptr) {
+ logger_ = global_session_logger();
+ logger_->RecordCreateSession(this->session_.get()).IgnoreError();
+ } else {
+ logger_ = nullptr;
+ }
+}
+
+SessionRef::~SessionRef() = default;
+
+Status SessionRef::CheckNotClosed() {
+ mutex_lock l(run_lock_);
+ if (session_ == nullptr) return errors::Cancelled("Session has been closed.");
+ return ::tensorflow::Status::OK();
+}
+
+// If logging is active, log the start and end time of the operation along with
+// the request and response.
+#define LOG_AND_RUN_OPERATION(OpName, ...) \
+ TF_RETURN_IF_ERROR(CheckNotClosed()); \
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); \
+ if (!logger_) { \
+ return rc.session->OpName(__VA_ARGS__); \
+ } \
+ return logger_->Record##OpName(rc.session.get(), __VA_ARGS__);
+
+Status SessionRef::Run(const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs,
+ RunMetadata* run_metadata) {
+ LOG_AND_RUN_OPERATION(Run, run_options, inputs, output_tensor_names,
+ target_node_names, outputs, run_metadata);
+}
+
+Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs) {
+ LOG_AND_RUN_OPERATION(Run, inputs, output_tensor_names, target_node_names,
+ outputs);
+}
+
+Status SessionRef::Create(const GraphDef& graph) {
+ LOG_AND_RUN_OPERATION(Create, graph);
+}
+
+Status SessionRef::Create(const RunOptions& run_options,
+ const GraphDef& graph) {
+ LOG_AND_RUN_OPERATION(Create, run_options, graph);
+}
+
+Status SessionRef::Extend(const RunOptions& run_options,
+ const GraphDef& graph) {
+ LOG_AND_RUN_OPERATION(Extend, run_options, graph);
+}
+
+Status SessionRef::Extend(const GraphDef& graph) {
+ LOG_AND_RUN_OPERATION(Extend, graph);
+}
+
+Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
+ LOG_AND_RUN_OPERATION(ListDevices, response);
+}
+
+Status SessionRef::PRunSetup(const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ string* handle) {
+ LOG_AND_RUN_OPERATION(PRunSetup, input_names, output_names, target_nodes,
+ handle);
+}
+
+Status SessionRef::PRun(const string& handle,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_names,
+ std::vector<Tensor>* outputs) {
+ LOG_AND_RUN_OPERATION(PRun, handle, inputs, output_names, outputs);
+}
+
+Status SessionRef::MakeCallable(const CallableOptions& callable_options,
+ CallableHandle* out_handle) {
+ LOG_AND_RUN_OPERATION(MakeCallable, callable_options, out_handle);
+}
+
+Status SessionRef::RunCallable(CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata) {
+ LOG_AND_RUN_OPERATION(RunCallable, handle, feed_tensors, fetch_tensors,
+ run_metadata);
+}
+
+Status SessionRef::ReleaseCallable(CallableHandle handle) {
+ LOG_AND_RUN_OPERATION(ReleaseCallable, handle);
+}
+
+Status SessionRef::Close(const RunOptions& run_options) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ mutex_lock l(run_lock_);
+ Status status;
+ if (logger_) {
+ status = logger_->RecordClose(session_.get(), run_options);
+ } else {
+ status = session_->Close(run_options);
+ }
+ session_.reset();
+ while (run_count_ > 0) {
+ run_finished_.wait(l);
+ }
+ return status;
+}
+
+Status SessionRef::Close() {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ mutex_lock l(run_lock_);
+ Status status;
+ if (logger_) {
+ status = logger_->RecordClose(session_.get());
+ } else {
+ status = session_->Close();
+ }
+ session_.reset();
+ while (run_count_ > 0) {
+ run_finished_.wait(l);
+ }
+ return status;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/python/client/session_ref.h b/tensorflow/python/client/session_ref.h
new file mode 100644
index 0000000000..b0fb12b189
--- /dev/null
+++ b/tensorflow/python/client/session_ref.h
@@ -0,0 +1,91 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_
+#define TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_
+
+#include <memory>
+
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+
+class SessionLogger;
+
+// A `SessionRef` manages the lifetime of a wrapped `Session` pointer.
+//
+// SessionRef blocks the return of Close() until all pending operations have
+// been completed or cancelled and underlying session has been freed. Any
+// subsequent operations on the SessionRef object will return errors::Cancelled.
+class SessionRef : public Session {
+ public:
+ explicit SessionRef(Session* session);
+ ~SessionRef() override;
+
+ Status Create(const GraphDef& graph) override;
+ Status Extend(const GraphDef& graph) override;
+ Status Create(const RunOptions& run_options, const GraphDef& graph) override;
+ Status Extend(const RunOptions& run_options, const GraphDef& graph) override;
+ Status Run(const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs) override;
+
+ Status ListDevices(std::vector<DeviceAttributes>* response) override;
+
+ Status Close() override;
+ Status Close(const RunOptions& run_options) override;
+
+ Status Run(const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs, RunMetadata* run_metadata) override;
+
+ Status PRunSetup(const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ string* handle) override;
+
+ Status PRun(const string& handle,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_names,
+ std::vector<Tensor>* outputs) override;
+
+ Status MakeCallable(const CallableOptions& callable_options,
+ CallableHandle* out_handle) override;
+
+ Status RunCallable(CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata) override;
+
+ Status ReleaseCallable(CallableHandle handle) override;
+
+ private:
+ mutex run_lock_;
+ condition_variable run_finished_;
+ uint64 run_count_ GUARDED_BY(run_lock_) = {0};
+ std::shared_ptr<Session> session_;
+
+ // Borrowed reference to global session logger.
+ SessionLogger* logger_;
+
+ Status CheckNotClosed();
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index bcd4af2912..dc0c10bab7 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
-#include "tensorflow/core/common_runtime/session_ref.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
@@ -31,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/equal_graph_def.h"
+#include "tensorflow/python/client/session_ref.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include "tensorflow/python/lib/core/safe_ptr.h"