aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/session_bundle/bundle_shim.cc
diff options
context:
space:
mode:
authorGravatar Sukriti Ramesh <sukritiramesh@google.com>2017-02-09 10:56:58 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-09 11:18:15 -0800
commit4a75d35b1a8cc13d4c40c93773a90f3000daf289 (patch)
tree7bd9df616b4cffa26b0e1e9b7965a48c761c9df7 /tensorflow/contrib/session_bundle/bundle_shim.cc
parent9683b095fce7b77df01d95ac3b07dcd17a083782 (diff)
Add functionality to populate dtype of TensorInfos in up-converted SessionBundles.
Change: 147054398
Diffstat (limited to 'tensorflow/contrib/session_bundle/bundle_shim.cc')
-rw-r--r--tensorflow/contrib/session_bundle/bundle_shim.cc145
1 files changed, 110 insertions, 35 deletions
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.cc b/tensorflow/contrib/session_bundle/bundle_shim.cc
index 9c7cdf192d..81b37b1cf2 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim.cc
+++ b/tensorflow/contrib/session_bundle/bundle_shim.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/contrib/session_bundle/manifest.pb.h"
#include "tensorflow/contrib/session_bundle/session_bundle.h"
#include "tensorflow/contrib/session_bundle/signature.h"
+#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -46,33 +47,39 @@ bool IsRegressionSignature(const Signature& signature) {
// SignatureDefs.
SignatureDef BuildRegressionSignatureDef(
- const RegressionSignature& regression_signature) {
+ const RegressionSignature& regression_signature,
+ const std::unordered_map<string, DataType>& tensor_name_to_dtype) {
SignatureDef signature_def;
signature_def.set_method_name(kRegressMethodName);
internal::AddInputToSignatureDef(regression_signature.input().tensor_name(),
- kRegressInputs, &signature_def);
+ tensor_name_to_dtype, kRegressInputs,
+ &signature_def);
internal::AddOutputToSignatureDef(regression_signature.output().tensor_name(),
- kRegressOutputs, &signature_def);
+ tensor_name_to_dtype, kRegressOutputs,
+ &signature_def);
return signature_def;
}
SignatureDef BuildClassificationSignatureDef(
- const ClassificationSignature& classification_signature) {
+ const ClassificationSignature& classification_signature,
+ const std::unordered_map<string, DataType>& tensor_name_to_dtype) {
SignatureDef signature_def;
signature_def.set_method_name(kClassifyMethodName);
internal::AddInputToSignatureDef(
- classification_signature.input().tensor_name(), kClassifyInputs,
- &signature_def);
+ classification_signature.input().tensor_name(), tensor_name_to_dtype,
+ kClassifyInputs, &signature_def);
internal::AddOutputToSignatureDef(
- classification_signature.classes().tensor_name(), kClassifyOutputClasses,
- &signature_def);
+ classification_signature.classes().tensor_name(), tensor_name_to_dtype,
+ kClassifyOutputClasses, &signature_def);
internal::AddOutputToSignatureDef(
- classification_signature.scores().tensor_name(), kClassifyOutputScores,
- &signature_def);
+ classification_signature.scores().tensor_name(), tensor_name_to_dtype,
+ kClassifyOutputScores, &signature_def);
return signature_def;
}
-Status MaybeBuildPredictSignatureDef(MetaGraphDef* meta_graph_def) {
+Status MaybeBuildPredictSignatureDef(
+ const std::unordered_map<string, DataType>& tensor_name_to_dtype,
+ MetaGraphDef* meta_graph_def) {
Signature input_signature, output_signature;
// Ensure that named signatures corresponding to `inputs` and `outputs` keys
// exist.
@@ -97,13 +104,15 @@ Status MaybeBuildPredictSignatureDef(MetaGraphDef* meta_graph_def) {
// signature def.
for (const auto& map_entry : input_signature.generic_signature().map()) {
internal::AddInputToSignatureDef(map_entry.second.tensor_name(),
- map_entry.first, &signature_def);
+ tensor_name_to_dtype, map_entry.first,
+ &signature_def);
}
// Add map entries from the `outputs` generic signature to the output map in
// the signature def.
for (const auto& map_entry : output_signature.generic_signature().map()) {
internal::AddOutputToSignatureDef(map_entry.second.tensor_name(),
- map_entry.first, &signature_def);
+ tensor_name_to_dtype, map_entry.first,
+ &signature_def);
}
// Add the constructed signature def to the signature def map of the meta
// graph def. Use the default key if it isn't already in use.
@@ -148,8 +157,10 @@ Status LoadSavedModelFromLegacySessionBundlePath(
// Up-conversion of default signatures is supported for classification and
// regression.
-Status ConvertDefaultSignatureToSignatureDef(const Signatures& signatures,
- MetaGraphDef* meta_graph_def) {
+Status ConvertDefaultSignatureToSignatureDef(
+ const Signatures& signatures,
+ const std::unordered_map<string, DataType>& tensor_name_to_dtype,
+ MetaGraphDef* meta_graph_def) {
if (!signatures.has_default_signature()) {
return Status::OK();
}
@@ -165,10 +176,12 @@ Status ConvertDefaultSignatureToSignatureDef(const Signatures& signatures,
const Signature& signature = signatures.default_signature();
if (IsRegressionSignature(signature)) {
(*meta_graph_def->mutable_signature_def())[kDefaultServingSignatureDefKey] =
- BuildRegressionSignatureDef(signature.regression_signature());
+ BuildRegressionSignatureDef(signature.regression_signature(),
+ tensor_name_to_dtype);
} else if (IsClassificationSignature(signature)) {
(*meta_graph_def->mutable_signature_def())[kDefaultServingSignatureDefKey] =
- BuildClassificationSignatureDef(signature.classification_signature());
+ BuildClassificationSignatureDef(signature.classification_signature(),
+ tensor_name_to_dtype);
} else {
LOG(WARNING) << "Default signature up-conversion to SignatureDef is only "
"supported for `Classification` and `Regression`. Could "
@@ -180,14 +193,16 @@ Status ConvertDefaultSignatureToSignatureDef(const Signatures& signatures,
return Status::OK();
}
-Status ConvertNamedSignaturesToSignatureDef(const Signatures& signatures,
- MetaGraphDef* meta_graph_def) {
+Status ConvertNamedSignaturesToSignatureDef(
+ const Signatures& signatures,
+ const std::unordered_map<string, DataType>& tensor_name_to_dtype,
+ MetaGraphDef* meta_graph_def) {
if (signatures.named_signatures().empty()) {
return Status::OK();
}
// Check for a Predict signature for up-conversion.
Status predict_signature_def_status =
- MaybeBuildPredictSignatureDef(meta_graph_def);
+ MaybeBuildPredictSignatureDef(tensor_name_to_dtype, meta_graph_def);
for (const auto& it_named_signature : signatures.named_signatures()) {
const string key = it_named_signature.first;
// If a Predict SignatureDef was successfully constructed, skip the entries
@@ -200,10 +215,12 @@ Status ConvertNamedSignaturesToSignatureDef(const Signatures& signatures,
const Signature signature = it_named_signature.second;
if (IsRegressionSignature(signature)) {
(*meta_graph_def->mutable_signature_def())[key] =
- BuildRegressionSignatureDef(signature.regression_signature());
+ BuildRegressionSignatureDef(signature.regression_signature(),
+ tensor_name_to_dtype);
} else if (IsClassificationSignature(signature)) {
(*meta_graph_def->mutable_signature_def())[key] =
- BuildClassificationSignatureDef(signature.classification_signature());
+ BuildClassificationSignatureDef(signature.classification_signature(),
+ tensor_name_to_dtype);
} else {
LOG(WARNING)
<< "Named signature up-conversion to SignatureDef is only supported "
@@ -223,39 +240,97 @@ namespace internal {
// Helper functions to populate SignatureDef fields.
// Adds an entry to the `inputs` map of the supplied SignatureDef.
-void AddInputToSignatureDef(const string& tensor_name, const string& map_key,
- SignatureDef* signature_def) {
+void AddInputToSignatureDef(
+ const string& tensor_name,
+ const std::unordered_map<string, DataType>& tensor_name_to_dtype,
+ const string& input_key, SignatureDef* signature_def) {
if (tensor_name.empty()) {
+ LOG(WARNING) << "Tensor name not provided. Cannot add TensorInfo to "
+ "SignatureDef inputs.";
return;
}
- // TensorInfo messages used in the SignatureDefs are thinly populated with
- // name only.
+ // Extract the tensor-name in case the supplied string is a tensor-reference.
+ // Example: Extract "x" from "x:0".
+ std::size_t pos = tensor_name.find(":");
+ const string key =
+ (pos != std::string::npos) ? tensor_name.substr(0, pos) : tensor_name;
+ const auto it_tensor_info = tensor_name_to_dtype.find(key);
TensorInfo tensor_info;
tensor_info.set_name(tensor_name);
- (*signature_def->mutable_inputs())[map_key] = tensor_info;
+ if (it_tensor_info != tensor_name_to_dtype.end()) {
+ tensor_info.set_dtype(it_tensor_info->second);
+ } else {
+ LOG(WARNING)
+ << "No dtype found for tensor with name: " << tensor_name << ". "
+ << "Building TensorInfo with only name for SignatureDef inputs. "
+ << "Downstream functionality including validation may be "
+ << "impacted.";
+ }
+ (*signature_def->mutable_inputs())[input_key] = tensor_info;
}
// Adds an entry to the `outputs` map of the supplied SignatureDef.
-void AddOutputToSignatureDef(const string& tensor_name, const string& map_key,
- SignatureDef* signature_def) {
+void AddOutputToSignatureDef(
+ const string& tensor_name,
+ const std::unordered_map<string, DataType>& tensor_name_to_dtype,
+ const string& output_key, SignatureDef* signature_def) {
if (tensor_name.empty()) {
+ LOG(WARNING) << "Tensor name not provided. Cannot add TensorInfo to "
+ "SignatureDef outputs.";
return;
}
- // TensorInfo messages used in the SignatureDefs are thinly populated with
- // name only.
+ // Extract the tensor-name in case the supplied string is a tensor-reference.
+ // Example: Extract "x" from "x:0".
+ std::size_t pos = tensor_name.find(":");
+ const string key =
+ (pos != std::string::npos) ? tensor_name.substr(0, pos) : tensor_name;
+ const auto it_tensor_info = tensor_name_to_dtype.find(key);
TensorInfo tensor_info;
tensor_info.set_name(tensor_name);
- (*signature_def->mutable_outputs())[map_key] = tensor_info;
+ if (it_tensor_info != tensor_name_to_dtype.end()) {
+ tensor_info.set_dtype(it_tensor_info->second);
+ } else {
+ LOG(WARNING)
+ << "No dtype found for tensor with name: " << tensor_name << ". "
+ << "Building TensorInfo with only name for SignatureDef outputs."
+ << " Downstream functionality including validation may be "
+ << "impacted.";
+ }
+ (*signature_def->mutable_outputs())[output_key] = tensor_info;
+}
+
+// Builds a map from tensor name to the corresponding datatype, by parsing the
+// MetaGraphDef.
+Status BuildTensorNameToDtypeMap(
+ const MetaGraphDef& meta_graph_def,
+ std::unordered_map<string, DataType>* tensor_name_to_dtype) {
+ GraphConstructorOptions opts;
+ Graph graph(OpRegistry::Global());
+ TF_RETURN_IF_ERROR(
+ ConvertGraphDefToGraph(opts, meta_graph_def.graph_def(), &graph));
+ for (Node* node : graph.nodes()) {
+ for (auto dt : node->output_types()) {
+ tensor_name_to_dtype->insert(std::make_pair(node->name(), dt));
+ }
+ }
+ return Status::OK();
}
// Converts SessionBundle signatures to SavedModel signature-defs.
Status ConvertSignaturesToSignatureDefs(MetaGraphDef* meta_graph_def) {
Signatures signatures;
GetSignatures(*meta_graph_def, &signatures);
+
+ // Build a map of tensor-names to the corresponding tensor-info with `name`
+ // and `dtype` fields.
+ std::unordered_map<string, DataType> tensor_name_to_dtype;
TF_RETURN_IF_ERROR(
- ConvertDefaultSignatureToSignatureDef(signatures, meta_graph_def));
- TF_RETURN_IF_ERROR(
- ConvertNamedSignaturesToSignatureDef(signatures, meta_graph_def));
+ BuildTensorNameToDtypeMap(*meta_graph_def, &tensor_name_to_dtype));
+
+ TF_RETURN_IF_ERROR(ConvertDefaultSignatureToSignatureDef(
+ signatures, tensor_name_to_dtype, meta_graph_def));
+ TF_RETURN_IF_ERROR(ConvertNamedSignaturesToSignatureDef(
+ signatures, tensor_name_to_dtype, meta_graph_def));
return Status::OK();
}