aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c
diff options
context:
space:
mode:
authorGravatar Mingsheng Hong <hongm@google.com>2018-09-04 09:41:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-04 09:44:58 -0700
commit102e0de242eccb2ac4664761183a7771b0a7c7af (patch)
tree3a036b7c44fdaa4744ee074688774cb71cb3cb38 /tensorflow/c
parent8f0487189a0310c0ce2d14077d0755fffe74a83c (diff)
Added a new eager C API TFE_NewContextFromSession(), where TFE_NewContext will
get an owned device mgr from the input session. One use case is in S4TF, we run a graph session to enqueue a tensor into a fifo queue, and then call TFE_Execute() on a dequeue op over the same queue, as a way to transfer a tensor from TF to host (tensor tranfer in the other direction also works). To make this work, we need TFE_Context and the the TF_Session to use the same ResourceMgr object (attached to a Device, which is in turn owned by DeviceMgr), so that both can access the fifo queue resource op. PiperOrigin-RevId: 211471075
Diffstat (limited to 'tensorflow/c')
-rw-r--r--tensorflow/c/BUILD1
-rw-r--r--tensorflow/c/c_api_experimental.h4
-rwxr-xr-xtensorflow/c/eager/c_api.cc15
-rw-r--r--tensorflow/c/eager/c_api_internal.h11
4 files changed, 24 insertions, 7 deletions
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 2c3a877edf..109b3b37aa 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -117,6 +117,7 @@ tf_cuda_library(
deps = [
":c_api",
":c_api_internal",
+ "//tensorflow/c/eager:c_api",
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/contrib/tpu:all_ops",
"//tensorflow/core:core_cpu",
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index 6617c5a572..09d482d6df 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <stdint.h>
#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/eager/c_api.h"
// --------------------------------------------------------------------------
// Experimental C API for TensorFlow.
@@ -131,6 +132,9 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
TF_Tensor* tensor,
TF_Status* status);
+TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession(
+ const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 1ccae3f138..77e3878a94 100755
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -273,7 +273,20 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
new tensorflow::IntraProcessRendezvous(device_mgr.get());
return new TFE_Context(opts->session_options.options, opts->policy,
- opts->async, std::move(device_mgr), r);
+ opts->async, device_mgr.release(),
+ /*device_mgr_owned*/ true, r);
+}
+
+TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
+ TF_Session* sess, TF_Status* status) {
+ const tensorflow::DeviceMgr* device_mgr = nullptr;
+ status->status = sess->session->LocalDeviceManager(&device_mgr);
+ if (!status->status.ok()) return nullptr;
+ tensorflow::Rendezvous* r =
+ new tensorflow::IntraProcessRendezvous(device_mgr);
+ return new TFE_Context(opts->session_options.options, opts->policy,
+ opts->async, device_mgr, /*device_mgr_owned*/ false,
+ r);
}
void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index a5c0681e2e..104d52430c 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -62,15 +62,14 @@ struct TFE_ContextOptions {
};
struct TFE_Context {
- explicit TFE_Context(const tensorflow::SessionOptions& opts,
- TFE_ContextDevicePlacementPolicy default_policy,
- bool async,
- std::unique_ptr<tensorflow::DeviceMgr> device_mgr,
- tensorflow::Rendezvous* rendezvous)
+ TFE_Context(const tensorflow::SessionOptions& opts,
+ TFE_ContextDevicePlacementPolicy default_policy, bool async,
+ const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
+ tensorflow::Rendezvous* rendezvous)
: context(opts,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
default_policy),
- async, std::move(device_mgr), rendezvous) {}
+ async, device_mgr, device_mgr_owned, rendezvous) {}
tensorflow::EagerContext context;
};