aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/cc')
-rw-r--r--tensorflow/cc/gradients/nn_grad.cc75
-rw-r--r--tensorflow/cc/gradients/nn_grad_test.cc25
-rw-r--r--tensorflow/cc/saved_model/BUILD41
-rw-r--r--tensorflow/cc/saved_model/loader.cc70
-rw-r--r--tensorflow/cc/saved_model/reader.cc88
-rw-r--r--tensorflow/cc/saved_model/reader.h39
-rw-r--r--tensorflow/cc/saved_model/reader_test.cc108
7 files changed, 324 insertions, 122 deletions
diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc
index dc6477e59d..588e96cb19 100644
--- a/tensorflow/cc/gradients/nn_grad.cc
+++ b/tensorflow/cc/gradients/nn_grad.cc
@@ -47,37 +47,28 @@ Status SoftmaxGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad);
-bool IsZero(const Scope& scope, Output grad) {
- std::array<std::string, 2> zero_op_type_names{{"ZerosLike", "Zeros"}};
+bool IsZero(const Scope& scope, const Output& grad) {
string op_type_name = grad.op().node()->type_string();
- for (auto& zero_op_type_name: zero_op_type_names) {
- if (op_type_name == zero_op_type_name) {
- return true;
- }
+ if (op_type_name == "ZerosLike" || op_type_name == "Zeros") {
+ return true;
}
- // the Operation we were provided is not named something obvious
+ // The Operation we were provided is not named something obvious so
// we need to actually look at its contents.
- // the original python code did this by calling a utility function called
- // tensor_util.constant_value. When you dig into tensor_tuil.constant_value
- // it is a large number of 'if' statements that measure certain edge cases
- // where it is possible to get the value of the tensor without actually
- // evaluating it. There are many kinds of tensors that can not have this
- // done.
+ // The original python code did this by calling a utility function called
+ // tensor_util.constant_value.
// There is no C++ equivalent to tensor_util.constant_value so we do nothing
// for the moment.
return false;
}
-Output BroadcastMul(const Scope& scope, Output vec, Output mat) {
- /* Multiply after broadcasting vec to match dimensions of mat.
- Args:
- vec: A 1-D tensor of dimension [D0]
- mat: A 2-D tensor of dimesnion [D0, D1]
-
- Returns:
- A tensor of dimension [D0, D1], the result fo vec * mat
- we use an element for element multiply here.
- */
+// Multiply after broadcasting vec to match dimensions of mat.
+// Args:
+// vec: A 1-D tensor of dimension [D0]
+// mat: A 2-D tensor of dimesnion [D0, D1]
+//
+// Returns:
+// A tensor of dimension [D0, D1], the result fo vec * mat.
+Output BroadcastMul(const Scope& scope, const Output& vec, const Output& mat) {
auto reshaped = ExpandDims(scope, vec, -1);
return Multiply(scope, reshaped, mat);
}
@@ -86,37 +77,37 @@ Status SoftmaxCrossEntropyWithLogitsGrad(const Scope& scope,
const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
- // Softmax gradient with cross entropy logits function
- // We multiply the backprop for cost with the gradients - op.output[1]
- // There is no gradient for labels
- auto logits =
- op.input(0); // the outputs of the network are at
- // input index 0. The "truth" labels are at index 1.
+ // Softmax gradient with cross entropy logits function.
+ // We multiply the backprop for cost with the gradients - op.output[1].
+ // There is no gradient for labels.
+
+ // The outputs of the network are at input index 0.
+ auto logits = op.input(0);
+ // The "truth" labels are at index 1.
auto softmax_grad = op.output(1);
- // The documentation for ops::SoftmaxCrossEntropyWithLogits says
- // loss is the output at index 0, and backprop is the output at index 1
+ // The loss is the output at index 0, and backprop is the output at index 1.
auto grad_loss = grad_inputs[0];
auto grad_grad = grad_inputs[1];
auto grad = BroadcastMul(scope, grad_loss, softmax_grad);
if (!IsZero(scope, grad_grad)) {
std::vector<int> axis;
- auto logitsSoftmax = Softmax(scope, logits);
+ auto logits_softmax = Softmax(scope, logits);
- auto grad_gradExpand = ExpandDims(scope, grad_grad, 1);
- auto logitsSoftMaxExpand = ExpandDims(scope, logitsSoftmax, 2);
- auto matMulResult =
- BatchMatMul(scope, grad_gradExpand, logitsSoftMaxExpand);
+ auto grad_grad_expand = ExpandDims(scope, grad_grad, 1);
+ auto logits_softmax_expand = ExpandDims(scope, logits_softmax, 2);
+ auto matmul_result =
+ BatchMatMul(scope, grad_grad_expand, logits_softmax_expand);
axis.push_back(1);
- auto squeezeResult = Squeeze(scope, matMulResult, Squeeze::Axis(axis));
- auto subtractionResult = Subtract(scope, grad_grad, squeezeResult);
- auto multiplyResult = Multiply(scope, subtractionResult, logitsSoftmax);
- grad = Add(scope, grad, multiplyResult);
+ auto squeeze_result = Squeeze(scope, matmul_result, Squeeze::Axis(axis));
+ auto subtraction_result = Subtract(scope, grad_grad, squeeze_result);
+ auto multiply_result = Multiply(scope, subtraction_result, logits_softmax);
+ grad = Add(scope, grad, multiply_result);
}
- auto minusLogSoftmax = Multiply(scope, LogSoftmax(scope, logits), -1.0f);
+ auto minus_log_softmax = Multiply(scope, LogSoftmax(scope, logits), -1.0f);
grad_outputs->push_back(grad);
- grad_outputs->push_back(BroadcastMul(scope, grad_loss, minusLogSoftmax));
+ grad_outputs->push_back(BroadcastMul(scope, grad_loss, minus_log_softmax));
return scope.status();
}
REGISTER_GRADIENT_OP("SoftmaxCrossEntropyWithLogits",
diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc
index f26a7e99e6..aa72cf7ba2 100644
--- a/tensorflow/cc/gradients/nn_grad_test.cc
+++ b/tensorflow/cc/gradients/nn_grad_test.cc
@@ -112,24 +112,17 @@ TEST_F(NNGradTest, SoftmaxGrad) {
}
TEST_F(NNGradTest, SoftmaxCrossEntropyWithLogitsGrad) {
- TensorShape logitsShape(
- {5, 3}); // batch size of 5,3 possible labels (classes),
- // logits is what is produced by a network
- // they are compared to labels which are the truth
- TensorShape lossShape(
- {5}); // batch size of 5, 1 value for each entry in the batch
- // loss is the difference between logits and labels
-
- auto logits = Placeholder(scope_, DT_FLOAT,
- Placeholder::Shape(logitsShape)); // estimation
- auto labels =
- Placeholder(scope_, DT_FLOAT, Placeholder::Shape(logitsShape)); // truth
+ TensorShape logits_shape({5, 3});
+ TensorShape loss_shape({5});
+
+ auto logits = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(logits_shape));
+ auto labels = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(logits_shape));
auto y =
tensorflow::ops::SoftmaxCrossEntropyWithLogits(scope_, logits, labels);
- // Please note the reversal of the backprop and loss orders. A separate issue
- // #18734 has been opened for this.
- RunTest({logits, labels}, {logitsShape, logitsShape}, {y.backprop, y.loss},
- {logitsShape, lossShape});
+ // Note the reversal of the backprop and loss orders. Issue #18734 has been
+ // opened for this.
+ RunTest({logits, labels}, {logits_shape, logits_shape}, {y.backprop, y.loss},
+ {logits_shape, loss_shape});
}
TEST_F(NNGradTest, LogSoftmaxGrad) {
diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD
index 06a3be18e0..3d3895c8fa 100644
--- a/tensorflow/cc/saved_model/BUILD
+++ b/tensorflow/cc/saved_model/BUILD
@@ -34,6 +34,46 @@ cc_library(
)
cc_library(
+ name = "reader",
+ srcs = ["reader.cc"],
+ hdrs = ["reader.h"],
+ deps = [
+ ":constants",
+ ] + if_not_mobile([
+ # TODO(b/111634734): :lib and :protos_all contain dependencies that
+ # cannot be built on mobile platforms. Instead, include the appropriate
+ # tf_lib depending on the build platform.
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ]) + if_mobile([
+ # Mobile-friendly SavedModel proto. See go/portable-proto for more info.
+ "//tensorflow/core:saved_model_portable_proto",
+ ]) + if_android([
+ "//tensorflow/core:android_tensorflow_lib",
+ ]) + if_ios([
+ "//tensorflow/core:ios_tensorflow_lib",
+ ]),
+)
+
+tf_cc_test(
+ name = "reader_test",
+ srcs = ["reader_test.cc"],
+ data = [
+ ":saved_model_half_plus_two",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":constants",
+ ":reader",
+ ":tag_constants",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+cc_library(
name = "loader",
hdrs = ["loader.h"],
deps = [
@@ -54,6 +94,7 @@ cc_library(
hdrs = ["loader.h"],
deps = [
":constants",
+ ":reader",
] + if_not_mobile([
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc
index faa1e378d0..07807ed2f3 100644
--- a/tensorflow/cc/saved_model/loader.cc
+++ b/tensorflow/cc/saved_model/loader.cc
@@ -18,8 +18,10 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/cc/saved_model/constants.h"
+#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/monitoring/counter.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf_internal.h"
@@ -43,56 +45,6 @@ auto* load_latency = monitoring::Counter<1>::New(
constexpr char kLoadAttemptFail[] = "fail";
constexpr char kLoadAttemptSuccess[] = "success";
-Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
- const string saved_model_pb_path =
- io::JoinPath(export_dir, kSavedModelFilenamePb);
- if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
- return ReadBinaryProto(Env::Default(), saved_model_pb_path,
- saved_model_proto);
- }
- const string saved_model_pbtxt_path =
- io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
- if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
- return ReadTextProto(Env::Default(), saved_model_pbtxt_path,
- saved_model_proto);
- }
- return Status(error::Code::NOT_FOUND,
- "Could not find SavedModel .pb or .pbtxt at supplied export "
- "directory path: " +
- export_dir);
-}
-
-string GetTagsAsString(const std::unordered_set<string>& tags) {
- string tags_as_string = "{ ";
- for (const string& tag : tags) {
- tags_as_string = strings::StrCat(tags_as_string, tag, " ");
- }
- tags_as_string = strings::StrCat(tags_as_string, "}");
- return tags_as_string;
-}
-
-Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto,
- const std::unordered_set<string>& tags,
- MetaGraphDef* meta_graph_def_to_load) {
- for (const MetaGraphDef& meta_graph_def : saved_model_proto.meta_graphs()) {
- // Get tags from the meta_graph_def.
- std::unordered_set<string> graph_tags;
- for (const string& tag : meta_graph_def.meta_info_def().tags()) {
- graph_tags.insert(tag);
- }
- // Match with the set of tags provided.
- if (graph_tags == tags) {
- *meta_graph_def_to_load = meta_graph_def;
- return Status::OK();
- }
- }
- return Status(error::Code::NOT_FOUND,
- "Could not find meta graph def matching supplied tags: " +
- GetTagsAsString(tags) +
- ". To inspect available tag-sets in the SavedModel, please "
- "use the SavedModel CLI: `saved_model_cli`");
-}
-
Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
const SessionOptions& session_options,
std::unique_ptr<Session>* session) {
@@ -235,18 +187,8 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
const string& export_dir,
const std::unordered_set<string>& tags,
SavedModelBundle* const bundle) {
- if (!MaybeSavedModelDirectory(export_dir)) {
- return Status(error::Code::NOT_FOUND,
- "SavedModel not found in export directory: " + export_dir);
- }
- LOG(INFO) << "Loading SavedModel with tags: " << GetTagsAsString(tags)
- << "; from: " << export_dir;
-
- SavedModel saved_model_proto;
- TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));
-
- TF_RETURN_IF_ERROR(
- FindMetaGraphDefToLoad(saved_model_proto, tags, &bundle->meta_graph_def));
+ TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags,
+ &bundle->meta_graph_def));
TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession(
bundle->meta_graph_def, session_options, &bundle->session));
@@ -288,8 +230,8 @@ Status LoadSavedModel(const SessionOptions& session_options,
return end_microseconds - start_microseconds;
}();
auto log_and_count = [&](const string& status_str) {
- LOG(INFO) << "SavedModel load for tags " << GetTagsAsString(tags)
- << "; Status: " << status_str << ". Took "
+ LOG(INFO) << "SavedModel load for tags { " << str_util::Join(tags, " ")
+ << " }; Status: " << status_str << ". Took "
<< load_latency_microsecs << " microseconds.";
load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1);
};
diff --git a/tensorflow/cc/saved_model/reader.cc b/tensorflow/cc/saved_model/reader.cc
new file mode 100644
index 0000000000..2146c8a197
--- /dev/null
+++ b/tensorflow/cc/saved_model/reader.cc
@@ -0,0 +1,88 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/saved_model/reader.h"
+
+#include <unordered_set>
+
+#include "tensorflow/cc/saved_model/constants.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/protobuf/saved_model.pb.h"
+
+namespace tensorflow {
+namespace {
+
+Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
+ LOG(INFO) << "Reading SavedModel from: " << export_dir;
+
+ const string saved_model_pb_path =
+ io::JoinPath(export_dir, kSavedModelFilenamePb);
+ if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
+ return ReadBinaryProto(Env::Default(), saved_model_pb_path,
+ saved_model_proto);
+ }
+ const string saved_model_pbtxt_path =
+ io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
+ if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
+ return ReadTextProto(Env::Default(), saved_model_pbtxt_path,
+ saved_model_proto);
+ }
+ return Status(error::Code::NOT_FOUND,
+ "Could not find SavedModel .pb or .pbtxt at supplied export "
+ "directory path: " +
+ export_dir);
+}
+
+Status FindMetaGraphDef(const SavedModel& saved_model_proto,
+ const std::unordered_set<string>& tags,
+ MetaGraphDef* meta_graph_def) {
+ LOG(INFO) << "Reading meta graph with tags { " << str_util::Join(tags, " ")
+ << " }";
+ for (const MetaGraphDef& graph_def : saved_model_proto.meta_graphs()) {
+ // Get tags from the graph_def.
+ std::unordered_set<string> graph_tags;
+ for (const string& tag : graph_def.meta_info_def().tags()) {
+ graph_tags.insert(tag);
+ }
+ // Match with the set of tags provided.
+ if (graph_tags == tags) {
+ *meta_graph_def = graph_def;
+ return Status::OK();
+ }
+ }
+ return Status(
+ error::Code::NOT_FOUND,
+ strings::StrCat(
+ "Could not find meta graph def matching supplied tags: { ",
+ str_util::Join(tags, " "),
+ " }. To inspect available tag-sets in the SavedModel, please "
+ "use the SavedModel CLI: `saved_model_cli`"));
+}
+
+} // namespace
+
+Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
+ const std::unordered_set<string>& tags,
+ MetaGraphDef* const meta_graph_def) {
+ SavedModel saved_model_proto;
+ TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));
+ TF_RETURN_IF_ERROR(FindMetaGraphDef(saved_model_proto, tags, meta_graph_def));
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/cc/saved_model/reader.h b/tensorflow/cc/saved_model/reader.h
new file mode 100644
index 0000000000..5815108df2
--- /dev/null
+++ b/tensorflow/cc/saved_model/reader.h
@@ -0,0 +1,39 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/// Functions to read the SavedModel proto, or parts of it.
+
+#ifndef TENSORFLOW_CC_SAVED_MODEL_READER_H_
+#define TENSORFLOW_CC_SAVED_MODEL_READER_H_
+
+#include <string>
+#include <unordered_set>
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+
+namespace tensorflow {
+
+// Reads the SavedModel proto from saved_model.pb(txt) in the given directory,
+// finds the MetaGraphDef that matches the given set of tags and writes it to
+// the `meta_graph_def` parameter. Returns a failure status when the SavedModel
+// file does not exist or no MetaGraphDef matches the tags.
+Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
+ const std::unordered_set<string>& tags,
+ MetaGraphDef* const meta_graph_def);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CC_SAVED_MODEL_READER_H_
diff --git a/tensorflow/cc/saved_model/reader_test.cc b/tensorflow/cc/saved_model/reader_test.cc
new file mode 100644
index 0000000000..620e9c2eec
--- /dev/null
+++ b/tensorflow/cc/saved_model/reader_test.cc
@@ -0,0 +1,108 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/saved_model/reader.h"
+
+#include "tensorflow/cc/saved_model/constants.h"
+#include "tensorflow/cc/saved_model/tag_constants.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+constexpr char kTestDataPbTxt[] =
+ "cc/saved_model/testdata/half_plus_two_pbtxt/00000123";
+constexpr char kTestDataSharded[] =
+ "cc/saved_model/testdata/half_plus_two/00000123";
+
+class ReaderTest : public ::testing::Test {
+ protected:
+ ReaderTest() {}
+
+ void CheckMetaGraphDef(const MetaGraphDef& meta_graph_def) {
+ const auto& tags = meta_graph_def.meta_info_def().tags();
+ EXPECT_TRUE(std::find(tags.begin(), tags.end(), kSavedModelTagServe) !=
+ tags.end());
+ EXPECT_NE(meta_graph_def.meta_info_def().tensorflow_version(), "");
+ EXPECT_EQ(
+ meta_graph_def.signature_def().at("serving_default").method_name(),
+ "tensorflow/serving/predict");
+ }
+};
+
+TEST_F(ReaderTest, TagMatch) {
+ MetaGraphDef meta_graph_def;
+
+ const string export_dir =
+ io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
+ TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
+ &meta_graph_def));
+ CheckMetaGraphDef(meta_graph_def);
+}
+
+TEST_F(ReaderTest, NoTagMatch) {
+ MetaGraphDef meta_graph_def;
+
+ const string export_dir =
+ io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
+ Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
+ &meta_graph_def);
+ EXPECT_FALSE(st.ok());
+ EXPECT_TRUE(str_util::StrContains(
+ st.error_message(),
+ "Could not find meta graph def matching supplied tags: { missing-tag }"))
+ << st.error_message();
+}
+
+TEST_F(ReaderTest, NoTagMatchMultiple) {
+ MetaGraphDef meta_graph_def;
+
+ const string export_dir =
+ io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
+ Status st = ReadMetaGraphDefFromSavedModel(
+ export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
+ EXPECT_FALSE(st.ok());
+ EXPECT_TRUE(str_util::StrContains(
+ st.error_message(),
+ "Could not find meta graph def matching supplied tags: "))
+ << st.error_message();
+}
+
+TEST_F(ReaderTest, PbtxtFormat) {
+ MetaGraphDef meta_graph_def;
+
+ const string export_dir =
+ io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt);
+ TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
+ &meta_graph_def));
+ CheckMetaGraphDef(meta_graph_def);
+}
+
+TEST_F(ReaderTest, InvalidExportPath) {
+ MetaGraphDef meta_graph_def;
+
+ const string export_dir =
+ io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
+ Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
+ &meta_graph_def);
+ EXPECT_FALSE(st.ok());
+}
+
+} // namespace
+} // namespace tensorflow