aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/session_bundle/session_bundle_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/session_bundle/session_bundle_test.cc')
-rw-r--r--tensorflow/contrib/session_bundle/session_bundle_test.cc102
1 files changed, 102 insertions, 0 deletions
diff --git a/tensorflow/contrib/session_bundle/session_bundle_test.cc b/tensorflow/contrib/session_bundle/session_bundle_test.cc
new file mode 100644
index 0000000000..2bf09ad438
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/session_bundle_test.cc
@@ -0,0 +1,102 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/session_bundle/session_bundle.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/any.pb.h"
+#include "tensorflow/contrib/session_bundle/test_util.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace {
+
+TEST(LoadSessionBundleFromPath, Basic) {
+ const string export_path = test_util::TestSrcDirPath(
+ "session_bundle/example/half_plus_two/00000123");
+ tensorflow::SessionOptions options;
+ SessionBundle bundle;
+ TF_ASSERT_OK(LoadSessionBundleFromPath(options, export_path, &bundle));
+
+ const string asset_path =
+ tensorflow::io::JoinPath(export_path, kAssetsDirectory);
+ // Validate the assets behavior.
+ std::vector<Tensor> path_outputs;
+ TF_ASSERT_OK(bundle.session->Run({}, {"filename1:0", "filename2:0"}, {},
+ &path_outputs));
+ ASSERT_EQ(2, path_outputs.size());
+ // Validate the two asset file tensors are set by the init_op and include the
+ // base_path and asset directory.
+ test::ExpectTensorEqual<string>(
+ test::AsTensor<string>(
+ {tensorflow::io::JoinPath(asset_path, "hello1.txt")},
+ TensorShape({})),
+ path_outputs[0]);
+ test::ExpectTensorEqual<string>(
+ test::AsTensor<string>(
+ {tensorflow::io::JoinPath(asset_path, "hello2.txt")},
+ TensorShape({})),
+ path_outputs[1]);
+
+ // Validate the half plus two behavior.
+ Tensor input = test::AsTensor<float>({0, 1, 2, 3}, TensorShape({4, 1}));
+
+ // Recover the Tensor names of our inputs and outputs.
+ auto collection_def = bundle.meta_graph_def.collection_def();
+ Signatures signatures;
+ ASSERT_EQ(1, collection_def[kSignaturesKey].any_list().value_size());
+ collection_def[kSignaturesKey].any_list().value(0).UnpackTo(&signatures);
+ ASSERT_TRUE(signatures.default_signature().has_regression_signature());
+ const tensorflow::contrib::RegressionSignature regression_signature =
+ signatures.default_signature().regression_signature();
+
+ const string input_name = regression_signature.input().tensor_name();
+ const string output_name = regression_signature.output().tensor_name();
+
+ std::vector<Tensor> outputs;
+ TF_ASSERT_OK(
+ bundle.session->Run({{input_name, input}}, {output_name}, {}, &outputs));
+ ASSERT_EQ(outputs.size(), 1);
+ test::ExpectTensorEqual<float>(
+ outputs[0], test::AsTensor<float>({2, 2.5, 3, 3.5}, TensorShape({4, 1})));
+}
+
+TEST(LoadSessionBundleFromPath, BadExportPath) {
+ const string export_path = test_util::TestSrcDirPath("/tmp/bigfoot");
+ tensorflow::SessionOptions options;
+ options.target = "local";
+ SessionBundle bundle;
+ const auto status = LoadSessionBundleFromPath(options, export_path, &bundle);
+ ASSERT_FALSE(status.ok());
+ const string msg = status.ToString();
+ EXPECT_TRUE(msg.find("Not found") != std::string::npos) << msg;
+}
+
+} // namespace
+} // namespace contrib
+} // namespace tensorflow