aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-10-01 17:05:45 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-10-01 17:09:56 -0700
commit 80f8931682aeaae89786f0940892a6557b4cfd67 (patch)
tree 63a716d5d72b3d423f3a0a286c3e2744e9cc1b27 /tensorflow/python/training
parent b72265dc002e712fc3d0f33434f13c7a36a484b2 (diff)
Mark bfloat16 as supported for ExponentialMovingAverage.
PiperOrigin-RevId: 215307701
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r-- tensorflow/python/training/moving_averages.py      |  9
-rw-r--r-- tensorflow/python/training/moving_averages_test.py | 27
2 files changed, 32 insertions, 4 deletions
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index 177a7ddfa5..041266da3e 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -372,13 +372,13 @@ class ExponentialMovingAverage(object):
Args:
var_list: A list of Variable or Tensor objects. The variables
- and Tensors must be of types float16, float32, or float64.
+ and Tensors must be of types bfloat16, float16, float32, or float64.
Returns:
An Operation that updates the moving averages.
Raises:
- TypeError: If the arguments are not all float16, float32, or float64.
+ TypeError: If the arguments are not an allowed type.
ValueError: If the moving average of one of the variables is already
being computed.
"""
@@ -387,8 +387,9 @@ class ExponentialMovingAverage(object):
var_list = variables.trainable_variables()
zero_debias_true = set() # set of vars to set `zero_debias=True`
for var in var_list:
- if var.dtype.base_dtype not in [dtypes.float16, dtypes.float32,
- dtypes.float64]:
+ if var.dtype.base_dtype not in [
+ dtypes.bfloat16, dtypes.float16, dtypes.float32, dtypes.float64
+ ]:
raise TypeError("The variables must be half, float, or double: %s" %
var.name)
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index 93991d0e14..bb2fca66e3 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -110,6 +111,32 @@ class MovingAveragesTest(test.TestCase):
denominator_2 = denominator_1 * decay + weight_2 * (1.0 - decay)
self.assertAllClose(numerator_2 / denominator_2, wma_array)
+ def testWeightedMovingAverageBfloat16(self):
+ bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
+ with self.cached_session() as sess:
+ decay = 0.5
+ weight = array_ops.placeholder(dtypes.bfloat16, [])
+ val = array_ops.placeholder(dtypes.bfloat16, [])
+
+ wma = moving_averages.weighted_moving_average(val, decay, weight)
+ variables.global_variables_initializer().run()
+
+ # Get the first weighted moving average.
+ val_1 = 3.0
+ weight_1 = 4.0
+ wma_array = sess.run(wma, feed_dict={val: val_1, weight: weight_1})
+ numerator_1 = val_1 * weight_1 * (1.0 - decay)
+ denominator_1 = weight_1 * (1.0 - decay)
+ self.assertAllClose(numerator_1 / denominator_1, wma_array)
+
+ # Get the second weighted moving average.
+ val_2 = 11.0
+ weight_2 = 22.0
+ wma_array = sess.run(wma, feed_dict={val: val_2, weight: weight_2})
+ numerator_2 = numerator_1 * decay + val_2 * weight_2 * (1.0 - decay)
+ denominator_2 = denominator_1 * decay + weight_2 * (1.0 - decay)
+ self.assertAllClose(bfloat16(numerator_2 / denominator_2), wma_array)
+
def _Repeat(value, dim):
if dim == 1: