diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-01 17:05:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-01 17:09:56 -0700 |
commit | 80f8931682aeaae89786f0940892a6557b4cfd67 (patch) | |
tree | 63a716d5d72b3d423f3a0a286c3e2744e9cc1b27 /tensorflow/python/training | |
parent | b72265dc002e712fc3d0f33434f13c7a36a484b2 (diff) |
Mark bfloat16 as supported for ExponentialMovingAverage.
PiperOrigin-RevId: 215307701
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r-- | tensorflow/python/training/moving_averages.py | 9 | ||||
-rw-r--r-- | tensorflow/python/training/moving_averages_test.py | 27 |
2 files changed, 32 insertions, 4 deletions
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index 177a7ddfa5..041266da3e 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -372,13 +372,13 @@ class ExponentialMovingAverage(object): Args: var_list: A list of Variable or Tensor objects. The variables - and Tensors must be of types float16, float32, or float64. + and Tensors must be of types bfloat16, float16, float32, or float64. Returns: An Operation that updates the moving averages. Raises: - TypeError: If the arguments are not all float16, float32, or float64. + TypeError: If the arguments are not an allowed type. ValueError: If the moving average of one of the variables is already being computed. """ @@ -387,8 +387,9 @@ class ExponentialMovingAverage(object): var_list = variables.trainable_variables() zero_debias_true = set() # set of vars to set `zero_debias=True` for var in var_list: - if var.dtype.base_dtype not in [dtypes.float16, dtypes.float32, - dtypes.float64]: + if var.dtype.base_dtype not in [ + dtypes.bfloat16, dtypes.float16, dtypes.float32, dtypes.float64 + ]: raise TypeError("The variables must be half, float, or double: %s" % var.name) diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py index 93991d0e14..bb2fca66e3 100644 --- a/tensorflow/python/training/moving_averages_test.py +++ b/tensorflow/python/training/moving_averages_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -110,6 +111,32 @@ class MovingAveragesTest(test.TestCase): denominator_2 = denominator_1 * decay + weight_2 * (1.0 - decay) self.assertAllClose(numerator_2 / denominator_2, wma_array) + def testWeightedMovingAverageBfloat16(self): + bfloat16 = pywrap_tensorflow.TF_bfloat16_type() + with self.cached_session() as sess: + decay = 0.5 + weight = array_ops.placeholder(dtypes.bfloat16, []) + val = array_ops.placeholder(dtypes.bfloat16, []) + + wma = moving_averages.weighted_moving_average(val, decay, weight) + variables.global_variables_initializer().run() + + # Get the first weighted moving average. + val_1 = 3.0 + weight_1 = 4.0 + wma_array = sess.run(wma, feed_dict={val: val_1, weight: weight_1}) + numerator_1 = val_1 * weight_1 * (1.0 - decay) + denominator_1 = weight_1 * (1.0 - decay) + self.assertAllClose(numerator_1 / denominator_1, wma_array) + + # Get the second weighted moving average. + val_2 = 11.0 + weight_2 = 22.0 + wma_array = sess.run(wma, feed_dict={val: val_2, weight: weight_2}) + numerator_2 = numerator_1 * decay + val_2 * weight_2 * (1.0 - decay) + denominator_2 = denominator_1 * decay + weight_2 * (1.0 - decay) + self.assertAllClose(bfloat16(numerator_2 / denominator_2), wma_array) + def _Repeat(value, dim): if dim == 1: |