author     Mingsheng Hong <hongm@google.com>                2018-05-05 07:10:58 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-05-07 15:42:04 -0700
commit     067be1aa75d9065b6cc57ba6316fc17544a9fdf1 (patch)
tree       cc952cc49d777f7fc1ac828a1c42314aefb8dace /tensorflow/c/c_api_experimental.cc
parent     150089e6e67e4492f098cdd8f9f2f48dc9f9cc56 (diff)
Part 2 of Swift<->TF sends/recvs: support receiving tensors in TF from Swift via direct session.

The changes are:

1. Added a TF experimental C API for the Swift host to enqueue a tensor for sending to TF. Again, the C APIs can be removed once the Fifo-queue-based design proves stable later.

2. TFLowerGraph is extended to generate Fifo-related nodes for TF to receive tensors, similar to the earlier extension for TF to send tensors.

3. TFPartition is extended to support host send (createHostSend()), which performs the tensor send via a new protocol method, TensorSendableReceivable.sendToDevice(). The main complexity is in sending a scalar: a new protocol method, TensorSendableReceivable.createScalarTensor(), is called to first create a tensor out of the scalar, which is then sent over to TF.

Also removed the code for protocol conformance on AccelerableByTensorFlow; the compiler now looks up that conformance from the SILFunction when sending/receiving tensors. AccelerableByTensorFlow could be removed from the compiler-known protocol list now, but we'll defer that until things have stabilized more (in the past this protocol has been added to and removed from the list at different times).

PiperOrigin-RevId: 195539436
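Point 2 above refers to the Fifo nodes that the Swift compiler's TFLowerGraph pass emits and that the new C API (in the diff below) looks up by name. As a rough orientation only, and not the compiler's actual output, an equivalent hand-written graph construction through the TF C API might look like the following sketch. The "arg_tensor_enqueue_<id>" / "fifo_queue_enqueue_<id>" names follow the convention TF_EnqueueNamedTensor expects; the queue name and the float dtype are illustrative assumptions.

#include <string>
#include "tensorflow/c/c_api.h"

// Sketch: build the two nodes TF_EnqueueNamedTensor looks up for a given
// tensor_id -- a Placeholder feeding a QueueEnqueueV2 on a FIFOQueueV2.
// In the real flow these nodes are generated by TFLowerGraph, not by hand.
static void BuildEnqueueNodes(TF_Graph* graph, int tensor_id,
                              TF_Status* status) {
  const std::string id = std::to_string(tensor_id);

  // FIFO queue carrying single float tensors (dtype is an assumption here).
  TF_OperationDescription* queue_desc =
      TF_NewOperation(graph, "FIFOQueueV2", ("fifo_queue_" + id).c_str());
  TF_DataType component_types[] = {TF_FLOAT};
  TF_SetAttrTypeList(queue_desc, "component_types", component_types, 1);
  TF_Operation* queue = TF_FinishOperation(queue_desc, status);
  if (TF_GetCode(status) != TF_OK) return;

  // Placeholder that TF_EnqueueNamedTensor feeds with the host tensor.
  TF_OperationDescription* ph_desc = TF_NewOperation(
      graph, "Placeholder", ("arg_tensor_enqueue_" + id).c_str());
  TF_SetAttrType(ph_desc, "dtype", TF_FLOAT);
  TF_Operation* placeholder = TF_FinishOperation(ph_desc, status);
  if (TF_GetCode(status) != TF_OK) return;

  // Enqueue op that TF_EnqueueNamedTensor runs as a target.
  TF_OperationDescription* enq_desc = TF_NewOperation(
      graph, "QueueEnqueueV2", ("fifo_queue_enqueue_" + id).c_str());
  TF_AddInput(enq_desc, TF_Output{queue, 0});
  TF_Output components[] = {{placeholder, 0}};
  TF_AddInputList(enq_desc, components, 1);
  TF_FinishOperation(enq_desc, status);
}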
Diffstat (limited to 'tensorflow/c/c_api_experimental.cc')
-rw-r--r--  tensorflow/c/c_api_experimental.cc  48
1 file changed, 48 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 82dbd3cdbc..95b04f9058 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -8407,3 +8407,51 @@ TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id,
}
return ret;
}
+
+void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
+ TF_Tensor* tensor, TF_Status* status) {
+ assert(session);
+ {
+ tensorflow::mutex_lock c(session->graph->mu);
+ if (VLOG_IS_ON(1)) {
+ VLOG(1) << "Enqueuing named tensor with id " << tensor_id
+ << ", with input graph: "
+ << session->graph->graph.ToGraphDefDebug().DebugString();
+ tensorflow::Tensor internal_tensor;
+ if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) {
+      VLOG(1) << "Enqueuing tensor content: "
+ << internal_tensor.DebugString();
+ }
+ }
+ }
+
+ TF_Operation* enqueue_op = TF_GraphOperationByName(
+ session->graph,
+ tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str());
+ if (enqueue_op == nullptr) {
+ status->status = tensorflow::errors::Internal(
+ "Unable to find the enqueue node in the TF graph.");
+ return;
+ }
+
+ TF_Operation* placeholder_op = TF_GraphOperationByName(
+ session->graph,
+ tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str());
+ if (placeholder_op == nullptr) {
+ status->status = tensorflow::errors::Internal(
+ "Unable to find the placeholder node as input to enqueue in the TF "
+ "graph.");
+ return;
+ }
+
+ VLOG(1) << "Running the enqueue op";
+ TF_Output input{placeholder_op, 0};
+ TF_SessionRun(session, /*run_options*/ nullptr,
+ // input related parameters
+ /*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1,
+ // output related parameters
+ /*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0,
+ /*targets*/ &enqueue_op, /*ntargets*/ 1,
+ /*run_metadata*/ nullptr, status);
+ VLOG(1) << "Enqueuing is done.";
+}
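
For reference, a host-side caller might exercise the new API roughly as follows. This is a hedged usage sketch, not part of the commit: the session is assumed to have been created over a graph that already contains the matching "fifo_queue_enqueue_<id>" / "arg_tensor_enqueue_<id>" nodes, the tensor_id comes from the Swift-generated program, and the scalar float payload is arbitrary.

#include <stdio.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"

// Sketch: enqueue a scalar float for the TF program to receive via the
// Fifo-queue based send/recv path.
static void SendScalarToTF(TF_Session* session, int tensor_id, float value) {
  TF_Status* status = TF_NewStatus();

  // Allocate a 0-d (scalar) float tensor and fill in the value.
  TF_Tensor* scalar = TF_AllocateTensor(TF_FLOAT, /*dims=*/nullptr,
                                        /*num_dims=*/0, sizeof(float));
  *static_cast<float*>(TF_TensorData(scalar)) = value;

  // Feed the placeholder and run the enqueue op inside the session.
  TF_EnqueueNamedTensor(session, tensor_id, scalar, status);
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "Enqueue failed: %s\n", TF_Message(status));
  }

  TF_DeleteTensor(scalar);
  TF_DeleteStatus(status);
}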