aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers/normalization_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/layers/normalization_test.py')
-rw-r--r--tensorflow/python/layers/normalization_test.py98
1 files changed, 4 insertions, 94 deletions
diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py
index b2876c58c2..90ebdc8c86 100644
--- a/tensorflow/python/layers/normalization_test.py
+++ b/tensorflow/python/layers/normalization_test.py
@@ -68,12 +68,11 @@ class BNTest(test.TestCase):
use_gpu,
is_fused,
restore=False,
- freeze_mode=False,
- dtype=dtypes.float32):
+ freeze_mode=False):
ops.reset_default_graph()
graph = ops.get_default_graph()
with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
- image = array_ops.placeholder(dtype=dtype, shape=shape)
+ image = array_ops.placeholder(dtype='float32', shape=shape)
loss, train_op, saver = self._simple_model(image, is_fused, freeze_mode)
if restore:
saver.restore(sess, checkpoint_path)
@@ -81,7 +80,7 @@ class BNTest(test.TestCase):
sess.run(variables.global_variables_initializer())
np.random.seed(0)
for _ in range(2):
- image_val = np.random.rand(*shape).astype(dtype.as_numpy_dtype)
+ image_val = np.random.rand(*shape).astype(np.float32)
sess.run([loss, train_op], feed_dict={image: image_val})
if restore:
all_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
@@ -91,74 +90,15 @@ class BNTest(test.TestCase):
saver.save(sess, checkpoint_path)
def _infer(self, checkpoint_path, image_val, shape, use_gpu, is_fused):
- dtype = image_val.dtype
ops.reset_default_graph()
graph = ops.get_default_graph()
with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
- image = array_ops.placeholder(dtype=dtype, shape=shape)
+ image = array_ops.placeholder(dtype='float32', shape=shape)
loss, _, saver = self._simple_model(image, is_fused, True)
saver.restore(sess, checkpoint_path)
loss_val = sess.run(loss, feed_dict={image: image_val})
return loss_val
- def _trainEvalSequence(self,
- dtype,
- train1_use_gpu,
- train2_use_gpu,
- infer_use_gpu):
- batch, height, width, input_channels = 2, 4, 5, 3
- shape = [batch, height, width, input_channels]
- checkpoint = os.path.join(self.get_temp_dir(), 'cp_%s_%s_%s_%s' %
- (dtype, train1_use_gpu, train2_use_gpu, infer_use_gpu))
-
- self._train(
- checkpoint,
- shape,
- use_gpu=train1_use_gpu,
- is_fused=True,
- restore=False,
- freeze_mode=False,
- dtype=dtype)
-
- train_vars = self._train(
- checkpoint,
- shape,
- use_gpu=train2_use_gpu,
- is_fused=True,
- restore=True,
- freeze_mode=False,
- dtype=dtype)
-
- np.random.seed(0)
- image_val = np.random.rand(batch,
- height,
- width,
- input_channels).astype(dtype.as_numpy_dtype)
- loss_val = self._infer(checkpoint, image_val, shape,
- use_gpu=infer_use_gpu, is_fused=True)
-
- return train_vars, loss_val
-
- def testHalfPrecision(self):
- ref_vars, ref_loss = self._trainEvalSequence(dtype=dtypes.float32,
- train1_use_gpu=True,
- train2_use_gpu=True,
- infer_use_gpu=True)
-
- self.assertEqual(len(ref_vars), 5)
-
- for train1_use_gpu in [True, False]:
- for train2_use_gpu in [True, False]:
- for infer_use_gpu in [True, False]:
- test_vars, test_loss = self._trainEvalSequence(dtypes.float16,
- train1_use_gpu,
- train2_use_gpu,
- infer_use_gpu)
- self.assertEqual(len(test_vars), 5)
- for test_var, ref_var in zip(test_vars, ref_vars):
- self.assertAllClose(test_var, ref_var, rtol=1.e-3, atol=1.e-3)
- self.assertAllClose(test_loss, ref_loss, rtol=1.e-3, atol=1.e-3)
-
def _testCheckpoint(self, is_fused_checkpoint_a, is_fused_checkpoint_b,
use_gpu_checkpoint_a, use_gpu_checkpoint_b,
use_gpu_test_a, use_gpu_test_b, freeze_mode):
@@ -278,36 +218,6 @@ class BNTest(test.TestCase):
ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES),
bn.trainable_variables)
- def testCreateFusedBNFloat16(self):
- # Call layer.
- bn = normalization_layers.BatchNormalization(axis=1, fused=True)
- inputs = random_ops.random_uniform((5, 4, 3, 3),
- seed=1,
- dtype=dtypes.float16)
- training = array_ops.placeholder(dtype='bool')
- outputs = bn.apply(inputs, training=training)
-
- # Verify shape.
- self.assertListEqual(outputs.get_shape().as_list(), [5, 4, 3, 3])
-
- # Verify layer attributes.
- self.assertEqual(len(bn.updates), 2)
- self.assertEqual(len(bn.variables), 4)
- self.assertEqual(len(bn.trainable_variables), 2)
- self.assertEqual(len(bn.non_trainable_variables), 2)
- for var in bn.variables:
- self.assertEqual(var.dtype, dtypes.float32_ref)
-
- # Test that updates were created and added to UPDATE_OPS.
- self.assertEqual(len(bn.updates), 2)
- self.assertListEqual(
- ops.get_collection(ops.GraphKeys.UPDATE_OPS), bn.updates)
-
- # Test that weights were created and added to TRAINABLE_VARIABLES.
- self.assertListEqual(
- ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES),
- bn.trainable_variables)
-
def test3DInputAxis1(self):
epsilon = 1e-3
bn = normalization_layers.BatchNormalization(