aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 13:26:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 13:27:15 -0700
commit3a1ae8f4ab49e8aa6ffda48273b56e4c68157a29 (patch)
tree1de3493e1281695cf0f964a478b13570ac8f0803 /tensorflow/python/layers
parent8d12c635cc48e896da0bcac1cd568bd6381ca64e (diff)
parent1a56a3299e904d5a3352a3a15e4cf7401f72bbc3 (diff)
Merge pull request #17672 from joeyearsley:patch-3
PiperOrigin-RevId: 215447391
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r--tensorflow/python/layers/core.py16
-rw-r--r--tensorflow/python/layers/core_test.py34
2 files changed, 48 insertions, 2 deletions
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index 9879e5020f..e06e9aba4a 100644
--- a/tensorflow/python/layers/core.py
+++ b/tensorflow/python/layers/core.py
@@ -269,6 +269,13 @@ def dropout(inputs,
class Flatten(keras_layers.Flatten, base.Layer):
"""Flattens an input tensor while preserving the batch axis (axis 0).
+ Arguments:
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, ..., channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, ...)`.
+
Examples:
```
@@ -285,12 +292,17 @@ class Flatten(keras_layers.Flatten, base.Layer):
@tf_export('layers.flatten')
-def flatten(inputs, name=None):
+def flatten(inputs, name=None, data_format='channels_last'):
"""Flattens an input tensor while preserving the batch axis (axis 0).
Arguments:
inputs: Tensor input.
name: The name of the layer (string).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, height, width, channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, height, width)`.
Returns:
Reshaped tensor.
@@ -307,7 +319,7 @@ def flatten(inputs, name=None):
# now `y` has shape `(None, None)`
```
"""
- layer = Flatten(name=name)
+ layer = Flatten(name=name, data_format=data_format)
return layer.apply(inputs)
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index d26f3f4789..0343bfa8bd 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -476,6 +476,40 @@ class FlattenTest(test.TestCase):
shape = core_layers.Flatten().compute_output_shape((None, 3, None))
self.assertEqual(shape.as_list(), [None, None])
+ def testDataFormat5d(self):
+ np_input_channels_last = np.arange(
+ 120, dtype='float32').reshape([1, 5, 4, 3, 2])
+
+ with self.test_session() as sess:
+ x = array_ops.placeholder(shape=(1, 5, 4, 3, 2), dtype='float32')
+ y = core_layers.Flatten(data_format='channels_last')(x)
+ np_output_cl = sess.run(y, feed_dict={x: np_input_channels_last})
+
+ x = array_ops.placeholder(shape=(1, 2, 5, 4, 3), dtype='float32')
+ y = core_layers.Flatten(data_format='channels_first')(x)
+ np_input_channels_first = np.transpose(np_input_channels_last,
+ [0, 4, 1, 2, 3])
+ np_output_cf = sess.run(y, feed_dict={x: np_input_channels_first})
+
+ self.assertAllEqual(np_output_cl, np_output_cf)
+
+ def testDataFormat4d(self):
+ np_input_channels_last = np.arange(
+ 24, dtype='float32').reshape([1, 4, 3, 2])
+
+ with self.test_session() as sess:
+ x = array_ops.placeholder(shape=(1, 4, 3, 2), dtype='float32')
+ y = core_layers.Flatten(data_format='channels_last')(x)
+ np_output_cl = sess.run(y, feed_dict={x: np_input_channels_last})
+
+ x = array_ops.placeholder(shape=(1, 2, 4, 3), dtype='float32')
+ y = core_layers.Flatten(data_format='channels_first')(x)
+ np_input_channels_first = np.transpose(np_input_channels_last,
+ [0, 3, 1, 2])
+ np_output_cf = sess.run(y, feed_dict={x: np_input_channels_first})
+
+ self.assertAllEqual(np_output_cl, np_output_cf)
+
def testFunctionalFlatten(self):
x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
y = core_layers.flatten(x, name='flatten')