aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms/transform_graph_test.cc
diff options
context:
space:
mode:
authorGravatar Pete Warden <petewarden@google.com>2016-12-21 20:50:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-21 21:05:54 -0800
commit0f0e29e7ba06c50fe4a1a7718e63731b96563a8d (patch)
treebfbb7dcc4dba793bfef84d0b1681f20ad4eeff2f /tensorflow/tools/graph_transforms/transform_graph_test.cc
parentbe60473c88175dbc9359c9d1bbb384518757ee81 (diff)
Create Graph Transform Tool for rewriting model files.
Change: 142729497
Diffstat (limited to 'tensorflow/tools/graph_transforms/transform_graph_test.cc')
-rw-r--r--tensorflow/tools/graph_transforms/transform_graph_test.cc228
1 files changed, 228 insertions, 0 deletions
diff --git a/tensorflow/tools/graph_transforms/transform_graph_test.cc b/tensorflow/tools/graph_transforms/transform_graph_test.cc
new file mode 100644
index 0000000000..12df4051fb
--- /dev/null
+++ b/tensorflow/tools/graph_transforms/transform_graph_test.cc
@@ -0,0 +1,228 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/tools/graph_transforms/transform_graph.h"
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/image_ops.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace graph_transforms {
+
+// Declared here so we don't have to expose it in the public header.
+Status ShouldIgnoreErrors(const TransformFuncParameters& transform_params,
+ bool* ignore_errors);
+
+namespace {
+Status test_empty_graph_transform(const GraphDef& graph_def,
+ const TransformFuncContext& context,
+ GraphDef* result) {
+ result->Clear();
+ return Status::OK();
+}
+} // namespace
+
+REGISTER_GRAPH_TRANSFORM("test_empty_graph_transform",
+ test_empty_graph_transform);
+
+class TransformGraphTest : public ::testing::Test {
+ protected:
+ void TestConstantFolding() {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ const int width = 100;
+
+ Tensor a_data(DT_FLOAT, TensorShape({width}));
+ test::FillIota<float>(&a_data, 1.0f);
+ Output a_const =
+ Const(root.WithOpName("a_expect_removed"), Input::Initializer(a_data));
+
+ Tensor b_data(DT_FLOAT, TensorShape({width}));
+ test::FillIota<float>(&b_data, 1.0f);
+ Output b_const =
+ Const(root.WithOpName("b_expect_removed"), Input::Initializer(b_data));
+
+ Output add = Add(root.WithOpName("add_expect_removed"), a_const, b_const);
+
+ Output placeholder =
+ Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
+
+ Output mul =
+ Mul(root.WithOpName("output_expect_remains"), add, placeholder);
+
+ GraphDef graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+ string graph_def_serialized;
+ graph_def.SerializeToString(&graph_def_serialized);
+ const string dir = testing::TmpDir();
+ const string in_filename_pb = io::JoinPath(dir, "in_graphdef.pb");
+ const string out_filename_pb = io::JoinPath(dir, "out_graphdef.pb");
+ TF_ASSERT_OK(WriteStringToFile(Env::Default(), in_filename_pb,
+ graph_def_serialized));
+
+ std::vector<string> args = {"some_binary",
+ "--in_graph=" + in_filename_pb,
+ "--out_graph=" + out_filename_pb,
+ "--inputs=placeholder_expect_remains",
+ "--outputs=output_expect_remains",
+ "--transforms=fold_constants"};
+ const int argc = 6;
+ EXPECT_EQ(argc, args.size());
+ char* argv[argc];
+ std::vector<char*> char_strings;
+ for (int i = 0; i < argc; ++i) {
+ string arg = args[i];
+ char* char_string = new char[arg.size() + 1];
+ std::copy_n(arg.c_str(), arg.size() + 1, char_string);
+ argv[i] = char_string;
+ char_strings.push_back(char_string);
+ }
+ ParseFlagsAndTransformGraph(argc, argv, false);
+ for (char* char_string : char_strings) {
+ delete[] char_string;
+ }
+
+ GraphDef out_graph_def;
+ TF_EXPECT_OK(
+ ReadBinaryProto(Env::Default(), out_filename_pb, &out_graph_def));
+
+ std::map<string, const NodeDef*> out_node_map;
+ graph_transforms::MapNamesToNodes(out_graph_def, &out_node_map);
+
+ for (const NodeDef& node : out_graph_def.node()) {
+ const StringPiece name(node.name());
+ const int occurrence_count = out_node_map.count(node.name());
+ if (name.ends_with("expect_removed")) {
+ EXPECT_EQ(0, occurrence_count) << "node.name()=" << node.name();
+ }
+ if (name.ends_with("expect_remains")) {
+ EXPECT_EQ(1, occurrence_count) << "node.name()=" << node.name();
+ }
+ }
+ }
+
+ void TestTransformRegistration() {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ Output placeholder =
+ Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
+ GraphDef graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+ EXPECT_EQ(1, graph_def.node().size());
+ TF_ASSERT_OK(TransformGraph({}, {}, {{"test_empty_graph_transform", {}}},
+ &graph_def));
+ EXPECT_EQ(0, graph_def.node().size());
+
+ TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+ Status no_such_status =
+ TransformGraph({}, {}, {{"test_no_such_transform", {}}}, &graph_def);
+ EXPECT_TRUE(
+ StringPiece(no_such_status.ToString()).contains("not recognized"));
+ }
+
+ void TestParseTransformParameters() {
+ TransformParameters params_list;
+
+ ParseTransformParameters("foo", &params_list);
+ EXPECT_EQ(1, params_list.size());
+ EXPECT_EQ("foo", params_list[0].first);
+ EXPECT_TRUE(params_list[0].second.empty());
+
+ ParseTransformParameters("foo bar", &params_list);
+ EXPECT_EQ(2, params_list.size());
+ EXPECT_EQ("foo", params_list[0].first);
+ EXPECT_TRUE(params_list[0].second.empty());
+ EXPECT_EQ("bar", params_list[1].first);
+ EXPECT_TRUE(params_list[1].second.empty());
+
+ ParseTransformParameters("foo() bar()", &params_list);
+ EXPECT_EQ(2, params_list.size());
+ EXPECT_EQ("foo", params_list[0].first);
+ EXPECT_TRUE(params_list[0].second.empty());
+ EXPECT_EQ("bar", params_list[1].first);
+ EXPECT_TRUE(params_list[1].second.empty());
+
+ ParseTransformParameters("foo(bob_something=sue)", &params_list);
+ EXPECT_EQ(1, params_list.size());
+ EXPECT_EQ("foo", params_list[0].first);
+ EXPECT_EQ(1, params_list[0].second.count("bob_something"));
+ EXPECT_EQ(1, params_list[0].second["bob_something"].size());
+ EXPECT_EQ("sue", params_list[0].second["bob_something"][0]);
+
+ ParseTransformParameters("bar(a=1, b=2, a=3)", &params_list);
+ EXPECT_EQ(1, params_list.size());
+ EXPECT_EQ("bar", params_list[0].first);
+ EXPECT_EQ(1, params_list[0].second.count("a"));
+ EXPECT_EQ(2, params_list[0].second["a"].size());
+ EXPECT_EQ("1", params_list[0].second["a"][0]);
+ EXPECT_EQ("3", params_list[0].second["a"][1]);
+ EXPECT_EQ(1, params_list[0].second.count("b"));
+ EXPECT_EQ(1, params_list[0].second["b"].size());
+ EXPECT_EQ("2", params_list[0].second["b"][0]);
+
+ ParseTransformParameters("bar(a=\"1\", b=\"1,2,3\", a=3)", &params_list);
+ EXPECT_EQ(1, params_list.size());
+ EXPECT_EQ("bar", params_list[0].first);
+ EXPECT_EQ(1, params_list[0].second.count("a"));
+ EXPECT_EQ(2, params_list[0].second["a"].size());
+ EXPECT_EQ("1", params_list[0].second["a"][0]);
+ EXPECT_EQ("3", params_list[0].second["a"][1]);
+ EXPECT_EQ(1, params_list[0].second.count("b"));
+ EXPECT_EQ(1, params_list[0].second["b"].size());
+ EXPECT_EQ("1,2,3", params_list[0].second["b"][0]);
+ }
+
+ void TestShouldIgnoreErrors() {
+ bool ignore_errors;
+ TF_EXPECT_OK(
+ ShouldIgnoreErrors({{"ignore_errors", {"true"}}}, &ignore_errors));
+ EXPECT_TRUE(ignore_errors);
+
+ TF_EXPECT_OK(
+ ShouldIgnoreErrors({{"ignore_errors", {"false"}}}, &ignore_errors));
+ EXPECT_FALSE(ignore_errors);
+
+ TF_EXPECT_OK(ShouldIgnoreErrors({}, &ignore_errors));
+ EXPECT_FALSE(ignore_errors);
+
+ EXPECT_FALSE(
+ ShouldIgnoreErrors({{"ignore_errors", {"foo"}}}, &ignore_errors).ok());
+ }
+};
+
+TEST_F(TransformGraphTest, TestConstantFolding) { TestConstantFolding(); }
+
+TEST_F(TransformGraphTest, TestTransformRegistration) {
+ TestTransformRegistration();
+}
+
+TEST_F(TransformGraphTest, TestParseTransformParameters) {
+ TestParseTransformParameters();
+}
+
+TEST_F(TransformGraphTest, TestShouldIgnoreErrors) { TestShouldIgnoreErrors(); }
+
+} // namespace graph_transforms
+} // namespace tensorflow