aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/backend.py
diff options
context:
space:
mode:
authorGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-06-22 14:16:16 +0800
committerGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-06-22 14:16:16 +0800
commit13241bb8c73746f0af81bfbf2c72f5b42e47f82b (patch)
treeba2ad8902dab51c9ab209286a09aefd7b22bf0db /tensorflow/python/keras/backend.py
parenta5840964e0eb422fbe73dd3738c8d14c1147276f (diff)
parent359f53686c87ee76e80353c32a3d22cfb1cf0989 (diff)
Merge remote-tracking branch 'upstream/master' into ENH/support_run_configs_for_keras_model
Diffstat (limited to 'tensorflow/python/keras/backend.py')
-rw-r--r--tensorflow/python/keras/backend.py246
1 files changed, 143 insertions, 103 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 8abdd5238a..5f302d8e90 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -22,6 +22,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import itertools
import json
import os
import weakref
@@ -2885,7 +2886,10 @@ class Function(object):
feed_arrays.append(tensor)
# We need to do array conversion and type casting at this level, since
# `callable_fn` only supports exact matches.
- array_vals.append(np.asarray(value, dtype=tensor.dtype.base_dtype.name))
+ tensor_type = dtypes_module.as_dtype(tensor.dtype)
+ array_vals.append(np.asarray(value,
+ dtype=tensor_type.as_numpy_dtype))
+
if self.feed_dict:
for key in sorted(self.feed_dict.keys()):
array_vals.append(
@@ -2978,30 +2982,29 @@ def rnn(step_function,
Arguments:
step_function: RNN step function.
- Parameters;
- input; tensor with shape `(samples, ...)` (no time dimension),
+ Args;
+ input; Tensor with shape `(samples, ...)` (no time dimension),
representing input for the batch of samples at a certain
time step.
- states; list of tensors.
+ states; List of tensors.
Returns;
- output; tensor with shape `(samples, output_dim)`
+ output; Tensor with shape `(samples, output_dim)`
(no time dimension).
- new_states; list of tensors, same length and shapes
+ new_states; List of tensors, same length and shapes
as 'states'. The first state in the list must be the
output tensor at the previous timestep.
- inputs: tensor of temporal data of shape `(samples, time, ...)`
+ inputs: Tensor of temporal data of shape `(samples, time, ...)`
(at least 3D).
- initial_states: tensor with shape (samples, output_dim)
+ initial_states: Tensor with shape `(samples, output_dim)`
(no time dimension),
containing the initial values for the states used in
the step function.
- go_backwards: boolean. If True, do the iteration over the time
+ go_backwards: Boolean. If True, do the iteration over the time
dimension in reverse order and return the reversed sequence.
- mask: binary tensor with shape `(samples, time, 1)`,
+ mask: Binary tensor with shape `(samples, time, 1)`,
with a zero for every element that is masked.
- constants: a list of constant values passed at each step.
- unroll: whether to unroll the RNN or to use a symbolic loop
- (`while_loop` or `scan` depending on backend).
+ constants: List of constant values passed at each step.
+ unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
input_length: If specified, assume time dimension is of this length.
Returns:
@@ -3642,12 +3645,12 @@ def _preprocess_conv1d_input(x, data_format):
Returns:
A tensor.
"""
- tf_data_format = 'NHWC' # to pass TF Conv2dNative operations
+ tf_data_format = 'NWC' # to pass TF Conv2dNative operations
if data_format == 'channels_first':
if not _has_nchw_support():
x = array_ops.transpose(x, (0, 2, 1)) # NCW -> NWC
else:
- tf_data_format = 'NCHW'
+ tf_data_format = 'NCW'
return x, tf_data_format
@@ -3746,10 +3749,8 @@ def conv1d(x,
x = temporal_padding(x, (left_pad, 0))
padding = 'valid'
padding = _preprocess_padding(padding)
- if data_format == 'channels_last':
- tf_data_format = 'NWC'
- else:
- tf_data_format = 'NCW'
+
+ x, tf_data_format = _preprocess_conv1d_input(x, data_format)
x = nn.convolution(
input=x,
filter=kernel,
@@ -3757,6 +3758,8 @@ def conv1d(x,
strides=(strides,),
padding=padding,
data_format=tf_data_format)
+ if data_format == 'channels_first' and tf_data_format == 'NWC':
+ x = array_ops.transpose(x, (0, 2, 1)) # NWC -> NCW
return x
@@ -3897,11 +3900,16 @@ def separable_conv1d(x,
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
+ if isinstance(strides, int):
+ strides = (strides,)
+ if isinstance(dilation_rate, int):
+ dilation_rate = (dilation_rate,)
+
x, tf_data_format = _preprocess_conv1d_input(x, data_format)
padding = _preprocess_padding(padding)
if not isinstance(strides, tuple):
strides = tuple(strides)
- if tf_data_format == 'NHWC':
+ if tf_data_format == 'NWC':
spatial_start_dim = 1
strides = (1,) + strides * 2 + (1,)
else:
@@ -3923,7 +3931,7 @@ def separable_conv1d(x,
x = array_ops.squeeze(x, [spatial_start_dim])
- if data_format == 'channels_first' and tf_data_format == 'NHWC':
+ if data_format == 'channels_first' and tf_data_format == 'NWC':
x = array_ops.transpose(x, (0, 2, 1)) # NWC -> NCW
return x
@@ -4243,58 +4251,115 @@ def pool3d(x,
return x
-def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
- """Apply 1D conv with un-shared weights.
-
- Arguments:
- inputs: 3D tensor with shape:
- (batch_size, steps, input_dim)
- if data_format is "channels_last" or
- (batch_size, input_dim, steps)
- if data_format is "channels_first".
- kernel: the unshared weight for convolution,
- with shape (output_length, feature_dim, filters)
- kernel_size: a tuple of a single integer,
- specifying the length of the 1D convolution window
- strides: a tuple of a single integer,
- specifying the stride length of the convolution
- data_format: the data format, channels_first or channels_last
-
- Returns:
- the tensor after 1d conv with un-shared weights, with shape (batch_size,
- output_length, filters)
+def local_conv(inputs,
+ kernel,
+ kernel_size,
+ strides,
+ output_shape,
+ data_format=None):
+ """Apply N-D convolution with un-shared weights.
+
+ Arguments:
+ inputs: (N+2)-D tensor with shape
+ (batch_size, channels_in, d_in1, ..., d_inN)
+ if data_format='channels_first', or
+ (batch_size, d_in1, ..., d_inN, channels_in)
+ if data_format='channels_last'.
+ kernel: the unshared weight for N-D convolution,
+ with shape (output_items, feature_dim, channels_out), where
+ feature_dim = np.prod(kernel_size) * channels_in,
+ output_items = np.prod(output_shape).
+ kernel_size: a tuple of N integers, specifying the
+ spatial dimensions of the N-D convolution window.
+ strides: a tuple of N integers, specifying the strides
+ of the convolution along the spatial dimensions.
+ output_shape: a tuple of (d_out1, ..., d_outN) specifying the spatial
+ dimensionality of the output.
+ data_format: string, "channels_first" or "channels_last".
+
+ Returns:
+ An (N+2)-D tensor with shape:
+ (batch_size, channels_out) + output_shape
+ if data_format='channels_first', or:
+ (batch_size,) + output_shape + (channels_out,)
+ if data_format='channels_last'.
Raises:
- ValueError: if `data_format` is neither `channels_last` or
- `channels_first`.
+ ValueError: if `data_format` is neither
+ `channels_last` nor `channels_first`.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
- stride = strides[0]
kernel_shape = int_shape(kernel)
- output_length = kernel_shape[0]
feature_dim = kernel_shape[1]
+ channels_out = kernel_shape[-1]
+ ndims = len(output_shape)
+ spatial_dimensions = list(range(ndims))
xs = []
- for i in range(output_length):
- slice_length = slice(i * stride, i * stride + kernel_size[0])
+ output_axes_ticks = [range(axis_max) for axis_max in output_shape]
+ for position in itertools.product(*output_axes_ticks):
+ slices = [slice(None)]
+
if data_format == 'channels_first':
- xs.append(reshape(inputs[:, :, slice_length], (1, -1, feature_dim)))
- else:
- xs.append(reshape(inputs[:, slice_length, :], (1, -1, feature_dim)))
+ slices.append(slice(None))
+
+ slices.extend([slice(position[d] * strides[d],
+ position[d] * strides[d] + kernel_size[d])
+ for d in spatial_dimensions])
+
+ if data_format == 'channels_last':
+ slices.append(slice(None))
+
+ xs.append(reshape(inputs[slices], (1, -1, feature_dim)))
x_aggregate = concatenate(xs, axis=0)
- # Shape: `(output_length, batch_size, filters)`.
output = batch_dot(x_aggregate, kernel)
+ output = reshape(output, output_shape + (-1, channels_out))
if data_format == 'channels_first':
- output = permute_dimensions(output, (1, 2, 0))
+ permutation = [ndims, ndims + 1] + spatial_dimensions
else:
- output = permute_dimensions(output, (1, 0, 2))
- return output
+ permutation = [ndims] + spatial_dimensions + [ndims + 1]
+
+ return permute_dimensions(output, permutation)
+
+
+def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
+ """Apply 1D conv with un-shared weights.
+
+ Arguments:
+ inputs: 3D tensor with shape:
+ (batch_size, steps, input_dim)
+ if data_format is "channels_last" or
+ (batch_size, input_dim, steps)
+ if data_format is "channels_first".
+ kernel: the unshared weight for convolution,
+ with shape (output_length, feature_dim, filters).
+ kernel_size: a tuple of a single integer,
+ specifying the length of the 1D convolution window.
+ strides: a tuple of a single integer,
+ specifying the stride length of the convolution.
+ data_format: the data format, channels_first or channels_last.
+
+ Returns:
+ A 3d tensor with shape:
+ (batch_size, output_length, filters)
+ if data_format='channels_first'
+ or 3D tensor with shape:
+ (batch_size, filters, output_length)
+ if data_format='channels_last'.
+ """
+ output_shape = (kernel.shape[0],)
+ return local_conv(inputs,
+ kernel,
+ kernel_size,
+ strides,
+ output_shape,
+ data_format)
def local_conv2d(inputs,
@@ -4307,64 +4372,34 @@ def local_conv2d(inputs,
Arguments:
inputs: 4D tensor with shape:
- (batch_size, filters, new_rows, new_cols)
- if data_format='channels_first'
- or 4D tensor with shape:
- (batch_size, new_rows, new_cols, filters)
- if data_format='channels_last'.
+ (batch_size, filters, new_rows, new_cols)
+ if data_format='channels_first'
+ or 4D tensor with shape:
+ (batch_size, new_rows, new_cols, filters)
+ if data_format='channels_last'.
kernel: the unshared weight for convolution,
- with shape (output_items, feature_dim, filters)
+ with shape (output_items, feature_dim, filters).
kernel_size: a tuple of 2 integers, specifying the
- width and height of the 2D convolution window.
+ width and height of the 2D convolution window.
strides: a tuple of 2 integers, specifying the strides
- of the convolution along the width and height.
- output_shape: a tuple with (output_row, output_col)
- data_format: the data format, channels_first or channels_last
+ of the convolution along the width and height.
+ output_shape: a tuple with (output_row, output_col).
+ data_format: the data format, channels_first or channels_last.
Returns:
- A 4d tensor with shape:
+ A 4D tensor with shape:
(batch_size, filters, new_rows, new_cols)
if data_format='channels_first'
or 4D tensor with shape:
(batch_size, new_rows, new_cols, filters)
if data_format='channels_last'.
-
- Raises:
- ValueError: if `data_format` is neither
- `channels_last` or `channels_first`.
"""
- if data_format is None:
- data_format = image_data_format()
- if data_format not in {'channels_first', 'channels_last'}:
- raise ValueError('Unknown data_format: ' + str(data_format))
-
- stride_row, stride_col = strides
- output_row, output_col = output_shape
- kernel_shape = int_shape(kernel)
- feature_dim = kernel_shape[1]
- filters = kernel_shape[2]
-
- xs = []
- for i in range(output_row):
- for j in range(output_col):
- slice_row = slice(i * stride_row, i * stride_row + kernel_size[0])
- slice_col = slice(j * stride_col, j * stride_col + kernel_size[1])
- if data_format == 'channels_first':
- xs.append(
- reshape(inputs[:, :, slice_row, slice_col], (1, -1, feature_dim)))
- else:
- xs.append(
- reshape(inputs[:, slice_row, slice_col, :], (1, -1, feature_dim)))
-
- x_aggregate = concatenate(xs, axis=0)
- output = batch_dot(x_aggregate, kernel)
- output = reshape(output, (output_row, output_col, -1, filters))
-
- if data_format == 'channels_first':
- output = permute_dimensions(output, (2, 3, 0, 1))
- else:
- output = permute_dimensions(output, (2, 0, 1, 3))
- return output
+ return local_conv(inputs,
+ kernel,
+ kernel_size,
+ strides,
+ output_shape,
+ data_format)
@tf_export('keras.backend.bias_add')
@@ -4722,8 +4757,13 @@ def foldr(fn, elems, initializer=None, name=None):
# Load Keras default configuration from config file if present.
-_keras_base_dir = os.path.expanduser('~')
-_keras_dir = os.path.join(_keras_base_dir, '.keras')
+# Set Keras base dir path given KERAS_HOME env variable, if applicable.
+# Otherwise either ~/.keras or /tmp.
+if 'KERAS_HOME' in os.environ:
+ _keras_dir = os.environ.get('KERAS_HOME')
+else:
+ _keras_base_dir = os.path.expanduser('~')
+ _keras_dir = os.path.join(_keras_base_dir, '.keras')
_config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json'))
if os.path.exists(_config_path):
try: