aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms
diff options
context:
space:
mode:
author A. Unique TensorFlower <gardener@tensorflow.org> 2017-09-28 00:31:59 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-09-28 00:36:05 -0700
commit c9435befb1bd50ad550deaebfac272eb97da7780 (patch)
tree 64559a21b4d2cb00d6ea26f2fbf20853adabc58b /tensorflow/tools/graph_transforms
parent 35a162a8ee61b6d3fadc6c108ce97446bbb6afd8 (diff)
Don't fold batch norm calculations if weights are used somewhere else in the graph.
PiperOrigin-RevId: 170309345
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r-- tensorflow/tools/graph_transforms/fold_batch_norms.cc 11
-rw-r--r-- tensorflow/tools/graph_transforms/fold_batch_norms_test.cc 58
2 files changed, 69 insertions, 0 deletions
diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_batch_norms.cc
index 2ff3bb641e..975b17380f 100644
--- a/tensorflow/tools/graph_transforms/fold_batch_norms.cc
+++ b/tensorflow/tools/graph_transforms/fold_batch_norms.cc
@@ -57,6 +57,17 @@ Status FoldBatchNorms(const GraphDef& input_graph_def,
const NodeDef& weights_node = match.inputs[0].inputs[1].node;
const NodeDef& mul_values_node = match.inputs[1].node;
+ // Check that nodes that we use are not used somewhere else.
+ for (const auto& node : {conv_node, weights_node, mul_values_node}) {
+ if (output_nodes.count(node.name())) {
+ // Return original nodes.
+ new_nodes->insert(new_nodes->end(),
+ {mul_node, conv_node, input_node, weights_node,
+ mul_values_node});
+ return Status::OK();
+ }
+ }
+
Tensor weights = GetNodeTensorAttr(weights_node, "value");
Tensor mul_values = GetNodeTensorAttr(mul_values_node, "value");
diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
index ed741f002c..a5d541feb6 100644
--- a/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
@@ -87,6 +87,64 @@ class FoldBatchNormsTest : public ::testing::Test {
}
}
+  // Builds a graph in which one Conv2D output feeds two separate Mul nodes,
+  // each with its own constant multiplier, runs FoldBatchNorms with both Mul
+  // nodes listed as outputs, and checks that both outputs of the transformed
+  // graph match the outputs of the original graph.
+  void TestFoldBatchNormsConv2DShared() {
+    auto root = tensorflow::Scope::NewRootScope();
+    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
+
+    Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
+    test::FillValues<float>(
+        &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
+                      -5.0f, -3.0f, -6.0f});
+    Output input_op =
+        Const(root.WithOpName("input_op"), Input::Initializer(input_data));
+
+    Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
+    test::FillValues<float>(&weights_data,
+                            {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
+    Output weights_op =
+        Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
+
+    // The convolution whose result is shared by both Mul nodes below.
+    Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op,
+                            {1, 1, 1, 1}, "VALID");
+
+    Tensor mul_values_data(DT_FLOAT, TensorShape({2}));
+    test::FillValues<float>(&mul_values_data, {2.0f, 3.0f});
+    Output mul_values_op = Const(root.WithOpName("mul_values"),
+                                 Input::Initializer(mul_values_data));
+
+    Output mul_op = Mul(root.WithOpName("output"), conv_op, mul_values_op);
+
+    // The second consumer gets its own multiplier values so the two outputs
+    // differ; the constant must be built from mul_values_data_2, not from the
+    // first multiplier tensor.
+    Tensor mul_values_data_2(DT_FLOAT, TensorShape({2}));
+    test::FillValues<float>(&mul_values_data_2, {1.0f, 2.0f});
+    Output mul_values_op_2 = Const(root.WithOpName("mul_values_2"),
+                                   Input::Initializer(mul_values_data_2));
+
+    Output mul_op_2 =
+        Mul(root.WithOpName("output_2"), conv_op, mul_values_op_2);
+
+    GraphDef original_graph_def;
+    TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
+
+    // Reference results from the untransformed graph.
+    std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
+    TF_ASSERT_OK(original_session->Create(original_graph_def));
+    std::vector<Tensor> original_outputs;
+    TF_ASSERT_OK(original_session->Run({}, {"output", "output_2"}, {},
+                                       &original_outputs));
+
+    GraphDef fused_graph_def;
+    TF_ASSERT_OK(FoldBatchNorms(
+        original_graph_def, {{}, {"output", "output_2"}}, &fused_graph_def));
+
+    // Results from the transformed graph, fetched in the same order.
+    std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
+    TF_ASSERT_OK(fused_session->Create(fused_graph_def));
+    std::vector<Tensor> fused_outputs;
+    TF_ASSERT_OK(
+        fused_session->Run({}, {"output", "output_2"}, {}, &fused_outputs));
+
+    test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
+    test::ExpectTensorNear<float>(original_outputs[1], fused_outputs[1], 1e-5);
+  }
+
void TestFoldBatchNormsMatMul() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)