diff options
author | Suharsh Sivakumar <suharshs@google.com> | 2017-11-07 16:52:51 -0800 |
---|---|---|
committer | Andrew Selle <aselle@andyselle.com> | 2017-11-10 16:14:34 -0800 |
commit | 969eae0da7aed343b382d12d6e65dcf1d3bbcfad (patch) | |
tree | 1e16a252a4115ac3fd3b85206bcc57ffd8882870 | |
parent | 7fdcbc508c6d94a785111bc9468b221335345ce7 (diff) |
Fix FakeQuant to correctly set zero on CPU.
PiperOrigin-RevId: 174935134
-rw-r--r-- | tensorflow/core/kernels/fake_quant_ops_functor.h | 15 |
1 file changed, 9 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/fake_quant_ops_functor.h b/tensorflow/core/kernels/fake_quant_ops_functor.h index b41b22d634..7aaad6e6c7 100644 --- a/tensorflow/core/kernels/fake_quant_ops_functor.h +++ b/tensorflow/core/kernels/fake_quant_ops_functor.h @@ -132,7 +132,7 @@ struct FakeQuantWithMinMaxVarsFunctor { const float max_val = max(); // If min and max are both zero, we should just return zero. if (min_val == 0.0f && max_val == 0.0f) { - outputs.setZero(); + outputs.device(d) = outputs.constant(0.0f); return; } float nudged_min, nudged_max, nudged_scale; @@ -163,8 +163,8 @@ struct FakeQuantWithMinMaxVarsGradientFunctor { // If min and max are both zero, we propagate everything to inputs. if (min_val == 0.0f && max_val == 0.0f) { backprops_wrt_input.device(d) = gradients; - backprop_wrt_min.setZero(); - backprop_wrt_max.setZero(); + backprop_wrt_min.device(d) = backprop_wrt_min.constant(0.0f); + backprop_wrt_max.device(d) = backprop_wrt_max.constant(0.0f); return; } float nudged_min, nudged_max, nudged_scale; @@ -205,7 +205,8 @@ struct FakeQuantWithMinMaxVarsPerChannelFunctor { const float max_val = max(i); // If min and max are both zero, we should just return zero. if (min_val == 0.0f && max_val == 0.0f) { - outputs.chip<1>(i).setZero(); + auto chip = outputs.chip<1>(i); + chip.device(d) = chip.constant(0.0f); continue; } float nudged_min, nudged_max, nudged_scale; @@ -242,8 +243,10 @@ struct FakeQuantWithMinMaxVarsPerChannelGradientFunctor { // If min and max are both zero, we propagate everything to inputs. if (min_val == 0.0f && max_val == 0.0f) { backprops_wrt_input.chip<1>(i).device(d) = gradients_chip; - backprop_wrt_min.chip<0>(i).setZero(); - backprop_wrt_max.chip<0>(i).setZero(); + auto min_chip = backprop_wrt_min.chip<0>(i); + auto max_chip = backprop_wrt_max.chip<0>(i); + min_chip.device(d) = min_chip.constant(0.0f); + max_chip.device(d) = max_chip.constant(0.0f); continue; } float nudged_min, nudged_max, nudged_scale; |