Diffstat (limited to 'tensorflow/python/layers/normalization_test.py')
-rw-r--r--  tensorflow/python/layers/normalization_test.py | 98
1 files changed, 94 insertions, 4 deletions
diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py
index 90ebdc8c86..b2876c58c2 100644
--- a/tensorflow/python/layers/normalization_test.py
+++ b/tensorflow/python/layers/normalization_test.py
@@ -68,11 +68,12 @@ class BNTest(test.TestCase):
              use_gpu,
              is_fused,
              restore=False,
-             freeze_mode=False):
+             freeze_mode=False,
+             dtype=dtypes.float32):
     ops.reset_default_graph()
     graph = ops.get_default_graph()
     with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
-      image = array_ops.placeholder(dtype='float32', shape=shape)
+      image = array_ops.placeholder(dtype=dtype, shape=shape)
       loss, train_op, saver = self._simple_model(image, is_fused, freeze_mode)
       if restore:
         saver.restore(sess, checkpoint_path)
@@ -80,7 +81,7 @@ class BNTest(test.TestCase):
         sess.run(variables.global_variables_initializer())
       np.random.seed(0)
       for _ in range(2):
-        image_val = np.random.rand(*shape).astype(np.float32)
+        image_val = np.random.rand(*shape).astype(dtype.as_numpy_dtype)
         sess.run([loss, train_op], feed_dict={image: image_val})
       if restore:
         all_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
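The change above relies on tf.DType.as_numpy_dtype, which maps a TensorFlow dtype to the corresponding NumPy type object, so the random feed generated here always matches the placeholder's dtype. A minimal standalone sketch of that mapping (the shape is the one used by the test below; nothing else from the test class is needed):

    import numpy as np
    from tensorflow.python.framework import dtypes

    # dtypes.float16.as_numpy_dtype is np.float16, so the generated batch has
    # exactly the dtype the float16 placeholder expects.
    image_val = np.random.rand(2, 4, 5, 3).astype(dtypes.float16.as_numpy_dtype)
    assert image_val.dtype == np.float16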
@@ -90,15 +91,74 @@ class BNTest(test.TestCase):
       saver.save(sess, checkpoint_path)
 
   def _infer(self, checkpoint_path, image_val, shape, use_gpu, is_fused):
+    dtype = image_val.dtype
     ops.reset_default_graph()
     graph = ops.get_default_graph()
     with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
-      image = array_ops.placeholder(dtype='float32', shape=shape)
+      image = array_ops.placeholder(dtype=dtype, shape=shape)
       loss, _, saver = self._simple_model(image, is_fused, True)
       saver.restore(sess, checkpoint_path)
       loss_val = sess.run(loss, feed_dict={image: image_val})
       return loss_val
 
+  def _trainEvalSequence(self,
+                         dtype,
+                         train1_use_gpu,
+                         train2_use_gpu,
+                         infer_use_gpu):
+    batch, height, width, input_channels = 2, 4, 5, 3
+    shape = [batch, height, width, input_channels]
+    checkpoint = os.path.join(self.get_temp_dir(), 'cp_%s_%s_%s_%s' %
+                              (dtype, train1_use_gpu, train2_use_gpu, infer_use_gpu))
+
+    self._train(
+        checkpoint,
+        shape,
+        use_gpu=train1_use_gpu,
+        is_fused=True,
+        restore=False,
+        freeze_mode=False,
+        dtype=dtype)
+
+    train_vars = self._train(
+        checkpoint,
+        shape,
+        use_gpu=train2_use_gpu,
+        is_fused=True,
+        restore=True,
+        freeze_mode=False,
+        dtype=dtype)
+
+    np.random.seed(0)
+    image_val = np.random.rand(batch,
+                               height,
+                               width,
+                               input_channels).astype(dtype.as_numpy_dtype)
+    loss_val = self._infer(checkpoint, image_val, shape,
+                           use_gpu=infer_use_gpu, is_fused=True)
+
+    return train_vars, loss_val
+
+  def testHalfPrecision(self):
+    ref_vars, ref_loss = self._trainEvalSequence(dtype=dtypes.float32,
+                                                 train1_use_gpu=True,
+                                                 train2_use_gpu=True,
+                                                 infer_use_gpu=True)
+
+    self.assertEqual(len(ref_vars), 5)
+
+    for train1_use_gpu in [True, False]:
+      for train2_use_gpu in [True, False]:
+        for infer_use_gpu in [True, False]:
+          test_vars, test_loss = self._trainEvalSequence(dtypes.float16,
+                                                         train1_use_gpu,
+                                                         train2_use_gpu,
+                                                         infer_use_gpu)
+          self.assertEqual(len(test_vars), 5)
+          for test_var, ref_var in zip(test_vars, ref_vars):
+            self.assertAllClose(test_var, ref_var, rtol=1.e-3, atol=1.e-3)
+          self.assertAllClose(test_loss, ref_loss, rtol=1.e-3, atol=1.e-3)
+
   def _testCheckpoint(self, is_fused_checkpoint_a, is_fused_checkpoint_b,
                       use_gpu_checkpoint_a, use_gpu_checkpoint_b,
                       use_gpu_test_a, use_gpu_test_b, freeze_mode):
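The testHalfPrecision sequence added above trains once from scratch, trains again after restoring the checkpoint (possibly on a different device), runs inference, and requires every variable and the final loss from the float16 runs to match a float32 reference within rtol/atol of 1e-3. A tiny sketch of that comparison pattern, with hypothetical values standing in for the variables returned by _trainEvalSequence:

    import numpy as np

    # Hypothetical stand-ins for one reference (float32-trained) and one test
    # (float16-trained) variable; the real test compares five variables plus
    # the inference loss.
    ref_var = np.array([0.4998, 1.0002], dtype=np.float32)
    test_var = np.array([0.5, 1.0], dtype=np.float16)

    # Same tolerances as the assertAllClose calls above: the float16 run is
    # accepted if it stays within 1e-3 (relative and absolute) of float32.
    np.testing.assert_allclose(test_var, ref_var, rtol=1e-3, atol=1e-3)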
@@ -218,6 +278,36 @@
         ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES),
         bn.trainable_variables)
 
+  def testCreateFusedBNFloat16(self):
+    # Call layer.
+    bn = normalization_layers.BatchNormalization(axis=1, fused=True)
+    inputs = random_ops.random_uniform((5, 4, 3, 3),
+                                       seed=1,
+                                       dtype=dtypes.float16)
+    training = array_ops.placeholder(dtype='bool')
+    outputs = bn.apply(inputs, training=training)
+
+    # Verify shape.
+    self.assertListEqual(outputs.get_shape().as_list(), [5, 4, 3, 3])
+
+    # Verify layer attributes.
+    self.assertEqual(len(bn.updates), 2)
+    self.assertEqual(len(bn.variables), 4)
+    self.assertEqual(len(bn.trainable_variables), 2)
+    self.assertEqual(len(bn.non_trainable_variables), 2)
+    for var in bn.variables:
+      self.assertEqual(var.dtype, dtypes.float32_ref)
+
+    # Test that updates were created and added to UPDATE_OPS.
+    self.assertEqual(len(bn.updates), 2)
+    self.assertListEqual(
+        ops.get_collection(ops.GraphKeys.UPDATE_OPS), bn.updates)
+
+    # Test that weights were created and added to TRAINABLE_VARIABLES.
+    self.assertListEqual(
+        ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES),
+        bn.trainable_variables)
+
   def test3DInputAxis1(self):
     epsilon = 1e-3
     bn = normalization_layers.BatchNormalization(