diff options
Diffstat (limited to 'tensorflow/python/keras/layers/normalization_test.py')
-rw-r--r-- | tensorflow/python/keras/layers/normalization_test.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py index b22f3bd152..a97b4cac46 100644 --- a/tensorflow/python/keras/layers/normalization_test.py +++ b/tensorflow/python/keras/layers/normalization_test.py @@ -95,6 +95,24 @@ class NormalizationLayersTest(test.TestCase): np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1) np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) + def test_batchnorm_mixed_precision(self): + with self.test_session(): + model = keras.models.Sequential() + norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8) + model.add(norm) + model.compile(loss='mse', optimizer='sgd') + + # centered on 5.0, variance 10.0 + x = np.random.normal( + loc=5.0, scale=10.0, size=(1000, 10)).astype(np.float16) + model.fit(x, x, epochs=4, verbose=0) + out = model.predict(x) + out -= keras.backend.eval(norm.beta) + out /= keras.backend.eval(norm.gamma) + + np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1) + np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) + def test_batchnorm_convnet(self): if test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True): |