author    Mingsheng Hong <hongm@google.com>    2018-04-28 08:55:08 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-04-28 08:57:32 -0700
commit 4c5699582aa368edfbe058d770407a558729f305 (patch)
tree   a342069d125520d58eb147dc5ec3132e0f0ec959 /tensorflow/c/c_api_experimental.cc
parent 102b0c87a024f95d619860b0ba492c93e4bd96c9 (diff)
This is Part 1 of Swift<->TF sends/recvs: support sending tensors from TF to
Swift via direct session. The changes are:

1. Added an experimental TF C API TF_DequeueNamedTensor() to consume the queued tensors from a dequeue op. One use case is for the Swift host program to consume tensors sent by TF, where the queue is a FIFO queue managed by TF. Tensors are enqueued by running an enqueue op in a graph. The queued tensors are not persisted, and will be lost if the process/machine dies. The queue has a bounded capacity, to prevent the producer from getting unboundedly ahead of the consumer. While the caller of TF_DequeueNamedTensor() could have run the FIFO dequeue op directly, the extra level of indirection provided by this API allows us to more easily switch the queueing implementation to another mechanism. If and once we stabilize on the FIFO-queue-based implementation, we can remove this API.

2. Added a new S4TF runtime API _TFCReceiveTensorHandle() that receives a tensor via TF_DequeueNamedTensor().

3. To support tensor receives in the host program, taught PartitionCloner in TFPartition to insert SIL code that calls _TFCReceiveTensorHandle().

4. To support tensor sends in the accelerator program, taught TFGraphLowering to generate QueueEnqueueV2 nodes in the TF graphs, with appropriate control dependence to make sure these nodes get executed.
   a) The enqueue produces no output tensor, and is executed only for its side effect, so control dependence is wired up to ensure it runs. The general design is: before a TF_Function (either a top-level function or the body function of a while op) produces an output tensor OT, make OT control-dependent on the enqueue op, so that the enqueue runs before the function returns.
   b) If a tensor send occurs in a while-loop body, the body logic currently gets lowered in 3 places: the while op's cond function, the while op's body function, and the ops at the same level as the while op itself (for running the last loop iteration). In this case, the correct TFGraph lowering is to run the enqueue in the last 2 of those 3 places.

After this CL, the dual versions of the above (dequeuing via an op, and enqueuing via a C API) will be added.

PiperOrigin-RevId: 194658511
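The bounded-capacity behavior described in point 1 — the producer blocks once it gets a fixed number of items ahead of the consumer — can be sketched in plain C++. This is an illustrative stand-in for the TF-managed FIFO queue, not TF internals; the class and method names here are made up for the sketch.

```cpp
#include <condition_variable>
#include <deque>
#include <mutex>

// Minimal sketch of bounded FIFO semantics: the producer (analogous to the
// graph's enqueue op) blocks once it is `capacity` items ahead of the
// consumer (analogous to the host calling TF_DequeueNamedTensor()).
template <typename T>
class BoundedFifoQueue {
 public:
  explicit BoundedFifoQueue(size_t capacity) : capacity_(capacity) {}

  // Blocks while the queue is full, providing the back-pressure that keeps
  // the producer from getting unboundedly ahead of the consumer.
  void Enqueue(T value) {
    std::unique_lock<std::mutex> lock(mu_);
    not_full_.wait(lock, [this] { return items_.size() < capacity_; });
    items_.push_back(std::move(value));
    not_empty_.notify_one();
  }

  // Blocks while the queue is empty; returns the oldest item (FIFO order).
  T Dequeue() {
    std::unique_lock<std::mutex> lock(mu_);
    not_empty_.wait(lock, [this] { return !items_.empty(); });
    T value = std::move(items_.front());
    items_.pop_front();
    not_full_.notify_one();
    return value;
  }

 private:
  const size_t capacity_;
  std::mutex mu_;
  std::condition_variable not_full_;
  std::condition_variable not_empty_;
  std::deque<T> items_;
};
```

Note that, matching the commit message, nothing here is persisted: if the process dies, any queued items are lost.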
Diffstat (limited to 'tensorflow/c/c_api_experimental.cc')
-rw-r--r--  tensorflow/c/c_api_experimental.cc | 39
1 file changed, 39 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index d3916bc167..82dbd3cdbc 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -8368,3 +8368,42 @@ TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets(
return getnext_node;
#endif
}
+
+TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id,
+ TF_Status* status) {
+ assert(session);
+ {
+ tensorflow::mutex_lock c(session->graph->mu);
+ VLOG(1) << "Dequeuing named tensor with id " << tensor_id
+ << ", with input graph: "
+ << session->graph->graph.ToGraphDefDebug().DebugString();
+ }
+
+ TF_Operation* dequeue_op = TF_GraphOperationByName(
+ session->graph,
+ tensorflow::strings::StrCat("fifo_queue_dequeue_", tensor_id).c_str());
+ if (dequeue_op == nullptr) {
+ status->status = tensorflow::errors::Internal(
+ "Unable to find the dequeue node in the TF graph.");
+ return nullptr;
+ }
+
+ VLOG(1) << "Running the dequeue op";
+ TF_Output output{dequeue_op, 0};
+ TF_Tensor* ret;
+ TF_SessionRun(session, /*run_options*/ nullptr,
+ // input related parameters
+ /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0,
+ // output related parameters
+ /*outputs*/ &output, /*output_values*/ &ret,
+ /*noutputs*/ 1,
+ /*targets*/ nullptr, /*ntargets*/ 0,
+ /*run_metadata*/ nullptr, status);
+ if (VLOG_IS_ON(1) && status->status.ok()) {
+ tensorflow::Tensor tensor;
+ if (tensorflow::TF_TensorToTensor(ret, &tensor).ok()) {
+ VLOG(1) << "Dequeued tensor content: " << tensor.DebugString();
+ }
+ }
+ return ret;
+}