aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/saved_model
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-01 20:35:29 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-01 21:47:30 -0800
commitd7fcf5a865570073569817fffafc07c8c74ec66d (patch)
tree0b75fadb4ca7c5edf1324ee2b43fe7aecfc0526e /tensorflow/cc/saved_model
parentc39d141948174b94213848d9b95541ee09af5e53 (diff)
Automated g4 rollback of changelist 183874527
PiperOrigin-RevId: 184236409
Diffstat (limited to 'tensorflow/cc/saved_model')
-rw-r--r--tensorflow/cc/saved_model/loader.cc4
-rw-r--r--tensorflow/cc/saved_model/loader_test.cc18
2 files changed, 21 insertions, 1 deletions
diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc
index acef098c7d..faa1e378d0 100644
--- a/tensorflow/cc/saved_model/loader.cc
+++ b/tensorflow/cc/saved_model/loader.cc
@@ -96,7 +96,9 @@ Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto,
Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
const SessionOptions& session_options,
std::unique_ptr<Session>* session) {
- session->reset(NewSession(session_options));
+ Session* session_p = nullptr;
+ TF_RETURN_IF_ERROR(NewSession(session_options, &session_p));
+ session->reset(session_p);
return (*session)->Create(meta_graph_def.graph_def());
}
diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc
index 0ad6b33bba..4c64d2cfe3 100644
--- a/tensorflow/cc/saved_model/loader_test.cc
+++ b/tensorflow/cc/saved_model/loader_test.cc
@@ -155,6 +155,24 @@ TEST_F(LoaderTest, NoTagMatchMultiple) {
<< st.error_message();
}
+TEST_F(LoaderTest, SessionCreationFailure) {
+ SavedModelBundle bundle;
+ // Use invalid SessionOptions to cause session creation to fail. Default
+ // options work, so provide an invalid value for the target field.
+ SessionOptions session_options;
+ constexpr char kInvalidTarget[] = "invalid target";
+ session_options.target = kInvalidTarget;
+ RunOptions run_options;
+
+ const string export_dir =
+ io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
+ Status st = LoadSavedModel(session_options, run_options, export_dir,
+ {kSavedModelTagServe}, &bundle);
+ EXPECT_FALSE(st.ok());
+ EXPECT_TRUE(StringPiece(st.error_message()).contains(kInvalidTarget))
+ << st.error_message();
+}
+
TEST_F(LoaderTest, PbtxtFormat) {
SavedModelBundle bundle;
SessionOptions session_options;