aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms
diff options
context:
space:
mode:
authorGravatar zhaoyongke <zhaoyongke@yeah.net>2018-08-03 10:58:06 +0800
committerGravatar zhaoyongke <zhaoyongke@yeah.net>2018-08-03 10:58:06 +0800
commit049c529e567a557461d9ffe8614e6c06a7b3a39f (patch)
tree483c63a792a38d388b04128acac90475cf1ce7cb /tensorflow/tools/graph_transforms
parent7a43aad35a4f806cec9715fa394c48dae3abd42a (diff)
Fold bn with depthwise conv, minor typo
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r--tensorflow/tools/graph_transforms/fold_batch_norms.cc3
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_batch_norms.cc
index 42eebd98c9..cb4230dd82 100644
--- a/tensorflow/tools/graph_transforms/fold_batch_norms.cc
+++ b/tensorflow/tools/graph_transforms/fold_batch_norms.cc
@@ -73,7 +73,8 @@ Status FoldBatchNorms(const GraphDef& input_graph_def,
// Make sure all the inputs really are vectors, with as many entries as
// there are columns in the weights.
- const int weights_cols_index = conv_node.op() == "Conv2D" ? 3 : (conv_node.op() == "DepthwiseConv2dNative" ? 2 : 1);
+ const int weights_cols_index = conv_node.op() == "Conv2D" ? 3 : \
+ (conv_node.op() == "DepthwiseConv2dNative" ? 2 : 1);
const int64 weights_cols = weights.shape().dim_size(weights_cols_index);
if ((mul_values.shape().dims() != 1) ||
(mul_values.shape().dim_size(0) != weights_cols)) {