diff options
author | Mingsheng Hong <hongm@google.com> | 2018-05-05 07:10:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-07 15:42:04 -0700 |
commit | 067be1aa75d9065b6cc57ba6316fc17544a9fdf1 (patch) | |
tree | cc952cc49d777f7fc1ac828a1c42314aefb8dace /tensorflow/c/c_api_experimental.cc | |
parent | 150089e6e67e4492f098cdd8f9f2f48dc9f9cc56 (diff) |
Part 2 of Swift<->TF sends/recvs: support receiving tensors in TF from
Swift via direct session.
The changes are:
1. Added a TF experimental C API for Swift host to enqueue a tensor for sending
to TF. Again, the C APIs can be removed once the Fifo-queue based design
proves stable later.
2. TFLowerGraph is extended to generate Fifo related nodes for TF to receive
tensors. This is similar to the extension for TF to send tensors.
3. TFPartition is extended to support host send (createHostSend()), which does
the tensor send via a new protocol method TensorSendableReceivable.sendToDevice().
The main complexity is in sending a scalar, where a new protocol method
TensorSendableReceivable.createScalarTensor() is called to first create a tensor
out of it, and then send it over to TF.
Also removed code for protocol conformance on AccelerableByTensorFlow. Instead
have compiler look up that conformance from the SILFunction on sending/receiving
tensors.
AccelerableByTensorFlow could be removed from the compiler-known protocol list
now, but we'll defer that till things can stabilized more (in the past this
protocol has been added to and removed from the list at different times).
PiperOrigin-RevId: 195539436
Diffstat (limited to 'tensorflow/c/c_api_experimental.cc')
-rw-r--r-- | tensorflow/c/c_api_experimental.cc | 48 |
1 files changed, 48 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 82dbd3cdbc..95b04f9058 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -8407,3 +8407,51 @@ TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id, } return ret; } + +void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id, + TF_Tensor* tensor, TF_Status* status) { + assert(session); + { + tensorflow::mutex_lock c(session->graph->mu); + if (VLOG_IS_ON(1)) { + VLOG(1) << "Enqueuing named tensor with id " << tensor_id + << ", with input graph: " + << session->graph->graph.ToGraphDefDebug().DebugString(); + tensorflow::Tensor internal_tensor; + if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) { + VLOG(1) << "Enqueu'ing tensor content: " + << internal_tensor.DebugString(); + } + } + } + + TF_Operation* enqueue_op = TF_GraphOperationByName( + session->graph, + tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str()); + if (enqueue_op == nullptr) { + status->status = tensorflow::errors::Internal( + "Unable to find the enqueue node in the TF graph."); + return; + } + + TF_Operation* placeholder_op = TF_GraphOperationByName( + session->graph, + tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str()); + if (placeholder_op == nullptr) { + status->status = tensorflow::errors::Internal( + "Unable to find the placeholder node as input to enqueue in the TF " + "graph."); + return; + } + + VLOG(1) << "Running the enqueue op"; + TF_Output input{placeholder_op, 0}; + TF_SessionRun(session, /*run_options*/ nullptr, + // input related parameters + /*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1, + // output related parameters + /*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0, + /*targets*/ &enqueue_op, /*ntargets*/ 1, + /*run_metadata*/ nullptr, status); + VLOG(1) << "Enqueuing is done."; +} |