aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms
diff options
context:
space:
mode:
authorGravatar Jacques Pienaar <jpienaar@google.com>2018-03-15 12:58:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-15 13:03:19 -0700
commitccd8079e579604547f4b4d8a6b061cfdc6ce49bf (patch)
tree0d498e84ca32a101afcada0993a30a5e3b0452a2 /tensorflow/tools/graph_transforms
parent61032e9ca7bf9849cb65db9b646381d124080856 (diff)
Merge changes from github.
PiperOrigin-RevId: 189231636
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r--tensorflow/tools/graph_transforms/BUILD1
-rw-r--r--tensorflow/tools/graph_transforms/fold_old_batch_norms.cc67
-rw-r--r--tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc95
3 files changed, 163 insertions, 0 deletions
diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD
index b7d7fac315..6e21aa2846 100644
--- a/tensorflow/tools/graph_transforms/BUILD
+++ b/tensorflow/tools/graph_transforms/BUILD
@@ -178,6 +178,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:quantization_utils",
"//tensorflow/core/kernels:quantized_ops",
"//tensorflow/core/util/tensor_bundle",
],
diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc
index d89afe85c7..d86f65325b 100644
--- a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc
+++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc
@@ -182,6 +182,36 @@ Status FuseBatchNormWithConv(const NodeMatch& match,
return Status::OK();
}
+Status FuseBatchNormWithBatchToSpace(const NodeMatch& match,
+ std::vector<NodeDef>* new_nodes) {
+ // Calculate the scale and offset values to apply.
+ std::vector<float> scale_values;
+ std::vector<float> offset_values;
+ TF_RETURN_IF_ERROR(
+ GetScaleAndOffsetValues(match, &scale_values, &offset_values));
+
+ // Fuse conv weights, and set the final output node name as batch_norm_node.
+ const NodeDef& batch_norm_node = match.node;
+ const NodeMatch& batch_to_space_node_match = match.inputs[0];
+ const NodeMatch& conv_node_match = batch_to_space_node_match.inputs[0];
+ const NodeDef& batch_to_space_node = batch_to_space_node_match.node;
+ const NodeDef& conv_node = conv_node_match.node;
+
+ string biasadd_name = conv_node.name() + "/biasadd";
+ TF_RETURN_IF_ERROR(
+ FuseScaleOffsetToConvWeights(scale_values, offset_values, conv_node_match,
+ biasadd_name , new_nodes));
+
+ NodeDef new_batch_to_space_node = batch_to_space_node;
+ // reuse batch_norm node name
+ new_batch_to_space_node.set_name(batch_norm_node.name());
+ new_batch_to_space_node.set_input(0, biasadd_name);
+ new_nodes->push_back(batch_to_space_node_match.inputs[1].node);
+ new_nodes->push_back(batch_to_space_node_match.inputs[2].node);
+ new_nodes->push_back(new_batch_to_space_node);
+ return Status::OK();
+}
+
Status FuseBatchNormWithConvConcat(const NodeMatch& match,
std::vector<NodeDef>* new_nodes) {
// Calculate the scale and offset values to apply.
@@ -287,6 +317,43 @@ Status FoldOldBatchNorms(const GraphDef& input_graph_def,
do {
did_graph_change = false;
GraphDef replaced_graph_def;
+ TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
+ current_graph_def, // clang-format off
+ {"BatchNormWithGlobalNormalization|FusedBatchNorm", // batch_norm_node
+ {
+ {"BatchToSpaceND", // batch_to_space_node
+ {
+ {"Conv2D", // conv_node
+ {
+ {"*"}, // input_node
+ {"Const"}, // weights_node
+ }
+ },
+ {"Const"}, // block_shape
+ {"Const"}, // crops
+ }
+ },
+ {"Const"}, // mean_node
+ {"Const"}, // variance_node
+ {"Const"}, // beta_node
+ {"Const"}, // gamma_node
+ }
+ }, // clang-format on
+ [&did_graph_change](const NodeMatch& match,
+ const std::set<string>& input_nodes,
+ const std::set<string>& output_nodes,
+ std::vector<NodeDef>* new_nodes) {
+ TF_RETURN_IF_ERROR(FuseBatchNormWithBatchToSpace(match, new_nodes));
+ did_graph_change = true;
+ return Status::OK();
+ },
+ {}, &replaced_graph_def));
+ current_graph_def = replaced_graph_def;
+ } while (did_graph_change);
+
+ do {
+ did_graph_change = false;
+ GraphDef replaced_graph_def;
// Replace BatchNorm with concat as input.
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
current_graph_def, // clang-format off
diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
index b30ba9ac8b..272410c693 100644
--- a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@@ -298,6 +299,96 @@ class FoldOldBatchNormsTest : public ::testing::Test {
}
};
+void TestFoldFusedBatchNormsWithBatchToSpace() {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ Tensor input_data(DT_FLOAT, TensorShape({2, 1, 3, 2}));
+ test::FillValues<float>(
+ &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
+ -5.0f, -3.0f, -6.0f});
+ Output input_op =
+ Const(root.WithOpName("input_op"), Input::Initializer(input_data));
+
+ Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
+ test::FillValues<float>(&weights_data,
+ {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
+ Output weights_op =
+ Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
+
+ Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op,
+ {1, 1, 1, 1}, "VALID");
+
+ Tensor block_shape_data(DT_INT32, TensorShape({2}));
+ test::FillValues<int32>(&block_shape_data, {1, 2});
+ Output block_shape_op =
+ Const(root.WithOpName("block_shape_op"), Input::Initializer(block_shape_data));
+
+ Tensor crops_data(DT_INT32, TensorShape({2, 2}));
+ test::FillValues<int32>(&crops_data, {0, 0, 0, 1});
+ Output crops_op =
+ Const(root.WithOpName("crops_op"), Input::Initializer(crops_data));
+
+ Output batch_to_space_op = BatchToSpaceND(root.WithOpName("batch_to_space_op"),
+ conv_op, block_shape_op, crops_data);
+
+ Tensor mean_data(DT_FLOAT, TensorShape({2}));
+ test::FillValues<float>(&mean_data, {10.0f, 20.0f});
+ Output mean_op =
+ Const(root.WithOpName("mean_op"), Input::Initializer(mean_data));
+
+ Tensor variance_data(DT_FLOAT, TensorShape({2}));
+ test::FillValues<float>(&variance_data, {0.25f, 0.5f});
+ Output variance_op = Const(root.WithOpName("variance_op"),
+ Input::Initializer(variance_data));
+
+ Tensor beta_data(DT_FLOAT, TensorShape({2}));
+ test::FillValues<float>(&beta_data, {0.1f, 0.6f});
+ Output beta_op =
+ Const(root.WithOpName("beta_op"), Input::Initializer(beta_data));
+
+ Tensor gamma_data(DT_FLOAT, TensorShape({2}));
+ test::FillValues<float>(&gamma_data, {1.0f, 2.0f});
+ Output gamma_op =
+ Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data));
+
+ GraphDef original_graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
+
+ NodeDef batch_norm_node;
+ batch_norm_node.set_op("FusedBatchNorm");
+ batch_norm_node.set_name("output");
+ AddNodeInput("batch_to_space_op", &batch_norm_node);
+ AddNodeInput("gamma_op", &batch_norm_node);
+ AddNodeInput("beta_op", &batch_norm_node);
+ AddNodeInput("mean_op", &batch_norm_node);
+ AddNodeInput("variance_op", &batch_norm_node);
+ SetNodeAttr("T", DT_FLOAT, &batch_norm_node);
+ SetNodeAttr("epsilon", 0.00001f, &batch_norm_node);
+ SetNodeAttr("is_training", false, &batch_norm_node);
+ *(original_graph_def.mutable_node()->Add()) = batch_norm_node;
+
+ std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(original_session->Create(original_graph_def));
+ std::vector<Tensor> original_outputs;
+ TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
+
+ GraphDef fused_graph_def;
+ TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}},
+ &fused_graph_def));
+
+ std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(fused_session->Create(fused_graph_def));
+ std::vector<Tensor> fused_outputs;
+ TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
+
+ test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
+
+ for (const NodeDef& node : fused_graph_def.node()) {
+ EXPECT_NE("FusedBatchNormWithBatchToSpace", node.op());
+ }
+}
+
TEST_F(FoldOldBatchNormsTest, TestFoldOldBatchNorms) {
TestFoldOldBatchNorms();
}
@@ -315,5 +406,9 @@ TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNormsWithConcat) {
TestFoldFusedBatchNormsWithConcat(/*split=*/false);
}
+TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNormsWithBatchToSpace) {
+ TestFoldFusedBatchNormsWithBatchToSpace();
+}
+
} // namespace graph_transforms
} // namespace tensorflow