about summary refs log tree commit diff homepage
path: root/tensorflow/c/eager/c_api.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-12 12:44:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-12 12:51:15 -0700
commit315369aacd002d8c668b86a52f3cd88956a9b9a2 (patch)
treeb88e7f48b8142cb20c4e8b2fdaa59f5d3ea51b46 /tensorflow/c/eager/c_api.h
parent694a8101316107088efdbc33f7a5a60c7c8e7c8d (diff)
Extend TF Eager C API to allow asynchronous execution.
PiperOrigin-RevId: 188763442
Diffstat (limited to 'tensorflow/c/eager/c_api.h')
-rw-r--r--tensorflow/c/eager/c_api.h58
1 file changed, 53 insertions, 5 deletions
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 9610ca1b3b..316006bafb 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -75,6 +75,11 @@ typedef enum TFE_ContextDevicePlacementPolicy {
TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
} TFE_ContextDevicePlacementPolicy;
+// Sets the default execution mode (sync/async). Note that this can be
+// overridden per thread using TFE_ContextSetAsyncForThread.
+TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*,
+ unsigned char async);
+
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy);
@@ -110,6 +115,30 @@ TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TF_CAPI_EXPORT extern TFE_ContextDevicePlacementPolicy
TFE_ContextGetDevicePlacementPolicy(TFE_Context*);
+// Overrides the execution mode (sync/async) for the current thread.
+TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*,
+ unsigned char async,
+ TF_Status* status);
+
+// Causes the calling thread to block till all ops dispatched in async mode
+// have been executed. Note that "execution" here refers to kernel execution /
+// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee
+// that lower level device queues (like GPU streams) have been flushed.
+//
+// This call may not block for execution of ops enqueued concurrently with this
+// call.
+TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context*,
+ TF_Status* status);
+
+// When an error happens, any pending operations are discarded and newly issued
+// ops return an error. This call clears the error state and re-enables
+// execution of newly issued ops.
+//
+// Note that outputs of discarded ops remain in a corrupt state and should not
+// be used for future calls.
+// TODO(agarwal): mark the affected handles and raise errors if they are used.
+TF_CAPI_EXPORT extern void TFE_ContextAsyncClearError(TFE_Context*);
+
// A handle to a tensor on a device.
//
// Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape,
@@ -119,15 +148,21 @@ typedef struct TFE_TensorHandle TFE_TensorHandle;
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t,
TF_Status* status);
+// Indicates that the caller will not be using `h` any more.
TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
+// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h,
TF_Status* status);
+// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h,
int dim_index,
TF_Status* status);
+// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(
TFE_TensorHandle* h, TF_Status* status);
+
+// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
TF_Status* status);
@@ -137,6 +172,9 @@ TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
// that shares the underlying buffer. Otherwise, it currently requires at least
// one of the source or destination devices to be CPU (i.e., for the source or
// destination tensor to be placed in host memory).
+// If async execution is enabled, the copy may be enqueued and the call will
+// return "non-ready" handle. Else, this function returns after the copy has
+// been done.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(
TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name,
TF_Status* status);
@@ -157,6 +195,7 @@ typedef struct TFE_Op TFE_Op;
TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx,
const char* op_or_function_name,
TF_Status* status);
+
TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op);
TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name,
@@ -242,13 +281,20 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunctionList(TFE_Op* op,
int num_values);
// Execute the operation defined by 'op' and return handles to computed
-// tensors in 'retvals'.
+// tensors in `retvals`.
+//
+// 'retvals' must point to a pre-allocated array of TFE_TensorHandle* and
+// '*num_retvals' should be set to the size of this array. It is an error if
+// the number of outputs is different from *num_retvals.
//
-// 'retvals' must point to a pre-allocated array of TFE_TensorHandle*
-// and '*num_retvals' should be set to the size of this array.
+// If async execution is enabled, the call may simply enqueue the execution
+// and return "non-ready" handles in `retvals`. Note that any handles contained
+// in 'op' should not be mutated till the kernel execution actually finishes.
//
-// On return, 'num_retvals' will be set to the actual number of outputs
-// returned by the operation.
+// For sync execution, if any of the inputs to `op` are not ready, this call
+// will block till they become ready and then return when the kernel execution
+// is done.
+// TODO(agarwal): change num_retvals to int from int*.
TF_CAPI_EXPORT extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals,
int* num_retvals, TF_Status* status);
@@ -274,6 +320,8 @@ TF_CAPI_EXPORT extern void TFE_ContextDisableRunMetadata(TFE_Context* ctx);
// Populates the passed-in buffer with a serialized RunMetadata protocol buffer
// containing any run metadata information accumulated so far and clears this
// information.
+// If async mode is enabled, this call blocks till all currently pending ops are
+// done.
TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx,
TF_Buffer* buf,
TF_Status* status);