aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers
diff options
context:
space:
mode:
authorGravatar josephyearsley <joggino23@gmail.com>2018-03-18 19:24:10 +0000
committerGravatar Joe Yearsley <josephelliotyearsley@gmail.com>2018-09-29 17:31:57 +0100
commitdd928d5ae31dd0484e5e4a96c6322adecc4e511b (patch)
treef05d788d02484f87360ee306713a63e23eb2808b /tensorflow/python/layers
parenteb6c1bdcbf6093888f2b443fdb49f836f3352316 (diff)
Added Flatten Test
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r--tensorflow/python/layers/core_test.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index d26f3f4789..0d019897aa 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -476,6 +476,22 @@ class FlattenTest(test.TestCase):
shape = core_layers.Flatten().compute_output_shape((None, 3, None))
self.assertEqual(shape.as_list(), [None, None])
+ def testDataFormat(self):
+ np_input_channels_last = np.arange(3, 7).reshape([1, 2, 3, 2])
+
+ with self.test_session() as sess:
+ x = array_ops.placeholder(shape=(1, 2, 3, 2), dtype='float32')
+ y = core_layers.Flatten(data_format='channels_last')(x)
+ np_output_cl = sess.run(y, feed_dict={x: np_input_channels_last})
+
+ x = array_ops.placeholder(shape=(1, 2, 3, 2), dtype='float32')
+ y = core_layers.Flatten(data_format='channels_first')(x)
+ np_input_channels_first = np.transpose(np_input_channels_last,
+ [0, 3, 1, 2])
+ np_output_cf = sess.run(y, feed_dict={x: np_input_channels_first})
+
+ self.assertEqual(np_output_cl, np_output_cf)
+
def testFunctionalFlatten(self):
x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
y = core_layers.flatten(x, name='flatten')