From b5594e6121e902f8dd2d5127653a1ec5f97daccd Mon Sep 17 00:00:00 2001 From: Mingsheng Hong Date: Fri, 14 Sep 2018 14:15:05 -0700 Subject: Added TFE_OpSetAttrTensor() to eager C API. Also added some experimental C APIs for facilitate the use of eager C APIs in S4TF compiler. PiperOrigin-RevId: 213041780 --- tensorflow/c/c_api_experimental.cc | 50 ++++++++++++++++++++++++++++++++++++++ tensorflow/c/c_api_experimental.h | 9 +++++++ tensorflow/c/eager/c_api.cc | 7 ++++++ tensorflow/c/eager/c_api.h | 5 ++++ 4 files changed, 71 insertions(+) (limited to 'tensorflow/c') diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index c195c9e01c..3bcc62cf2d 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -8705,3 +8705,53 @@ TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id, return createTFEDequeue(ctx, TF_VARIANT, queue, status); } + +static void CheckOk(TF_Status* status) { + CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); +} + +void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) { + auto* status = TF_NewStatus(); + TF_Tensor* t = TFE_TensorHandleResolve(handle, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + tensorflow::Tensor dst; + TF_CHECK_OK(TF_TensorToTensor(t, &dst)); + LOG(INFO) << dst.DebugString(); + + TF_DeleteTensor(t); + TF_DeleteStatus(status); +} + +TFE_TensorHandle* TFE_RunConstOp(TFE_Context* ctx) { + // Intentionally LOG into INFO below for ease of debugging. + VLOG(1) << "TFE_RunConstOp called"; + + auto* status = TF_NewStatus(); + auto* op = TFE_NewOp(ctx, "Const", status); + CheckOk(status); + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + + auto* tensor = + TF_AllocateTensor(TF_FLOAT, /*shape.data()*/ nullptr, /*shape.size()*/ 0, + TF_DataTypeSize(TF_FLOAT) * 1); + auto* ptr = reinterpret_cast(TF_TensorData(tensor)); + *reinterpret_cast(ptr) = 17.0; + + TFE_OpSetAttrTensor(op, "value", tensor, status); + CheckOk(status); + TF_DeleteTensor(tensor); + VLOG(1) << "New op created"; + + TFE_TensorHandle* retval; + int num_retvals = 1; + TFE_Execute(op, &retval, &num_retvals, status); + CheckOk(status); + CHECK_EQ(num_retvals, 1); + VLOG(1) << "Op executed"; + + TFE_DeleteOp(op); + TF_DeleteStatus(status); + + return retval; +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 522c91f67e..a3ca847d96 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -174,6 +174,15 @@ TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session, TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor( TF_Session* session, int tensor_id, TF_Status* status); +// Prints `handle` in a human readable format to standard output for debugging. +TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString( + TFE_TensorHandle* handle); + +// Returns a const scalar tensor. +// Caller owns both the input and the output tensor handles. +// TODO: Remove this API with hard-coded tensor computation. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_RunConstOp(TFE_Context* ctx); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 349d9bcd7c..6f86ea80e5 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -567,6 +567,13 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, op->operation.MutableAttrs()->Set(attr_name, attr_value); } +void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor, + TF_Status* status) { + tensorflow::Tensor t; + status->status = TF_TensorToTensor(tensor, &t); + if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t); +} + void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, const void* const* values, const size_t* lengths, int num_values) { diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 337447eec9..a87d73ec8e 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -311,6 +311,11 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, const TFE_Op* value); +TF_CAPI_EXPORT extern void TFE_OpSetAttrTensor(TFE_Op* op, + const char* attr_name, + TF_Tensor* tensor, + TF_Status* status); + TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, const void* const* values, -- cgit v1.2.3