aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/direct_session.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/direct_session.cc')
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc54
1 files changed, 30 insertions, 24 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 87ba609dd7..d1fd930d25 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -146,18 +146,15 @@ class DirectSessionFactory : public SessionFactory {
return options.target.empty();
}
- Session* NewSession(const SessionOptions& options) override {
+ Status NewSession(const SessionOptions& options,
+ Session** out_session) override {
// Must do this before the CPU allocator is created.
if (options.config.graph_options().build_cost_model() > 0) {
EnableCPUAllocatorFullStats(true);
}
std::vector<Device*> devices;
- const Status s = DeviceFactory::AddDevices(
- options, "/job:localhost/replica:0/task:0", &devices);
- if (!s.ok()) {
- LOG(ERROR) << s;
- return nullptr;
- }
+ TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
+ options, "/job:localhost/replica:0/task:0", &devices));
DirectSession* session =
new DirectSession(options, new DeviceMgr(devices), this);
@@ -165,7 +162,8 @@ class DirectSessionFactory : public SessionFactory {
mutex_lock l(sessions_lock_);
sessions_.push_back(session);
}
- return session;
+ *out_session = session;
+ return Status::OK();
}
Status Reset(const SessionOptions& options,
@@ -237,7 +235,11 @@ void DirectSession::SchedClosure(thread::ThreadPool* pool,
// safe given the reasoning above.
c();
#else
- pool->Schedule(std::move(c));
+ if (pool != nullptr) {
+ pool->Schedule(std::move(c));
+ } else {
+ c();
+ }
#endif // __ANDROID__
}
@@ -524,8 +526,9 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
}
}
- if (run_options.inter_op_thread_pool() < 0 ||
- run_options.inter_op_thread_pool() >= thread_pools_.size()) {
+ if (run_options.inter_op_thread_pool() < -1 ||
+ run_options.inter_op_thread_pool() >=
+ static_cast<int32>(thread_pools_.size())) {
run_state.executors_done.Notify();
delete barrier;
return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
@@ -550,7 +553,19 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
}
thread::ThreadPool* pool =
- thread_pools_[run_options.inter_op_thread_pool()].first;
+ run_options.inter_op_thread_pool() >= 0
+ ? thread_pools_[run_options.inter_op_thread_pool()].first
+ : nullptr;
+
+ if (pool == nullptr) {
+ // We allow using the caller thread only when having a single executor
+ // specified.
+ if (executors_and_keys->items.size() > 1) {
+ pool = thread_pools_[0].first;
+ } else {
+ VLOG(1) << "Executing Session::Run() synchronously!";
+ }
+ }
Executor::Args::Runner default_runner = [this,
pool](Executor::Args::Closure c) {
@@ -702,7 +717,8 @@ Status DirectSession::Run(const RunOptions& run_options,
// Receive outputs.
if (outputs) {
std::vector<Tensor> sorted_outputs;
- const Status s = call_frame.ConsumeRetvals(&sorted_outputs);
+ const Status s = call_frame.ConsumeRetvals(
+ &sorted_outputs, /* allow_dead_tensors = */ false);
if (errors::IsInternal(s)) {
return errors::InvalidArgument(s.error_message());
} else if (!s.ok()) {
@@ -1188,12 +1204,11 @@ Status DirectSession::CreateExecutors(
delete kernel;
}
};
- params.node_outputs_cb = node_outputs_callback_;
optimizer.Optimize(lib, options_.env, device, &iter->second,
/*shape_map=*/nullptr);
- // EXPERIMENTAL: tfdbg inserts debug nodes in the graph.
+ // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
const DebugOptions& debug_options =
options.callable_options.run_options().debug_options();
if (!debug_options.debug_tensor_watch_opts().empty()) {
@@ -1626,15 +1641,6 @@ Status DirectSession::MakeCallable(const CallableOptions& callable_options,
TF_RETURN_IF_ERROR(CheckNotClosed());
TF_RETURN_IF_ERROR(CheckGraphCreated("MakeCallable()"));
- if (!callable_options.run_options()
- .debug_options()
- .debug_tensor_watch_opts()
- .empty()) {
- return errors::Unimplemented(
- "Debug options are not currently supported via the C++ MakeCallable "
- "interface.");
- }
-
std::unique_ptr<ExecutorsAndKeys> ek;
std::unique_ptr<FunctionInfo> func_info;
RunStateArgs run_state_args(callable_options.run_options().debug_options());