aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms/remove_device_test.cc
diff options
context:
space:
mode:
authorGravatar Pete Warden <petewarden@google.com>2016-12-21 20:50:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-21 21:05:54 -0800
commit0f0e29e7ba06c50fe4a1a7718e63731b96563a8d (patch)
treebfbb7dcc4dba793bfef84d0b1681f20ad4eeff2f /tensorflow/tools/graph_transforms/remove_device_test.cc
parentbe60473c88175dbc9359c9d1bbb384518757ee81 (diff)
Create Graph Transform Tool for rewriting model files.
Change: 142729497
Diffstat (limited to 'tensorflow/tools/graph_transforms/remove_device_test.cc')
-rw-r--r--tensorflow/tools/graph_transforms/remove_device_test.cc95
1 files changed, 95 insertions, 0 deletions
diff --git a/tensorflow/tools/graph_transforms/remove_device_test.cc b/tensorflow/tools/graph_transforms/remove_device_test.cc
new file mode 100644
index 0000000000..554c5e3595
--- /dev/null
+++ b/tensorflow/tools/graph_transforms/remove_device_test.cc
@@ -0,0 +1,95 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/image_ops.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace graph_transforms {
+
+// Declare here, so we don't need a public header.
+Status RemoveDevice(const GraphDef& input_graph_def,
+ const TransformFuncContext& context,
+ GraphDef* output_graph_def);
+
+class RemoveDeviceTest : public ::testing::Test {
+ protected:
+ void TestRemoveDevice() {
+ GraphDef graph_def;
+
+ NodeDef* mul_node1 = graph_def.add_node();
+ mul_node1->set_name("mul_node1");
+ mul_node1->set_op("Mul");
+ mul_node1->set_device("//cpu:0");
+ mul_node1->add_input("add_node2");
+ mul_node1->add_input("add_node3");
+
+ NodeDef* add_node2 = graph_def.add_node();
+ add_node2->set_name("add_node2");
+ add_node2->set_op("Add");
+ add_node2->add_input("const_node1");
+ add_node2->add_input("const_node2");
+ add_node2->set_device("//gpu:1");
+
+ NodeDef* add_node3 = graph_def.add_node();
+ add_node3->set_name("add_node3");
+ add_node3->set_op("Add");
+ add_node3->add_input("const_node1");
+ add_node3->add_input("const_node3");
+
+ NodeDef* const_node1 = graph_def.add_node();
+ const_node1->set_name("const_node1");
+ const_node1->set_op("Const");
+
+ NodeDef* const_node2 = graph_def.add_node();
+ const_node2->set_name("const_node2");
+ const_node2->set_op("Const");
+
+ NodeDef* const_node3 = graph_def.add_node();
+ const_node3->set_name("const_node3");
+ const_node3->set_op("Const");
+
+ NodeDef* add_node4 = graph_def.add_node();
+ add_node4->set_name("add_node4");
+ add_node4->set_op("Add");
+ add_node4->add_input("add_node2");
+ add_node4->add_input("add_node3");
+
+ GraphDef result;
+ TransformFuncContext context;
+ context.input_names = {};
+ context.output_names = {"mul_node1"};
+ TF_ASSERT_OK(RemoveDevice(graph_def, context, &result));
+
+ std::map<string, const NodeDef*> node_lookup;
+ MapNamesToNodes(result, &node_lookup);
+ EXPECT_EQ("", node_lookup.at("mul_node1")->device());
+ EXPECT_EQ("", node_lookup.at("add_node2")->device());
+ }
+};
+
+TEST_F(RemoveDeviceTest, TestRemoveDevice) { TestRemoveDevice(); }
+
+} // namespace graph_transforms
+} // namespace tensorflow