aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/batch_norm_op.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-05-26 15:14:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-26 16:18:36 -0700
commit36357e7e1127873165694a38e3a989df4e0b6ffe (patch)
treeeaf00810dfaefdc7f308a088829a3798af23e7bc /tensorflow/core/kernels/batch_norm_op.h
parente1b4934bb59904ee4dd243a34cc8356ff6bd266d (diff)
Added support for half floats to the batch normalization op
Change: 123368006
Diffstat (limited to 'tensorflow/core/kernels/batch_norm_op.h')
-rw-r--r--tensorflow/core/kernels/batch_norm_op.h4
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/batch_norm_op.h b/tensorflow/core/kernels/batch_norm_op.h
index baef68125e..94707e9be9 100644
--- a/tensorflow/core/kernels/batch_norm_op.h
+++ b/tensorflow/core/kernels/batch_norm_op.h
@@ -29,7 +29,7 @@ struct BatchNorm {
typename TTypes<T>::ConstVec mean,
typename TTypes<T>::ConstVec var,
typename TTypes<T>::ConstVec beta,
- typename TTypes<T>::ConstVec gamma, float variance_epsilon,
+ typename TTypes<T>::ConstVec gamma, T variance_epsilon,
bool scale_after_normalization,
typename TTypes<T, 4>::Tensor output) {
const int depth = mean.dimension(0);
@@ -77,7 +77,7 @@ struct BatchNormGrad {
typename TTypes<T>::ConstVec var,
typename TTypes<T>::ConstVec gamma,
typename TTypes<T, 4>::ConstTensor out_backprop,
- float variance_epsilon, bool scale_after_normalization,
+ T variance_epsilon, bool scale_after_normalization,
typename TTypes<T, 4>::Tensor dx, typename TTypes<T>::Vec dm,
typename TTypes<T>::Vec dv, typename TTypes<T>::Vec db,
typename TTypes<T>::Vec dg, typename TTypes<T>::Vec scratch1,