aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mingsheng Hong <hongm@google.com>2018-09-04 09:41:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-04 09:44:58 -0700
commit102e0de242eccb2ac4664761183a7771b0a7c7af (patch)
tree3a036b7c44fdaa4744ee074688774cb71cb3cb38
parent8f0487189a0310c0ce2d14077d0755fffe74a83c (diff)
Added a new eager C API TFE_NewContextFromSession(), where TFE_NewContext will
get an owned device mgr from the input session. One use case is in S4TF, we run a graph session to enqueue a tensor into a fifo queue, and then call TFE_Execute() on a dequeue op over the same queue, as a way to transfer a tensor from TF to host (tensor tranfer in the other direction also works). To make this work, we need TFE_Context and the the TF_Session to use the same ResourceMgr object (attached to a Device, which is in turn owned by DeviceMgr), so that both can access the fifo queue resource op. PiperOrigin-RevId: 211471075
-rw-r--r--tensorflow/c/BUILD1
-rw-r--r--tensorflow/c/c_api_experimental.h4
-rwxr-xr-xtensorflow/c/eager/c_api.cc15
-rw-r--r--tensorflow/c/eager/c_api_internal.h11
-rw-r--r--tensorflow/core/common_runtime/eager/context.cc24
-rw-r--r--tensorflow/core/common_runtime/eager/context.h19
6 files changed, 55 insertions, 19 deletions
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 2c3a877edf..109b3b37aa 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -117,6 +117,7 @@ tf_cuda_library(
deps = [
":c_api",
":c_api_internal",
+ "//tensorflow/c/eager:c_api",
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/contrib/tpu:all_ops",
"//tensorflow/core:core_cpu",
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index 6617c5a572..09d482d6df 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <stdint.h>
#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/eager/c_api.h"
// --------------------------------------------------------------------------
// Experimental C API for TensorFlow.
@@ -131,6 +132,9 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
TF_Tensor* tensor,
TF_Status* status);
+TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession(
+ const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 1ccae3f138..77e3878a94 100755
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -273,7 +273,20 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
new tensorflow::IntraProcessRendezvous(device_mgr.get());
return new TFE_Context(opts->session_options.options, opts->policy,
- opts->async, std::move(device_mgr), r);
+ opts->async, device_mgr.release(),
+ /*device_mgr_owned*/ true, r);
+}
+
+TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
+ TF_Session* sess, TF_Status* status) {
+ const tensorflow::DeviceMgr* device_mgr = nullptr;
+ status->status = sess->session->LocalDeviceManager(&device_mgr);
+ if (!status->status.ok()) return nullptr;
+ tensorflow::Rendezvous* r =
+ new tensorflow::IntraProcessRendezvous(device_mgr);
+ return new TFE_Context(opts->session_options.options, opts->policy,
+ opts->async, device_mgr, /*device_mgr_owned*/ false,
+ r);
}
void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index a5c0681e2e..104d52430c 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -62,15 +62,14 @@ struct TFE_ContextOptions {
};
struct TFE_Context {
- explicit TFE_Context(const tensorflow::SessionOptions& opts,
- TFE_ContextDevicePlacementPolicy default_policy,
- bool async,
- std::unique_ptr<tensorflow::DeviceMgr> device_mgr,
- tensorflow::Rendezvous* rendezvous)
+ TFE_Context(const tensorflow::SessionOptions& opts,
+ TFE_ContextDevicePlacementPolicy default_policy, bool async,
+ const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
+ tensorflow::Rendezvous* rendezvous)
: context(opts,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
default_policy),
- async, std::move(device_mgr), rendezvous) {}
+ async, device_mgr, device_mgr_owned, rendezvous) {}
tensorflow::EagerContext context;
};
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 39a3b49cd1..879a794368 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -36,22 +36,34 @@ bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
EagerContext::EagerContext(const SessionOptions& opts,
ContextDevicePlacementPolicy default_policy,
- bool async, std::unique_ptr<DeviceMgr> device_mgr,
+ bool async,
+ std::unique_ptr<const DeviceMgr> device_mgr,
Rendezvous* rendezvous)
+ : EagerContext(opts, default_policy, async, device_mgr.release(),
+ /*device_mgr_owned*/ true, rendezvous) {}
+
+EagerContext::EagerContext(const SessionOptions& opts,
+ ContextDevicePlacementPolicy default_policy,
+ bool async, const DeviceMgr* device_mgr,
+ bool device_mgr_owned, Rendezvous* rendezvous)
: policy_(default_policy),
- local_device_manager_(std::move(device_mgr)),
- local_unowned_device_manager_(nullptr),
- devices_(local_device_manager_->ListDevices()),
+ devices_(device_mgr->ListDevices()),
rendezvous_(rendezvous),
thread_pool_(NewThreadPoolFromSessionOptions(opts)),
pflr_(new ProcessFunctionLibraryRuntime(
- local_device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION,
- &func_lib_def_, {}, thread_pool_.get())),
+ device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, {},
+ thread_pool_.get())),
log_device_placement_(opts.config.log_device_placement()),
num_active_steps_(0),
async_default_(async),
env_(opts.env),
use_send_tensor_rpc_(false) {
+ if (device_mgr_owned) {
+ local_device_manager_.reset(device_mgr);
+ local_unowned_device_manager_ = nullptr;
+ } else {
+ local_unowned_device_manager_ = device_mgr;
+ }
InitDeviceMapAndAsync();
if (opts.config.inter_op_parallelism_threads() > 0) {
runner_ = [this](std::function<void()> closure) {
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 3c95ac590d..eb6eb0d55a 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -65,10 +65,17 @@ enum ContextDevicePlacementPolicy {
class EagerContext {
public:
- explicit EagerContext(const SessionOptions& opts,
- ContextDevicePlacementPolicy default_policy, bool async,
- std::unique_ptr<DeviceMgr> device_mgr,
- Rendezvous* rendezvous);
+ // TODO: remove this constructor once we migrate all callers to the next one.
+ EagerContext(const SessionOptions& opts,
+ ContextDevicePlacementPolicy default_policy, bool async,
+ std::unique_ptr<const DeviceMgr> device_mgr,
+ Rendezvous* rendezvous);
+
+ EagerContext(const SessionOptions& opts,
+ ContextDevicePlacementPolicy default_policy, bool async,
+ const DeviceMgr* device_mgr, bool device_mgr_owned,
+ Rendezvous* rendezvous);
+
~EagerContext();
// Returns the function library runtime for the given device.
@@ -207,8 +214,8 @@ class EagerContext {
thread_local_policies_ GUARDED_BY(policy_map_mu_);
// Only one of the below is set.
- std::unique_ptr<DeviceMgr> local_device_manager_;
- DeviceMgr* local_unowned_device_manager_;
+ std::unique_ptr<const DeviceMgr> local_device_manager_;
+ const DeviceMgr* local_unowned_device_manager_;
std::unique_ptr<DeviceMgr> remote_device_manager_;
// Devices owned by device_manager