Diffstat (limited to 'tensorflow/contrib/layers/python/layers/layers_test.py'):
 tensorflow/contrib/layers/python/layers/layers_test.py | 73 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 68 insertions(+), 5 deletions(-)
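
This change extends BatchNormTest to cover float16: fused batch norm must accept half-precision inputs while still creating its beta/gamma/moving_mean/moving_variance variables as float32, and the float16 output must agree with the float32 path to roughly float16 precision. As a minimal standalone sketch of the behavior under test (assuming the TF 1.x tf.contrib.layers API; the shapes and session setup here are illustrative, not part of the patch):

    import tensorflow as tf

    # Illustrative only: float16 activations through fused batch norm, NCHW layout.
    images = tf.placeholder(tf.float16, shape=(5, 4, 2, 3))
    output = tf.contrib.layers.batch_norm(
        images, fused=True, is_training=True, scale=True, data_format='NCHW')

    # The layer's parameters should nevertheless be created as float32 variables,
    # which is what testVariablesAreFloat32 below asserts via dtypes.float32_ref.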
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 2837a3172d..7ccd9d8868 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1766,10 +1766,12 @@ class BatchNormTest(test.TestCase):
       with self.assertRaisesRegexp(ValueError, 'undefined'):
         _layers.batch_norm(inputs, data_format='NCHW')
 
-  def _testCreateOp(self, fused):
+  def _testCreateOp(self, fused, dtype=None):
+    if dtype is None:
+      dtype = dtypes.float32
     height, width = 3, 3
     with self.test_session():
-      images = np.random.uniform(size=(5, height, width, 3)).astype('f')
+      images = np.random.uniform(size=(5, height, width, 3)).astype(dtype.as_numpy_dtype)
       output = _layers.batch_norm(images, fused=fused)
       expected_name = ('BatchNorm/FusedBatchNorm' if fused else
                        'BatchNorm/batchnorm')
@@ -1784,6 +1786,9 @@ class BatchNormTest(test.TestCase):
   def testCreateOpFused(self):
     self._testCreateOp(True)
 
+  def testCreateOpFusedFloat16(self):
+    self._testCreateOp(True, dtypes.float16)
+
   def _testCreateOpBetaRegularizer(self, fused=True):
     height, width = 3, 3
     with self.test_session():
@@ -2651,10 +2656,68 @@ class BatchNormTest(test.TestCase):
   def testBatchNormBeta(self):
     # Test case for 11673
     with self.test_session() as sess:
-      a = array_ops.placeholder(dtypes.float32, shape=(10, 10, 10, 10))
-      b = _layers.batch_norm(a, center=False, data_format='NCHW',
-                             zero_debias_moving_mean=True)
+      a_32 = array_ops.placeholder(dtypes.float32, shape=(10, 10, 10, 10))
+      b_32 = _layers.batch_norm(a_32, center=False, data_format='NCHW',
+                                zero_debias_moving_mean=True)
+      a_16 = array_ops.placeholder(dtypes.float16, shape=(10, 10, 10, 10))
+      b_16 = _layers.batch_norm(a_16, center=False, data_format='NCHW',
+                                zero_debias_moving_mean=True)
+      sess.run(variables_lib.global_variables_initializer())
+
+  def testVariablesAreFloat32(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = random_ops.random_uniform((5, height, width, 3),
+                                         seed=1, dtype=dtypes.float16)
+      _layers.batch_norm(images, scale=True)
+      beta = variables.get_variables_by_name('beta')[0]
+      gamma = variables.get_variables_by_name('gamma')[0]
+      self.assertEqual(beta.dtype, dtypes.float32_ref)
+      self.assertEqual(gamma.dtype, dtypes.float32_ref)
+      moving_mean = variables.get_variables_by_name('moving_mean')[0]
+      moving_variance = variables.get_variables_by_name('moving_variance')[0]
+      self.assertEqual(moving_mean.dtype, dtypes.float32_ref)
+      self.assertEqual(moving_variance.dtype, dtypes.float32_ref)
+
+  def _runFusedBatchNorm(self, shape, dtype):
+    channels = shape[1]
+    images = np.arange(np.product(shape), dtype=dtype).reshape(shape)
+    beta = init_ops.constant_initializer(
+        np.arange(
+            2, channels + 2, dtype=np.float32))
+    gamma = init_ops.constant_initializer(
+        np.arange(
+            10, channels + 10, dtype=np.float32) * 2.0)
+    mean = init_ops.constant_initializer(
+        np.arange(
+            3, channels + 3, dtype=np.float32) * 5.0)
+    variance = init_ops.constant_initializer(
+        np.arange(
+            1, channels + 1, dtype=np.float32) * 4.0)
+    output = _layers.batch_norm(
+        images,
+        fused=True,
+        is_training=True,
+        scale=True,
+        epsilon=0.5,
+        param_initializers={
+            'beta': beta,
+            'gamma': gamma,
+            'moving_mean': mean,
+            'moving_variance': variance,
+        },
+        data_format='NCHW')
+    with self.test_session(use_gpu=True) as sess:
       sess.run(variables_lib.global_variables_initializer())
+      return sess.run(output)
+
+  def testFusedBatchNormFloat16MatchesFloat32(self):
+    if test.is_gpu_available(cuda_only=True):
+      shape = [5, 4, 2, 3]
+      res_32 = self._runFusedBatchNorm(shape, np.float32)
+      res_16 = self._runFusedBatchNorm(shape, np.float16)
+      self.assertAllClose(res_32, res_16, rtol=1e-3)
+
 
   def testAdjustmentCreated(self):
     # Tests that the adjustment is appropriately passed to and used by the core
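
A note on the tolerance in testFusedBatchNormFloat16MatchesFloat32: float16 carries a 10-bit significand, so its machine epsilon is 2**-10 = 0.0009765625, and rtol=1e-3 is therefore about the tightest relative tolerance a float16 result can be expected to meet against a float32 reference. A quick standalone check (plain numpy, not part of the patch):

    import numpy as np

    # eps is one unit of relative rounding error for float16: 2**-10.
    print(np.finfo(np.float16).eps)  # 0.000977, just under the test's rtol=1e-3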