From 55672b52559c73b5bf12c4827277959bda765e59 Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Mon, 20 Nov 2017 17:34:19 -0800 Subject: TFE_ContextAddFunction to interface with the TFE_Function* API PiperOrigin-RevId: 176443014 --- tensorflow/c/eager/c_api.cc | 6 ++++ tensorflow/c/eager/c_api.h | 7 +++++ tensorflow/c/eager/c_api_test.cc | 60 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 8359de62b7..706c89536d 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -571,6 +571,12 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, status->status = ctx->func_lib_def.AddFunctionDef(function_def); } +void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, + TF_Status* status) { + tensorflow::mutex_lock l(ctx->functions_mu); + status->status = ctx->func_lib_def.AddFunctionDef(function->fdef); +} + } // extern "C" TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) { diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 865580c5f3..ca105962df 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -200,6 +200,13 @@ TF_CAPI_EXPORT extern void TFE_ContextAddFunctionDef(TFE_Context* ctx, const char* serialized_function_def, size_t size, TF_Status* status); +// Adds a function (created from TF_GraphToFunction or +// TF_FunctionImportFunctionDef) to the context, allowing it to be executed with +// TFE_Execute by creating an op with the same name as the function. +TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx, + TF_Function* function, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 4af91b8853..03843fa913 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -295,6 +295,66 @@ TEST(CAPI, Execute) { TF_DeleteStatus(status); } +TEST(CAPI, Function) { + // First create a simple identity function. + TF_Graph* function_graph = TF_NewGraph(); + TF_OperationDescription* arg_descr = + TF_NewOperation(function_graph, "Placeholder", "arg"); + TF_SetAttrType(arg_descr, "dtype", TF_INT32); + TF_Status* status = TF_NewStatus(); + TF_Operation* arg = TF_FinishOperation(arg_descr, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_OperationDescription* id_descr = + TF_NewOperation(function_graph, "Identity", "id"); + TF_SetAttrType(id_descr, "T", TF_INT32); + TF_AddInput(id_descr, {arg, 0}); + TF_Operation* id = TF_FinishOperation(id_descr, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_Output input{arg, 0}; + TF_Output output{id, 0}; + TF_Function* fn = + TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1, + &output, nullptr, nullptr, "test", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteGraph(function_graph); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_DeleteContextOptions(opts); + TFE_ContextAddFunction(ctx, fn, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteFunction(fn); + + TF_Tensor* t = TF_AllocateTensor(TF_INT32, nullptr, 0, 1); + *reinterpret_cast(TF_TensorData(t)) = 42; + TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteTensor(t); + + TFE_Op* op = TFE_NewOp(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_OpAddInput(op, h, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + + std::vector result; + result.push_back(nullptr); + int num_retvals = 1; + TFE_Execute(op, result.data(), &num_retvals, status); + TFE_DeleteOp(op); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + ASSERT_EQ(num_retvals, 1); + + TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); + TFE_DeleteTensorHandle(h); + TF_DeleteTensor(r); + TFE_DeleteTensorHandle(result[0]); + TFE_DeleteContext(ctx, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteStatus(status); +} + string MatMulFunction() { tensorflow::FunctionDef def; CHECK(tensorflow::protobuf::TextFormat::ParseFromString( -- cgit v1.2.3