diff options
author | 2017-09-28 12:14:42 -0700 | |
---|---|---|
committer | 2017-09-28 12:19:18 -0700 | |
commit | bdab2691068757ee4872167898bc8768a7303ae9 (patch) | |
tree | 9d91c5b5f8e73860eab61fc8d5adb3a5788d43b6 /tensorflow/c | |
parent | 860b30b2d42d0a21a86f59ef392e5fd9962a1d7c (diff) |
Add append_hash_to_fn_name arg to TF_GraphToFunction
PiperOrigin-RevId: 170379490
Diffstat (limited to 'tensorflow/c')
-rw-r--r-- | tensorflow/c/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/c/c_api.h | 17 | ||||
-rw-r--r-- | tensorflow/c/c_api_function.cc | 29 | ||||
-rw-r--r-- | tensorflow/c/c_api_function_test.cc | 23 |
4 files changed, 56 insertions, 14 deletions
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index aead7154ee..077fb053fb 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -72,6 +72,7 @@ tf_cuda_library( "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", ], }), ) diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index a17c877804..33fd1794cf 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1039,12 +1039,14 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, // fn_body - the graph whose operations (or subset of whose operations) will be // converted to TF_Function. // fn_name - the name of the new TF_Function. Should match the operation -// name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]* and be distinct -// from other operation names (at least those registered in graphs -// where this function will be used). -// TODO(iga): Allow null in here and have C API come up with -// a unique name with high probability (similarly to -// _create_hash_str in function.py) +// name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]*. +// If `append_hash_to_fn_name` is false, `fn_name` must be distinct +// from other function and operation names (at least those +// registered in graphs where this function will be used). +// append_hash_to_fn_name - Must be 0 or 1. If set to 1, the actual name +// of the function will be `fn_name` appended with +// '_<hash_of_this_function's_definition>'. +// If set to 0, the function's name will be `fn_name`. // num_opers - `num_opers` contains the number of elements in the `opers` array // or a special value of -1 meaning that no array is given. // The distinction between an empty array of operations and no @@ -1114,7 +1116,8 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, // // On failure, null. TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction( - const TF_Graph* fn_body, const char* fn_name, int num_opers, + const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, int num_opers, const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, int noutputs, const TF_Output* outputs, const char* const* output_names, const TF_FunctionOptions* opts, const char* description, TF_Status* status); diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 61484fd8ea..7924c31a5f 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/strings/base64.h" #include "tensorflow/core/lib/strings/strcat.h" using tensorflow::errors::InvalidArgument; @@ -232,6 +233,7 @@ Status FillFunctionBody( // Graph to FunctionDef conversion. This code is closely modeled on the Python // code in third_party/tensorflow/python/framework/function.py. Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, + bool append_hash_to_fn_name, const std::vector<const Node*>& body_nodes, const std::vector<OutputTensor>& inputs, const std::vector<OutputTensor>& outputs, @@ -241,7 +243,6 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, DCHECK_EQ(output_names.size(), outputs.size()); } - fdef->mutable_signature()->set_name(fn_name); if (description != nullptr) { fdef->mutable_signature()->set_description(description); } @@ -328,7 +329,6 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, // Remap return values. for (int r = 0; r < fdef->signature().output_arg_size(); ++r) { const string& ret_name = fdef->signature().output_arg(r).name(); - // We convert this flat tensor name to the nested value // (e.g. `add:z:1`) that we stored in tensor_renaming. const string& return_value = @@ -343,6 +343,24 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, (*fdef->mutable_ret())[ret_name] = iter->second; } + if (append_hash_to_fn_name) { + const uint64 hash = FunctionDefHash(*fdef); + string encoded; + TF_RETURN_IF_ERROR(Base64Encode( + StringPiece(reinterpret_cast<const char*>(&hash), sizeof(hash)), + &encoded)); + // Besides letters and digits our Base64 encoding uses '_' and '-'. + // Dash is invalid in operation names and multiple underscores in random + // places look strange. Since we never need to decode the hash back, + // replace these chars with with 'a' and 'A'. Replacing with different + // letters keeps more entropy. + std::replace(encoded.begin(), encoded.end(), '-', 'a'); + std::replace(encoded.begin(), encoded.end(), '_', 'A'); + fdef->mutable_signature()->set_name(strings::StrCat(fn_name, "_", encoded)); + } else { + fdef->mutable_signature()->set_name(fn_name); + } + return Status::OK(); } @@ -451,6 +469,7 @@ using tensorflow::Node; using tensorflow::string; TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, int num_opers, const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, int noutputs, const TF_Output* outputs, @@ -489,9 +508,11 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, // Do the actual function creation. TF_Function* tf_function = new TF_Function(); + DCHECK(append_hash_to_fn_name <= 1); status->status = tensorflow::GraphToFunctionDef( - fn_body->graph, fn_name, body_nodes, input_tensors, output_tensors, - output_names_vec, description, &tf_function->fdef); + fn_body->graph, fn_name, append_hash_to_fn_name != 0, body_nodes, + input_tensors, output_tensors, output_names_vec, description, + &tf_function->fdef); if (!status->status.ok()) { TF_DeleteFunction(tf_function); return nullptr; diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index a5a66d9385..f76273e93b 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -179,7 +179,7 @@ class CApiFunctionTest : public ::testing::Test { bool expect_failure = false) { ASSERT_EQ(func_, nullptr); const char** output_names_ptr = ToArray(output_names); - func_ = TF_GraphToFunction(func_graph_, func_name_, num_opers, + func_ = TF_GraphToFunction(func_graph_, func_name_, false, num_opers, num_opers == -1 ? nullptr : opers.data(), inputs.size(), inputs.data(), outputs.size(), outputs.data(), output_names_ptr, @@ -1200,7 +1200,8 @@ TEST_F(CApiFunctionTest, OutputOpNotInBody) { } void DefineFunction(const char* name, TF_Function** func, - const char* description = nullptr) { + const char* description = nullptr, + bool append_hash = false) { std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph( TF_NewGraph(), TF_DeleteGraph); std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(), @@ -1211,7 +1212,7 @@ void DefineFunction(const char* name, TF_Function** func, TF_Output inputs[] = {{feed, 0}}; TF_Output outputs[] = {{neg, 0}}; - *func = TF_GraphToFunction(func_graph.get(), name, -1, + *func = TF_GraphToFunction(func_graph.get(), name, append_hash, -1, /*opers=*/nullptr, 1, inputs, 1, outputs, /*output_names=*/nullptr, /*opts=*/nullptr, description, s.get()); @@ -1453,5 +1454,21 @@ TEST_F(CApiFunctionTest, Description) { ASSERT_EQ(string("Return something"), fdef.signature().description()); } +TEST_F(CApiFunctionTest, Name) { + DefineFunction("long_func_name", &func_, "Return something", + /*append_hash=*/false); + tensorflow::FunctionDef fdef; + ASSERT_TRUE(GetFunctionDef(func_, &fdef)); + ASSERT_EQ(string("long_func_name"), fdef.signature().name()); +} + +TEST_F(CApiFunctionTest, AppendHash) { + DefineFunction("func_name_base", &func_, "Return something", + /*append_hash=*/true); + tensorflow::FunctionDef fdef; + ASSERT_TRUE(GetFunctionDef(func_, &fdef)); + ASSERT_EQ(string("func_name_base_qaJ8jA8UmGY"), fdef.signature().name()); +} + } // namespace } // namespace tensorflow |