Diffstat (limited to 'tensorflow/python/layers/normalization_test.py')
-rw-r--r--  tensorflow/python/layers/normalization_test.py | 98
1 file changed, 94 insertions, 4 deletions
diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py
index 90ebdc8c86..b2876c58c2 100644
--- a/tensorflow/python/layers/normalization_test.py
+++ b/tensorflow/python/layers/normalization_test.py
@@ -68,11 +68,12 @@ class BNTest(test.TestCase):
              use_gpu,
              is_fused,
              restore=False,
-             freeze_mode=False):
+             freeze_mode=False,
+             dtype=dtypes.float32):
     ops.reset_default_graph()
     graph = ops.get_default_graph()
     with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
-      image = array_ops.placeholder(dtype='float32', shape=shape)
+      image = array_ops.placeholder(dtype=dtype, shape=shape)
       loss, train_op, saver = self._simple_model(image, is_fused, freeze_mode)
       if restore:
         saver.restore(sess, checkpoint_path)
@@ -80,7 +81,7 @@ class BNTest(test.TestCase):
         sess.run(variables.global_variables_initializer())
       np.random.seed(0)
       for _ in range(2):
-        image_val = np.random.rand(*shape).astype(np.float32)
+        image_val = np.random.rand(*shape).astype(dtype.as_numpy_dtype)
         sess.run([loss, train_op], feed_dict={image: image_val})
       if restore:
         all_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
@@ -90,15 +91,74 @@ class BNTest(test.TestCase):
         saver.save(sess, checkpoint_path)
 
   def _infer(self, checkpoint_path, image_val, shape, use_gpu, is_fused):
+    dtype = image_val.dtype
     ops.reset_default_graph()
     graph = ops.get_default_graph()
     with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
-      image = array_ops.placeholder(dtype='float32', shape=shape)
+      image = array_ops.placeholder(dtype=dtype, shape=shape)
       loss, _, saver = self._simple_model(image, is_fused, True)
       saver.restore(sess, checkpoint_path)
       loss_val = sess.run(loss, feed_dict={image: image_val})
       return loss_val
 
+  def _trainEvalSequence(self,
+                         dtype,
+                         train1_use_gpu,
+                         train2_use_gpu,
+                         infer_use_gpu):
+    batch, height, width, input_channels = 2, 4, 5, 3
+    shape = [batch, height, width, input_channels]
+    checkpoint = os.path.join(self.get_temp_dir(), 'cp_%s_%s_%s_%s' %
+                              (dtype, train1_use_gpu, train2_use_gpu, infer_use_gpu))
+
+    self._train(
+        checkpoint,
+        shape,
+        use_gpu=train1_use_gpu,
+        is_fused=True,
+        restore=False,
+        freeze_mode=False,
+        dtype=dtype)
+
+    train_vars = self._train(
+        checkpoint,
+        shape,
+        use_gpu=train2_use_gpu,
+        is_fused=True,
+        restore=True,
+        freeze_mode=False,
+        dtype=dtype)
+
+    np.random.seed(0)
+    image_val = np.random.rand(batch,
+                               height,
+                               width,
+                               input_channels).astype(dtype.as_numpy_dtype)
+    loss_val = self._infer(checkpoint, image_val, shape,
+                           use_gpu=infer_use_gpu, is_fused=True)
+
+    return train_vars, loss_val
+
+  def testHalfPrecision(self):
+    ref_vars, ref_loss = self._trainEvalSequence(dtype=dtypes.float32,
+                                                 train1_use_gpu=True,
+                                                 train2_use_gpu=True,
+                                                 infer_use_gpu=True)
+
+    self.assertEqual(len(ref_vars), 5)
+
+    for train1_use_gpu in [True, False]:
+      for train2_use_gpu in [True, False]:
+        for infer_use_gpu in [True, False]:
+          test_vars, test_loss = self._trainEvalSequence(dtypes.float16,
+                                                         train1_use_gpu,
+                                                         train2_use_gpu,
+                                                         infer_use_gpu)
+          self.assertEqual(len(test_vars), 5)
+          for test_var, ref_var in zip(test_vars, ref_vars):
+            self.assertAllClose(test_var, ref_var, rtol=1.e-3, atol=1.e-3)
+          self.assertAllClose(test_loss, ref_loss, rtol=1.e-3, atol=1.e-3)
+
   def _testCheckpoint(self, is_fused_checkpoint_a, is_fused_checkpoint_b,
                       use_gpu_checkpoint_a, use_gpu_checkpoint_b,
                       use_gpu_test_a, use_gpu_test_b, freeze_mode):
@@ -218,6 +278,36 @@ class BNTest(test.TestCase):
         ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES),
         bn.trainable_variables)
 
+  def testCreateFusedBNFloat16(self):
+    # Call layer.
+    bn = normalization_layers.BatchNormalization(axis=1, fused=True)
+    inputs = random_ops.random_uniform((5, 4, 3, 3),
+                                       seed=1,
+                                       dtype=dtypes.float16)
+    training = array_ops.placeholder(dtype='bool')
+    outputs = bn.apply(inputs, training=training)
+
+    # Verify shape.
+    self.assertListEqual(outputs.get_shape().as_list(), [5, 4, 3, 3])
+
+    # Verify layer attributes.
+    self.assertEqual(len(bn.updates), 2)
+    self.assertEqual(len(bn.variables), 4)
+    self.assertEqual(len(bn.trainable_variables), 2)
+    self.assertEqual(len(bn.non_trainable_variables), 2)
+    for var in bn.variables:
+      self.assertEqual(var.dtype, dtypes.float32_ref)
+
+    # Test that updates were created and added to UPDATE_OPS.
+    self.assertEqual(len(bn.updates), 2)
+    self.assertListEqual(
+        ops.get_collection(ops.GraphKeys.UPDATE_OPS), bn.updates)
+
+    # Test that weights were created and added to TRAINABLE_VARIABLES.
+    self.assertListEqual(
+        ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES),
+        bn.trainable_variables)
+
   def test3DInputAxis1(self):
     epsilon = 1e-3
     bn = normalization_layers.BatchNormalization(
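Note: the behavior these new tests rely on is that a fused BatchNormalization layer fed float16 inputs still creates its gamma/beta and moving-statistics variables in float32, so a float16 training run reaches variable values and a loss close to the float32 reference run within the 1e-3 tolerances asserted above. Below is a minimal standalone sketch of that behavior (not part of the diff), assuming a TensorFlow 1.x graph-mode environment and using the public tf.layers counterpart of the internal normalization_layers module:

    import tensorflow as tf

    # Fused batch norm over the channel axis, applied to float16 inputs.
    bn = tf.layers.BatchNormalization(axis=1, fused=True)
    inputs = tf.random_uniform((5, 4, 3, 3), seed=1, dtype=tf.float16)
    training = tf.placeholder(dtype='bool')
    outputs = bn.apply(inputs, training=training)

    # The output keeps the input shape.
    print(outputs.get_shape().as_list())       # [5, 4, 3, 3]

    # gamma, beta, moving_mean and moving_variance are created as float32
    # variables even though the inputs are float16 (what the test asserts
    # via dtypes.float32_ref).
    for var in bn.variables:
      print(var.name, var.dtype.base_dtype)    # all float32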