diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-01-30 09:57:16 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-30 10:11:06 -0800 |
commit | 88eb6c61ef7659c2b5bb1ec6586c7d3cca5e4e9c (patch) | |
tree | 9fe3980894fa336bf0c96c4e54c924af6dac3b9b /tensorflow/cc | |
parent | 39bc42ebcf0df005b378fa88a4650a5bebb1eb0c (diff) |
TensorFlow SavedModel loader: avoid segmentation fault when NewSession returns null
PiperOrigin-RevId: 183846994
Diffstat (limited to 'tensorflow/cc')
-rw-r--r-- | tensorflow/cc/saved_model/loader.cc | 4 | ||||
-rw-r--r-- | tensorflow/cc/saved_model/loader_test.cc | 18 |
2 files changed, 21 insertions, 1 deletions
diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index acef098c7d..faa1e378d0 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -96,7 +96,9 @@ Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto, Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, const SessionOptions& session_options, std::unique_ptr<Session>* session) { - session->reset(NewSession(session_options)); + Session* session_p = nullptr; + TF_RETURN_IF_ERROR(NewSession(session_options, &session_p)); + session->reset(session_p); return (*session)->Create(meta_graph_def.graph_def()); } diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index 0ad6b33bba..4c64d2cfe3 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -155,6 +155,24 @@ TEST_F(LoaderTest, NoTagMatchMultiple) { << st.error_message(); } +TEST_F(LoaderTest, SessionCreationFailure) { + SavedModelBundle bundle; + // Use invalid SessionOptions to cause session creation to fail. Default + // options work, so provide an invalid value for the target field. + SessionOptions session_options; + constexpr char kInvalidTarget[] = "invalid target"; + session_options.target = kInvalidTarget; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + Status st = LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle); + EXPECT_FALSE(st.ok()); + EXPECT_TRUE(StringPiece(st.error_message()).contains(kInvalidTarget)) + << st.error_message(); +} + TEST_F(LoaderTest, PbtxtFormat) { SavedModelBundle bundle; SessionOptions session_options; |