diff options
author | Mingsheng Hong <hongm@google.com> | 2018-09-04 09:41:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-04 09:44:58 -0700 |
commit | 102e0de242eccb2ac4664761183a7771b0a7c7af (patch) | |
tree | 3a036b7c44fdaa4744ee074688774cb71cb3cb38 | |
parent | 8f0487189a0310c0ce2d14077d0755fffe74a83c (diff) |
Added a new eager C API TFE_NewContextFromSession(), where TFE_NewContext will
get an owned device mgr from the input session.
One use case is in S4TF, we run a graph session to enqueue a tensor into a fifo
queue, and then call TFE_Execute() on a dequeue op over the same queue, as a way
to transfer a tensor from TF to host (tensor tranfer in the other direction also
works).
To make this work, we need TFE_Context and the the TF_Session to use the same
ResourceMgr object (attached to a Device, which is in turn owned by DeviceMgr),
so that both can access the fifo queue resource op.
PiperOrigin-RevId: 211471075
-rw-r--r-- | tensorflow/c/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/c/c_api_experimental.h | 4 | ||||
-rwxr-xr-x | tensorflow/c/eager/c_api.cc | 15 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api_internal.h | 11 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/context.cc | 24 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/context.h | 19 |
6 files changed, 55 insertions, 19 deletions
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 2c3a877edf..109b3b37aa 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -117,6 +117,7 @@ tf_cuda_library( deps = [ ":c_api", ":c_api_internal", + "//tensorflow/c/eager:c_api", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/contrib/tpu:all_ops", "//tensorflow/core:core_cpu", diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 6617c5a572..09d482d6df 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -20,6 +20,7 @@ limitations under the License. #include <stdint.h> #include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" // -------------------------------------------------------------------------- // Experimental C API for TensorFlow. @@ -131,6 +132,9 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session, TF_Tensor* tensor, TF_Status* status); +TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession( + const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 1ccae3f138..77e3878a94 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -273,7 +273,20 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { new tensorflow::IntraProcessRendezvous(device_mgr.get()); return new TFE_Context(opts->session_options.options, opts->policy, - opts->async, std::move(device_mgr), r); + opts->async, device_mgr.release(), + /*device_mgr_owned*/ true, r); +} + +TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts, + TF_Session* sess, TF_Status* status) { + const tensorflow::DeviceMgr* device_mgr = nullptr; + status->status = sess->session->LocalDeviceManager(&device_mgr); + if (!status->status.ok()) return nullptr; + tensorflow::Rendezvous* r = + new tensorflow::IntraProcessRendezvous(device_mgr); + return new TFE_Context(opts->session_options.options, opts->policy, + opts->async, device_mgr, /*device_mgr_owned*/ false, + r); } void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; } diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index a5c0681e2e..104d52430c 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -62,15 +62,14 @@ struct TFE_ContextOptions { }; struct TFE_Context { - explicit TFE_Context(const tensorflow::SessionOptions& opts, - TFE_ContextDevicePlacementPolicy default_policy, - bool async, - std::unique_ptr<tensorflow::DeviceMgr> device_mgr, - tensorflow::Rendezvous* rendezvous) + TFE_Context(const tensorflow::SessionOptions& opts, + TFE_ContextDevicePlacementPolicy default_policy, bool async, + const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned, + tensorflow::Rendezvous* rendezvous) : context(opts, static_cast<tensorflow::ContextDevicePlacementPolicy>( default_policy), - async, std::move(device_mgr), rendezvous) {} + async, device_mgr, device_mgr_owned, rendezvous) {} tensorflow::EagerContext context; }; diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 39a3b49cd1..879a794368 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -36,22 +36,34 @@ bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) { EagerContext::EagerContext(const SessionOptions& opts, ContextDevicePlacementPolicy default_policy, - bool async, std::unique_ptr<DeviceMgr> device_mgr, + bool async, + std::unique_ptr<const DeviceMgr> device_mgr, Rendezvous* rendezvous) + : EagerContext(opts, default_policy, async, device_mgr.release(), + /*device_mgr_owned*/ true, rendezvous) {} + +EagerContext::EagerContext(const SessionOptions& opts, + ContextDevicePlacementPolicy default_policy, + bool async, const DeviceMgr* device_mgr, + bool device_mgr_owned, Rendezvous* rendezvous) : policy_(default_policy), - local_device_manager_(std::move(device_mgr)), - local_unowned_device_manager_(nullptr), - devices_(local_device_manager_->ListDevices()), + devices_(device_mgr->ListDevices()), rendezvous_(rendezvous), thread_pool_(NewThreadPoolFromSessionOptions(opts)), pflr_(new ProcessFunctionLibraryRuntime( - local_device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION, - &func_lib_def_, {}, thread_pool_.get())), + device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, {}, + thread_pool_.get())), log_device_placement_(opts.config.log_device_placement()), num_active_steps_(0), async_default_(async), env_(opts.env), use_send_tensor_rpc_(false) { + if (device_mgr_owned) { + local_device_manager_.reset(device_mgr); + local_unowned_device_manager_ = nullptr; + } else { + local_unowned_device_manager_ = device_mgr; + } InitDeviceMapAndAsync(); if (opts.config.inter_op_parallelism_threads() > 0) { runner_ = [this](std::function<void()> closure) { diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 3c95ac590d..eb6eb0d55a 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -65,10 +65,17 @@ enum ContextDevicePlacementPolicy { class EagerContext { public: - explicit EagerContext(const SessionOptions& opts, - ContextDevicePlacementPolicy default_policy, bool async, - std::unique_ptr<DeviceMgr> device_mgr, - Rendezvous* rendezvous); + // TODO: remove this constructor once we migrate all callers to the next one. + EagerContext(const SessionOptions& opts, + ContextDevicePlacementPolicy default_policy, bool async, + std::unique_ptr<const DeviceMgr> device_mgr, + Rendezvous* rendezvous); + + EagerContext(const SessionOptions& opts, + ContextDevicePlacementPolicy default_policy, bool async, + const DeviceMgr* device_mgr, bool device_mgr_owned, + Rendezvous* rendezvous); + ~EagerContext(); // Returns the function library runtime for the given device. @@ -207,8 +214,8 @@ class EagerContext { thread_local_policies_ GUARDED_BY(policy_map_mu_); // Only one of the below is set. - std::unique_ptr<DeviceMgr> local_device_manager_; - DeviceMgr* local_unowned_device_manager_; + std::unique_ptr<const DeviceMgr> local_device_manager_; + const DeviceMgr* local_unowned_device_manager_; std::unique_ptr<DeviceMgr> remote_device_manager_; // Devices owned by device_manager |