aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/tools/graph_transforms/fold_old_batch_norms.cc')
-rw-r--r--tensorflow/tools/graph_transforms/fold_old_batch_norms.cc67
1 files changed, 67 insertions, 0 deletions
diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc
index d89afe85c7..d86f65325b 100644
--- a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc
+++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc
@@ -182,6 +182,36 @@ Status FuseBatchNormWithConv(const NodeMatch& match,
return Status::OK();
}
+Status FuseBatchNormWithBatchToSpace(const NodeMatch& match,
+ std::vector<NodeDef>* new_nodes) {
+ // Calculate the scale and offset values to apply.
+ std::vector<float> scale_values;
+ std::vector<float> offset_values;
+ TF_RETURN_IF_ERROR(
+ GetScaleAndOffsetValues(match, &scale_values, &offset_values));
+
+ // Fuse conv weights, and set the final output node name as batch_norm_node.
+ const NodeDef& batch_norm_node = match.node;
+ const NodeMatch& batch_to_space_node_match = match.inputs[0];
+ const NodeMatch& conv_node_match = batch_to_space_node_match.inputs[0];
+ const NodeDef& batch_to_space_node = batch_to_space_node_match.node;
+ const NodeDef& conv_node = conv_node_match.node;
+
+ string biasadd_name = conv_node.name() + "/biasadd";
+ TF_RETURN_IF_ERROR(
+ FuseScaleOffsetToConvWeights(scale_values, offset_values, conv_node_match,
+ biasadd_name , new_nodes));
+
+ NodeDef new_batch_to_space_node = batch_to_space_node;
+ // reuse batch_norm node name
+ new_batch_to_space_node.set_name(batch_norm_node.name());
+ new_batch_to_space_node.set_input(0, biasadd_name);
+ new_nodes->push_back(batch_to_space_node_match.inputs[1].node);
+ new_nodes->push_back(batch_to_space_node_match.inputs[2].node);
+ new_nodes->push_back(new_batch_to_space_node);
+ return Status::OK();
+}
+
Status FuseBatchNormWithConvConcat(const NodeMatch& match,
std::vector<NodeDef>* new_nodes) {
// Calculate the scale and offset values to apply.
@@ -287,6 +317,43 @@ Status FoldOldBatchNorms(const GraphDef& input_graph_def,
do {
did_graph_change = false;
GraphDef replaced_graph_def;
+ TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
+ current_graph_def, // clang-format off
+ {"BatchNormWithGlobalNormalization|FusedBatchNorm", // batch_norm_node
+ {
+ {"BatchToSpaceND", // batch_to_space_node
+ {
+ {"Conv2D", // conv_node
+ {
+ {"*"}, // input_node
+ {"Const"}, // weights_node
+ }
+ },
+ {"Const"}, // block_shape
+ {"Const"}, // crops
+ }
+ },
+ {"Const"}, // mean_node
+ {"Const"}, // variance_node
+ {"Const"}, // beta_node
+ {"Const"}, // gamma_node
+ }
+ }, // clang-format on
+ [&did_graph_change](const NodeMatch& match,
+ const std::set<string>& input_nodes,
+ const std::set<string>& output_nodes,
+ std::vector<NodeDef>* new_nodes) {
+ TF_RETURN_IF_ERROR(FuseBatchNormWithBatchToSpace(match, new_nodes));
+ did_graph_change = true;
+ return Status::OK();
+ },
+ {}, &replaced_graph_def));
+ current_graph_def = replaced_graph_def;
+ } while (did_graph_change);
+
+ do {
+ did_graph_change = false;
+ GraphDef replaced_graph_def;
// Replace BatchNorm with concat as input.
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
current_graph_def, // clang-format off