diff options
author | Yifei Feng <yifeif@google.com> | 2018-07-02 17:07:06 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-02 17:10:57 -0700 |
commit | 73e38c29c74d9d9bf7128bf4737a410ff005611e (patch) | |
tree | f84c84429850d1b38cb4c0f0df24aadfefc7db8e /tensorflow/contrib/quantize | |
parent | eacdfdf6c0353ac0578afbd962dbbafa6121c28f (diff) |
Merge changes from github.
PiperOrigin-RevId: 203037623
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r-- | tensorflow/contrib/quantize/python/fold_batch_norms.py | 14 | ||||
-rw-r--r-- | tensorflow/contrib/quantize/python/fold_batch_norms_test.py | 6 |
2 files changed, 10 insertions, 10 deletions
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index 804cd8d72d..e3c4899830 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -506,7 +506,7 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay): def _IsValidUnfusedBatchNorm(graph, context): """Checks that the output of the unfused batch norm has consumers.""" add_shift = graph.get_operation_by_name( - context + '/BatchNorm/batchnorm/add_1') + context + '/BatchNorm/batchnorm_1/add_1') # Ensure that the output tensor of batch norm has consumers, otherwise this # is a dangling node and not a match. return bool(add_shift.outputs[0].consumers()) @@ -599,7 +599,7 @@ def _GetBatchNormParams(graph, context, has_scaling): op_suffix_mean = '/BatchNorm/moments/Squeeze' op_suffix_variance = '/BatchNorm/moments/Squeeze_1' - op_suffix_epsilon = '/BatchNorm/batchnorm/add/y' + op_suffix_epsilon = '/BatchNorm/batchnorm_1/add/y' op_suffix_bn_decay_mean = '/BatchNorm/AssignMovingAvg/decay' op_suffix_bn_decay_var = '/BatchNorm/AssignMovingAvg_1/decay' @@ -675,12 +675,12 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, Returns: A pair of Operations, the first is the original consumer node of the batch - norm (../BatchNorm/batchnorm/add_1), the second is the consumer node of + norm (../BatchNorm/batchnorm_1/add_1), the second is the consumer node of the folded graph (add_fold). """ mul_scale_name = 'mul_1' if has_scaling else 'mul' mul_scale = graph.get_operation_by_name(context + - '/BatchNorm/batchnorm/' + + '/BatchNorm/batchnorm_1/' + mul_scale_name) op_below = mul_scale.inputs[0].op # Skip over the BatchToSpace operation in the case of atrous convolutions. @@ -707,7 +707,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, ] scale_name = 'mul' if has_scaling else 'Rsqrt' scale = graph.get_operation_by_name( - context + '/BatchNorm/batchnorm/' + scale_name) + context + '/BatchNorm/batchnorm_1/' + scale_name) scale = array_ops.reshape(scale.outputs[0], new_shape, context + '/scale_reshape') @@ -735,7 +735,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, [(1, mul_fold.outputs[0])]) add_shift = graph.get_operation_by_name( - context + '/BatchNorm/batchnorm/add_1') + context + '/BatchNorm/batchnorm_1/add_1') corrected_output = conv_or_fc_folded.outputs[0] # Copy the batch to space operation if we have a atrous convolution. @@ -930,7 +930,7 @@ def _HasScaling(graph, input_to_ops_map, bn): Returns: A boolean indicating whether this batch norm layer has scaling enabled. """ - rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm/Rsqrt') + rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm_1/Rsqrt') rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op) return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1 diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py index 272afcdf07..7c907ffd92 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py @@ -600,13 +600,13 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): if has_scaling: if fused: return scope + '/BatchNorm_Fold/mul' - return scope + '/BatchNorm/batchnorm/mul' - return scope + '/BatchNorm/batchnorm/Rsqrt' + return scope + '/BatchNorm/batchnorm_1/mul' + return scope + '/BatchNorm/batchnorm_1/Rsqrt' def _BathNormBiasName(self, scope, fused): if fused: return scope + '/BatchNorm_Fold/bias' - return scope + '/BatchNorm/batchnorm/sub' + return scope + '/BatchNorm/batchnorm_1/sub' def _WeightInit(self, stddev): """Returns a truncated normal variable initializer. |