aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantize
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-05-03 00:16:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-03 00:18:26 -0700
commit283e8fe7e191f8e0e2ca6ea62b8b4553c30a6286 (patch)
tree4a871b5ec92f529a63c608793568438755a1345d /tensorflow/contrib/quantize
parent985351dc1ab33cedbfd7790dd9cccc36d2d4b150 (diff)
Use tensorflow size to determine number of elements instead of the static shape, which can sometimes be missing.
PiperOrigin-RevId: 195209826
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r--tensorflow/contrib/quantize/python/fold_batch_norms.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index 1f286bc39a..76f695dce0 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -414,7 +414,8 @@ def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor):
def _FoldFusedBatchNormGrad(op, unused_grad_y, grad_mean, grad_var, unused_1,
unused_2):
x = op.inputs[0]
- n = x.get_shape().num_elements() / grad_mean.get_shape().num_elements()
+ n = math_ops.cast(
+ array_ops.size(x) / array_ops.size(grad_mean), dtypes.float32)
dmean_dx = grad_mean / n
dvar_dx = 2 * grad_var * (x - op.outputs[1]) / (n - 1)
return (dmean_dx + dvar_dx), None, None, None, None