diff options
-rw-r--r-- | tensorflow/contrib/session_bundle/session_bundle.cc | 34 | ||||
-rw-r--r-- | tensorflow/contrib/session_bundle/session_bundle_test.cc | 190 | ||||
-rw-r--r-- | tensorflow/contrib/session_bundle/signature.cc | 17 | ||||
-rw-r--r-- | tensorflow/contrib/session_bundle/signature_test.cc | 41 |
4 files changed, 266 insertions, 16 deletions
diff --git a/tensorflow/contrib/session_bundle/session_bundle.cc b/tensorflow/contrib/session_bundle/session_bundle.cc index 5029dcd7fe..6895c094be 100644 --- a/tensorflow/contrib/session_bundle/session_bundle.cc +++ b/tensorflow/contrib/session_bundle/session_bundle.cc @@ -91,7 +91,7 @@ string GetVariablesFilename(const StringPiece export_dir) { const char kVariablesFilename[] = "export"; const char kVariablesFilenamePattern[] = "export-\?\?\?\?\?-of-\?\?\?\?\?"; if (Env::Default()->FileExists( - tensorflow::io::JoinPath(export_dir, kVariablesFilename))) { + tensorflow::io::JoinPath(export_dir, kVariablesFilename))) { return tensorflow::io::JoinPath(export_dir, kVariablesFilename); } else { return tensorflow::io::JoinPath(export_dir, kVariablesFilenamePattern); @@ -104,8 +104,8 @@ Status RunRestoreOp(const StringPiece export_dir, const StringPiece variables_filename_const_op_name, tensorflow::Session* session) { LOG(INFO) << "Running restore op for SessionBundle"; - Tensor variables_tensor = CreateStringTensor( - GetVariablesFilename(export_dir)); + Tensor variables_tensor = + CreateStringTensor(GetVariablesFilename(export_dir)); std::vector<std::pair<string, Tensor>> inputs = { {variables_filename_const_op_name.ToString(), variables_tensor}}; AddAssetsTensorsToInputs(export_dir, asset_files, &inputs); @@ -137,11 +137,21 @@ tensorflow::Status LoadSessionBundleFromPath( // Use serving graph_def in MetaGraphDef collection_def. if (graph_collection_def.any_list().value_size() != 1) { return errors::FailedPrecondition( - strings::StrCat("Expected exactly one serving GraphDef in : ", - bundle->meta_graph_def.DebugString())); + "Expected exactly one serving GraphDef in : ", + bundle->meta_graph_def.DebugString()); + } + const auto& any = graph_collection_def.any_list().value(0); + if (!any.Is<GraphDef>()) { + return errors::FailedPrecondition( + "Expected Any type_url for: ", + tensorflow::GraphDef::default_instance().descriptor()->full_name(), + ". Got: ", string(any.type_url().data(), any.type_url().size()), "."); } tensorflow::GraphDef graph_def; - graph_collection_def.any_list().value(0).UnpackTo(&graph_def); + if (!any.UnpackTo(&graph_def)) { + return errors::FailedPrecondition("Failed to unpack: ", + any.DebugString()); + } TF_RETURN_IF_ERROR( CreateSessionFromGraphDef(options, graph_def, &bundle->session)); } else { @@ -157,7 +167,17 @@ tensorflow::Status LoadSessionBundleFromPath( const auto& any_assets = assets_it->second.any_list().value(); for (const auto& any_asset : any_assets) { AssetFile asset_file; - any_asset.UnpackTo(&asset_file); + if (!any_asset.Is<AssetFile>()) { + return errors::FailedPrecondition( + "Expected asset Any type_url for: ", + asset_file.descriptor()->full_name(), ". Got: ", + string(any_asset.type_url().data(), any_asset.type_url().size()), + "."); + } + if (!any_asset.UnpackTo(&asset_file)) { + return errors::FailedPrecondition("Failed to unpack: ", + any_asset.DebugString()); + } asset_files.push_back(asset_file); } } diff --git a/tensorflow/contrib/session_bundle/session_bundle_test.cc b/tensorflow/contrib/session_bundle/session_bundle_test.cc index 5aeb344524..5ecbb32f88 100644 --- a/tensorflow/contrib/session_bundle/session_bundle_test.cc +++ b/tensorflow/contrib/session_bundle/session_bundle_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include <vector> #include "google/protobuf/any.pb.h" +#include "tensorflow/contrib/session_bundle/signature.h" #include "tensorflow/contrib/session_bundle/test_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -36,9 +37,48 @@ namespace tensorflow { namespace serving { namespace { -TEST(LoadSessionBundleFromPath, Basic) { - const string export_path = test_util::TestSrcDirPath( +// Constants for the export file names. +const char kVariablesFilename[] = "export-00000-of-00001"; +const char kMetaGraphDefFilename[] = "export.meta"; + +// Function used to rewrite a MetaGraphDef. +using MetaGraphDefTwiddler = std::function<void(MetaGraphDef*)>; + +// Copy the base half_plus_two to `export_path`. +// Outputs the files using the passed names (typically the constants above). +// The Twiddler can be used to update the MetaGraphDef before output. +Status CopyExport(const string& export_path, const string& variables_filename, + const string& meta_graph_def_filename, + const MetaGraphDefTwiddler& twiddler) { + TF_RETURN_IF_ERROR(Env::Default()->CreateDir(export_path)); + const string orig_path = test_util::TestSrcDirPath( "session_bundle/example/half_plus_two/00000123"); + { + const string source = + tensorflow::io::JoinPath(orig_path, kVariablesFilename); + const string sink = + tensorflow::io::JoinPath(export_path, variables_filename); + + string data; + TF_RETURN_IF_ERROR(ReadFileToString(Env::Default(), source, &data)); + TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), sink, data)); + } + { + const string source = + tensorflow::io::JoinPath(orig_path, kMetaGraphDefFilename); + const string sink = + tensorflow::io::JoinPath(export_path, meta_graph_def_filename); + + tensorflow::MetaGraphDef graph_def; + TF_RETURN_IF_ERROR(ReadBinaryProto(Env::Default(), source, &graph_def)); + twiddler(&graph_def); + TF_RETURN_IF_ERROR( + WriteStringToFile(Env::Default(), sink, graph_def.SerializeAsString())); + } + return Status::OK(); +} + +void BasicTest(const string& export_path) { tensorflow::SessionOptions options; SessionBundle bundle; TF_ASSERT_OK(LoadSessionBundleFromPath(options, export_path, &bundle)); @@ -67,10 +107,8 @@ TEST(LoadSessionBundleFromPath, Basic) { Tensor input = test::AsTensor<float>({0, 1, 2, 3}, TensorShape({4, 1})); // Recover the Tensor names of our inputs and outputs. - auto collection_def = bundle.meta_graph_def.collection_def(); Signatures signatures; - ASSERT_EQ(1, collection_def[kSignaturesKey].any_list().value_size()); - collection_def[kSignaturesKey].any_list().value(0).UnpackTo(&signatures); + TF_ASSERT_OK(GetSignatures(bundle.meta_graph_def, &signatures)); ASSERT_TRUE(signatures.default_signature().has_regression_signature()); const tensorflow::serving::RegressionSignature regression_signature = signatures.default_signature().regression_signature(); @@ -86,6 +124,12 @@ TEST(LoadSessionBundleFromPath, Basic) { outputs[0], test::AsTensor<float>({2, 2.5, 3, 3.5}, TensorShape({4, 1}))); } +TEST(LoadSessionBundleFromPath, BasicTensorflowContrib) { + const string export_path = test_util::TestSrcDirPath( + "session_bundle/example/half_plus_two/00000123"); + BasicTest(export_path); +} + TEST(LoadSessionBundleFromPath, BadExportPath) { const string export_path = test_util::TestSrcDirPath("/tmp/bigfoot"); tensorflow::SessionOptions options; @@ -97,6 +141,142 @@ TEST(LoadSessionBundleFromPath, BadExportPath) { EXPECT_TRUE(msg.find("Not found") != std::string::npos) << msg; } +class SessionBundleTest : public ::testing::Test { + protected: + // Copy the half_plus_two graph and apply the twiddler to rewrite the + // MetaGraphDef. + // Returns the path of the export. + // ** Should only be called once per test ** + string SetupExport(MetaGraphDefTwiddler twiddler) { + return SetupExport(twiddler, kVariablesFilename, kMetaGraphDefFilename); + } + // SetupExport that allows for the variables and meta_graph_def filenames + // to be overridden. + string SetupExport(MetaGraphDefTwiddler twiddler, + const string& variables_filename, + const string& meta_graph_def_filename) { + // Construct a unique path name based on the test name. + const ::testing::TestInfo* const test_info = + ::testing::UnitTest::GetInstance()->current_test_info(); + const string export_path = tensorflow::io::JoinPath( + testing::TmpDir(), + strings::StrCat(test_info->test_case_name(), test_info->name())); + TF_CHECK_OK(CopyExport(export_path, variables_filename, + meta_graph_def_filename, twiddler)); + return export_path; + } + + tensorflow::SessionOptions options_; + SessionBundle bundle_; + Status status_; +}; + +TEST_F(SessionBundleTest, Basic) { + const string export_path = SetupExport([](MetaGraphDef*) {}); + BasicTest(export_path); +} + +TEST_F(SessionBundleTest, UnshardedVariableFile) { + // Test that we can properly read the variables when exported + // without sharding. + const string export_path = + SetupExport([](MetaGraphDef*) {}, "export", kMetaGraphDefFilename); + BasicTest(export_path); +} + +TEST_F(SessionBundleTest, ServingGraph_Empty) { + const string path = SetupExport([](MetaGraphDef* def) { + (*def->mutable_collection_def())[kGraphKey].clear_any_list(); + }); + status_ = LoadSessionBundleFromPath(options_, path, &bundle_); + EXPECT_FALSE(status_.ok()); + EXPECT_TRUE(StringPiece(status_.error_message()) + .contains("Expected exactly one serving GraphDef")) + << status_.error_message(); +} + +TEST_F(SessionBundleTest, ServingGraphAny_IncorrectType) { + const string path = SetupExport([](MetaGraphDef* def) { + // Pack an unexpected type in the GraphDef Any. + (*def->mutable_collection_def())[kGraphKey].clear_any_list(); + auto* any = (*def->mutable_collection_def())[kGraphKey] + .mutable_any_list() + ->add_value(); + any->PackFrom(AssetFile()); + }); + status_ = LoadSessionBundleFromPath(options_, path, &bundle_); + EXPECT_FALSE(status_.ok()); + EXPECT_TRUE(StringPiece(status_.error_message()) + .contains("Expected Any type_url for: tensorflow.GraphDef")) + << status_.error_message(); +} + +TEST_F(SessionBundleTest, ServingGraphAnyValue_Corrupted) { + const string path = SetupExport([](MetaGraphDef* def) { + // Pack an unexpected type in the GraphDef Any. + (*def->mutable_collection_def())[kGraphKey].clear_any_list(); + auto* any = (*def->mutable_collection_def())[kGraphKey] + .mutable_any_list() + ->add_value(); + any->PackFrom(GraphDef()); + any->set_value("junk junk"); + }); + status_ = LoadSessionBundleFromPath(options_, path, &bundle_); + EXPECT_FALSE(status_.ok()); + EXPECT_TRUE(StringPiece(status_.error_message()).contains("Failed to unpack")) + << status_.error_message(); +} + +TEST_F(SessionBundleTest, AssetFileAny_IncorrectType) { + const string path = SetupExport([](MetaGraphDef* def) { + // Pack an unexpected type in the AssetFile Any. + (*def->mutable_collection_def())[kAssetsKey].clear_any_list(); + auto* any = (*def->mutable_collection_def())[kAssetsKey] + .mutable_any_list() + ->add_value(); + any->PackFrom(GraphDef()); + }); + status_ = LoadSessionBundleFromPath(options_, path, &bundle_); + EXPECT_FALSE(status_.ok()); + EXPECT_TRUE( + StringPiece(status_.error_message()) + .contains( + "Expected asset Any type_url for: tensorflow.serving.AssetFile")) + << status_.error_message(); +} + +TEST_F(SessionBundleTest, AssetFileAny_ValueCorrupted) { + const string path = SetupExport([](MetaGraphDef* def) { + // Pack an unexpected type in the AssetFile Any. + (*def->mutable_collection_def())[kAssetsKey].clear_any_list(); + auto* any = (*def->mutable_collection_def())[kAssetsKey] + .mutable_any_list() + ->add_value(); + any->PackFrom(AssetFile()); + any->set_value("junk junk"); + }); + status_ = LoadSessionBundleFromPath(options_, path, &bundle_); + EXPECT_FALSE(status_.ok()); + EXPECT_TRUE(StringPiece(status_.error_message()).contains("Failed to unpack")) + << status_.error_message(); +} + +TEST_F(SessionBundleTest, InitOp_TooManyValues) { + const string path = SetupExport([](MetaGraphDef* def) { + // Pack multiple init ops in to the collection. + (*def->mutable_collection_def())[kInitOpKey].clear_node_list(); + auto* node_list = + (*def->mutable_collection_def())[kInitOpKey].mutable_node_list(); + node_list->add_value("foo"); + node_list->add_value("bar"); + }); + status_ = LoadSessionBundleFromPath(options_, path, &bundle_); + EXPECT_FALSE(status_.ok()); + EXPECT_TRUE(StringPiece(status_.error_message()) + .contains("Expected exactly one serving init op")) + << status_.error_message(); +} + } // namespace } // namespace serving } // namespace tensorflow diff --git a/tensorflow/contrib/session_bundle/signature.cc b/tensorflow/contrib/session_bundle/signature.cc index 50fea99ef2..04fb479523 100644 --- a/tensorflow/contrib/session_bundle/signature.cc +++ b/tensorflow/contrib/session_bundle/signature.cc @@ -48,14 +48,23 @@ Status BatchSizesMatch(const Tensor& input, const Tensor& output) { Status GetSignatures(const tensorflow::MetaGraphDef& meta_graph_def, Signatures* signatures) { - auto collection_def = meta_graph_def.collection_def(); - auto any_list = collection_def[kSignaturesKey].any_list(); - if (any_list.value_size() != 1) { + const auto& collection_def = meta_graph_def.collection_def(); + const auto it = collection_def.find(kSignaturesKey); + if (it == collection_def.end() || it->second.any_list().value_size() != 1) { return errors::FailedPrecondition( strings::StrCat("Expected exactly one signatures proto in : ", meta_graph_def.DebugString())); } - any_list.value(0).UnpackTo(signatures); + const auto& any = it->second.any_list().value(0); + if (!any.Is<Signatures>()) { + return errors::FailedPrecondition( + "Expected signature Any type_url for: ", + signatures->descriptor()->full_name(), ". Got: ", + string(any.type_url().data(), any.type_url().size()), "."); + } + if (!any.UnpackTo(signatures)) { + return errors::FailedPrecondition("Failed to unpack: ", any.DebugString()); + } return Status::OK(); } diff --git a/tensorflow/contrib/session_bundle/signature_test.cc b/tensorflow/contrib/session_bundle/signature_test.cc index c890551217..10908cc649 100644 --- a/tensorflow/contrib/session_bundle/signature_test.cc +++ b/tensorflow/contrib/session_bundle/signature_test.cc @@ -478,6 +478,47 @@ TEST(SetAndGetSignatures, RoundTrip) { .tensor_name()); } +TEST(GetSignatures, MissingSignature) { + tensorflow::MetaGraphDef meta_graph_def; + Signatures read_signatures; + const auto status = GetSignatures(meta_graph_def, &read_signatures); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE( + StringPiece(status.error_message()).contains("Expected exactly one")) + << status.error_message(); +} + +TEST(GetSignatures, WrongProtoInAny) { + tensorflow::MetaGraphDef meta_graph_def; + auto& collection_def = *(meta_graph_def.mutable_collection_def()); + auto* any = + collection_def[kSignaturesKey].mutable_any_list()->mutable_value()->Add(); + // Put an unexpected type into the Signatures Any. + any->PackFrom(TensorBinding()); + Signatures read_signatures; + const auto status = GetSignatures(meta_graph_def, &read_signatures); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()) + .contains("Expected signature Any type_url for: " + "tensorflow.serving.Signatures")) + << status.error_message(); +} + +TEST(GetSignatures, JunkInAny) { + tensorflow::MetaGraphDef meta_graph_def; + auto& collection_def = *(meta_graph_def.mutable_collection_def()); + auto* any = + collection_def[kSignaturesKey].mutable_any_list()->mutable_value()->Add(); + // Create a valid Any then corrupt it. + any->PackFrom(Signatures()); + any->set_value("junk junk"); + Signatures read_signatures; + const auto status = GetSignatures(meta_graph_def, &read_signatures); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()).contains("Failed to unpack")) + << status.error_message(); +} + // GenericSignature test fixture that contains a signature initialized with two // bound Tensors. class GenericSignatureTest : public ::testing::Test { |