diff options
Diffstat (limited to 'tensorflow/core/kernels/fake_quant_ops_functor.h')
-rw-r--r-- | tensorflow/core/kernels/fake_quant_ops_functor.h | 15 |
1 files changed, 6 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/fake_quant_ops_functor.h b/tensorflow/core/kernels/fake_quant_ops_functor.h index 7aaad6e6c7..b41b22d634 100644 --- a/tensorflow/core/kernels/fake_quant_ops_functor.h +++ b/tensorflow/core/kernels/fake_quant_ops_functor.h @@ -132,7 +132,7 @@ struct FakeQuantWithMinMaxVarsFunctor { const float max_val = max(); // If min and max are both zero, we should just return zero. if (min_val == 0.0f && max_val == 0.0f) { - outputs.device(d) = outputs.constant(0.0f); + outputs.setZero(); return; } float nudged_min, nudged_max, nudged_scale; @@ -163,8 +163,8 @@ struct FakeQuantWithMinMaxVarsGradientFunctor { // If min and max are both zero, we propagate everything to inputs. if (min_val == 0.0f && max_val == 0.0f) { backprops_wrt_input.device(d) = gradients; - backprop_wrt_min.device(d) = backprop_wrt_min.constant(0.0f); - backprop_wrt_max.device(d) = backprop_wrt_max.constant(0.0f); + backprop_wrt_min.setZero(); + backprop_wrt_max.setZero(); return; } float nudged_min, nudged_max, nudged_scale; @@ -205,8 +205,7 @@ struct FakeQuantWithMinMaxVarsPerChannelFunctor { const float max_val = max(i); // If min and max are both zero, we should just return zero. if (min_val == 0.0f && max_val == 0.0f) { - auto chip = outputs.chip<1>(i); - chip.device(d) = chip.constant(0.0f); + outputs.chip<1>(i).setZero(); continue; } float nudged_min, nudged_max, nudged_scale; @@ -243,10 +242,8 @@ struct FakeQuantWithMinMaxVarsPerChannelGradientFunctor { // If min and max are both zero, we propagate everything to inputs. if (min_val == 0.0f && max_val == 0.0f) { backprops_wrt_input.chip<1>(i).device(d) = gradients_chip; - auto min_chip = backprop_wrt_min.chip<0>(i); - auto max_chip = backprop_wrt_max.chip<0>(i); - min_chip.device(d) = min_chip.constant(0.0f); - max_chip.device(d) = max_chip.constant(0.0f); + backprop_wrt_min.chip<0>(i).setZero(); + backprop_wrt_max.chip<0>(i).setZero(); continue; } float nudged_min, nudged_max, nudged_scale; |