diff options
-rw-r--r--  tensorflow/python/layers/core.py                        | 16
-rw-r--r--  tensorflow/python/layers/core_test.py                   | 34
-rw-r--r--  tensorflow/tools/api/golden/v1/tensorflow.layers.pbtxt  |  2
-rw-r--r--  tensorflow/tools/api/golden/v2/tensorflow.layers.pbtxt  |  2
4 files changed, 50 insertions, 4 deletions
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py index 9879e5020f..e06e9aba4a 100644 --- a/tensorflow/python/layers/core.py +++ b/tensorflow/python/layers/core.py @@ -269,6 +269,13 @@ def dropout(inputs, class Flatten(keras_layers.Flatten, base.Layer): """Flattens an input tensor while preserving the batch axis (axis 0). + Arguments: + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, ..., channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, ...)`. + Examples: ``` @@ -285,12 +292,17 @@ class Flatten(keras_layers.Flatten, base.Layer): @tf_export('layers.flatten') -def flatten(inputs, name=None): +def flatten(inputs, name=None, data_format='channels_last'): """Flattens an input tensor while preserving the batch axis (axis 0). Arguments: inputs: Tensor input. name: The name of the layer (string). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. Returns: Reshaped tensor. 
@@ -307,7 +319,7 @@ def flatten(inputs, name=None): # now `y` has shape `(None, None)` ``` """ - layer = Flatten(name=name) + layer = Flatten(name=name, data_format=data_format) return layer.apply(inputs) diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py index d26f3f4789..0343bfa8bd 100644 --- a/tensorflow/python/layers/core_test.py +++ b/tensorflow/python/layers/core_test.py @@ -476,6 +476,40 @@ class FlattenTest(test.TestCase): shape = core_layers.Flatten().compute_output_shape((None, 3, None)) self.assertEqual(shape.as_list(), [None, None]) + def testDataFormat5d(self): + np_input_channels_last = np.arange( + 120, dtype='float32').reshape([1, 5, 4, 3, 2]) + + with self.test_session() as sess: + x = array_ops.placeholder(shape=(1, 5, 4, 3, 2), dtype='float32') + y = core_layers.Flatten(data_format='channels_last')(x) + np_output_cl = sess.run(y, feed_dict={x: np_input_channels_last}) + + x = array_ops.placeholder(shape=(1, 2, 5, 4, 3), dtype='float32') + y = core_layers.Flatten(data_format='channels_first')(x) + np_input_channels_first = np.transpose(np_input_channels_last, + [0, 4, 1, 2, 3]) + np_output_cf = sess.run(y, feed_dict={x: np_input_channels_first}) + + self.assertAllEqual(np_output_cl, np_output_cf) + + def testDataFormat4d(self): + np_input_channels_last = np.arange( + 24, dtype='float32').reshape([1, 4, 3, 2]) + + with self.test_session() as sess: + x = array_ops.placeholder(shape=(1, 4, 3, 2), dtype='float32') + y = core_layers.Flatten(data_format='channels_last')(x) + np_output_cl = sess.run(y, feed_dict={x: np_input_channels_last}) + + x = array_ops.placeholder(shape=(1, 2, 4, 3), dtype='float32') + y = core_layers.Flatten(data_format='channels_first')(x) + np_input_channels_first = np.transpose(np_input_channels_last, + [0, 3, 1, 2]) + np_output_cf = sess.run(y, feed_dict={x: np_input_channels_first}) + + self.assertAllEqual(np_output_cl, np_output_cf) + def testFunctionalFlatten(self): x = 
array_ops.placeholder(shape=(None, 2, 3), dtype='float32') y = core_layers.flatten(x, name='flatten') diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.pbtxt index df74c32e1f..0c24e9c7dd 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.pbtxt @@ -122,7 +122,7 @@ tf_module { } member_method { name: "flatten" - argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'inputs\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'channels_last\'], " } member_method { name: "max_pooling1d" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.pbtxt index df74c32e1f..0c24e9c7dd 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.layers.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.pbtxt @@ -122,7 +122,7 @@ tf_module { } member_method { name: "flatten" - argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'inputs\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'channels_last\'], " } member_method { name: "max_pooling1d" |