aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers
diff options
context:
space:
mode:
authorGravatar Joe Yearsley <josephelliotyearsley@gmail.com>2018-03-13 07:23:18 +0000
committerGravatar Joe Yearsley <josephelliotyearsley@gmail.com>2018-09-29 17:31:57 +0100
commiteb6c1bdcbf6093888f2b443fdb49f836f3352316 (patch)
tree4707374e9539b34f3e4e9aac21bd0843f3a76755 /tensorflow/python/layers
parentd8db18b4201d9d82d1c93ed5453914ff16f1adf4 (diff)
Update core.py
Added `data_format` to flatten to allow changing of it during inference time.
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r--tensorflow/python/layers/core.py18
1 files changed, 15 insertions, 3 deletions
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index 9879e5020f..5f89e3c0c3 100644
--- a/tensorflow/python/layers/core.py
+++ b/tensorflow/python/layers/core.py
@@ -268,7 +268,14 @@ def dropout(inputs,
@tf_export('layers.Flatten')
class Flatten(keras_layers.Flatten, base.Layer):
"""Flattens an input tensor while preserving the batch axis (axis 0).
-
+
+ Arguments:
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, ..., channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, ...)`.
+
Examples:
```
@@ -285,11 +292,16 @@ class Flatten(keras_layers.Flatten, base.Layer):
@tf_export('layers.flatten')
-def flatten(inputs, name=None):
+def flatten(inputs, data_format='channels_last', name=None):
"""Flattens an input tensor while preserving the batch axis (axis 0).
Arguments:
inputs: Tensor input.
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, height, width, channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, height, width)`.
name: The name of the layer (string).
Returns:
@@ -307,7 +319,7 @@ def flatten(inputs, name=None):
# now `y` has shape `(None, None)`
```
"""
- layer = Flatten(name=name)
+ layer = Flatten(data_format=data_format, name=name)
return layer.apply(inputs)