diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-09 13:30:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-09 13:30:50 -0700 |
commit | ee38f86972b13f3eb90032e93b305e822152bf62 (patch) | |
tree | 76830b9ae48559dfcf7fe119521e529424d904ea /tensorflow/tools/graph_transforms | |
parent | b7b8ab67bc4c86be8e52a51e2f85d53212ccdf64 (diff) | |
parent | 049c529e567a557461d9ffe8614e6c06a7b3a39f (diff) |
Merge pull request #21023 from zhaoyongke:master
PiperOrigin-RevId: 208100633
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r-- | tensorflow/tools/graph_transforms/fold_batch_norms.cc | 20 | ||||
-rw-r--r-- | tensorflow/tools/graph_transforms/fold_old_batch_norms.cc | 67 |
2 files changed, 44 insertions, 43 deletions
diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_batch_norms.cc index 975b17380f..39f682e8b0 100644 --- a/tensorflow/tools/graph_transforms/fold_batch_norms.cc +++ b/tensorflow/tools/graph_transforms/fold_batch_norms.cc @@ -38,7 +38,7 @@ Status FoldBatchNorms(const GraphDef& input_graph_def, input_graph_def, // clang-format off {"Mul", // mul_node { - {"Conv2D|MatMul", // conv_node + {"Conv2D|MatMul|DepthwiseConv2dNative", // conv_node { {"*"}, // input_node {"Const"}, // weights_node @@ -73,7 +73,10 @@ Status FoldBatchNorms(const GraphDef& input_graph_def, // Make sure all the inputs really are vectors, with as many entries as // there are columns in the weights. - const int weights_cols_index = conv_node.op() == "Conv2D" ? 3 : 1; + const int weights_cols_index = + conv_node.op() == "Conv2D" + ? 3 + : (conv_node.op() == "DepthwiseConv2dNative" ? 2 : 1); const int64 weights_cols = weights.shape().dim_size(weights_cols_index); if ((mul_values.shape().dims() != 1) || (mul_values.shape().dim_size(0) != weights_cols)) { @@ -83,14 +86,13 @@ Status FoldBatchNorms(const GraphDef& input_graph_def, } // Multiply the original weights by the scale vector. - auto weights_matrix = weights.flat_inner_dims<float>(); + auto weights_vector = weights.flat<float>(); Tensor scaled_weights(DT_FLOAT, weights.shape()); - auto scaled_weights_matrix = scaled_weights.flat_inner_dims<float>(); - for (int64 row = 0; row < weights_matrix.dimension(0); ++row) { - for (int64 col = 0; col < weights_cols; ++col) { - scaled_weights_matrix(row, col) = - weights_matrix(row, col) * mul_values.flat<float>()(col); - } + auto scaled_weights_vector = scaled_weights.flat<float>(); + for (int64 row = 0; row < weights_vector.dimension(0); ++row) { + scaled_weights_vector(row) = + weights_vector(row) * + mul_values.flat<float>()(row % weights_cols); } // Construct the new nodes. diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc index 156636ab82..a35d64b789 100644 --- a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc +++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc @@ -110,24 +110,23 @@ Status FuseScaleOffsetToConvWeights(const std::vector<float>& scale_values, const string& conv_output_name, std::vector<NodeDef>* new_nodes) { const NodeDef& conv_node = conv_node_match.node; - CHECK_EQ("Conv2D", conv_node.op()); + // CHECK_EQ("Conv2D", conv_node.op()); const NodeDef& input_node = conv_node_match.inputs[0].node; const NodeDef& weights_node = conv_node_match.inputs[1].node; CHECK_EQ("Const", weights_node.op()); Tensor weights = GetNodeTensorAttr(weights_node, "value"); - const int64 weights_cols = weights.shape().dim_size(3); + const int weights_cols_idx = conv_node.op() == "Conv2D" ? 3 : 2; + const int64 weights_cols = weights.shape().dim_size(weights_cols_idx); CHECK_EQ(weights_cols, scale_values.size()); // Multiply the original weights by the scale vector. - auto weights_matrix = weights.flat_inner_dims<float>(); + auto weights_vector = weights.flat<float>(); Tensor scaled_weights(DT_FLOAT, weights.shape()); - auto scaled_weights_matrix = scaled_weights.flat_inner_dims<float>(); - for (int64 row = 0; row < weights_matrix.dimension(0); ++row) { - for (int64 col = 0; col < weights_cols; ++col) { - scaled_weights_matrix(row, col) = - weights_matrix(row, col) * scale_values[col]; - } + auto scaled_weights_vector = scaled_weights.flat<float>(); + for (int64 row = 0; row < weights_vector.dimension(0); ++row) { + scaled_weights_vector(row) = + weights_vector(row) * scale_values[row % weights_cols]; } // Figure out the remaining bias to add on. Tensor bias_offset(DT_FLOAT, {weights_cols}); @@ -293,7 +292,7 @@ Status FoldOldBatchNorms(const GraphDef& input_graph_def, current_graph_def, // clang-format off {"BatchNormWithGlobalNormalization|FusedBatchNorm", // batch_norm_node { - {"Conv2D", // conv_node + {"Conv2D|DepthwiseConv2dNative", // conv_node { {"*"}, // input_node {"Const"}, // weights_node @@ -322,24 +321,24 @@ Status FoldOldBatchNorms(const GraphDef& input_graph_def, GraphDef replaced_graph_def; TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( current_graph_def, // clang-format off - {"BatchNormWithGlobalNormalization|FusedBatchNorm", // batch_norm_node + {"BatchNormWithGlobalNormalization|FusedBatchNorm", // batch_norm_node { - {"BatchToSpaceND", // batch_to_space_node + {"BatchToSpaceND", // batch_to_space_node { - {"Conv2D", // conv_node + {"Conv2D|DepthwiseConv2dNative", // conv_node { - {"*"}, // input_node - {"Const"}, // weights_node + {"*"}, // input_node + {"Const"}, // weights_node } }, - {"Const"}, // block_shape - {"Const"}, // crops + {"Const"}, // block_shape + {"Const"}, // crops } }, - {"Const"}, // mean_node - {"Const"}, // variance_node - {"Const"}, // beta_node - {"Const"}, // gamma_node + {"Const"}, // mean_node + {"Const"}, // variance_node + {"Const"}, // beta_node + {"Const"}, // gamma_node } }, // clang-format on [&did_graph_change](const NodeMatch& match, @@ -360,29 +359,29 @@ Status FoldOldBatchNorms(const GraphDef& input_graph_def, // Replace BatchNorm with concat as input. TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( current_graph_def, // clang-format off - {"BatchNormWithGlobalNormalization|FusedBatchNorm", // batch_norm_node + {"BatchNormWithGlobalNormalization|FusedBatchNorm", // batch_norm_node { - {"ConcatV2|Concat", // concat two conv2d. + {"ConcatV2|Concat", // concat two conv2d. { - {"Conv2D", // conv_node + {"Conv2D|DepthwiseConv2dNative", // conv_node { - {"*"}, // input_node - {"Const"}, // weights_node + {"*"}, // input_node + {"Const"}, // weights_node } }, - {"Conv2D", // conv_node + {"Conv2D|DepthwiseConv2dNative", // conv_node { - {"*"}, // input_node - {"Const"}, // weights_node + {"*"}, // input_node + {"Const"}, // weights_node } }, - {"Const"}, // axis + {"Const"}, // axis }, }, - {"Const"}, // mean_node - {"Const"}, // variance_node - {"Const"}, // beta_node - {"Const"}, // gamma_node + {"Const"}, // mean_node + {"Const"}, // variance_node + {"Const"}, // beta_node + {"Const"}, // gamma_node } }, // clang-format on [&did_graph_change](const NodeMatch& match, |