aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/backend_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/backend_test.py')
-rw-r--r--tensorflow/python/keras/backend_test.py44
1 files changed, 38 insertions, 6 deletions
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index ab71589940..0834448699 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -26,6 +26,7 @@ from tensorflow.python import keras
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import nn
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
@@ -1381,6 +1382,36 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
self.assertEqual(mean.get_shape().as_list(), [3,])
self.assertEqual(var.get_shape().as_list(), [3,])
+ def test_batch_normalization(self):
+ g_val = np.random.random((3,))
+ b_val = np.random.random((3,))
+ gamma = keras.backend.variable(g_val)
+ beta = keras.backend.variable(b_val)
+
+ # 3D NHC case
+ val = np.random.random((10, 5, 3))
+ x = keras.backend.variable(val)
+ mean, var = nn.moments(x, (0, 1), None, None, False)
+ normed = keras.backend.batch_normalization(
+ x, mean, var, beta, gamma, axis=-1, epsilon=1e-3)
+ self.assertEqual(normed.shape.as_list(), [10, 5, 3])
+
+ # 4D NHWC case
+ val = np.random.random((10, 5, 5, 3))
+ x = keras.backend.variable(val)
+ mean, var = nn.moments(x, (0, 1, 2), None, None, False)
+ normed = keras.backend.batch_normalization(
+ x, mean, var, beta, gamma, axis=-1, epsilon=1e-3)
+ self.assertEqual(normed.shape.as_list(), [10, 5, 5, 3])
+
+ # 4D NCHW case
+ val = np.random.random((10, 3, 5, 5))
+ x = keras.backend.variable(val)
+ mean, var = nn.moments(x, (0, 2, 3), None, None, False)
+ normed = keras.backend.batch_normalization(
+ x, mean, var, beta, gamma, axis=1, epsilon=1e-3)
+ self.assertEqual(normed.shape.as_list(), [10, 3, 5, 5])
+
class TestCTC(test.TestCase):
@@ -1506,12 +1537,13 @@ class TestRandomOps(test.TestCase):
self.assertAllClose(np.min(y), -2., atol=0.1)
def test_string_input(self):
- seq = keras.Sequential([
- keras.layers.InputLayer(input_shape=(1,), dtype=dtypes.string),
- keras.layers.Lambda(lambda x: x[0])
- ])
- preds = seq.predict([['tensorflow eager']])
- self.assertEqual(preds.shape, (1,))
+ with self.cached_session():
+ seq = keras.Sequential([
+ keras.layers.InputLayer(input_shape=(1,), dtype=dtypes.string),
+ keras.layers.Lambda(lambda x: x[0])
+ ])
+ preds = seq.predict([['tensorflow eager']])
+ self.assertEqual(preds.shape, (1,))
if __name__ == '__main__':
test.main()