aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_function.cc
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2017-09-28 12:14:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-28 12:19:18 -0700
commitbdab2691068757ee4872167898bc8768a7303ae9 (patch)
tree9d91c5b5f8e73860eab61fc8d5adb3a5788d43b6 /tensorflow/c/c_api_function.cc
parent860b30b2d42d0a21a86f59ef392e5fd9962a1d7c (diff)
Add append_hash_to_fn_name arg to TF_GraphToFunction
PiperOrigin-RevId: 170379490
Diffstat (limited to 'tensorflow/c/c_api_function.cc')
-rw-r--r--tensorflow/c/c_api_function.cc29
1 files changed, 25 insertions, 4 deletions
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc
index 61484fd8ea..7924c31a5f 100644
--- a/tensorflow/c/c_api_function.cc
+++ b/tensorflow/c/c_api_function.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/strings/base64.h"
#include "tensorflow/core/lib/strings/strcat.h"
using tensorflow::errors::InvalidArgument;
@@ -232,6 +233,7 @@ Status FillFunctionBody(
// Graph to FunctionDef conversion. This code is closely modeled on the Python
// code in third_party/tensorflow/python/framework/function.py.
Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
+ bool append_hash_to_fn_name,
const std::vector<const Node*>& body_nodes,
const std::vector<OutputTensor>& inputs,
const std::vector<OutputTensor>& outputs,
@@ -241,7 +243,6 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
DCHECK_EQ(output_names.size(), outputs.size());
}
- fdef->mutable_signature()->set_name(fn_name);
if (description != nullptr) {
fdef->mutable_signature()->set_description(description);
}
@@ -328,7 +329,6 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
// Remap return values.
for (int r = 0; r < fdef->signature().output_arg_size(); ++r) {
const string& ret_name = fdef->signature().output_arg(r).name();
-
// We convert this flat tensor name to the nested value
// (e.g. `add:z:1`) that we stored in tensor_renaming.
const string& return_value =
@@ -343,6 +343,24 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
(*fdef->mutable_ret())[ret_name] = iter->second;
}
+ if (append_hash_to_fn_name) {
+ const uint64 hash = FunctionDefHash(*fdef);
+ string encoded;
+ TF_RETURN_IF_ERROR(Base64Encode(
+ StringPiece(reinterpret_cast<const char*>(&hash), sizeof(hash)),
+ &encoded));
+ // Besides letters and digits our Base64 encoding uses '_' and '-'.
+ // Dash is invalid in operation names and multiple underscores in random
+ // places look strange. Since we never need to decode the hash back,
+ // replace these chars with with 'a' and 'A'. Replacing with different
+ // letters keeps more entropy.
+ std::replace(encoded.begin(), encoded.end(), '-', 'a');
+ std::replace(encoded.begin(), encoded.end(), '_', 'A');
+ fdef->mutable_signature()->set_name(strings::StrCat(fn_name, "_", encoded));
+ } else {
+ fdef->mutable_signature()->set_name(fn_name);
+ }
+
return Status::OK();
}
@@ -451,6 +469,7 @@ using tensorflow::Node;
using tensorflow::string;
TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
+ unsigned char append_hash_to_fn_name,
int num_opers, const TF_Operation* const* opers,
int ninputs, const TF_Output* inputs,
int noutputs, const TF_Output* outputs,
@@ -489,9 +508,11 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
// Do the actual function creation.
TF_Function* tf_function = new TF_Function();
+ DCHECK(append_hash_to_fn_name <= 1);
status->status = tensorflow::GraphToFunctionDef(
- fn_body->graph, fn_name, body_nodes, input_tensors, output_tensors,
- output_names_vec, description, &tf_function->fdef);
+ fn_body->graph, fn_name, append_hash_to_fn_name != 0, body_nodes,
+ input_tensors, output_tensors, output_names_vec, description,
+ &tf_function->fdef);
if (!status->status.ok()) {
TF_DeleteFunction(tf_function);
return nullptr;