From 88eb6c61ef7659c2b5bb1ec6586c7d3cca5e4e9c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 30 Jan 2018 09:57:16 -0800 Subject: TensorFlow SavedModel loader: avoid segmentation fault when NewSession returns null PiperOrigin-RevId: 183846994 --- tensorflow/cc/saved_model/loader.cc | 4 +++- tensorflow/cc/saved_model/loader_test.cc | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) (limited to 'tensorflow/cc/saved_model') diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index acef098c7d..faa1e378d0 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -96,7 +96,9 @@ Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto, Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, const SessionOptions& session_options, std::unique_ptr* session) { - session->reset(NewSession(session_options)); + Session* session_p = nullptr; + TF_RETURN_IF_ERROR(NewSession(session_options, &session_p)); + session->reset(session_p); return (*session)->Create(meta_graph_def.graph_def()); } diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index 0ad6b33bba..4c64d2cfe3 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -155,6 +155,24 @@ TEST_F(LoaderTest, NoTagMatchMultiple) { << st.error_message(); } +TEST_F(LoaderTest, SessionCreationFailure) { + SavedModelBundle bundle; + // Use invalid SessionOptions to cause session creation to fail. Default + // options work, so provide an invalid value for the target field. + SessionOptions session_options; + constexpr char kInvalidTarget[] = "invalid target"; + session_options.target = kInvalidTarget; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + Status st = LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle); + EXPECT_FALSE(st.ok()); + EXPECT_TRUE(StringPiece(st.error_message()).contains(kInvalidTarget)) + << st.error_message(); +} + TEST_F(LoaderTest, PbtxtFormat) { SavedModelBundle bundle; SessionOptions session_options; -- cgit v1.2.3