Diffstat (limited to 'tensorflow/contrib/layers/python/layers/layers_test.py')
 tensorflow/contrib/layers/python/layers/layers_test.py | 73
 1 file changed, 5 insertions(+), 68 deletions(-)
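In short, this change removes the float16 variants of the fused batch-norm tests (testCreateOpFusedFloat16, testVariablesAreFloat32, _runFusedBatchNorm, and testFusedBatchNormFloat16MatchesFloat32) and narrows testBatchNormBeta to a single float32 path. A self-contained sketch of the surviving test appears after the diff.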
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 5aa2253516..ff7f0e4462 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1774,12 +1774,10 @@ class BatchNormTest(test.TestCase):
     with self.assertRaisesRegexp(ValueError, 'undefined'):
       _layers.batch_norm(inputs, data_format='NCHW')
 
-  def _testCreateOp(self, fused, dtype=None):
-    if dtype is None:
-      dtype = dtypes.float32
+  def _testCreateOp(self, fused):
     height, width = 3, 3
     with self.test_session():
-      images = np.random.uniform(size=(5, height, width, 3)).astype(dtype.as_numpy_dtype)
+      images = np.random.uniform(size=(5, height, width, 3)).astype('f')
       output = _layers.batch_norm(images, fused=fused)
       expected_name = ('BatchNorm/FusedBatchNorm' if fused else
                        'BatchNorm/batchnorm')
@@ -1794,9 +1792,6 @@ class BatchNormTest(test.TestCase):
   def testCreateOpFused(self):
     self._testCreateOp(True)
 
-  def testCreateOpFusedFloat16(self):
-    self._testCreateOp(True, dtypes.float16)
-
   def _testCreateOpBetaRegularizer(self, fused=True):
     height, width = 3, 3
     with self.test_session():
@@ -2664,68 +2659,10 @@ class BatchNormTest(test.TestCase):
   def testBatchNormBeta(self):
     # Test case for 11673
     with self.test_session() as sess:
-      a_32 = array_ops.placeholder(dtypes.float32, shape=(10, 10, 10, 10))
-      b_32 = _layers.batch_norm(a_32, center=False, data_format='NCHW',
-                                zero_debias_moving_mean=True)
-      a_16 = array_ops.placeholder(dtypes.float16, shape=(10, 10, 10, 10))
-      b_16 = _layers.batch_norm(a_16, center=False, data_format='NCHW',
-                                zero_debias_moving_mean=True)
-      sess.run(variables_lib.global_variables_initializer())
-
-  def testVariablesAreFloat32(self):
-    height, width = 3, 3
-    with self.test_session():
-      images = random_ops.random_uniform((5, height, width, 3),
-                                         seed=1, dtype=dtypes.float16)
-      _layers.batch_norm(images, scale=True)
-      beta = variables.get_variables_by_name('beta')[0]
-      gamma = variables.get_variables_by_name('gamma')[0]
-      self.assertEqual(beta.dtype, dtypes.float32_ref)
-      self.assertEqual(gamma.dtype, dtypes.float32_ref)
-      moving_mean = variables.get_variables_by_name('moving_mean')[0]
-      moving_variance = variables.get_variables_by_name('moving_variance')[0]
-      self.assertEqual(moving_mean.dtype, dtypes.float32_ref)
-      self.assertEqual(moving_variance.dtype, dtypes.float32_ref)
-
-  def _runFusedBatchNorm(self, shape, dtype):
-    channels = shape[1]
-    images = np.arange(np.product(shape), dtype=dtype).reshape(shape)
-    beta = init_ops.constant_initializer(
-        np.arange(
-            2, channels + 2, dtype=np.float32))
-    gamma = init_ops.constant_initializer(
-        np.arange(
-            10, channels + 10, dtype=np.float32) * 2.0)
-    mean = init_ops.constant_initializer(
-        np.arange(
-            3, channels + 3, dtype=np.float32) * 5.0)
-    variance = init_ops.constant_initializer(
-        np.arange(
-            1, channels + 1, dtype=np.float32) * 4.0)
-    output = _layers.batch_norm(
-        images,
-        fused=True,
-        is_training=True,
-        scale=True,
-        epsilon=0.5,
-        param_initializers={
-            'beta': beta,
-            'gamma': gamma,
-            'moving_mean': mean,
-            'moving_variance': variance,
-        },
-        data_format='NCHW')
-    with self.test_session(use_gpu=True) as sess:
+      a = array_ops.placeholder(dtypes.float32, shape=(10, 10, 10, 10))
+      b = _layers.batch_norm(a, center=False, data_format='NCHW',
+                             zero_debias_moving_mean=True)
       sess.run(variables_lib.global_variables_initializer())
-      return sess.run(output)
-
-  def testFusedBatchNormFloat16MatchesFloat32(self):
-    if test.is_gpu_available(cuda_only=True):
-      shape = [5, 4, 2, 3]
-      res_32 = self._runFusedBatchNorm(shape, np.float32)
-      res_16 = self._runFusedBatchNorm(shape, np.float16)
-      self.assertAllClose(res_32, res_16, rtol=1e-3)
-
 
   def testAdjustmentCreated(self):
     # Tests that the adjustment is appropriately passed to and used by the core
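
For reference, here is a minimal, self-contained sketch of the simplified _testCreateOp path as it reads after this change. This is a reconstruction, not the verbatim file: the test class name is invented, the imports mirror those used elsewhere in layers_test.py, and the final startswith assertion is an assumption, since the hunk cuts off right after expected_name. It presumes a TensorFlow 1.x tree where tensorflow.contrib is available.

# Sketch of the simplified test after this change (reconstruction, not
# the verbatim file). Assumes TF 1.x with tensorflow.contrib available.
import numpy as np

from tensorflow.contrib.layers.python.layers import layers as _layers
from tensorflow.python.platform import test


class BatchNormCreateOpSketch(test.TestCase):

  def _testCreateOp(self, fused):
    height, width = 3, 3
    with self.test_session():
      # Inputs are plain float32 now ('f' is numpy shorthand for float32);
      # the dtype parameter and the float16 variants are removed.
      images = np.random.uniform(size=(5, height, width, 3)).astype('f')
      output = _layers.batch_norm(images, fused=fused)
      expected_name = ('BatchNorm/FusedBatchNorm' if fused else
                       'BatchNorm/batchnorm')
      # Assumed assertion: the diff hunk ends just after expected_name.
      self.assertTrue(output.op.name.startswith(expected_name))

  def testCreateOpFused(self):
    self._testCreateOp(True)


if __name__ == '__main__':
  test.main()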