aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/layers/core.py16
-rw-r--r--tensorflow/python/layers/core_test.py34
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.layers.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.layers.pbtxt2
4 files changed, 50 insertions, 4 deletions
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index 9879e5020f..e06e9aba4a 100644
--- a/tensorflow/python/layers/core.py
+++ b/tensorflow/python/layers/core.py
@@ -269,6 +269,13 @@ def dropout(inputs,
class Flatten(keras_layers.Flatten, base.Layer):
"""Flattens an input tensor while preserving the batch axis (axis 0).
+ Arguments:
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, ..., channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, ...)`.
+
Examples:
```
@@ -285,12 +292,17 @@ class Flatten(keras_layers.Flatten, base.Layer):
@tf_export('layers.flatten')
-def flatten(inputs, name=None):
+def flatten(inputs, name=None, data_format='channels_last'):
"""Flattens an input tensor while preserving the batch axis (axis 0).
Arguments:
inputs: Tensor input.
name: The name of the layer (string).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, height, width, channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, height, width)`.
Returns:
Reshaped tensor.
@@ -307,7 +319,7 @@ def flatten(inputs, name=None):
# now `y` has shape `(None, None)`
```
"""
- layer = Flatten(name=name)
+ layer = Flatten(name=name, data_format=data_format)
return layer.apply(inputs)
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index d26f3f4789..0343bfa8bd 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -476,6 +476,40 @@ class FlattenTest(test.TestCase):
shape = core_layers.Flatten().compute_output_shape((None, 3, None))
self.assertEqual(shape.as_list(), [None, None])
+ def testDataFormat5d(self):
+ np_input_channels_last = np.arange(
+ 120, dtype='float32').reshape([1, 5, 4, 3, 2])
+
+ with self.test_session() as sess:
+ x = array_ops.placeholder(shape=(1, 5, 4, 3, 2), dtype='float32')
+ y = core_layers.Flatten(data_format='channels_last')(x)
+ np_output_cl = sess.run(y, feed_dict={x: np_input_channels_last})
+
+ x = array_ops.placeholder(shape=(1, 2, 5, 4, 3), dtype='float32')
+ y = core_layers.Flatten(data_format='channels_first')(x)
+ np_input_channels_first = np.transpose(np_input_channels_last,
+ [0, 4, 1, 2, 3])
+ np_output_cf = sess.run(y, feed_dict={x: np_input_channels_first})
+
+ self.assertAllEqual(np_output_cl, np_output_cf)
+
+ def testDataFormat4d(self):
+ np_input_channels_last = np.arange(
+ 24, dtype='float32').reshape([1, 4, 3, 2])
+
+ with self.test_session() as sess:
+ x = array_ops.placeholder(shape=(1, 4, 3, 2), dtype='float32')
+ y = core_layers.Flatten(data_format='channels_last')(x)
+ np_output_cl = sess.run(y, feed_dict={x: np_input_channels_last})
+
+ x = array_ops.placeholder(shape=(1, 2, 4, 3), dtype='float32')
+ y = core_layers.Flatten(data_format='channels_first')(x)
+ np_input_channels_first = np.transpose(np_input_channels_last,
+ [0, 3, 1, 2])
+ np_output_cf = sess.run(y, feed_dict={x: np_input_channels_first})
+
+ self.assertAllEqual(np_output_cl, np_output_cf)
+
def testFunctionalFlatten(self):
x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
y = core_layers.flatten(x, name='flatten')
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.pbtxt
index df74c32e1f..0c24e9c7dd 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.layers.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.pbtxt
@@ -122,7 +122,7 @@ tf_module {
}
member_method {
name: "flatten"
- argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'inputs\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'channels_last\'], "
}
member_method {
name: "max_pooling1d"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.pbtxt
index df74c32e1f..0c24e9c7dd 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.layers.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.pbtxt
@@ -122,7 +122,7 @@ tf_module {
}
member_method {
name: "flatten"
- argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'inputs\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'channels_last\'], "
}
member_method {
name: "max_pooling1d"