diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-02-01 20:35:29 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-01 21:47:30 -0800 |
commit | d7fcf5a865570073569817fffafc07c8c74ec66d (patch) | |
tree | 0b75fadb4ca7c5edf1324ee2b43fe7aecfc0526e /tensorflow/cc | |
parent | c39d141948174b94213848d9b95541ee09af5e53 (diff) |
Automated g4 rollback of changelist 183874527
PiperOrigin-RevId: 184236409
Diffstat (limited to 'tensorflow/cc')
-rw-r--r-- | tensorflow/cc/saved_model/loader.cc | 4 | ||||
-rw-r--r-- | tensorflow/cc/saved_model/loader_test.cc | 18 |
2 files changed, 21 insertions, 1 deletions
diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index acef098c7d..faa1e378d0 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -96,7 +96,9 @@ Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto, Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, const SessionOptions& session_options, std::unique_ptr<Session>* session) { - session->reset(NewSession(session_options)); + Session* session_p = nullptr; + TF_RETURN_IF_ERROR(NewSession(session_options, &session_p)); + session->reset(session_p); return (*session)->Create(meta_graph_def.graph_def()); } diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index 0ad6b33bba..4c64d2cfe3 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -155,6 +155,24 @@ TEST_F(LoaderTest, NoTagMatchMultiple) { << st.error_message(); } +TEST_F(LoaderTest, SessionCreationFailure) { + SavedModelBundle bundle; + // Use invalid SessionOptions to cause session creation to fail. Default + // options work, so provide an invalid value for the target field. + SessionOptions session_options; + constexpr char kInvalidTarget[] = "invalid target"; + session_options.target = kInvalidTarget; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + Status st = LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle); + EXPECT_FALSE(st.ok()); + EXPECT_TRUE(StringPiece(st.error_message()).contains(kInvalidTarget)) + << st.error_message(); +} + TEST_F(LoaderTest, PbtxtFormat) { SavedModelBundle bundle; SessionOptions session_options; |