diff options
Diffstat (limited to 'tensorflow/c/c_api_experimental.cc')
-rw-r--r-- | tensorflow/c/c_api_experimental.cc | 210 |
1 files changed, 210 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 69b3ffe2a1..c046bd66cd 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -79,6 +79,18 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, auto* gpu_options = config.mutable_gpu_options(); gpu_options->set_allow_growth(gpu_memory_allow_growth); + // TODO(b/113217601): This is needed for EagerContext::runner_ to use a + // threadpool, so that we avoid the possibility of running the runner_ in the + // threadpool of GPU event mgr, as that can trigger more callbacks to be + // scheduled on that same threadpool, causing a deadlock in cases where the + // caller of event_mgr->ThenExecute() blocks on the completion of the callback + // (as in the case of ConstOp kernel creation on GPU, which involves copying a + // CPU tensor to GPU). + // Setting a larger thread pool does not help with the Swift caller, as we use + // a different TFE context for each thread of execution (for running graph + // functions, and their send/recvs corountines). + config.set_inter_op_parallelism_threads(1); + TF_Buffer* ret = TF_NewBuffer(); TF_CHECK_OK(MessageToBuffer(config, ret)); return ret; @@ -8494,3 +8506,201 @@ void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id, /*run_metadata*/ nullptr, status); VLOG(1) << "Enqueuing is done."; } + +TFE_Context* TFE_CreateContextFromSession(TF_Session* session, + TF_Status* status) { + auto* opts = TFE_NewContextOptions(); + + // Reduce GPU memory allocation, and set appropriate config options for TFE + // context. + auto* config = + TF_CreateConfig(/*xla*/ false, /* gpu_memory_allow_growth */ true); + TFE_ContextOptionsSetConfig(opts, config->data, config->length, status); + if (!status->status.ok()) { + CHECK(!config); + TFE_DeleteContextOptions(opts); + return nullptr; + } + + auto* ctx = TFE_NewContextFromSession(opts, session, status); + TF_DeleteBuffer(config); + TFE_DeleteContextOptions(opts); + return ctx; +} + +// TODO: retrieve the device string via TFE_ContextListDevices() +static const char DEFAULT_CPU_DEVICE[] = + "/job:localhost/replica:0/task:0/device:CPU:0"; + +static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType, + int tensor_id, TF_Status* status) { + std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> queueOp( + TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp); + TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status); + if (!status->status.ok()) return nullptr; + // TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler. + TFE_OpSetAttrInt(queueOp.get(), "capacity", 1); + TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1); + auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id); + TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(), + shared_name.size()); + TFE_OpSetAttrString(queueOp.get(), "container", "", 0); + + // TODO: consider making this an unknown shape. + const int64_t* dims_ptr = nullptr; + int num_dims = 0; + TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims, + /*num_values*/ 0, status); + if (!status->status.ok()) return nullptr; + + int num_retvals = 1; + TFE_TensorHandle* queue = nullptr; + TFE_Execute(queueOp.get(), &queue, &num_retvals, status); + if (!status->status.ok()) return nullptr; + CHECK_EQ(num_retvals, 1); + + return queue; +} + +static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType, + TFE_TensorHandle* queue, TFE_TensorHandle* tensor, + TF_Status* status) { + TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status); + if (!status->status.ok()) return; + std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp); + TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status); + if (!status->status.ok()) return; + TFE_OpAddInput(op, queue, status); + if (!status->status.ok()) return; + TFE_OpAddInput(op, tensor, status); + if (!status->status.ok()) return; + TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1); + TFE_OpSetAttrInt(op, "timeout_ms", -1); + + int num_retvals = 0; + TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status); + if (!status->status.ok()) return; + CHECK_EQ(num_retvals, 0); +} + +static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx, + TF_DataType inputType, + TFE_TensorHandle* queue, + TF_Status* status) { + TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status); + if (!status->status.ok()) return nullptr; + std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp); + TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status); + if (!status->status.ok()) return nullptr; + + TFE_OpAddInput(op, queue, status); + if (!status->status.ok()) return nullptr; + TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1); + TFE_OpSetAttrInt(op, "timeout_ms", -1); + TFE_TensorHandle* ret; + int num_retvals = 1; + TFE_Execute(op, &ret, &num_retvals, status); + if (!status->status.ok()) return nullptr; + CHECK_EQ(num_retvals, 1); + return ret; +} + +TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id, + TF_DataType inputType, + TF_Status* status) { + assert(session); + VLOG(1) << "Dequeuing data tensor with id " << tensor_id; + + auto ctx = TFE_CreateContextFromSession(session, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter( + ctx, TFE_DeleteContext); + + TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> + queue_deleter(queue, TFE_DeleteTensorHandle); + + auto* ret = createTFEDequeue(ctx, inputType, queue, status); + return ret; +} + +TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id, + TF_DataType inputType, + TF_Status* status) { + TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> + queue_deleter(queue, TFE_DeleteTensorHandle); + + auto* ret = createTFEDequeue(ctx, inputType, queue, status); + + return ret; +} + +void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id, + TFE_TensorHandle* tensor, TF_Status* status) { + assert(session); + VLOG(1) << "Enqueuing data tensor with id " << tensor_id; + + auto ctx = TFE_CreateContextFromSession(session, status); + if (!status->status.ok()) return; + std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter( + ctx, TFE_DeleteContext); + + TF_DataType inputType = TFE_TensorHandleDataType(tensor); + TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); + if (!status->status.ok()) return; + std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> + queue_deleter(queue, TFE_DeleteTensorHandle); + + createTFEEnqueue(ctx, inputType, queue, tensor, status); +} + +void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id, + TFE_TensorHandle* tensor, + TF_Status* status) { + VLOG(1) << "Enqueuing data tensor with id " << tensor_id; + + TF_DataType inputType = TFE_TensorHandleDataType(tensor); + TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); + if (!status->status.ok()) return; + std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> + queue_deleter(queue, TFE_DeleteTensorHandle); + + createTFEEnqueue(ctx, inputType, queue, tensor, status); +} + +void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id, + TFE_TensorHandle* tensor, TF_Status* status) { + VLOG(1) << "Enqueuing variant tensor with id " << tensor_id; + + auto ctx = TFE_CreateContextFromSession(session, status); + if (!status->status.ok()) return; + std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter( + ctx, TFE_DeleteContext); + + TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status); + if (!status->status.ok()) return; + std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> + queue_deleter(queue, TFE_DeleteTensorHandle); + + createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status); +} + +TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id, + TF_Status* status) { + VLOG(1) << "Dequeuing variant tensor with id " << tensor_id; + + auto ctx = TFE_CreateContextFromSession(session, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter( + ctx, TFE_DeleteContext); + + TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> + queue_deleter(queue, TFE_DeleteTensorHandle); + + return createTFEDequeue(ctx, TF_VARIANT, queue, status); +} |