aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/session_bundle/session_bundle.cc34
-rw-r--r--tensorflow/contrib/session_bundle/session_bundle_test.cc190
-rw-r--r--tensorflow/contrib/session_bundle/signature.cc17
-rw-r--r--tensorflow/contrib/session_bundle/signature_test.cc41
4 files changed, 266 insertions, 16 deletions
diff --git a/tensorflow/contrib/session_bundle/session_bundle.cc b/tensorflow/contrib/session_bundle/session_bundle.cc
index 5029dcd7fe..6895c094be 100644
--- a/tensorflow/contrib/session_bundle/session_bundle.cc
+++ b/tensorflow/contrib/session_bundle/session_bundle.cc
@@ -91,7 +91,7 @@ string GetVariablesFilename(const StringPiece export_dir) {
const char kVariablesFilename[] = "export";
const char kVariablesFilenamePattern[] = "export-\?\?\?\?\?-of-\?\?\?\?\?";
if (Env::Default()->FileExists(
- tensorflow::io::JoinPath(export_dir, kVariablesFilename))) {
+ tensorflow::io::JoinPath(export_dir, kVariablesFilename))) {
return tensorflow::io::JoinPath(export_dir, kVariablesFilename);
} else {
return tensorflow::io::JoinPath(export_dir, kVariablesFilenamePattern);
@@ -104,8 +104,8 @@ Status RunRestoreOp(const StringPiece export_dir,
const StringPiece variables_filename_const_op_name,
tensorflow::Session* session) {
LOG(INFO) << "Running restore op for SessionBundle";
- Tensor variables_tensor = CreateStringTensor(
- GetVariablesFilename(export_dir));
+ Tensor variables_tensor =
+ CreateStringTensor(GetVariablesFilename(export_dir));
std::vector<std::pair<string, Tensor>> inputs = {
{variables_filename_const_op_name.ToString(), variables_tensor}};
AddAssetsTensorsToInputs(export_dir, asset_files, &inputs);
@@ -137,11 +137,21 @@ tensorflow::Status LoadSessionBundleFromPath(
// Use serving graph_def in MetaGraphDef collection_def.
if (graph_collection_def.any_list().value_size() != 1) {
return errors::FailedPrecondition(
- strings::StrCat("Expected exactly one serving GraphDef in : ",
- bundle->meta_graph_def.DebugString()));
+ "Expected exactly one serving GraphDef in : ",
+ bundle->meta_graph_def.DebugString());
+ }
+ const auto& any = graph_collection_def.any_list().value(0);
+ if (!any.Is<GraphDef>()) {
+ return errors::FailedPrecondition(
+ "Expected Any type_url for: ",
+ tensorflow::GraphDef::default_instance().descriptor()->full_name(),
+ ". Got: ", string(any.type_url().data(), any.type_url().size()), ".");
}
tensorflow::GraphDef graph_def;
- graph_collection_def.any_list().value(0).UnpackTo(&graph_def);
+ if (!any.UnpackTo(&graph_def)) {
+ return errors::FailedPrecondition("Failed to unpack: ",
+ any.DebugString());
+ }
TF_RETURN_IF_ERROR(
CreateSessionFromGraphDef(options, graph_def, &bundle->session));
} else {
@@ -157,7 +167,17 @@ tensorflow::Status LoadSessionBundleFromPath(
const auto& any_assets = assets_it->second.any_list().value();
for (const auto& any_asset : any_assets) {
AssetFile asset_file;
- any_asset.UnpackTo(&asset_file);
+ if (!any_asset.Is<AssetFile>()) {
+ return errors::FailedPrecondition(
+ "Expected asset Any type_url for: ",
+ asset_file.descriptor()->full_name(), ". Got: ",
+ string(any_asset.type_url().data(), any_asset.type_url().size()),
+ ".");
+ }
+ if (!any_asset.UnpackTo(&asset_file)) {
+ return errors::FailedPrecondition("Failed to unpack: ",
+ any_asset.DebugString());
+ }
asset_files.push_back(asset_file);
}
}
diff --git a/tensorflow/contrib/session_bundle/session_bundle_test.cc b/tensorflow/contrib/session_bundle/session_bundle_test.cc
index 5aeb344524..5ecbb32f88 100644
--- a/tensorflow/contrib/session_bundle/session_bundle_test.cc
+++ b/tensorflow/contrib/session_bundle/session_bundle_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "google/protobuf/any.pb.h"
+#include "tensorflow/contrib/session_bundle/signature.h"
#include "tensorflow/contrib/session_bundle/test_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -36,9 +37,48 @@ namespace tensorflow {
namespace serving {
namespace {
-TEST(LoadSessionBundleFromPath, Basic) {
- const string export_path = test_util::TestSrcDirPath(
+// Constants for the export file names.
+const char kVariablesFilename[] = "export-00000-of-00001";
+const char kMetaGraphDefFilename[] = "export.meta";
+
+// Function used to rewrite a MetaGraphDef.
+using MetaGraphDefTwiddler = std::function<void(MetaGraphDef*)>;
+
+// Copy the base half_plus_two to `export_path`.
+// Outputs the files using the passed names (typically the constants above).
+// The Twiddler can be used to update the MetaGraphDef before output.
+Status CopyExport(const string& export_path, const string& variables_filename,
+ const string& meta_graph_def_filename,
+ const MetaGraphDefTwiddler& twiddler) {
+ TF_RETURN_IF_ERROR(Env::Default()->CreateDir(export_path));
+ const string orig_path = test_util::TestSrcDirPath(
"session_bundle/example/half_plus_two/00000123");
+ {
+ const string source =
+ tensorflow::io::JoinPath(orig_path, kVariablesFilename);
+ const string sink =
+ tensorflow::io::JoinPath(export_path, variables_filename);
+
+ string data;
+ TF_RETURN_IF_ERROR(ReadFileToString(Env::Default(), source, &data));
+ TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), sink, data));
+ }
+ {
+ const string source =
+ tensorflow::io::JoinPath(orig_path, kMetaGraphDefFilename);
+ const string sink =
+ tensorflow::io::JoinPath(export_path, meta_graph_def_filename);
+
+ tensorflow::MetaGraphDef graph_def;
+ TF_RETURN_IF_ERROR(ReadBinaryProto(Env::Default(), source, &graph_def));
+ twiddler(&graph_def);
+ TF_RETURN_IF_ERROR(
+ WriteStringToFile(Env::Default(), sink, graph_def.SerializeAsString()));
+ }
+ return Status::OK();
+}
+
+void BasicTest(const string& export_path) {
tensorflow::SessionOptions options;
SessionBundle bundle;
TF_ASSERT_OK(LoadSessionBundleFromPath(options, export_path, &bundle));
@@ -67,10 +107,8 @@ TEST(LoadSessionBundleFromPath, Basic) {
Tensor input = test::AsTensor<float>({0, 1, 2, 3}, TensorShape({4, 1}));
// Recover the Tensor names of our inputs and outputs.
- auto collection_def = bundle.meta_graph_def.collection_def();
Signatures signatures;
- ASSERT_EQ(1, collection_def[kSignaturesKey].any_list().value_size());
- collection_def[kSignaturesKey].any_list().value(0).UnpackTo(&signatures);
+ TF_ASSERT_OK(GetSignatures(bundle.meta_graph_def, &signatures));
ASSERT_TRUE(signatures.default_signature().has_regression_signature());
const tensorflow::serving::RegressionSignature regression_signature =
signatures.default_signature().regression_signature();
@@ -86,6 +124,12 @@ TEST(LoadSessionBundleFromPath, Basic) {
outputs[0], test::AsTensor<float>({2, 2.5, 3, 3.5}, TensorShape({4, 1})));
}
+TEST(LoadSessionBundleFromPath, BasicTensorflowContrib) {
+ const string export_path = test_util::TestSrcDirPath(
+ "session_bundle/example/half_plus_two/00000123");
+ BasicTest(export_path);
+}
+
TEST(LoadSessionBundleFromPath, BadExportPath) {
const string export_path = test_util::TestSrcDirPath("/tmp/bigfoot");
tensorflow::SessionOptions options;
@@ -97,6 +141,142 @@ TEST(LoadSessionBundleFromPath, BadExportPath) {
EXPECT_TRUE(msg.find("Not found") != std::string::npos) << msg;
}
+class SessionBundleTest : public ::testing::Test {
+ protected:
+ // Copy the half_plus_two graph and apply the twiddler to rewrite the
+ // MetaGraphDef.
+ // Returns the path of the export.
+ // ** Should only be called once per test **
+ string SetupExport(MetaGraphDefTwiddler twiddler) {
+ return SetupExport(twiddler, kVariablesFilename, kMetaGraphDefFilename);
+ }
+ // SetupExport that allows for the variables and meta_graph_def filenames
+ // to be overridden.
+ string SetupExport(MetaGraphDefTwiddler twiddler,
+ const string& variables_filename,
+ const string& meta_graph_def_filename) {
+ // Construct a unique path name based on the test name.
+ const ::testing::TestInfo* const test_info =
+ ::testing::UnitTest::GetInstance()->current_test_info();
+ const string export_path = tensorflow::io::JoinPath(
+ testing::TmpDir(),
+ strings::StrCat(test_info->test_case_name(), test_info->name()));
+ TF_CHECK_OK(CopyExport(export_path, variables_filename,
+ meta_graph_def_filename, twiddler));
+ return export_path;
+ }
+
+ tensorflow::SessionOptions options_;
+ SessionBundle bundle_;
+ Status status_;
+};
+
+TEST_F(SessionBundleTest, Basic) {
+ const string export_path = SetupExport([](MetaGraphDef*) {});
+ BasicTest(export_path);
+}
+
+TEST_F(SessionBundleTest, UnshardedVariableFile) {
+ // Test that we can properly read the variables when exported
+ // without sharding.
+ const string export_path =
+ SetupExport([](MetaGraphDef*) {}, "export", kMetaGraphDefFilename);
+ BasicTest(export_path);
+}
+
+TEST_F(SessionBundleTest, ServingGraph_Empty) {
+ const string path = SetupExport([](MetaGraphDef* def) {
+ (*def->mutable_collection_def())[kGraphKey].clear_any_list();
+ });
+ status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
+ EXPECT_FALSE(status_.ok());
+ EXPECT_TRUE(StringPiece(status_.error_message())
+ .contains("Expected exactly one serving GraphDef"))
+ << status_.error_message();
+}
+
+TEST_F(SessionBundleTest, ServingGraphAny_IncorrectType) {
+ const string path = SetupExport([](MetaGraphDef* def) {
+ // Pack an unexpected type in the GraphDef Any.
+ (*def->mutable_collection_def())[kGraphKey].clear_any_list();
+ auto* any = (*def->mutable_collection_def())[kGraphKey]
+ .mutable_any_list()
+ ->add_value();
+ any->PackFrom(AssetFile());
+ });
+ status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
+ EXPECT_FALSE(status_.ok());
+ EXPECT_TRUE(StringPiece(status_.error_message())
+ .contains("Expected Any type_url for: tensorflow.GraphDef"))
+ << status_.error_message();
+}
+
+TEST_F(SessionBundleTest, ServingGraphAnyValue_Corrupted) {
+ const string path = SetupExport([](MetaGraphDef* def) {
+ // Pack an unexpected type in the GraphDef Any.
+ (*def->mutable_collection_def())[kGraphKey].clear_any_list();
+ auto* any = (*def->mutable_collection_def())[kGraphKey]
+ .mutable_any_list()
+ ->add_value();
+ any->PackFrom(GraphDef());
+ any->set_value("junk junk");
+ });
+ status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
+ EXPECT_FALSE(status_.ok());
+ EXPECT_TRUE(StringPiece(status_.error_message()).contains("Failed to unpack"))
+ << status_.error_message();
+}
+
+TEST_F(SessionBundleTest, AssetFileAny_IncorrectType) {
+ const string path = SetupExport([](MetaGraphDef* def) {
+ // Pack an unexpected type in the AssetFile Any.
+ (*def->mutable_collection_def())[kAssetsKey].clear_any_list();
+ auto* any = (*def->mutable_collection_def())[kAssetsKey]
+ .mutable_any_list()
+ ->add_value();
+ any->PackFrom(GraphDef());
+ });
+ status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
+ EXPECT_FALSE(status_.ok());
+ EXPECT_TRUE(
+ StringPiece(status_.error_message())
+ .contains(
+ "Expected asset Any type_url for: tensorflow.serving.AssetFile"))
+ << status_.error_message();
+}
+
+TEST_F(SessionBundleTest, AssetFileAny_ValueCorrupted) {
+ const string path = SetupExport([](MetaGraphDef* def) {
+ // Pack an unexpected type in the AssetFile Any.
+ (*def->mutable_collection_def())[kAssetsKey].clear_any_list();
+ auto* any = (*def->mutable_collection_def())[kAssetsKey]
+ .mutable_any_list()
+ ->add_value();
+ any->PackFrom(AssetFile());
+ any->set_value("junk junk");
+ });
+ status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
+ EXPECT_FALSE(status_.ok());
+ EXPECT_TRUE(StringPiece(status_.error_message()).contains("Failed to unpack"))
+ << status_.error_message();
+}
+
+TEST_F(SessionBundleTest, InitOp_TooManyValues) {
+ const string path = SetupExport([](MetaGraphDef* def) {
+ // Pack multiple init ops in to the collection.
+ (*def->mutable_collection_def())[kInitOpKey].clear_node_list();
+ auto* node_list =
+ (*def->mutable_collection_def())[kInitOpKey].mutable_node_list();
+ node_list->add_value("foo");
+ node_list->add_value("bar");
+ });
+ status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
+ EXPECT_FALSE(status_.ok());
+ EXPECT_TRUE(StringPiece(status_.error_message())
+ .contains("Expected exactly one serving init op"))
+ << status_.error_message();
+}
+
} // namespace
} // namespace serving
} // namespace tensorflow
diff --git a/tensorflow/contrib/session_bundle/signature.cc b/tensorflow/contrib/session_bundle/signature.cc
index 50fea99ef2..04fb479523 100644
--- a/tensorflow/contrib/session_bundle/signature.cc
+++ b/tensorflow/contrib/session_bundle/signature.cc
@@ -48,14 +48,23 @@ Status BatchSizesMatch(const Tensor& input, const Tensor& output) {
Status GetSignatures(const tensorflow::MetaGraphDef& meta_graph_def,
Signatures* signatures) {
- auto collection_def = meta_graph_def.collection_def();
- auto any_list = collection_def[kSignaturesKey].any_list();
- if (any_list.value_size() != 1) {
+ const auto& collection_def = meta_graph_def.collection_def();
+ const auto it = collection_def.find(kSignaturesKey);
+ if (it == collection_def.end() || it->second.any_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one signatures proto in : ",
meta_graph_def.DebugString()));
}
- any_list.value(0).UnpackTo(signatures);
+ const auto& any = it->second.any_list().value(0);
+ if (!any.Is<Signatures>()) {
+ return errors::FailedPrecondition(
+ "Expected signature Any type_url for: ",
+ signatures->descriptor()->full_name(), ". Got: ",
+ string(any.type_url().data(), any.type_url().size()), ".");
+ }
+ if (!any.UnpackTo(signatures)) {
+ return errors::FailedPrecondition("Failed to unpack: ", any.DebugString());
+ }
return Status::OK();
}
diff --git a/tensorflow/contrib/session_bundle/signature_test.cc b/tensorflow/contrib/session_bundle/signature_test.cc
index c890551217..10908cc649 100644
--- a/tensorflow/contrib/session_bundle/signature_test.cc
+++ b/tensorflow/contrib/session_bundle/signature_test.cc
@@ -478,6 +478,47 @@ TEST(SetAndGetSignatures, RoundTrip) {
.tensor_name());
}
+TEST(GetSignatures, MissingSignature) {
+ tensorflow::MetaGraphDef meta_graph_def;
+ Signatures read_signatures;
+ const auto status = GetSignatures(meta_graph_def, &read_signatures);
+ EXPECT_FALSE(status.ok());
+ EXPECT_TRUE(
+ StringPiece(status.error_message()).contains("Expected exactly one"))
+ << status.error_message();
+}
+
+TEST(GetSignatures, WrongProtoInAny) {
+ tensorflow::MetaGraphDef meta_graph_def;
+ auto& collection_def = *(meta_graph_def.mutable_collection_def());
+ auto* any =
+ collection_def[kSignaturesKey].mutable_any_list()->mutable_value()->Add();
+ // Put an unexpected type into the Signatures Any.
+ any->PackFrom(TensorBinding());
+ Signatures read_signatures;
+ const auto status = GetSignatures(meta_graph_def, &read_signatures);
+ EXPECT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message())
+ .contains("Expected signature Any type_url for: "
+ "tensorflow.serving.Signatures"))
+ << status.error_message();
+}
+
+TEST(GetSignatures, JunkInAny) {
+ tensorflow::MetaGraphDef meta_graph_def;
+ auto& collection_def = *(meta_graph_def.mutable_collection_def());
+ auto* any =
+ collection_def[kSignaturesKey].mutable_any_list()->mutable_value()->Add();
+ // Create a valid Any then corrupt it.
+ any->PackFrom(Signatures());
+ any->set_value("junk junk");
+ Signatures read_signatures;
+ const auto status = GetSignatures(meta_graph_def, &read_signatures);
+ EXPECT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message()).contains("Failed to unpack"))
+ << status.error_message();
+}
+
// GenericSignature test fixture that contains a signature initialized with two
// bound Tensors.
class GenericSignatureTest : public ::testing::Test {