aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2017-09-28 12:14:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-28 12:19:18 -0700
commitbdab2691068757ee4872167898bc8768a7303ae9 (patch)
tree9d91c5b5f8e73860eab61fc8d5adb3a5788d43b6 /tensorflow/c
parent860b30b2d42d0a21a86f59ef392e5fd9962a1d7c (diff)
Add append_hash_to_fn_name arg to TF_GraphToFunction
PiperOrigin-RevId: 170379490
Diffstat (limited to 'tensorflow/c')
-rw-r--r--tensorflow/c/BUILD1
-rw-r--r--tensorflow/c/c_api.h17
-rw-r--r--tensorflow/c/c_api_function.cc29
-rw-r--r--tensorflow/c/c_api_function_test.cc23
4 files changed, 56 insertions, 14 deletions
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index aead7154ee..077fb053fb 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -72,6 +72,7 @@ tf_cuda_library(
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
],
}),
)
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index a17c877804..33fd1794cf 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -1039,12 +1039,14 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny,
// fn_body - the graph whose operations (or subset of whose operations) will be
// converted to TF_Function.
// fn_name - the name of the new TF_Function. Should match the operation
-// name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]* and be distinct
-// from other operation names (at least those registered in graphs
-// where this function will be used).
-// TODO(iga): Allow null in here and have C API come up with
-// a unique name with high probability (similarly to
-// _create_hash_str in function.py)
+// name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]*.
+// If `append_hash_to_fn_name` is false, `fn_name` must be distinct
+// from other function and operation names (at least those
+// registered in graphs where this function will be used).
+// append_hash_to_fn_name - Must be 0 or 1. If set to 1, the actual name
+// of the function will be `fn_name` appended with
+// '_<hash_of_this_function's_definition>'.
+// If set to 0, the function's name will be `fn_name`.
// num_opers - `num_opers` contains the number of elements in the `opers` array
// or a special value of -1 meaning that no array is given.
// The distinction between an empty array of operations and no
@@ -1114,7 +1116,8 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny,
//
// On failure, null.
TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction(
- const TF_Graph* fn_body, const char* fn_name, int num_opers,
+ const TF_Graph* fn_body, const char* fn_name,
+ unsigned char append_hash_to_fn_name, int num_opers,
const TF_Operation* const* opers, int ninputs, const TF_Output* inputs,
int noutputs, const TF_Output* outputs, const char* const* output_names,
const TF_FunctionOptions* opts, const char* description, TF_Status* status);
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc
index 61484fd8ea..7924c31a5f 100644
--- a/tensorflow/c/c_api_function.cc
+++ b/tensorflow/c/c_api_function.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/strings/base64.h"
#include "tensorflow/core/lib/strings/strcat.h"
using tensorflow::errors::InvalidArgument;
@@ -232,6 +233,7 @@ Status FillFunctionBody(
// Graph to FunctionDef conversion. This code is closely modeled on the Python
// code in third_party/tensorflow/python/framework/function.py.
Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
+ bool append_hash_to_fn_name,
const std::vector<const Node*>& body_nodes,
const std::vector<OutputTensor>& inputs,
const std::vector<OutputTensor>& outputs,
@@ -241,7 +243,6 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
DCHECK_EQ(output_names.size(), outputs.size());
}
- fdef->mutable_signature()->set_name(fn_name);
if (description != nullptr) {
fdef->mutable_signature()->set_description(description);
}
@@ -328,7 +329,6 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
// Remap return values.
for (int r = 0; r < fdef->signature().output_arg_size(); ++r) {
const string& ret_name = fdef->signature().output_arg(r).name();
-
// We convert this flat tensor name to the nested value
// (e.g. `add:z:1`) that we stored in tensor_renaming.
const string& return_value =
@@ -343,6 +343,24 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
(*fdef->mutable_ret())[ret_name] = iter->second;
}
+ if (append_hash_to_fn_name) {
+ const uint64 hash = FunctionDefHash(*fdef);
+ string encoded;
+ TF_RETURN_IF_ERROR(Base64Encode(
+ StringPiece(reinterpret_cast<const char*>(&hash), sizeof(hash)),
+ &encoded));
+ // Besides letters and digits our Base64 encoding uses '_' and '-'.
+ // Dash is invalid in operation names and multiple underscores in random
+ // places look strange. Since we never need to decode the hash back,
+ // replace these chars with with 'a' and 'A'. Replacing with different
+ // letters keeps more entropy.
+ std::replace(encoded.begin(), encoded.end(), '-', 'a');
+ std::replace(encoded.begin(), encoded.end(), '_', 'A');
+ fdef->mutable_signature()->set_name(strings::StrCat(fn_name, "_", encoded));
+ } else {
+ fdef->mutable_signature()->set_name(fn_name);
+ }
+
return Status::OK();
}
@@ -451,6 +469,7 @@ using tensorflow::Node;
using tensorflow::string;
TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
+ unsigned char append_hash_to_fn_name,
int num_opers, const TF_Operation* const* opers,
int ninputs, const TF_Output* inputs,
int noutputs, const TF_Output* outputs,
@@ -489,9 +508,11 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
// Do the actual function creation.
TF_Function* tf_function = new TF_Function();
+ DCHECK(append_hash_to_fn_name <= 1);
status->status = tensorflow::GraphToFunctionDef(
- fn_body->graph, fn_name, body_nodes, input_tensors, output_tensors,
- output_names_vec, description, &tf_function->fdef);
+ fn_body->graph, fn_name, append_hash_to_fn_name != 0, body_nodes,
+ input_tensors, output_tensors, output_names_vec, description,
+ &tf_function->fdef);
if (!status->status.ok()) {
TF_DeleteFunction(tf_function);
return nullptr;
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc
index a5a66d9385..f76273e93b 100644
--- a/tensorflow/c/c_api_function_test.cc
+++ b/tensorflow/c/c_api_function_test.cc
@@ -179,7 +179,7 @@ class CApiFunctionTest : public ::testing::Test {
bool expect_failure = false) {
ASSERT_EQ(func_, nullptr);
const char** output_names_ptr = ToArray(output_names);
- func_ = TF_GraphToFunction(func_graph_, func_name_, num_opers,
+ func_ = TF_GraphToFunction(func_graph_, func_name_, false, num_opers,
num_opers == -1 ? nullptr : opers.data(),
inputs.size(), inputs.data(), outputs.size(),
outputs.data(), output_names_ptr,
@@ -1200,7 +1200,8 @@ TEST_F(CApiFunctionTest, OutputOpNotInBody) {
}
void DefineFunction(const char* name, TF_Function** func,
- const char* description = nullptr) {
+ const char* description = nullptr,
+ bool append_hash = false) {
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
TF_NewGraph(), TF_DeleteGraph);
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
@@ -1211,7 +1212,7 @@ void DefineFunction(const char* name, TF_Function** func,
TF_Output inputs[] = {{feed, 0}};
TF_Output outputs[] = {{neg, 0}};
- *func = TF_GraphToFunction(func_graph.get(), name, -1,
+ *func = TF_GraphToFunction(func_graph.get(), name, append_hash, -1,
/*opers=*/nullptr, 1, inputs, 1, outputs,
/*output_names=*/nullptr,
/*opts=*/nullptr, description, s.get());
@@ -1453,5 +1454,21 @@ TEST_F(CApiFunctionTest, Description) {
ASSERT_EQ(string("Return something"), fdef.signature().description());
}
+TEST_F(CApiFunctionTest, Name) {
+ DefineFunction("long_func_name", &func_, "Return something",
+ /*append_hash=*/false);
+ tensorflow::FunctionDef fdef;
+ ASSERT_TRUE(GetFunctionDef(func_, &fdef));
+ ASSERT_EQ(string("long_func_name"), fdef.signature().name());
+}
+
+TEST_F(CApiFunctionTest, AppendHash) {
+ DefineFunction("func_name_base", &func_, "Return something",
+ /*append_hash=*/true);
+ tensorflow::FunctionDef fdef;
+ ASSERT_TRUE(GetFunctionDef(func_, &fdef));
+ ASSERT_EQ(string("func_name_base_qaJ8jA8UmGY"), fdef.signature().name());
+}
+
} // namespace
} // namespace tensorflow