diff options
author    TensorFlower Gardener <gardener@tensorflow.org>  2018-10-02 13:26:53 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-10-02 13:27:15 -0700
commit    3a1ae8f4ab49e8aa6ffda48273b56e4c68157a29 (patch)
tree      1de3493e1281695cf0f964a478b13570ac8f0803 /tensorflow/python/layers
parent    8d12c635cc48e896da0bcac1cd568bd6381ca64e (diff)
parent    1a56a3299e904d5a3352a3a15e4cf7401f72bbc3 (diff)
Merge pull request #17672 from joeyearsley:patch-3
PiperOrigin-RevId: 215447391
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r--  tensorflow/python/layers/core.py       | 16
-rw-r--r--  tensorflow/python/layers/core_test.py  | 34
2 files changed, 48 insertions, 2 deletions
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index 9879e5020f..e06e9aba4a 100644
--- a/tensorflow/python/layers/core.py
+++ b/tensorflow/python/layers/core.py
@@ -269,6 +269,13 @@ def dropout(inputs,
 class Flatten(keras_layers.Flatten, base.Layer):
   """Flattens an input tensor while preserving the batch axis (axis 0).
 
+  Arguments:
+    data_format: A string, one of `channels_last` (default) or `channels_first`.
+      The ordering of the dimensions in the inputs.
+      `channels_last` corresponds to inputs with shape
+      `(batch, ..., channels)` while `channels_first` corresponds to
+      inputs with shape `(batch, channels, ...)`.
+
   Examples:
 
   ```
@@ -285,12 +292,17 @@ class Flatten(keras_layers.Flatten, base.Layer):
 
 
 @tf_export('layers.flatten')
-def flatten(inputs, name=None):
+def flatten(inputs, name=None, data_format='channels_last'):
   """Flattens an input tensor while preserving the batch axis (axis 0).
 
   Arguments:
     inputs: Tensor input.
     name: The name of the layer (string).
+    data_format: A string, one of `channels_last` (default) or `channels_first`.
+      The ordering of the dimensions in the inputs.
+      `channels_last` corresponds to inputs with shape
+      `(batch, height, width, channels)` while `channels_first` corresponds to
+      inputs with shape `(batch, channels, height, width)`.
 
   Returns:
     Reshaped tensor.
@@ -307,7 +319,7 @@ def flatten(inputs, name=None):
   # now `y` has shape `(None, None)`
   ```
   """
-  layer = Flatten(name=name)
+  layer = Flatten(name=name, data_format=data_format)
   return layer.apply(inputs)
 
 
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index d26f3f4789..0343bfa8bd 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -476,6 +476,40 @@ class FlattenTest(test.TestCase):
     shape = core_layers.Flatten().compute_output_shape((None, 3, None))
     self.assertEqual(shape.as_list(), [None, None])
 
+  def testDataFormat5d(self):
+    np_input_channels_last = np.arange(
+        120, dtype='float32').reshape([1, 5, 4, 3, 2])
+
+    with self.test_session() as sess:
+      x = array_ops.placeholder(shape=(1, 5, 4, 3, 2), dtype='float32')
+      y = core_layers.Flatten(data_format='channels_last')(x)
+      np_output_cl = sess.run(y, feed_dict={x: np_input_channels_last})
+
+      x = array_ops.placeholder(shape=(1, 2, 5, 4, 3), dtype='float32')
+      y = core_layers.Flatten(data_format='channels_first')(x)
+      np_input_channels_first = np.transpose(np_input_channels_last,
+                                             [0, 4, 1, 2, 3])
+      np_output_cf = sess.run(y, feed_dict={x: np_input_channels_first})
+
+      self.assertAllEqual(np_output_cl, np_output_cf)
+
+  def testDataFormat4d(self):
+    np_input_channels_last = np.arange(
+        24, dtype='float32').reshape([1, 4, 3, 2])
+
+    with self.test_session() as sess:
+      x = array_ops.placeholder(shape=(1, 4, 3, 2), dtype='float32')
+      y = core_layers.Flatten(data_format='channels_last')(x)
+      np_output_cl = sess.run(y, feed_dict={x: np_input_channels_last})
+
+      x = array_ops.placeholder(shape=(1, 2, 4, 3), dtype='float32')
+      y = core_layers.Flatten(data_format='channels_first')(x)
+      np_input_channels_first = np.transpose(np_input_channels_last,
+                                             [0, 3, 1, 2])
+      np_output_cf = sess.run(y, feed_dict={x: np_input_channels_first})
+
+      self.assertAllEqual(np_output_cl, np_output_cf)
+
   def testFunctionalFlatten(self):
     x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
     y = core_layers.flatten(x, name='flatten')