aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2017-11-07 16:52:51 -0800
committerGravatar Andrew Selle <aselle@andyselle.com>2017-11-10 16:14:34 -0800
commit969eae0da7aed343b382d12d6e65dcf1d3bbcfad (patch)
tree1e16a252a4115ac3fd3b85206bcc57ffd8882870
parent7fdcbc508c6d94a785111bc9468b221335345ce7 (diff)
Fix FakeQuant to correctly set zero on CPU.
PiperOrigin-RevId: 174935134
-rw-r--r--tensorflow/core/kernels/fake_quant_ops_functor.h15
1 files changed, 9 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/fake_quant_ops_functor.h b/tensorflow/core/kernels/fake_quant_ops_functor.h
index b41b22d634..7aaad6e6c7 100644
--- a/tensorflow/core/kernels/fake_quant_ops_functor.h
+++ b/tensorflow/core/kernels/fake_quant_ops_functor.h
@@ -132,7 +132,7 @@ struct FakeQuantWithMinMaxVarsFunctor {
const float max_val = max();
// If min and max are both zero, we should just return zero.
if (min_val == 0.0f && max_val == 0.0f) {
- outputs.setZero();
+ outputs.device(d) = outputs.constant(0.0f);
return;
}
float nudged_min, nudged_max, nudged_scale;
@@ -163,8 +163,8 @@ struct FakeQuantWithMinMaxVarsGradientFunctor {
// If min and max are both zero, we propagate everything to inputs.
if (min_val == 0.0f && max_val == 0.0f) {
backprops_wrt_input.device(d) = gradients;
- backprop_wrt_min.setZero();
- backprop_wrt_max.setZero();
+ backprop_wrt_min.device(d) = backprop_wrt_min.constant(0.0f);
+ backprop_wrt_max.device(d) = backprop_wrt_max.constant(0.0f);
return;
}
float nudged_min, nudged_max, nudged_scale;
@@ -205,7 +205,8 @@ struct FakeQuantWithMinMaxVarsPerChannelFunctor {
const float max_val = max(i);
// If min and max are both zero, we should just return zero.
if (min_val == 0.0f && max_val == 0.0f) {
- outputs.chip<1>(i).setZero();
+ auto chip = outputs.chip<1>(i);
+ chip.device(d) = chip.constant(0.0f);
continue;
}
float nudged_min, nudged_max, nudged_scale;
@@ -242,8 +243,10 @@ struct FakeQuantWithMinMaxVarsPerChannelGradientFunctor {
// If min and max are both zero, we propagate everything to inputs.
if (min_val == 0.0f && max_val == 0.0f) {
backprops_wrt_input.chip<1>(i).device(d) = gradients_chip;
- backprop_wrt_min.chip<0>(i).setZero();
- backprop_wrt_max.chip<0>(i).setZero();
+ auto min_chip = backprop_wrt_min.chip<0>(i);
+ auto max_chip = backprop_wrt_max.chip<0>(i);
+ min_chip.device(d) = min_chip.constant(0.0f);
+ max_chip.device(d) = max_chip.constant(0.0f);
continue;
}
float nudged_min, nudged_max, nudged_scale;