Diffstat (limited to 'tensorflow/python/keras/layers/normalization_test.py')
-rw-r--r--  tensorflow/python/keras/layers/normalization_test.py | 18 ++++++++++++++++++
1 file changed, 18 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py
index b22f3bd152..a97b4cac46 100644
--- a/tensorflow/python/keras/layers/normalization_test.py
+++ b/tensorflow/python/keras/layers/normalization_test.py
@@ -95,6 +95,24 @@ class NormalizationLayersTest(test.TestCase):
       np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
       np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
 
+  def test_batchnorm_mixed_precision(self):
+    with self.test_session():
+      model = keras.models.Sequential()
+      norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8)
+      model.add(norm)
+      model.compile(loss='mse', optimizer='sgd')
+
+      # centered on 5.0, variance 10.0
+      x = np.random.normal(
+          loc=5.0, scale=10.0, size=(1000, 10)).astype(np.float16)
+      model.fit(x, x, epochs=4, verbose=0)
+      out = model.predict(x)
+      out -= keras.backend.eval(norm.beta)
+      out /= keras.backend.eval(norm.gamma)
+
+      np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
+      np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
+
   def test_batchnorm_convnet(self):
     if test.is_gpu_available(cuda_only=True):
       with self.test_session(use_gpu=True):
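
For reference, the added test_batchnorm_mixed_precision asserts that BatchNormalization still recovers the input statistics when the inputs are float16: after fitting an identity mapping and undoing the learned beta/gamma affine transform, the output should be roughly zero-mean with unit standard deviation. Below is a minimal standalone sketch of the same check, assuming the public tf.keras API rather than the internal keras import and test harness used in the diff; it is illustrative, not part of the change.

import numpy as np
import tensorflow as tf

# One-layer model: BatchNormalization over 10 features.
norm = tf.keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8)
model = tf.keras.models.Sequential([norm])
model.compile(loss='mse', optimizer='sgd')

# float16 inputs centered on 5.0 with stddev 10.0, mirroring the test data.
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10)).astype(np.float16)

# Fitting x -> x drives the layer's moving mean/variance toward the
# input statistics.
model.fit(x, x, epochs=4, verbose=0)

# Remove the learned affine transform (beta, gamma); what remains should be
# approximately normalized data.
out = model.predict(x)
out -= tf.keras.backend.eval(norm.beta)
out /= tf.keras.backend.eval(norm.gamma)

np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)

The loose atol=1e-1 tolerance accounts both for the stochastic training and for the reduced precision of float16 inputs.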