aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2017-11-20 17:34:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-20 17:37:34 -0800
commit55672b52559c73b5bf12c4827277959bda765e59 (patch)
treeeabf35171fff9d33ddc813e4d04d7e03734a6f87
parentfd92829df41984de014fd5f6807ad061fa45090a (diff)
TFE_ContextAddFunction to interface with the TFE_Function* API
PiperOrigin-RevId: 176443014
-rw-r--r--tensorflow/c/eager/c_api.cc6
-rw-r--r--tensorflow/c/eager/c_api.h7
-rw-r--r--tensorflow/c/eager/c_api_test.cc60
3 files changed, 73 insertions, 0 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 8359de62b7..706c89536d 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -571,6 +571,12 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
status->status = ctx->func_lib_def.AddFunctionDef(function_def);
}
+void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
+ TF_Status* status) {
+ tensorflow::mutex_lock l(ctx->functions_mu);
+ status->status = ctx->func_lib_def.AddFunctionDef(function->fdef);
+}
+
} // extern "C"
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 865580c5f3..ca105962df 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -200,6 +200,13 @@ TF_CAPI_EXPORT extern void TFE_ContextAddFunctionDef(TFE_Context* ctx,
const char* serialized_function_def,
size_t size, TF_Status* status);
+// Adds a function (created from TF_GraphToFunction or
+// TF_FunctionImportFunctionDef) to the context, allowing it to be executed with
+// TFE_Execute by creating an op with the same name as the function.
+TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx,
+ TF_Function* function,
+ TF_Status* status);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 4af91b8853..03843fa913 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -295,6 +295,66 @@ TEST(CAPI, Execute) {
TF_DeleteStatus(status);
}
+TEST(CAPI, Function) {
+ // First create a simple identity function.
+ TF_Graph* function_graph = TF_NewGraph();
+ TF_OperationDescription* arg_descr =
+ TF_NewOperation(function_graph, "Placeholder", "arg");
+ TF_SetAttrType(arg_descr, "dtype", TF_INT32);
+ TF_Status* status = TF_NewStatus();
+ TF_Operation* arg = TF_FinishOperation(arg_descr, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TF_OperationDescription* id_descr =
+ TF_NewOperation(function_graph, "Identity", "id");
+ TF_SetAttrType(id_descr, "T", TF_INT32);
+ TF_AddInput(id_descr, {arg, 0});
+ TF_Operation* id = TF_FinishOperation(id_descr, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TF_Output input{arg, 0};
+ TF_Output output{id, 0};
+ TF_Function* fn =
+ TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
+ &output, nullptr, nullptr, "test", status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TF_DeleteGraph(function_graph);
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+ TFE_ContextAddFunction(ctx, fn, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TF_DeleteFunction(fn);
+
+ TF_Tensor* t = TF_AllocateTensor(TF_INT32, nullptr, 0, 1);
+ *reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
+ TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TF_DeleteTensor(t);
+
+ TFE_Op* op = TFE_NewOp(ctx, "ident", status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TFE_OpAddInput(op, h, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+
+ std::vector<TFE_TensorHandle*> result;
+ result.push_back(nullptr);
+ int num_retvals = 1;
+ TFE_Execute(op, result.data(), &num_retvals, status);
+ TFE_DeleteOp(op);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ ASSERT_EQ(num_retvals, 1);
+
+ TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
+ TFE_DeleteTensorHandle(h);
+ TF_DeleteTensor(r);
+ TFE_DeleteTensorHandle(result[0]);
+ TFE_DeleteContext(ctx, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TF_DeleteStatus(status);
+}
+
string MatMulFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(