about summary refs log tree commit diff homepage
path: root/tensorflow/tools/graph_transforms
diff options
context:
space:
mode:
author: TensorFlower Gardener <gardener@tensorflow.org> 2018-08-09 13:30:50 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-08-09 13:30:50 -0700
commit ee38f86972b13f3eb90032e93b305e822152bf62 (patch)
tree 76830b9ae48559dfcf7fe119521e529424d904ea /tensorflow/tools/graph_transforms
parent b7b8ab67bc4c86be8e52a51e2f85d53212ccdf64 (diff)
parent 049c529e567a557461d9ffe8614e6c06a7b3a39f (diff)
Merge pull request #21023 from zhaoyongke:master
PiperOrigin-RevId: 208100633
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r--  tensorflow/tools/graph_transforms/fold_batch_norms.cc     | 20
-rw-r--r--  tensorflow/tools/graph_transforms/fold_old_batch_norms.cc | 67
2 files changed, 44 insertions, 43 deletions
diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_batch_norms.cc
index 975b17380f..39f682e8b0 100644
--- a/tensorflow/tools/graph_transforms/fold_batch_norms.cc
+++ b/tensorflow/tools/graph_transforms/fold_batch_norms.cc
@@ -38,7 +38,7 @@ Status FoldBatchNorms(const GraphDef& input_graph_def,
input_graph_def, // clang-format off
{"Mul", // mul_node
{
- {"Conv2D|MatMul", // conv_node
+ {"Conv2D|MatMul|DepthwiseConv2dNative", // conv_node
{
{"*"}, // input_node
{"Const"}, // weights_node
@@ -73,7 +73,10 @@ Status FoldBatchNorms(const GraphDef& input_graph_def,
// Make sure all the inputs really are vectors, with as many entries as
// there are columns in the weights.
- const int weights_cols_index = conv_node.op() == "Conv2D" ? 3 : 1;
+ const int weights_cols_index =
+ conv_node.op() == "Conv2D"
+ ? 3
+ : (conv_node.op() == "DepthwiseConv2dNative" ? 2 : 1);
const int64 weights_cols = weights.shape().dim_size(weights_cols_index);
if ((mul_values.shape().dims() != 1) ||
(mul_values.shape().dim_size(0) != weights_cols)) {
@@ -83,14 +86,13 @@ Status FoldBatchNorms(const GraphDef& input_graph_def,
}
// Multiply the original weights by the scale vector.
- auto weights_matrix = weights.flat_inner_dims<float>();
+ auto weights_vector = weights.flat<float>();
Tensor scaled_weights(DT_FLOAT, weights.shape());
- auto scaled_weights_matrix = scaled_weights.flat_inner_dims<float>();
- for (int64 row = 0; row < weights_matrix.dimension(0); ++row) {
- for (int64 col = 0; col < weights_cols; ++col) {
- scaled_weights_matrix(row, col) =
- weights_matrix(row, col) * mul_values.flat<float>()(col);
- }
+ auto scaled_weights_vector = scaled_weights.flat<float>();
+ for (int64 row = 0; row < weights_vector.dimension(0); ++row) {
+ scaled_weights_vector(row) =
+ weights_vector(row) *
+ mul_values.flat<float>()(row % weights_cols);
}
// Construct the new nodes.
diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc
index 156636ab82..a35d64b789 100644
--- a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc
+++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc
@@ -110,24 +110,23 @@ Status FuseScaleOffsetToConvWeights(const std::vector<float>& scale_values,
const string& conv_output_name,
std::vector<NodeDef>* new_nodes) {
const NodeDef& conv_node = conv_node_match.node;
- CHECK_EQ("Conv2D", conv_node.op());
+ // CHECK_EQ("Conv2D", conv_node.op());
const NodeDef& input_node = conv_node_match.inputs[0].node;
const NodeDef& weights_node = conv_node_match.inputs[1].node;
CHECK_EQ("Const", weights_node.op());
Tensor weights = GetNodeTensorAttr(weights_node, "value");
- const int64 weights_cols = weights.shape().dim_size(3);
+ const int weights_cols_idx = conv_node.op() == "Conv2D" ? 3 : 2;
+ const int64 weights_cols = weights.shape().dim_size(weights_cols_idx);
CHECK_EQ(weights_cols, scale_values.size());
// Multiply the original weights by the scale vector.
- auto weights_matrix = weights.flat_inner_dims<float>();
+ auto weights_vector = weights.flat<float>();
Tensor scaled_weights(DT_FLOAT, weights.shape());
- auto scaled_weights_matrix = scaled_weights.flat_inner_dims<float>();
- for (int64 row = 0; row < weights_matrix.dimension(0); ++row) {
- for (int64 col = 0; col < weights_cols; ++col) {
- scaled_weights_matrix(row, col) =
- weights_matrix(row, col) * scale_values[col];
- }
+ auto scaled_weights_vector = scaled_weights.flat<float>();
+ for (int64 row = 0; row < weights_vector.dimension(0); ++row) {
+ scaled_weights_vector(row) =
+ weights_vector(row) * scale_values[row % weights_cols];
}
// Figure out the remaining bias to add on.
Tensor bias_offset(DT_FLOAT, {weights_cols});
@@ -293,7 +292,7 @@ Status FoldOldBatchNorms(const GraphDef& input_graph_def,
current_graph_def, // clang-format off
{"BatchNormWithGlobalNormalization|FusedBatchNorm", // batch_norm_node
{
- {"Conv2D", // conv_node
+ {"Conv2D|DepthwiseConv2dNative", // conv_node
{
{"*"}, // input_node
{"Const"}, // weights_node
@@ -322,24 +321,24 @@ Status FoldOldBatchNorms(const GraphDef& input_graph_def,
GraphDef replaced_graph_def;
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
current_graph_def, // clang-format off
- {"BatchNormWithGlobalNormalization|FusedBatchNorm", // batch_norm_node
+ {"BatchNormWithGlobalNormalization|FusedBatchNorm", // batch_norm_node
{
- {"BatchToSpaceND", // batch_to_space_node
+ {"BatchToSpaceND", // batch_to_space_node
{
- {"Conv2D", // conv_node
+ {"Conv2D|DepthwiseConv2dNative", // conv_node
{
- {"*"}, // input_node
- {"Const"}, // weights_node
+ {"*"}, // input_node
+ {"Const"}, // weights_node
}
},
- {"Const"}, // block_shape
- {"Const"}, // crops
+ {"Const"}, // block_shape
+ {"Const"}, // crops
}
},
- {"Const"}, // mean_node
- {"Const"}, // variance_node
- {"Const"}, // beta_node
- {"Const"}, // gamma_node
+ {"Const"}, // mean_node
+ {"Const"}, // variance_node
+ {"Const"}, // beta_node
+ {"Const"}, // gamma_node
}
}, // clang-format on
[&did_graph_change](const NodeMatch& match,
@@ -360,29 +359,29 @@ Status FoldOldBatchNorms(const GraphDef& input_graph_def,
// Replace BatchNorm with concat as input.
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
current_graph_def, // clang-format off
- {"BatchNormWithGlobalNormalization|FusedBatchNorm", // batch_norm_node
+ {"BatchNormWithGlobalNormalization|FusedBatchNorm", // batch_norm_node
{
- {"ConcatV2|Concat", // concat two conv2d.
+ {"ConcatV2|Concat", // concat two conv2d.
{
- {"Conv2D", // conv_node
+ {"Conv2D|DepthwiseConv2dNative", // conv_node
{
- {"*"}, // input_node
- {"Const"}, // weights_node
+ {"*"}, // input_node
+ {"Const"}, // weights_node
}
},
- {"Conv2D", // conv_node
+ {"Conv2D|DepthwiseConv2dNative", // conv_node
{
- {"*"}, // input_node
- {"Const"}, // weights_node
+ {"*"}, // input_node
+ {"Const"}, // weights_node
}
},
- {"Const"}, // axis
+ {"Const"}, // axis
},
},
- {"Const"}, // mean_node
- {"Const"}, // variance_node
- {"Const"}, // beta_node
- {"Const"}, // gamma_node
+ {"Const"}, // mean_node
+ {"Const"}, // variance_node
+ {"Const"}, // beta_node
+ {"Const"}, // gamma_node
}
}, // clang-format on
[&did_graph_change](const NodeMatch& match,