diff options
author | Sukriti Ramesh <sukritiramesh@google.com> | 2017-02-09 10:56:58 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-09 11:18:15 -0800 |
commit | 4a75d35b1a8cc13d4c40c93773a90f3000daf289 (patch) | |
tree | 7bd9df616b4cffa26b0e1e9b7965a48c761c9df7 /tensorflow/contrib/session_bundle/bundle_shim.cc | |
parent | 9683b095fce7b77df01d95ac3b07dcd17a083782 (diff) |
Add functionality to populate dtype of TensorInfos in up-converted SessionBundles.
Change: 147054398
Diffstat (limited to 'tensorflow/contrib/session_bundle/bundle_shim.cc')
-rw-r--r-- | tensorflow/contrib/session_bundle/bundle_shim.cc | 145 |
1 files changed, 110 insertions, 35 deletions
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.cc b/tensorflow/contrib/session_bundle/bundle_shim.cc index 9c7cdf192d..81b37b1cf2 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim.cc +++ b/tensorflow/contrib/session_bundle/bundle_shim.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/contrib/session_bundle/manifest.pb.h" #include "tensorflow/contrib/session_bundle/session_bundle.h" #include "tensorflow/contrib/session_bundle/signature.h" +#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -46,33 +47,39 @@ bool IsRegressionSignature(const Signature& signature) { // SignatureDefs. SignatureDef BuildRegressionSignatureDef( - const RegressionSignature& regression_signature) { + const RegressionSignature& regression_signature, + const std::unordered_map<string, DataType>& tensor_name_to_dtype) { SignatureDef signature_def; signature_def.set_method_name(kRegressMethodName); internal::AddInputToSignatureDef(regression_signature.input().tensor_name(), - kRegressInputs, &signature_def); + tensor_name_to_dtype, kRegressInputs, + &signature_def); internal::AddOutputToSignatureDef(regression_signature.output().tensor_name(), - kRegressOutputs, &signature_def); + tensor_name_to_dtype, kRegressOutputs, + &signature_def); return signature_def; } SignatureDef BuildClassificationSignatureDef( - const ClassificationSignature& classification_signature) { + const ClassificationSignature& classification_signature, + const std::unordered_map<string, DataType>& tensor_name_to_dtype) { SignatureDef signature_def; signature_def.set_method_name(kClassifyMethodName); internal::AddInputToSignatureDef( - classification_signature.input().tensor_name(), kClassifyInputs, - &signature_def); + classification_signature.input().tensor_name(), tensor_name_to_dtype, + kClassifyInputs, &signature_def); internal::AddOutputToSignatureDef( - classification_signature.classes().tensor_name(), kClassifyOutputClasses, - &signature_def); + classification_signature.classes().tensor_name(), tensor_name_to_dtype, + kClassifyOutputClasses, &signature_def); internal::AddOutputToSignatureDef( - classification_signature.scores().tensor_name(), kClassifyOutputScores, - &signature_def); + classification_signature.scores().tensor_name(), tensor_name_to_dtype, + kClassifyOutputScores, &signature_def); return signature_def; } -Status MaybeBuildPredictSignatureDef(MetaGraphDef* meta_graph_def) { +Status MaybeBuildPredictSignatureDef( + const std::unordered_map<string, DataType>& tensor_name_to_dtype, + MetaGraphDef* meta_graph_def) { Signature input_signature, output_signature; // Ensure that named signatures corresponding to `inputs` and `outputs` keys // exist. @@ -97,13 +104,15 @@ Status MaybeBuildPredictSignatureDef(MetaGraphDef* meta_graph_def) { // signature def. for (const auto& map_entry : input_signature.generic_signature().map()) { internal::AddInputToSignatureDef(map_entry.second.tensor_name(), - map_entry.first, &signature_def); + tensor_name_to_dtype, map_entry.first, + &signature_def); } // Add map entries from the `outputs` generic signature to the output map in // the signature def. for (const auto& map_entry : output_signature.generic_signature().map()) { internal::AddOutputToSignatureDef(map_entry.second.tensor_name(), - map_entry.first, &signature_def); + tensor_name_to_dtype, map_entry.first, + &signature_def); } // Add the constructed signature def to the signature def map of the meta // graph def. Use the default key if it isn't already in use. @@ -148,8 +157,10 @@ Status LoadSavedModelFromLegacySessionBundlePath( // Up-conversion of default signatures is supported for classification and // regression. -Status ConvertDefaultSignatureToSignatureDef(const Signatures& signatures, - MetaGraphDef* meta_graph_def) { +Status ConvertDefaultSignatureToSignatureDef( + const Signatures& signatures, + const std::unordered_map<string, DataType>& tensor_name_to_dtype, + MetaGraphDef* meta_graph_def) { if (!signatures.has_default_signature()) { return Status::OK(); } @@ -165,10 +176,12 @@ Status ConvertDefaultSignatureToSignatureDef(const Signatures& signatures, const Signature& signature = signatures.default_signature(); if (IsRegressionSignature(signature)) { (*meta_graph_def->mutable_signature_def())[kDefaultServingSignatureDefKey] = - BuildRegressionSignatureDef(signature.regression_signature()); + BuildRegressionSignatureDef(signature.regression_signature(), + tensor_name_to_dtype); } else if (IsClassificationSignature(signature)) { (*meta_graph_def->mutable_signature_def())[kDefaultServingSignatureDefKey] = - BuildClassificationSignatureDef(signature.classification_signature()); + BuildClassificationSignatureDef(signature.classification_signature(), + tensor_name_to_dtype); } else { LOG(WARNING) << "Default signature up-conversion to SignatureDef is only " "supported for `Classification` and `Regression`. Could " @@ -180,14 +193,16 @@ Status ConvertDefaultSignatureToSignatureDef(const Signatures& signatures, return Status::OK(); } -Status ConvertNamedSignaturesToSignatureDef(const Signatures& signatures, - MetaGraphDef* meta_graph_def) { +Status ConvertNamedSignaturesToSignatureDef( + const Signatures& signatures, + const std::unordered_map<string, DataType>& tensor_name_to_dtype, + MetaGraphDef* meta_graph_def) { if (signatures.named_signatures().empty()) { return Status::OK(); } // Check for a Predict signature for up-conversion. Status predict_signature_def_status = - MaybeBuildPredictSignatureDef(meta_graph_def); + MaybeBuildPredictSignatureDef(tensor_name_to_dtype, meta_graph_def); for (const auto& it_named_signature : signatures.named_signatures()) { const string key = it_named_signature.first; // If a Predict SignatureDef was successfully constructed, skip the entries @@ -200,10 +215,12 @@ Status ConvertNamedSignaturesToSignatureDef(const Signatures& signatures, const Signature signature = it_named_signature.second; if (IsRegressionSignature(signature)) { (*meta_graph_def->mutable_signature_def())[key] = - BuildRegressionSignatureDef(signature.regression_signature()); + BuildRegressionSignatureDef(signature.regression_signature(), + tensor_name_to_dtype); } else if (IsClassificationSignature(signature)) { (*meta_graph_def->mutable_signature_def())[key] = - BuildClassificationSignatureDef(signature.classification_signature()); + BuildClassificationSignatureDef(signature.classification_signature(), + tensor_name_to_dtype); } else { LOG(WARNING) << "Named signature up-conversion to SignatureDef is only supported " @@ -223,39 +240,97 @@ namespace internal { // Helper functions to populate SignatureDef fields. // Adds an entry to the `inputs` map of the supplied SignatureDef. -void AddInputToSignatureDef(const string& tensor_name, const string& map_key, - SignatureDef* signature_def) { +void AddInputToSignatureDef( + const string& tensor_name, + const std::unordered_map<string, DataType>& tensor_name_to_dtype, + const string& input_key, SignatureDef* signature_def) { if (tensor_name.empty()) { + LOG(WARNING) << "Tensor name not provided. Cannot add TensorInfo to " + "SignatureDef inputs."; return; } - // TensorInfo messages used in the SignatureDefs are thinly populated with - // name only. + // Extract the tensor-name in case the supplied string is a tensor-reference. + // Example: Extract "x" from "x:0". + std::size_t pos = tensor_name.find(":"); + const string key = + (pos != std::string::npos) ? tensor_name.substr(0, pos) : tensor_name; + const auto it_tensor_info = tensor_name_to_dtype.find(key); TensorInfo tensor_info; tensor_info.set_name(tensor_name); - (*signature_def->mutable_inputs())[map_key] = tensor_info; + if (it_tensor_info != tensor_name_to_dtype.end()) { + tensor_info.set_dtype(it_tensor_info->second); + } else { + LOG(WARNING) + << "No dtype found for tensor with name: " << tensor_name << ". " + << "Building TensorInfo with only name for SignatureDef inputs. " + << "Downstream functionality including validation may be " + << "impacted."; + } + (*signature_def->mutable_inputs())[input_key] = tensor_info; } // Adds an entry to the `outputs` map of the supplied SignatureDef. -void AddOutputToSignatureDef(const string& tensor_name, const string& map_key, - SignatureDef* signature_def) { +void AddOutputToSignatureDef( + const string& tensor_name, + const std::unordered_map<string, DataType>& tensor_name_to_dtype, + const string& output_key, SignatureDef* signature_def) { if (tensor_name.empty()) { + LOG(WARNING) << "Tensor name not provided. Cannot add TensorInfo to " + "SignatureDef outputs."; return; } - // TensorInfo messages used in the SignatureDefs are thinly populated with - // name only. + // Extract the tensor-name in case the supplied string is a tensor-reference. + // Example: Extract "x" from "x:0". + std::size_t pos = tensor_name.find(":"); + const string key = + (pos != std::string::npos) ? tensor_name.substr(0, pos) : tensor_name; + const auto it_tensor_info = tensor_name_to_dtype.find(key); TensorInfo tensor_info; tensor_info.set_name(tensor_name); - (*signature_def->mutable_outputs())[map_key] = tensor_info; + if (it_tensor_info != tensor_name_to_dtype.end()) { + tensor_info.set_dtype(it_tensor_info->second); + } else { + LOG(WARNING) + << "No dtype found for tensor with name: " << tensor_name << ". " + << "Building TensorInfo with only name for SignatureDef outputs." + << " Downstream functionality including validation may be " + << "impacted."; + } + (*signature_def->mutable_outputs())[output_key] = tensor_info; +} + +// Builds a map from tensor name to the corresponding datatype, by parsing the +// MetaGraphDef. +Status BuildTensorNameToDtypeMap( + const MetaGraphDef& meta_graph_def, + std::unordered_map<string, DataType>* tensor_name_to_dtype) { + GraphConstructorOptions opts; + Graph graph(OpRegistry::Global()); + TF_RETURN_IF_ERROR( + ConvertGraphDefToGraph(opts, meta_graph_def.graph_def(), &graph)); + for (Node* node : graph.nodes()) { + for (auto dt : node->output_types()) { + tensor_name_to_dtype->insert(std::make_pair(node->name(), dt)); + } + } + return Status::OK(); } // Converts SessionBundle signatures to SavedModel signature-defs. Status ConvertSignaturesToSignatureDefs(MetaGraphDef* meta_graph_def) { Signatures signatures; GetSignatures(*meta_graph_def, &signatures); + + // Build a map of tensor-names to the corresponding tensor-info with `name` + // and `dtype` fields. + std::unordered_map<string, DataType> tensor_name_to_dtype; TF_RETURN_IF_ERROR( - ConvertDefaultSignatureToSignatureDef(signatures, meta_graph_def)); - TF_RETURN_IF_ERROR( - ConvertNamedSignaturesToSignatureDef(signatures, meta_graph_def)); + BuildTensorNameToDtypeMap(*meta_graph_def, &tensor_name_to_dtype)); + + TF_RETURN_IF_ERROR(ConvertDefaultSignatureToSignatureDef( + signatures, tensor_name_to_dtype, meta_graph_def)); + TF_RETURN_IF_ERROR(ConvertNamedSignaturesToSignatureDef( + signatures, tensor_name_to_dtype, meta_graph_def)); return Status::OK(); } |