Diffstat (limited to 'tensorflow/c/c_api_experimental.cc')
 tensorflow/c/c_api_experimental.cc | 210 +++++++++++++++++++++++++++++++++++
 1 file changed, 210 insertions(+), 0 deletions(-)
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 69b3ffe2a1..c046bd66cd 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -79,6 +79,18 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
auto* gpu_options = config.mutable_gpu_options();
gpu_options->set_allow_growth(gpu_memory_allow_growth);
+ // TODO(b/113217601): This is needed for EagerContext::runner_ to use a
+ // threadpool, so that we avoid the possibility of running the runner_ in the
+ // threadpool of GPU event mgr, as that can trigger more callbacks to be
+ // scheduled on that same threadpool, causing a deadlock in cases where the
+ // caller of event_mgr->ThenExecute() blocks on the completion of the callback
+ // (as in the case of ConstOp kernel creation on GPU, which involves copying a
+ // CPU tensor to GPU).
+ // Setting a larger thread pool size does not help with the Swift caller, as
+ // we use a different TFE context for each thread of execution (for running
+ // graph functions and their send/recv coroutines).
+ config.set_inter_op_parallelism_threads(1);
+
TF_Buffer* ret = TF_NewBuffer();
TF_CHECK_OK(MessageToBuffer(config, ret));
return ret;
@@ -8494,3 +8506,201 @@ void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
/*run_metadata*/ nullptr, status);
VLOG(1) << "Enqueuing is done.";
}
+
+TFE_Context* TFE_CreateContextFromSession(TF_Session* session,
+ TF_Status* status) {
+ auto* opts = TFE_NewContextOptions();
+
+ // Reduce GPU memory allocation, and set appropriate config options for TFE
+ // context.
+ auto* config =
+ TF_CreateConfig(/*xla*/ false, /* gpu_memory_allow_growth */ true);
+ TFE_ContextOptionsSetConfig(opts, config->data, config->length, status);
+ if (!status->status.ok()) {
+ // The config buffer is not consumed on failure, so release it here.
+ TF_DeleteBuffer(config);
+ TFE_DeleteContextOptions(opts);
+ return nullptr;
+ }
+
+ auto* ctx = TFE_NewContextFromSession(opts, session, status);
+ TF_DeleteBuffer(config);
+ TFE_DeleteContextOptions(opts);
+ return ctx;
+}
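+
+// Example (sketch only; not part of this change): a caller holding a
+// TF_Session can obtain and release an eager context as follows.
+//
+//   TF_Status* s = TF_NewStatus();
+//   TFE_Context* ctx = TFE_CreateContextFromSession(session, s);
+//   if (TF_GetCode(s) == TF_OK) {
+//     // ... use ctx with the enqueue/dequeue helpers below ...
+//     TFE_DeleteContext(ctx);
+//   }
+//   TF_DeleteStatus(s);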
+
+// TODO: retrieve the device string via TFE_ContextListDevices()
+static const char DEFAULT_CPU_DEVICE[] =
+ "/job:localhost/replica:0/task:0/device:CPU:0";
+
+static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType,
+ int tensor_id, TF_Status* status) {
+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> queueOp(
+ TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp);
+ TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status);
+ if (!status->status.ok()) return nullptr;
+ // TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler.
+ TFE_OpSetAttrInt(queueOp.get(), "capacity", 1);
+ TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1);
+ auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id);
+ TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(),
+ shared_name.size());
+ TFE_OpSetAttrString(queueOp.get(), "container", "", 0);
+
+ // TODO: consider making this an unknown shape.
+ const int64_t* dims_ptr = nullptr;
+ int num_dims = 0;
+ TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims,
+ /*num_values*/ 0, status);
+ if (!status->status.ok()) return nullptr;
+
+ int num_retvals = 1;
+ TFE_TensorHandle* queue = nullptr;
+ TFE_Execute(queueOp.get(), &queue, &num_retvals, status);
+ if (!status->status.ok()) return nullptr;
+ CHECK_EQ(num_retvals, 1);
+
+ return queue;
+}
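+
+// Note on sharing: because the queue above is created with an empty
+// `container` and a deterministic `shared_name` ("fifo_queue_<tensor_id>"),
+// later ops created with the same attributes -- whether eager ops from the
+// helpers below or graph ops emitted by the S4TF compiler -- should resolve
+// to the same queue resource, letting the enqueue and dequeue sides
+// rendezvous on `tensor_id` alone.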
+
+static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType,
+ TFE_TensorHandle* queue, TFE_TensorHandle* tensor,
+ TF_Status* status) {
+ TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status);
+ if (!status->status.ok()) return;
+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
+ TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
+ if (!status->status.ok()) return;
+ TFE_OpAddInput(op, queue, status);
+ if (!status->status.ok()) return;
+ TFE_OpAddInput(op, tensor, status);
+ if (!status->status.ok()) return;
+ TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1);
+ TFE_OpSetAttrInt(op, "timeout_ms", -1);
+
+ int num_retvals = 0;
+ TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status);
+ if (!status->status.ok()) return;
+ CHECK_EQ(num_retvals, 0);
+}
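+
+// With `capacity` 1 and `timeout_ms` -1, enqueueing onto an already-full
+// queue blocks until the matching dequeue runs, so each named queue behaves,
+// in effect, as a single-slot blocking rendezvous channel.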
+
+static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx,
+ TF_DataType inputType,
+ TFE_TensorHandle* queue,
+ TF_Status* status) {
+ TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status);
+ if (!status->status.ok()) return nullptr;
+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
+ TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
+ if (!status->status.ok()) return nullptr;
+
+ TFE_OpAddInput(op, queue, status);
+ if (!status->status.ok()) return nullptr;
+ TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1);
+ TFE_OpSetAttrInt(op, "timeout_ms", -1);
+ TFE_TensorHandle* ret;
+ int num_retvals = 1;
+ TFE_Execute(op, &ret, &num_retvals, status);
+ if (!status->status.ok()) return nullptr;
+ CHECK_EQ(num_retvals, 1);
+ return ret;
+}
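+
+// Symmetrically, dequeueing from an empty queue blocks until the producer
+// enqueues a tensor (`timeout_ms` of -1 disables the timeout).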
+
+TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id,
+ TF_DataType inputType,
+ TF_Status* status) {
+ assert(session);
+ VLOG(1) << "Dequeuing data tensor with id " << tensor_id;
+
+ auto ctx = TFE_CreateContextFromSession(session, status);
+ if (!status->status.ok()) return nullptr;
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
+ ctx, TFE_DeleteContext);
+
+ TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
+ if (!status->status.ok()) return nullptr;
+ std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+ queue_deleter(queue, TFE_DeleteTensorHandle);
+
+ auto* ret = createTFEDequeue(ctx, inputType, queue, status);
+ return ret;
+}
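+
+// Deleting `queue` above releases only the eager tensor handle; the queue
+// resource itself is expected to live on in the resource manager, keyed by
+// container/shared_name, until the session is closed -- which is why a later
+// enqueue with the same tensor_id can still find it.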
+
+TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
+ TF_DataType inputType,
+ TF_Status* status) {
+ TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
+ if (!status->status.ok()) return nullptr;
+ std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+ queue_deleter(queue, TFE_DeleteTensorHandle);
+
+ auto* ret = createTFEDequeue(ctx, inputType, queue, status);
+
+ return ret;
+}
+
+void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id,
+ TFE_TensorHandle* tensor, TF_Status* status) {
+ assert(session);
+ VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
+
+ auto ctx = TFE_CreateContextFromSession(session, status);
+ if (!status->status.ok()) return;
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
+ ctx, TFE_DeleteContext);
+
+ TF_DataType inputType = TFE_TensorHandleDataType(tensor);
+ TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
+ if (!status->status.ok()) return;
+ std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+ queue_deleter(queue, TFE_DeleteTensorHandle);
+
+ createTFEEnqueue(ctx, inputType, queue, tensor, status);
+}
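+
+// Example pairing (sketch; the tensor ids, dtype, and `input` handle are
+// assumptions): the host program enqueues a tensor that the graph side
+// dequeues by id, and vice versa for results.
+//
+//   TFE_EnqueueNamedTensor(session, /*tensor_id=*/0, input, s);
+//   TFE_TensorHandle* result =
+//       TFE_DequeueNamedTensor(session, /*tensor_id=*/1, TF_FLOAT, s);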
+
+void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
+ TFE_TensorHandle* tensor,
+ TF_Status* status) {
+ VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
+
+ TF_DataType inputType = TFE_TensorHandleDataType(tensor);
+ TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
+ if (!status->status.ok()) return;
+ std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+ queue_deleter(queue, TFE_DeleteTensorHandle);
+
+ createTFEEnqueue(ctx, inputType, queue, tensor, status);
+}
+
+void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id,
+ TFE_TensorHandle* tensor, TF_Status* status) {
+ VLOG(1) << "Enqueuing variant tensor with id " << tensor_id;
+
+ auto ctx = TFE_CreateContextFromSession(session, status);
+ if (!status->status.ok()) return;
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
+ ctx, TFE_DeleteContext);
+
+ TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
+ if (!status->status.ok()) return;
+ std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+ queue_deleter(queue, TFE_DeleteTensorHandle);
+
+ createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status);
+}
+
+TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
+ TF_Status* status) {
+ VLOG(1) << "Dequeuing variant tensor with id " << tensor_id;
+
+ auto ctx = TFE_CreateContextFromSession(session, status);
+ if (!status->status.ok()) return nullptr;
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
+ ctx, TFE_DeleteContext);
+
+ TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
+ if (!status->status.ok()) return nullptr;
+ std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+ queue_deleter(queue, TFE_DeleteTensorHandle);
+
+ return createTFEDequeue(ctx, TF_VARIANT, queue, status);
+}
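+
+// Example (sketch; `dataset_handle` is a hypothetical TF_VARIANT-typed eager
+// handle): round-tripping a variant tensor through a named queue.
+//
+//   TFE_EnqueueVariantTensor(session, /*tensor_id=*/2, dataset_handle, s);
+//   TFE_TensorHandle* h =
+//       TFE_DequeueVariantTensor(session, /*tensor_id=*/2, s);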