author Yifei Feng <yifeif@google.com> 2018-07-02 17:07:06 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-07-02 17:10:57 -0700
commit 73e38c29c74d9d9bf7128bf4737a410ff005611e (patch)
tree f84c84429850d1b38cb4c0f0df24aadfefc7db8e /tensorflow/contrib/quantize
parent eacdfdf6c0353ac0578afbd962dbbafa6121c28f (diff)
Merge changes from github.
PiperOrigin-RevId: 203037623
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r-- tensorflow/contrib/quantize/python/fold_batch_norms.py      | 14 +++++++-------
-rw-r--r-- tensorflow/contrib/quantize/python/fold_batch_norms_test.py |  6 +++---
2 files changed, 10 insertions(+), 10 deletions(-)
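The change below renames the hard-coded unfused batch-norm scope from 'BatchNorm/batchnorm/...' to 'BatchNorm/batchnorm_1/...'. A plausible reading, assumed here rather than stated in the commit message, is that an upstream change opens the 'batchnorm' name scope twice, and TensorFlow's graph-level name uniquification gives the second use a '_1' suffix. A minimal TF 1.x sketch of that uniquification behavior:

    import tensorflow as tf  # TF 1.x graph-mode API assumed

    g = tf.Graph()
    with g.as_default():
        # First use of the scope keeps the bare name.
        with tf.name_scope('batchnorm'):
            a = tf.add(1.0, 2.0, name='add_1')  # op name: batchnorm/add_1
        # Reopening the same scope in the same graph is uniquified.
        with tf.name_scope('batchnorm'):
            b = tf.add(3.0, 4.0, name='add_1')  # op name: batchnorm_1/add_1

    print(a.op.name)  # batchnorm/add_1
    print(b.op.name)  # batchnorm_1/add_1

Because the fold logic matches ops by exact name, every hard-coded suffix has to track the rename.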
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index 804cd8d72d..e3c4899830 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -506,7 +506,7 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
def _IsValidUnfusedBatchNorm(graph, context):
"""Checks that the output of the unfused batch norm has consumers."""
add_shift = graph.get_operation_by_name(
- context + '/BatchNorm/batchnorm/add_1')
+ context + '/BatchNorm/batchnorm_1/add_1')
# Ensure that the output tensor of batch norm has consumers, otherwise this
# is a dangling node and not a match.
return bool(add_shift.outputs[0].consumers())
@@ -599,7 +599,7 @@ def _GetBatchNormParams(graph, context, has_scaling):
op_suffix_mean = '/BatchNorm/moments/Squeeze'
op_suffix_variance = '/BatchNorm/moments/Squeeze_1'
- op_suffix_epsilon = '/BatchNorm/batchnorm/add/y'
+ op_suffix_epsilon = '/BatchNorm/batchnorm_1/add/y'
op_suffix_bn_decay_mean = '/BatchNorm/AssignMovingAvg/decay'
op_suffix_bn_decay_var = '/BatchNorm/AssignMovingAvg_1/decay'
@@ -675,12 +675,12 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
Returns:
A pair of Operations, the first is the original consumer node of the batch
- norm (../BatchNorm/batchnorm/add_1), the second is the consumer node of
+ norm (../BatchNorm/batchnorm_1/add_1), the second is the consumer node of
the folded graph (add_fold).
"""
mul_scale_name = 'mul_1' if has_scaling else 'mul'
mul_scale = graph.get_operation_by_name(context +
- '/BatchNorm/batchnorm/' +
+ '/BatchNorm/batchnorm_1/' +
mul_scale_name)
op_below = mul_scale.inputs[0].op
# Skip over the BatchToSpace operation in the case of atrous convolutions.
@@ -707,7 +707,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
]
scale_name = 'mul' if has_scaling else 'Rsqrt'
scale = graph.get_operation_by_name(
- context + '/BatchNorm/batchnorm/' + scale_name)
+ context + '/BatchNorm/batchnorm_1/' + scale_name)
scale = array_ops.reshape(scale.outputs[0], new_shape,
context + '/scale_reshape')
@@ -735,7 +735,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
[(1, mul_fold.outputs[0])])
add_shift = graph.get_operation_by_name(
- context + '/BatchNorm/batchnorm/add_1')
+ context + '/BatchNorm/batchnorm_1/add_1')
corrected_output = conv_or_fc_folded.outputs[0]
# Copy the batch to space operation if we have an atrous convolution.
@@ -930,7 +930,7 @@ def _HasScaling(graph, input_to_ops_map, bn):
Returns:
A boolean indicating whether this batch norm layer has scaling enabled.
"""
- rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm/Rsqrt')
+ rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm_1/Rsqrt')
rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op)
return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1
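Every lookup above goes through graph.get_operation_by_name with an exact, hard-coded op name, so a scope rename like this one surfaces as a KeyError at match time. A hedged sketch of a more tolerant lookup; the _find_bn_op helper and its fallback list are hypothetical, not part of this change:

    def _find_bn_op(graph, context, op_name):
        # Hypothetical helper: try the uniquified scope first, then the
        # original one, since the suffix depends on graph construction order.
        for scope in ('/BatchNorm/batchnorm_1/', '/BatchNorm/batchnorm/'):
            try:
                # tf.Graph.get_operation_by_name raises KeyError for unknown names.
                return graph.get_operation_by_name(context + scope + op_name)
            except KeyError:
                continue
        raise KeyError('no batchnorm op %r under %r' % (op_name, context))

For example, the _HasScaling check could then start with rsqrt_op = _find_bn_op(graph, bn, 'Rsqrt') and would survive either spelling of the scope.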
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
index 272afcdf07..7c907ffd92 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
@@ -600,13 +600,13 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
if has_scaling:
if fused:
return scope + '/BatchNorm_Fold/mul'
- return scope + '/BatchNorm/batchnorm/mul'
- return scope + '/BatchNorm/batchnorm/Rsqrt'
+ return scope + '/BatchNorm/batchnorm_1/mul'
+ return scope + '/BatchNorm/batchnorm_1/Rsqrt'
def _BathNormBiasName(self, scope, fused):
if fused:
return scope + '/BatchNorm_Fold/bias'
- return scope + '/BatchNorm/batchnorm/sub'
+ return scope + '/BatchNorm/batchnorm_1/sub'
def _WeightInit(self, stddev):
"""Returns a truncated normal variable initializer.