diff options
author | Joe Yearsley <josephelliotyearsley@gmail.com> | 2018-03-13 07:23:18 +0000 |
---|---|---|
committer | Joe Yearsley <josephelliotyearsley@gmail.com> | 2018-09-29 17:31:57 +0100 |
commit | eb6c1bdcbf6093888f2b443fdb49f836f3352316 (patch) | |
tree | 4707374e9539b34f3e4e9aac21bd0843f3a76755 /tensorflow/python/layers | |
parent | d8db18b4201d9d82d1c93ed5453914ff16f1adf4 (diff) |
Update core.py
Added `data_format` to flatten to allow changing of it during inference time.
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r-- | tensorflow/python/layers/core.py | 18 |
1 files changed, 15 insertions, 3 deletions
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py index 9879e5020f..5f89e3c0c3 100644 --- a/tensorflow/python/layers/core.py +++ b/tensorflow/python/layers/core.py @@ -268,7 +268,14 @@ def dropout(inputs, @tf_export('layers.Flatten') class Flatten(keras_layers.Flatten, base.Layer): """Flattens an input tensor while preserving the batch axis (axis 0). - + + Arguments: + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, ..., channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, ...)`. + Examples: ``` @@ -285,11 +292,16 @@ class Flatten(keras_layers.Flatten, base.Layer): @tf_export('layers.flatten') -def flatten(inputs, name=None): +def flatten(inputs, data_format='channels_last', name=None): """Flattens an input tensor while preserving the batch axis (axis 0). Arguments: inputs: Tensor input. + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. name: The name of the layer (string). Returns: @@ -307,7 +319,7 @@ def flatten(inputs, name=None): # now `y` has shape `(None, None)` ``` """ - layer = Flatten(name=name) + layer = Flatten(data_format=data_format, name=name) return layer.apply(inputs) |