diff options
Diffstat (limited to 'tensorflow/python/keras/backend_test.py')
-rw-r--r-- | tensorflow/python/keras/backend_test.py | 44 |
1 files changed, 38 insertions, 6 deletions
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py index ab71589940..0834448699 100644 --- a/tensorflow/python/keras/backend_test.py +++ b/tensorflow/python/keras/backend_test.py @@ -26,6 +26,7 @@ from tensorflow.python import keras from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import nn from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.util import tf_inspect @@ -1381,6 +1382,36 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase): self.assertEqual(mean.get_shape().as_list(), [3,]) self.assertEqual(var.get_shape().as_list(), [3,]) + def test_batch_normalization(self): + g_val = np.random.random((3,)) + b_val = np.random.random((3,)) + gamma = keras.backend.variable(g_val) + beta = keras.backend.variable(b_val) + + # 3D NHC case + val = np.random.random((10, 5, 3)) + x = keras.backend.variable(val) + mean, var = nn.moments(x, (0, 1), None, None, False) + normed = keras.backend.batch_normalization( + x, mean, var, beta, gamma, axis=-1, epsilon=1e-3) + self.assertEqual(normed.shape.as_list(), [10, 5, 3]) + + # 4D NHWC case + val = np.random.random((10, 5, 5, 3)) + x = keras.backend.variable(val) + mean, var = nn.moments(x, (0, 1, 2), None, None, False) + normed = keras.backend.batch_normalization( + x, mean, var, beta, gamma, axis=-1, epsilon=1e-3) + self.assertEqual(normed.shape.as_list(), [10, 5, 5, 3]) + + # 4D NCHW case + val = np.random.random((10, 3, 5, 5)) + x = keras.backend.variable(val) + mean, var = nn.moments(x, (0, 2, 3), None, None, False) + normed = keras.backend.batch_normalization( + x, mean, var, beta, gamma, axis=1, epsilon=1e-3) + self.assertEqual(normed.shape.as_list(), [10, 3, 5, 5]) + class TestCTC(test.TestCase): @@ -1506,12 +1537,13 @@ class TestRandomOps(test.TestCase): self.assertAllClose(np.min(y), -2., atol=0.1) def test_string_input(self): - seq = keras.Sequential([ - keras.layers.InputLayer(input_shape=(1,), dtype=dtypes.string), - keras.layers.Lambda(lambda x: x[0]) - ]) - preds = seq.predict([['tensorflow eager']]) - self.assertEqual(preds.shape, (1,)) + with self.cached_session(): + seq = keras.Sequential([ + keras.layers.InputLayer(input_shape=(1,), dtype=dtypes.string), + keras.layers.Lambda(lambda x: x[0]) + ]) + preds = seq.predict([['tensorflow eager']]) + self.assertEqual(preds.shape, (1,)) if __name__ == '__main__': test.main() |