diff options
Diffstat (limited to 'tensorflow/python/layers/normalization_test.py')
-rw-r--r-- | tensorflow/python/layers/normalization_test.py | 98 |
1 files changed, 4 insertions, 94 deletions
diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py index b2876c58c2..90ebdc8c86 100644 --- a/tensorflow/python/layers/normalization_test.py +++ b/tensorflow/python/layers/normalization_test.py @@ -68,12 +68,11 @@ class BNTest(test.TestCase): use_gpu, is_fused, restore=False, - freeze_mode=False, - dtype=dtypes.float32): + freeze_mode=False): ops.reset_default_graph() graph = ops.get_default_graph() with self.test_session(graph=graph, use_gpu=use_gpu) as sess: - image = array_ops.placeholder(dtype=dtype, shape=shape) + image = array_ops.placeholder(dtype='float32', shape=shape) loss, train_op, saver = self._simple_model(image, is_fused, freeze_mode) if restore: saver.restore(sess, checkpoint_path) @@ -81,7 +80,7 @@ class BNTest(test.TestCase): sess.run(variables.global_variables_initializer()) np.random.seed(0) for _ in range(2): - image_val = np.random.rand(*shape).astype(dtype.as_numpy_dtype) + image_val = np.random.rand(*shape).astype(np.float32) sess.run([loss, train_op], feed_dict={image: image_val}) if restore: all_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) @@ -91,74 +90,15 @@ class BNTest(test.TestCase): saver.save(sess, checkpoint_path) def _infer(self, checkpoint_path, image_val, shape, use_gpu, is_fused): - dtype = image_val.dtype ops.reset_default_graph() graph = ops.get_default_graph() with self.test_session(graph=graph, use_gpu=use_gpu) as sess: - image = array_ops.placeholder(dtype=dtype, shape=shape) + image = array_ops.placeholder(dtype='float32', shape=shape) loss, _, saver = self._simple_model(image, is_fused, True) saver.restore(sess, checkpoint_path) loss_val = sess.run(loss, feed_dict={image: image_val}) return loss_val - def _trainEvalSequence(self, - dtype, - train1_use_gpu, - train2_use_gpu, - infer_use_gpu): - batch, height, width, input_channels = 2, 4, 5, 3 - shape = [batch, height, width, input_channels] - checkpoint = os.path.join(self.get_temp_dir(), 'cp_%s_%s_%s_%s' % - (dtype, train1_use_gpu, train2_use_gpu, infer_use_gpu)) - - self._train( - checkpoint, - shape, - use_gpu=train1_use_gpu, - is_fused=True, - restore=False, - freeze_mode=False, - dtype=dtype) - - train_vars = self._train( - checkpoint, - shape, - use_gpu=train2_use_gpu, - is_fused=True, - restore=True, - freeze_mode=False, - dtype=dtype) - - np.random.seed(0) - image_val = np.random.rand(batch, - height, - width, - input_channels).astype(dtype.as_numpy_dtype) - loss_val = self._infer(checkpoint, image_val, shape, - use_gpu=infer_use_gpu, is_fused=True) - - return train_vars, loss_val - - def testHalfPrecision(self): - ref_vars, ref_loss = self._trainEvalSequence(dtype=dtypes.float32, - train1_use_gpu=True, - train2_use_gpu=True, - infer_use_gpu=True) - - self.assertEqual(len(ref_vars), 5) - - for train1_use_gpu in [True, False]: - for train2_use_gpu in [True, False]: - for infer_use_gpu in [True, False]: - test_vars, test_loss = self._trainEvalSequence(dtypes.float16, - train1_use_gpu, - train2_use_gpu, - infer_use_gpu) - self.assertEqual(len(test_vars), 5) - for test_var, ref_var in zip(test_vars, ref_vars): - self.assertAllClose(test_var, ref_var, rtol=1.e-3, atol=1.e-3) - self.assertAllClose(test_loss, ref_loss, rtol=1.e-3, atol=1.e-3) - def _testCheckpoint(self, is_fused_checkpoint_a, is_fused_checkpoint_b, use_gpu_checkpoint_a, use_gpu_checkpoint_b, use_gpu_test_a, use_gpu_test_b, freeze_mode): @@ -278,36 +218,6 @@ class BNTest(test.TestCase): ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES), bn.trainable_variables) - def testCreateFusedBNFloat16(self): - # Call layer. - bn = normalization_layers.BatchNormalization(axis=1, fused=True) - inputs = random_ops.random_uniform((5, 4, 3, 3), - seed=1, - dtype=dtypes.float16) - training = array_ops.placeholder(dtype='bool') - outputs = bn.apply(inputs, training=training) - - # Verify shape. - self.assertListEqual(outputs.get_shape().as_list(), [5, 4, 3, 3]) - - # Verify layer attributes. - self.assertEqual(len(bn.updates), 2) - self.assertEqual(len(bn.variables), 4) - self.assertEqual(len(bn.trainable_variables), 2) - self.assertEqual(len(bn.non_trainable_variables), 2) - for var in bn.variables: - self.assertEqual(var.dtype, dtypes.float32_ref) - - # Test that updates were created and added to UPDATE_OPS. - self.assertEqual(len(bn.updates), 2) - self.assertListEqual( - ops.get_collection(ops.GraphKeys.UPDATE_OPS), bn.updates) - - # Test that weights were created and added to TRAINABLE_VARIABLES. - self.assertListEqual( - ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES), - bn.trainable_variables) - def test3DInputAxis1(self): epsilon = 1e-3 bn = normalization_layers.BatchNormalization( |