aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Sukriti Ramesh <sukritiramesh@google.com>2017-02-09 10:56:58 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-09 11:18:15 -0800
commit4a75d35b1a8cc13d4c40c93773a90f3000daf289 (patch)
tree7bd9df616b4cffa26b0e1e9b7965a48c761c9df7
parent9683b095fce7b77df01d95ac3b07dcd17a083782 (diff)
Add functionality to populate dtype of TensorInfos in up-converted SessionBundles.
Change: 147054398
-rw-r--r--tensorflow/contrib/session_bundle/bundle_shim.cc145
-rw-r--r--tensorflow/contrib/session_bundle/bundle_shim.h22
-rw-r--r--tensorflow/contrib/session_bundle/bundle_shim_test.cc61
3 files changed, 167 insertions, 61 deletions
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.cc b/tensorflow/contrib/session_bundle/bundle_shim.cc
index 9c7cdf192d..81b37b1cf2 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim.cc
+++ b/tensorflow/contrib/session_bundle/bundle_shim.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/contrib/session_bundle/manifest.pb.h"
#include "tensorflow/contrib/session_bundle/session_bundle.h"
#include "tensorflow/contrib/session_bundle/signature.h"
+#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -46,33 +47,39 @@ bool IsRegressionSignature(const Signature& signature) {
// SignatureDefs.
SignatureDef BuildRegressionSignatureDef(
- const RegressionSignature& regression_signature) {
+ const RegressionSignature& regression_signature,
+ const std::unordered_map<string, DataType>& tensor_name_to_dtype) {
SignatureDef signature_def;
signature_def.set_method_name(kRegressMethodName);
internal::AddInputToSignatureDef(regression_signature.input().tensor_name(),
- kRegressInputs, &signature_def);
+ tensor_name_to_dtype, kRegressInputs,
+ &signature_def);
internal::AddOutputToSignatureDef(regression_signature.output().tensor_name(),
- kRegressOutputs, &signature_def);
+ tensor_name_to_dtype, kRegressOutputs,
+ &signature_def);
return signature_def;
}
SignatureDef BuildClassificationSignatureDef(
- const ClassificationSignature& classification_signature) {
+ const ClassificationSignature& classification_signature,
+ const std::unordered_map<string, DataType>& tensor_name_to_dtype) {
SignatureDef signature_def;
signature_def.set_method_name(kClassifyMethodName);
internal::AddInputToSignatureDef(
- classification_signature.input().tensor_name(), kClassifyInputs,
- &signature_def);
+ classification_signature.input().tensor_name(), tensor_name_to_dtype,
+ kClassifyInputs, &signature_def);
internal::AddOutputToSignatureDef(
- classification_signature.classes().tensor_name(), kClassifyOutputClasses,
- &signature_def);
+ classification_signature.classes().tensor_name(), tensor_name_to_dtype,
+ kClassifyOutputClasses, &signature_def);
internal::AddOutputToSignatureDef(
- classification_signature.scores().tensor_name(), kClassifyOutputScores,
- &signature_def);
+ classification_signature.scores().tensor_name(), tensor_name_to_dtype,
+ kClassifyOutputScores, &signature_def);
return signature_def;
}
-Status MaybeBuildPredictSignatureDef(MetaGraphDef* meta_graph_def) {
+Status MaybeBuildPredictSignatureDef(
+ const std::unordered_map<string, DataType>& tensor_name_to_dtype,
+ MetaGraphDef* meta_graph_def) {
Signature input_signature, output_signature;
// Ensure that named signatures corresponding to `inputs` and `outputs` keys
// exist.
@@ -97,13 +104,15 @@ Status MaybeBuildPredictSignatureDef(MetaGraphDef* meta_graph_def) {
// signature def.
for (const auto& map_entry : input_signature.generic_signature().map()) {
internal::AddInputToSignatureDef(map_entry.second.tensor_name(),
- map_entry.first, &signature_def);
+ tensor_name_to_dtype, map_entry.first,
+ &signature_def);
}
// Add map entries from the `outputs` generic signature to the output map in
// the signature def.
for (const auto& map_entry : output_signature.generic_signature().map()) {
internal::AddOutputToSignatureDef(map_entry.second.tensor_name(),
- map_entry.first, &signature_def);
+ tensor_name_to_dtype, map_entry.first,
+ &signature_def);
}
// Add the constructed signature def to the signature def map of the meta
// graph def. Use the default key if it isn't already in use.
@@ -148,8 +157,10 @@ Status LoadSavedModelFromLegacySessionBundlePath(
// Up-conversion of default signatures is supported for classification and
// regression.
-Status ConvertDefaultSignatureToSignatureDef(const Signatures& signatures,
- MetaGraphDef* meta_graph_def) {
+Status ConvertDefaultSignatureToSignatureDef(
+ const Signatures& signatures,
+ const std::unordered_map<string, DataType>& tensor_name_to_dtype,
+ MetaGraphDef* meta_graph_def) {
if (!signatures.has_default_signature()) {
return Status::OK();
}
@@ -165,10 +176,12 @@ Status ConvertDefaultSignatureToSignatureDef(const Signatures& signatures,
const Signature& signature = signatures.default_signature();
if (IsRegressionSignature(signature)) {
(*meta_graph_def->mutable_signature_def())[kDefaultServingSignatureDefKey] =
- BuildRegressionSignatureDef(signature.regression_signature());
+ BuildRegressionSignatureDef(signature.regression_signature(),
+ tensor_name_to_dtype);
} else if (IsClassificationSignature(signature)) {
(*meta_graph_def->mutable_signature_def())[kDefaultServingSignatureDefKey] =
- BuildClassificationSignatureDef(signature.classification_signature());
+ BuildClassificationSignatureDef(signature.classification_signature(),
+ tensor_name_to_dtype);
} else {
LOG(WARNING) << "Default signature up-conversion to SignatureDef is only "
"supported for `Classification` and `Regression`. Could "
@@ -180,14 +193,16 @@ Status ConvertDefaultSignatureToSignatureDef(const Signatures& signatures,
return Status::OK();
}
-Status ConvertNamedSignaturesToSignatureDef(const Signatures& signatures,
- MetaGraphDef* meta_graph_def) {
+Status ConvertNamedSignaturesToSignatureDef(
+ const Signatures& signatures,
+ const std::unordered_map<string, DataType>& tensor_name_to_dtype,
+ MetaGraphDef* meta_graph_def) {
if (signatures.named_signatures().empty()) {
return Status::OK();
}
// Check for a Predict signature for up-conversion.
Status predict_signature_def_status =
- MaybeBuildPredictSignatureDef(meta_graph_def);
+ MaybeBuildPredictSignatureDef(tensor_name_to_dtype, meta_graph_def);
for (const auto& it_named_signature : signatures.named_signatures()) {
const string key = it_named_signature.first;
// If a Predict SignatureDef was successfully constructed, skip the entries
@@ -200,10 +215,12 @@ Status ConvertNamedSignaturesToSignatureDef(const Signatures& signatures,
const Signature signature = it_named_signature.second;
if (IsRegressionSignature(signature)) {
(*meta_graph_def->mutable_signature_def())[key] =
- BuildRegressionSignatureDef(signature.regression_signature());
+ BuildRegressionSignatureDef(signature.regression_signature(),
+ tensor_name_to_dtype);
} else if (IsClassificationSignature(signature)) {
(*meta_graph_def->mutable_signature_def())[key] =
- BuildClassificationSignatureDef(signature.classification_signature());
+ BuildClassificationSignatureDef(signature.classification_signature(),
+ tensor_name_to_dtype);
} else {
LOG(WARNING)
<< "Named signature up-conversion to SignatureDef is only supported "
@@ -223,39 +240,97 @@ namespace internal {
// Helper functions to populate SignatureDef fields.
// Adds an entry to the `inputs` map of the supplied SignatureDef.
-void AddInputToSignatureDef(const string& tensor_name, const string& map_key,
- SignatureDef* signature_def) {
+void AddInputToSignatureDef(
+ const string& tensor_name,
+ const std::unordered_map<string, DataType>& tensor_name_to_dtype,
+ const string& input_key, SignatureDef* signature_def) {
if (tensor_name.empty()) {
+ LOG(WARNING) << "Tensor name not provided. Cannot add TensorInfo to "
+ "SignatureDef inputs.";
return;
}
- // TensorInfo messages used in the SignatureDefs are thinly populated with
- // name only.
+ // Extract the tensor-name in case the supplied string is a tensor-reference.
+ // Example: Extract "x" from "x:0".
+ std::size_t pos = tensor_name.find(":");
+ const string key =
+ (pos != std::string::npos) ? tensor_name.substr(0, pos) : tensor_name;
+ const auto it_tensor_info = tensor_name_to_dtype.find(key);
TensorInfo tensor_info;
tensor_info.set_name(tensor_name);
- (*signature_def->mutable_inputs())[map_key] = tensor_info;
+ if (it_tensor_info != tensor_name_to_dtype.end()) {
+ tensor_info.set_dtype(it_tensor_info->second);
+ } else {
+ LOG(WARNING)
+ << "No dtype found for tensor with name: " << tensor_name << ". "
+ << "Building TensorInfo with only name for SignatureDef inputs. "
+ << "Downstream functionality including validation may be "
+ << "impacted.";
+ }
+ (*signature_def->mutable_inputs())[input_key] = tensor_info;
}
// Adds an entry to the `outputs` map of the supplied SignatureDef.
-void AddOutputToSignatureDef(const string& tensor_name, const string& map_key,
- SignatureDef* signature_def) {
+void AddOutputToSignatureDef(
+ const string& tensor_name,
+ const std::unordered_map<string, DataType>& tensor_name_to_dtype,
+ const string& output_key, SignatureDef* signature_def) {
if (tensor_name.empty()) {
+ LOG(WARNING) << "Tensor name not provided. Cannot add TensorInfo to "
+ "SignatureDef outputs.";
return;
}
- // TensorInfo messages used in the SignatureDefs are thinly populated with
- // name only.
+ // Extract the tensor-name in case the supplied string is a tensor-reference.
+ // Example: Extract "x" from "x:0".
+ std::size_t pos = tensor_name.find(":");
+ const string key =
+ (pos != std::string::npos) ? tensor_name.substr(0, pos) : tensor_name;
+ const auto it_tensor_info = tensor_name_to_dtype.find(key);
TensorInfo tensor_info;
tensor_info.set_name(tensor_name);
- (*signature_def->mutable_outputs())[map_key] = tensor_info;
+ if (it_tensor_info != tensor_name_to_dtype.end()) {
+ tensor_info.set_dtype(it_tensor_info->second);
+ } else {
+ LOG(WARNING)
+ << "No dtype found for tensor with name: " << tensor_name << ". "
+ << "Building TensorInfo with only name for SignatureDef outputs."
+ << " Downstream functionality including validation may be "
+ << "impacted.";
+ }
+ (*signature_def->mutable_outputs())[output_key] = tensor_info;
+}
+
+// Builds a map from tensor name to the corresponding datatype, by parsing the
+// MetaGraphDef.
+Status BuildTensorNameToDtypeMap(
+ const MetaGraphDef& meta_graph_def,
+ std::unordered_map<string, DataType>* tensor_name_to_dtype) {
+ GraphConstructorOptions opts;
+ Graph graph(OpRegistry::Global());
+ TF_RETURN_IF_ERROR(
+ ConvertGraphDefToGraph(opts, meta_graph_def.graph_def(), &graph));
+ for (Node* node : graph.nodes()) {
+ for (auto dt : node->output_types()) {
+ tensor_name_to_dtype->insert(std::make_pair(node->name(), dt));
+ }
+ }
+ return Status::OK();
}
// Converts SessionBundle signatures to SavedModel signature-defs.
Status ConvertSignaturesToSignatureDefs(MetaGraphDef* meta_graph_def) {
Signatures signatures;
GetSignatures(*meta_graph_def, &signatures);
+
+ // Build a map of tensor-names to the corresponding tensor-info with `name`
+ // and `dtype` fields.
+ std::unordered_map<string, DataType> tensor_name_to_dtype;
TF_RETURN_IF_ERROR(
- ConvertDefaultSignatureToSignatureDef(signatures, meta_graph_def));
- TF_RETURN_IF_ERROR(
- ConvertNamedSignaturesToSignatureDef(signatures, meta_graph_def));
+ BuildTensorNameToDtypeMap(*meta_graph_def, &tensor_name_to_dtype));
+
+ TF_RETURN_IF_ERROR(ConvertDefaultSignatureToSignatureDef(
+ signatures, tensor_name_to_dtype, meta_graph_def));
+ TF_RETURN_IF_ERROR(ConvertNamedSignaturesToSignatureDef(
+ signatures, tensor_name_to_dtype, meta_graph_def));
return Status::OK();
}
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.h b/tensorflow/contrib/session_bundle/bundle_shim.h
index 37c242c6ea..e24efa0de1 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim.h
+++ b/tensorflow/contrib/session_bundle/bundle_shim.h
@@ -32,15 +32,21 @@ namespace tensorflow {
namespace serving {
namespace internal {
-// Adds an entry (key and value) to the input map of the signature def.
-void AddInputToSignatureDef(const string& tensor_name,
- const string& input_map_key,
- SignatureDef* signature_def);
+// Adds an entry (key and value) to the input map of the signature def. Builds
+// TensorInfos for the SignatureDefs by using the name and dtype information
+// from the supplied map.
+void AddInputToSignatureDef(
+ const string& tensor_name,
+ const std::unordered_map<string, DataType>& tensor_name_to_dtype,
+ const string& input_map_key, SignatureDef* signature_def);
-// Adds an entry (key and value) to the output map of the signature def.
-void AddOutputToSignatureDef(const string& tensor_name,
- const string& output_map_key,
- SignatureDef* signature_def);
+// Adds an entry (key and value) to the output map of the signature def. Builds
+// TensorInfos for the SignatureDefs by using the name and dtype information
+// from the supplied map.
+void AddOutputToSignatureDef(
+ const string& tensor_name,
+ const std::unordered_map<string, DataType>& tensor_name_to_dtype,
+ const string& output_map_key, SignatureDef* signature_def);
// Converts signatures in the MetaGraphDef into a SignatureDefs in the
// MetaGraphDef.
diff --git a/tensorflow/contrib/session_bundle/bundle_shim_test.cc b/tensorflow/contrib/session_bundle/bundle_shim_test.cc
index 0f4f0f6ee4..ac2acfe870 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim_test.cc
+++ b/tensorflow/contrib/session_bundle/bundle_shim_test.cc
@@ -83,21 +83,43 @@ void LoadAndValidateSavedModelBundle(const string& export_dir,
TensorInfo input_tensor_info =
regression_signature_def.inputs().find(kRegressInputs)->second;
EXPECT_EQ(1, regression_signature_def.outputs_size());
+ // Ensure the TensorInfo has dtype populated.
+ EXPECT_EQ(DT_STRING, input_tensor_info.dtype());
ASSERT_FALSE(regression_signature_def.outputs().find(kRegressOutputs) ==
regression_signature_def.outputs().end());
TensorInfo output_tensor_info =
regression_signature_def.outputs().find(kRegressOutputs)->second;
+ // Ensure the TensorInfo has dtype populated.
+ EXPECT_EQ(DT_FLOAT, output_tensor_info.dtype());
ValidateHalfPlusTwo(saved_model_bundle, input_tensor_info.name(),
output_tensor_info.name());
}
+// Helper function to validate that the SignatureDef found in the MetaGraphDef
+// with the provided key has the expected string representation.
+void ValidateSignatureDef(const MetaGraphDef& meta_graph_def, const string& key,
+ const string& expected_string_signature_def) {
+ tensorflow::SignatureDef expected_signature;
+ CHECK(protobuf::TextFormat::ParseFromString(expected_string_signature_def,
+ &expected_signature));
+ auto iter = meta_graph_def.signature_def().find(key);
+ ASSERT_TRUE(iter != meta_graph_def.signature_def().end());
+ EXPECT_EQ(expected_signature.DebugString(), iter->second.DebugString());
+}
+
// Checks that the input map in a signature def is populated correctly.
TEST(BundleShimTest, AddInputToSignatureDef) {
SignatureDef signature_def;
const string tensor_name = "foo_tensor";
const string map_key = "foo_key";
- AddInputToSignatureDef(tensor_name, map_key, &signature_def);
+
+ // Build a map of tensor-name to dtype, for the unit-test.
+ std::unordered_map<string, DataType> tensor_name_to_dtype;
+ tensor_name_to_dtype[tensor_name] = tensorflow::DT_STRING;
+
+ AddInputToSignatureDef(tensor_name, tensor_name_to_dtype, map_key,
+ &signature_def);
EXPECT_EQ(1, signature_def.inputs_size());
EXPECT_EQ(tensor_name, signature_def.inputs().find(map_key)->second.name());
}
@@ -107,7 +129,13 @@ TEST(BundleShimTest, AddOutputToSignatureDef) {
SignatureDef signature_def;
const string tensor_name = "foo_tensor";
const string map_key = "foo_key";
- AddOutputToSignatureDef(tensor_name, map_key, &signature_def);
+
+ // Build a map of tensor-name to dtype, for the unit-test.
+ std::unordered_map<string, DataType> tensor_name_to_dtype;
+ tensor_name_to_dtype[tensor_name] = tensorflow::DT_STRING;
+
+ AddOutputToSignatureDef(tensor_name, tensor_name_to_dtype, map_key,
+ &signature_def);
EXPECT_EQ(1, signature_def.outputs_size());
EXPECT_EQ(tensor_name, signature_def.outputs().find(map_key)->second.name());
}
@@ -213,18 +241,6 @@ TEST(BundleShimTest, DefaultSignatureGeneric) {
EXPECT_EQ(0, meta_graph_def.signature_def_size());
}
-// Helper function to validate that the SignatureDef found in the MetaGraphDef
-// with the provided key has the expected string representation.
-void ValidateSignatureDef(const MetaGraphDef& meta_graph_def, const string& key,
- const string& expected_string_signature_def) {
- tensorflow::SignatureDef expected_signature;
- CHECK(protobuf::TextFormat::ParseFromString(expected_string_signature_def,
- &expected_signature));
- auto iter = meta_graph_def.signature_def().find(key);
- ASSERT_TRUE(iter != meta_graph_def.signature_def().end());
- EXPECT_EQ(expected_signature.DebugString(), iter->second.DebugString());
-}
-
TEST(BundleShimTest, NamedRegressionSignatures) {
Signatures signatures;
@@ -522,11 +538,20 @@ TEST(BundleShimTest, BasicExportSessionBundle) {
found_named_signature = true;
EXPECT_EQ(1, signature_def.inputs_size());
- EXPECT_FALSE(signature_def.inputs().find("x") ==
- signature_def.inputs().end());
+ const auto it_inputs_x = signature_def.inputs().find("x");
+ EXPECT_FALSE(it_inputs_x == signature_def.inputs().end());
+ // Ensure the TensorInfo has name and dtype populated.
+ const TensorInfo& tensor_info_x = it_inputs_x->second;
+ EXPECT_EQ("x:0", tensor_info_x.name());
+ EXPECT_EQ(DT_FLOAT, tensor_info_x.dtype());
+
EXPECT_EQ(1, signature_def.outputs_size());
- EXPECT_FALSE(signature_def.outputs().find("y") ==
- signature_def.outputs().end());
+ const auto it_outputs_y = signature_def.outputs().find("y");
+ EXPECT_FALSE(it_outputs_y == signature_def.outputs().end());
+ // Ensure the TensorInfo has name and dtype populated.
+ const TensorInfo& tensor_info_y = it_outputs_y->second;
+ EXPECT_EQ("y:0", tensor_info_y.name());
+ EXPECT_EQ(DT_FLOAT, tensor_info_y.dtype());
}
EXPECT_TRUE(found_named_signature);
}