aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/layers/python/layers/layers_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/layers/python/layers/layers_test.py')
-rw-r--r--tensorflow/contrib/layers/python/layers/layers_test.py6
1 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index ae64b75d93..1150328b7a 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1747,6 +1747,12 @@ class BatchNormTest(test.TestCase):
expected_var *= correction_factor
return expected_var, correction_factor
+ def testBatchNormCenterFalse(self):
+ a = array_ops.placeholder(dtype=dtypes.float32, shape=(10, 10, 10, 10))
+ # Test that center=False builds a valid graph.
+ _layers.batch_norm(a, center=False, data_format='NCHW',
+ zero_debias_moving_mean=True)
+
def testUnknownShape(self):
with ops.Graph().as_default() as g, self.test_session(g):
inputs = array_ops.placeholder(dtype=dtypes.float32)