diff options
Diffstat (limited to 'tensorflow/core/common_runtime/direct_session.cc')
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.cc | 54 |
1 files changed, 30 insertions, 24 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 87ba609dd7..d1fd930d25 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -146,18 +146,15 @@ class DirectSessionFactory : public SessionFactory { return options.target.empty(); } - Session* NewSession(const SessionOptions& options) override { + Status NewSession(const SessionOptions& options, + Session** out_session) override { // Must do this before the CPU allocator is created. if (options.config.graph_options().build_cost_model() > 0) { EnableCPUAllocatorFullStats(true); } std::vector<Device*> devices; - const Status s = DeviceFactory::AddDevices( - options, "/job:localhost/replica:0/task:0", &devices); - if (!s.ok()) { - LOG(ERROR) << s; - return nullptr; - } + TF_RETURN_IF_ERROR(DeviceFactory::AddDevices( + options, "/job:localhost/replica:0/task:0", &devices)); DirectSession* session = new DirectSession(options, new DeviceMgr(devices), this); @@ -165,7 +162,8 @@ class DirectSessionFactory : public SessionFactory { mutex_lock l(sessions_lock_); sessions_.push_back(session); } - return session; + *out_session = session; + return Status::OK(); } Status Reset(const SessionOptions& options, @@ -237,7 +235,11 @@ void DirectSession::SchedClosure(thread::ThreadPool* pool, // safe given the reasoning above. c(); #else - pool->Schedule(std::move(c)); + if (pool != nullptr) { + pool->Schedule(std::move(c)); + } else { + c(); + } #endif // __ANDROID__ } @@ -524,8 +526,9 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options, } } - if (run_options.inter_op_thread_pool() < 0 || - run_options.inter_op_thread_pool() >= thread_pools_.size()) { + if (run_options.inter_op_thread_pool() < -1 || + run_options.inter_op_thread_pool() >= + static_cast<int32>(thread_pools_.size())) { run_state.executors_done.Notify(); delete barrier; return errors::InvalidArgument("Invalid inter_op_thread_pool: ", @@ -550,7 +553,19 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options, } thread::ThreadPool* pool = - thread_pools_[run_options.inter_op_thread_pool()].first; + run_options.inter_op_thread_pool() >= 0 + ? thread_pools_[run_options.inter_op_thread_pool()].first + : nullptr; + + if (pool == nullptr) { + // We allow using the caller thread only when having a single executor + // specified. + if (executors_and_keys->items.size() > 1) { + pool = thread_pools_[0].first; + } else { + VLOG(1) << "Executing Session::Run() synchronously!"; + } + } Executor::Args::Runner default_runner = [this, pool](Executor::Args::Closure c) { @@ -702,7 +717,8 @@ Status DirectSession::Run(const RunOptions& run_options, // Receive outputs. if (outputs) { std::vector<Tensor> sorted_outputs; - const Status s = call_frame.ConsumeRetvals(&sorted_outputs); + const Status s = call_frame.ConsumeRetvals( + &sorted_outputs, /* allow_dead_tensors = */ false); if (errors::IsInternal(s)) { return errors::InvalidArgument(s.error_message()); } else if (!s.ok()) { @@ -1188,12 +1204,11 @@ Status DirectSession::CreateExecutors( delete kernel; } }; - params.node_outputs_cb = node_outputs_callback_; optimizer.Optimize(lib, options_.env, device, &iter->second, /*shape_map=*/nullptr); - // EXPERIMENTAL: tfdbg inserts debug nodes in the graph. + // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph. const DebugOptions& debug_options = options.callable_options.run_options().debug_options(); if (!debug_options.debug_tensor_watch_opts().empty()) { @@ -1626,15 +1641,6 @@ Status DirectSession::MakeCallable(const CallableOptions& callable_options, TF_RETURN_IF_ERROR(CheckNotClosed()); TF_RETURN_IF_ERROR(CheckGraphCreated("MakeCallable()")); - if (!callable_options.run_options() - .debug_options() - .debug_tensor_watch_opts() - .empty()) { - return errors::Unimplemented( - "Debug options are not currently supported via the C++ MakeCallable " - "interface."); - } - std::unique_ptr<ExecutorsAndKeys> ek; std::unique_ptr<FunctionInfo> func_info; RunStateArgs run_state_args(callable_options.run_options().debug_options()); |