diff options
author | josephyearsley <joggino23@gmail.com> | 2018-03-18 19:24:10 +0000 |
---|---|---|
committer | Joe Yearsley <josephelliotyearsley@gmail.com> | 2018-09-29 17:31:57 +0100 |
commit | dd928d5ae31dd0484e5e4a96c6322adecc4e511b (patch) | |
tree | f05d788d02484f87360ee306713a63e23eb2808b /tensorflow/python/layers | |
parent | eb6c1bdcbf6093888f2b443fdb49f836f3352316 (diff) |
Added Flatten Test
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r-- | tensorflow/python/layers/core_test.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py index d26f3f4789..0d019897aa 100644 --- a/tensorflow/python/layers/core_test.py +++ b/tensorflow/python/layers/core_test.py @@ -476,6 +476,22 @@ class FlattenTest(test.TestCase): shape = core_layers.Flatten().compute_output_shape((None, 3, None)) self.assertEqual(shape.as_list(), [None, None]) + def testDataFormat(self): + np_input_channels_last = np.arange(3, 7).reshape([1, 2, 3, 2]) + + with self.test_session() as sess: + x = array_ops.placeholder(shape=(1, 2, 3, 2), dtype='float32') + y = core_layers.Flatten(data_format='channels_last')(x) + np_output_cl = sess.run(y, feed_dict={x: np_input_channels_last}) + + x = array_ops.placeholder(shape=(1, 2, 3, 2), dtype='float32') + y = core_layers.Flatten(data_format='channels_first')(x) + np_input_channels_first = np.transpose(np_input_channels_last, + [0, 3, 1, 2]) + np_output_cf = sess.run(y, feed_dict={x: np_input_channels_first}) + + self.assertEqual(np_output_cl, np_output_cf) + def testFunctionalFlatten(self): x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32') y = core_layers.flatten(x, name='flatten') |