diff options
author | 2016-12-21 20:50:31 -0800 | |
---|---|---|
committer | 2016-12-21 21:05:54 -0800 | |
commit | 0f0e29e7ba06c50fe4a1a7718e63731b96563a8d (patch) | |
tree | bfbb7dcc4dba793bfef84d0b1681f20ad4eeff2f /tensorflow/tools/graph_transforms/remove_device_test.cc | |
parent | be60473c88175dbc9359c9d1bbb384518757ee81 (diff) |
Create Graph Transform Tool for rewriting model files.
Change: 142729497
Diffstat (limited to 'tensorflow/tools/graph_transforms/remove_device_test.cc')
-rw-r--r-- | tensorflow/tools/graph_transforms/remove_device_test.cc | 95 |
1 files changed, 95 insertions, 0 deletions
diff --git a/tensorflow/tools/graph_transforms/remove_device_test.cc b/tensorflow/tools/graph_transforms/remove_device_test.cc new file mode 100644 index 0000000000..554c5e3595 --- /dev/null +++ b/tensorflow/tools/graph_transforms/remove_device_test.cc @@ -0,0 +1,95 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. +Status RemoveDevice(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +class RemoveDeviceTest : public ::testing::Test { + protected: + void TestRemoveDevice() { + GraphDef graph_def; + + NodeDef* mul_node1 = graph_def.add_node(); + mul_node1->set_name("mul_node1"); + mul_node1->set_op("Mul"); + mul_node1->set_device("//cpu:0"); + mul_node1->add_input("add_node2"); + mul_node1->add_input("add_node3"); + + NodeDef* add_node2 = graph_def.add_node(); + add_node2->set_name("add_node2"); + add_node2->set_op("Add"); + add_node2->add_input("const_node1"); + add_node2->add_input("const_node2"); + add_node2->set_device("//gpu:1"); + + NodeDef* add_node3 = graph_def.add_node(); + add_node3->set_name("add_node3"); + add_node3->set_op("Add"); + add_node3->add_input("const_node1"); + add_node3->add_input("const_node3"); + + NodeDef* const_node1 = graph_def.add_node(); + const_node1->set_name("const_node1"); + const_node1->set_op("Const"); + + NodeDef* const_node2 = graph_def.add_node(); + const_node2->set_name("const_node2"); + const_node2->set_op("Const"); + + NodeDef* const_node3 = graph_def.add_node(); + const_node3->set_name("const_node3"); + const_node3->set_op("Const"); + + NodeDef* add_node4 = graph_def.add_node(); + add_node4->set_name("add_node4"); + add_node4->set_op("Add"); + add_node4->add_input("add_node2"); + add_node4->add_input("add_node3"); + + GraphDef result; + TransformFuncContext context; + context.input_names = {}; + context.output_names = {"mul_node1"}; + TF_ASSERT_OK(RemoveDevice(graph_def, context, &result)); + + std::map<string, const NodeDef*> node_lookup; + MapNamesToNodes(result, &node_lookup); + EXPECT_EQ("", node_lookup.at("mul_node1")->device()); + EXPECT_EQ("", node_lookup.at("add_node2")->device()); + } +}; + +TEST_F(RemoveDeviceTest, TestRemoveDevice) { TestRemoveDevice(); } + +} // namespace graph_transforms +} // namespace tensorflow |