aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c
diff options
context:
space:
mode:
authorGravatar Mingsheng Hong <hongm@google.com>2018-09-14 14:15:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-14 14:19:20 -0700
commitb5594e6121e902f8dd2d5127653a1ec5f97daccd (patch)
tree5c17dc342ace8fb5c30ff44adafcb4fcb81ba718 /tensorflow/c
parent19d66a950e2091bb598c6a2d375e14208f5773b2 (diff)
Added TFE_OpSetAttrTensor() to eager C API.
Also added some experimental C APIs for facilitate the use of eager C APIs in S4TF compiler. PiperOrigin-RevId: 213041780
Diffstat (limited to 'tensorflow/c')
-rw-r--r--tensorflow/c/c_api_experimental.cc50
-rw-r--r--tensorflow/c/c_api_experimental.h9
-rwxr-xr-xtensorflow/c/eager/c_api.cc7
-rwxr-xr-xtensorflow/c/eager/c_api.h5
4 files changed, 71 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index c195c9e01c..3bcc62cf2d 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -8705,3 +8705,53 @@ TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
return createTFEDequeue(ctx, TF_VARIANT, queue, status);
}
+
+static void CheckOk(TF_Status* status) {
+ CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+}
+
+void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
+ auto* status = TF_NewStatus();
+ TF_Tensor* t = TFE_TensorHandleResolve(handle, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ tensorflow::Tensor dst;
+ TF_CHECK_OK(TF_TensorToTensor(t, &dst));
+ LOG(INFO) << dst.DebugString();
+
+ TF_DeleteTensor(t);
+ TF_DeleteStatus(status);
+}
+
+TFE_TensorHandle* TFE_RunConstOp(TFE_Context* ctx) {
+ // Intentionally LOG into INFO below for ease of debugging.
+ VLOG(1) << "TFE_RunConstOp called";
+
+ auto* status = TF_NewStatus();
+ auto* op = TFE_NewOp(ctx, "Const", status);
+ CheckOk(status);
+ TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+
+ auto* tensor =
+ TF_AllocateTensor(TF_FLOAT, /*shape.data()*/ nullptr, /*shape.size()*/ 0,
+ TF_DataTypeSize(TF_FLOAT) * 1);
+ auto* ptr = reinterpret_cast<char*>(TF_TensorData(tensor));
+ *reinterpret_cast<float*>(ptr) = 17.0;
+
+ TFE_OpSetAttrTensor(op, "value", tensor, status);
+ CheckOk(status);
+ TF_DeleteTensor(tensor);
+ VLOG(1) << "New op created";
+
+ TFE_TensorHandle* retval;
+ int num_retvals = 1;
+ TFE_Execute(op, &retval, &num_retvals, status);
+ CheckOk(status);
+ CHECK_EQ(num_retvals, 1);
+ VLOG(1) << "Op executed";
+
+ TFE_DeleteOp(op);
+ TF_DeleteStatus(status);
+
+ return retval;
+}
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index 522c91f67e..a3ca847d96 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -174,6 +174,15 @@ TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session,
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
TF_Session* session, int tensor_id, TF_Status* status);
+// Prints `handle` in a human readable format to standard output for debugging.
+TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString(
+ TFE_TensorHandle* handle);
+
+// Returns a const scalar tensor.
+// Caller owns both the input and the output tensor handles.
+// TODO: Remove this API with hard-coded tensor computation.
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_RunConstOp(TFE_Context* ctx);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 349d9bcd7c..6f86ea80e5 100755
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -567,6 +567,13 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
op->operation.MutableAttrs()->Set(attr_name, attr_value);
}
+void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
+ TF_Status* status) {
+ tensorflow::Tensor t;
+ status->status = TF_TensorToTensor(tensor, &t);
+ if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t);
+}
+
void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
const void* const* values, const size_t* lengths,
int num_values) {
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 337447eec9..a87d73ec8e 100755
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -311,6 +311,11 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op,
const char* attr_name,
const TFE_Op* value);
+TF_CAPI_EXPORT extern void TFE_OpSetAttrTensor(TFE_Op* op,
+ const char* attr_name,
+ TF_Tensor* tensor,
+ TF_Status* status);
+
TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op,
const char* attr_name,
const void* const* values,