aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Francois Chollet <fchollet@google.com>2017-04-04 15:51:13 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-04 17:13:08 -0700
commit9477900946f923cb43ed76ed215490d01474bfe7 (patch)
tree96761e0793fea27992ec7485807138bf72b86868
parent8f74d595ef8a7a617596529e129d2a05dd18bb90 (diff)
Backport fixes and improvements from external Keras.
Change: 152198296
-rw-r--r--tensorflow/contrib/keras/python/keras/__init__.py2
-rw-r--r--tensorflow/contrib/keras/python/keras/activations.py24
-rw-r--r--tensorflow/contrib/keras/python/keras/applications/resnet50.py4
-rw-r--r--tensorflow/contrib/keras/python/keras/backend.py94
-rw-r--r--tensorflow/contrib/keras/python/keras/engine/topology.py30
-rw-r--r--tensorflow/contrib/keras/python/keras/engine/training.py26
-rw-r--r--tensorflow/contrib/keras/python/keras/initializers.py9
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/convolutional.py36
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py2
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/core.py6
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/local.py16
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/merge.py156
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/normalization.py2
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/pooling.py16
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/recurrent.py26
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/wrappers.py50
-rw-r--r--tensorflow/contrib/keras/python/keras/metrics.py9
-rw-r--r--tensorflow/contrib/keras/python/keras/models.py23
-rw-r--r--tensorflow/contrib/keras/python/keras/preprocessing/image.py2
-rw-r--r--tensorflow/contrib/keras/python/keras/utils/generic_utils.py5
-rw-r--r--tensorflow/contrib/keras/python/keras/utils/layer_utils.py2
-rw-r--r--tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py34
22 files changed, 424 insertions, 150 deletions
diff --git a/tensorflow/contrib/keras/python/keras/__init__.py b/tensorflow/contrib/keras/python/keras/__init__.py
index cdfc40dff1..ec316253db 100644
--- a/tensorflow/contrib/keras/python/keras/__init__.py
+++ b/tensorflow/contrib/keras/python/keras/__init__.py
@@ -37,4 +37,4 @@ from tensorflow.contrib.keras.python.keras import utils
from tensorflow.contrib.keras.python.keras import wrappers
-__version__ = '2.0.0-tf'
+__version__ = '2.0.2-tf'
diff --git a/tensorflow/contrib/keras/python/keras/activations.py b/tensorflow/contrib/keras/python/keras/activations.py
index 1eac52dfad..67762c83ba 100644
--- a/tensorflow/contrib/keras/python/keras/activations.py
+++ b/tensorflow/contrib/keras/python/keras/activations.py
@@ -24,18 +24,28 @@ from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object
-def softmax(x):
+def softmax(x, axis=-1):
+ """Softmax activation function.
+
+ Arguments:
+ x : Tensor.
+ axis: Integer, axis along which the softmax normalization is applied.
+
+ Returns:
+ Tensor, output of softmax transformation.
+
+ Raises:
+ ValueError: In case `dim(x) == 1`.
+ """
ndim = K.ndim(x)
if ndim == 2:
return K.softmax(x)
- elif ndim == 3:
- e = K.exp(x - K.max(x, axis=-1, keepdims=True))
- s = K.sum(e, axis=-1, keepdims=True)
+ elif ndim > 2:
+ e = K.exp(x - K.max(x, axis=axis, keepdims=True))
+ s = K.sum(e, axis=axis, keepdims=True)
return e / s
else:
- raise ValueError('Cannot apply softmax to a tensor '
- 'that is not 2D or 3D. '
- 'Here, ndim=' + str(ndim))
+ raise ValueError('Cannot apply softmax to a tensor that is 1D')
def elu(x, alpha=1.0):
diff --git a/tensorflow/contrib/keras/python/keras/applications/resnet50.py b/tensorflow/contrib/keras/python/keras/applications/resnet50.py
index 546fcb9433..12f7ca424e 100644
--- a/tensorflow/contrib/keras/python/keras/applications/resnet50.py
+++ b/tensorflow/contrib/keras/python/keras/applications/resnet50.py
@@ -163,8 +163,8 @@ def ResNet50(include_top=True,
specified in your Keras config file.
Arguments:
- include_top: whether to include the 3 fully-connected
- layers at the top of the network.
+ include_top: whether to include the fully-connected
+ layer at the top of the network.
weights: one of `None` (random initialization)
or "imagenet" (pre-training on ImageNet).
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
diff --git a/tensorflow/contrib/keras/python/keras/backend.py b/tensorflow/contrib/keras/python/keras/backend.py
index 9769bce3b0..d7c646c19a 100644
--- a/tensorflow/contrib/keras/python/keras/backend.py
+++ b/tensorflow/contrib/keras/python/keras/backend.py
@@ -22,7 +22,6 @@ from __future__ import division
from __future__ import print_function
from collections import defaultdict
-import errno
import json
import os
import warnings
@@ -270,6 +269,7 @@ def clear_session():
reset_uids()
_SESSION = None
phase = array_ops.placeholder(dtype='bool', name='keras_learning_phase')
+ _GRAPH_LEARNING_PHASES = {}
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = phase
@@ -1257,6 +1257,34 @@ def prod(x, axis=None, keepdims=False):
return math_ops.reduce_prod(x, reduction_indices=axis, keep_dims=keepdims)
+def cumsum(x, axis=0):
+ """Cumulative sum of the values in a tensor, alongside the specified axis.
+
+ Arguments:
+ x: A tensor or variable.
+ axis: An integer, the axis to compute the sum.
+
+ Returns:
+ A tensor of the cumulative sum of values of `x` along `axis`.
+ """
+ axis = _normalize_axis(axis, ndim(x))
+ return math_ops.cumsum(x, axis=axis)
+
+
+def cumprod(x, axis=0):
+ """Cumulative product of the values in a tensor, alongside the specified axis.
+
+ Arguments:
+ x: A tensor or variable.
+ axis: An integer, the axis to compute the product.
+
+ Returns:
+ A tensor of the cumulative product of values of `x` along `axis`.
+ """
+ axis = _normalize_axis(axis, ndim(x))
+ return math_ops.cumprod(x, axis=axis)
+
+
def var(x, axis=None, keepdims=False):
"""Variance of a tensor, alongside the specified axis.
@@ -1330,8 +1358,7 @@ def any(x, axis=None, keepdims=False):
"""
axis = _normalize_axis(axis, ndim(x))
x = math_ops.cast(x, dtypes_module.bool)
- x = math_ops.reduce_any(x, reduction_indices=axis, keep_dims=keepdims)
- return math_ops.cast(x, dtypes_module.uint8)
+ return math_ops.reduce_any(x, reduction_indices=axis, keep_dims=keepdims)
def all(x, axis=None, keepdims=False):
@@ -1347,8 +1374,7 @@ def all(x, axis=None, keepdims=False):
"""
axis = _normalize_axis(axis, ndim(x))
x = math_ops.cast(x, dtypes_module.bool)
- x = math_ops.reduce_all(x, reduction_indices=axis, keep_dims=keepdims)
- return math_ops.cast(x, dtypes_module.uint8)
+ return math_ops.reduce_all(x, reduction_indices=axis, keep_dims=keepdims)
def argmax(x, axis=-1):
@@ -1645,7 +1671,7 @@ def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
"""
mean, var = nn.moments(
x, reduction_axes, shift=None, name=None, keep_dims=False)
- if sorted(reduction_axes) == range(ndim(x))[:-1]:
+ if sorted(reduction_axes) == list(range(ndim(x)))[:-1]:
normed = nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
else:
# need broadcasting
@@ -2324,8 +2350,8 @@ def rnn(step_function,
(no time dimension),
containing the initial values for the states used in
the step function.
- go_backwards: boolean. If True, do the iteration over
- the time dimension in reverse order.
+ go_backwards: boolean. If True, do the iteration over the time
+ dimension in reverse order and return the reversed sequence.
mask: binary tensor with shape `(samples, time, 1)`,
with a zero for every element that is masked.
constants: a list of constant values passed at each step.
@@ -2414,9 +2440,9 @@ def rnn(step_function,
states = return_states
successive_outputs.append(output)
successive_states.append(states)
- last_output = successive_outputs[-1]
- new_states = successive_states[-1]
- outputs = array_ops.stack(successive_outputs)
+ last_output = successive_outputs[-1]
+ new_states = successive_states[-1]
+ outputs = array_ops.stack(successive_outputs)
else:
for inp in input_list:
output, states = step_function(inp, states + constants)
@@ -3534,19 +3560,19 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
# HIGH ORDER FUNCTIONS
-def map_fn(fn, elems, name=None):
+def map_fn(fn, elems, name=None, dtype=None):
"""Map the function fn over the elements elems and return the outputs.
Arguments:
fn: Callable that will be called upon each element in elems
elems: tensor
name: A string name for the map node in the graph
+ dtype: Output data type.
Returns:
- Tensor with first dimension equal to the elems and second depending on
- fn
+ Tensor with dtype `dtype`.
"""
- return functional_ops.map_fn(fn, elems, name=name)
+ return functional_ops.map_fn(fn, elems, name=name, dtype=dtype)
def foldl(fn, elems, initializer=None, name=None):
@@ -3560,7 +3586,7 @@ def foldl(fn, elems, initializer=None, name=None):
name: A string name for the foldl node in the graph
Returns:
- Same type and shape as initializer
+ Tensor with same type and shape as `initializer`.
"""
return functional_ops.foldl(fn, elems, initializer=initializer, name=name)
@@ -3583,27 +3609,39 @@ def foldr(fn, elems, initializer=None, name=None):
# Load Keras default configuration from config file if present.
_keras_base_dir = os.path.expanduser('~')
-if not os.access(_keras_base_dir, os.W_OK):
- _keras_base_dir = '/tmp'
_keras_dir = os.path.join(_keras_base_dir, '.keras')
-if not os.path.exists(_keras_dir):
- try:
- os.makedirs(_keras_dir)
- except OSError as e:
- if e.errno == errno.EEXIST:
- pass
- else:
- raise
_config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json'))
if os.path.exists(_config_path):
- _config = json.load(open(_config_path))
+ try:
+ _config = json.load(open(_config_path))
+ except json.decoder.JSONDecodeError:
+ _config = {}
_floatx = _config.get('floatx', floatx())
assert _floatx in {'float16', 'float32', 'float64'}
_epsilon = _config.get('epsilon', epsilon())
assert isinstance(_epsilon, float)
- _backend = backend()
_image_data_format = _config.get('image_data_format', image_data_format())
assert _image_data_format in {'channels_last', 'channels_first'}
set_floatx(_floatx)
set_epsilon(_epsilon)
set_image_data_format(_image_data_format)
+
+# Save config file.
+if os.access(_keras_base_dir, os.W_OK):
+ if not os.path.exists(_keras_dir):
+ try:
+ os.makedirs(_keras_dir)
+ except OSError:
+ # Except potential race conditions
+ # in multi-threaded environments.
+ pass
+
+ if not os.path.exists(_config_path):
+ _config = {
+ 'floatx': floatx(),
+ 'epsilon': epsilon(),
+ 'backend': 'tensorflow',
+ 'image_data_format': image_data_format()
+ }
+ with open(_config_path, 'w') as f:
+ f.write(json.dumps(_config, indent=4))
diff --git a/tensorflow/contrib/keras/python/keras/engine/topology.py b/tensorflow/contrib/keras/python/keras/engine/topology.py
index 0f506ff0a4..e33268235f 100644
--- a/tensorflow/contrib/keras/python/keras/engine/topology.py
+++ b/tensorflow/contrib/keras/python/keras/engine/topology.py
@@ -295,8 +295,14 @@ class Layer(object):
# are only applicable to input layers: do not pass these keywords
# to non-input layers.
allowed_kwargs = {
- 'input_shape', 'batch_input_shape', 'batch_size', 'dtype', 'name',
- 'trainable', 'weights'
+ 'input_shape',
+ 'batch_input_shape',
+ 'batch_size',
+ 'dtype',
+ 'name',
+ 'trainable',
+ 'weights',
+ 'input_dtype', # legacy
}
for kwarg in kwargs:
if kwarg not in allowed_kwargs:
@@ -320,8 +326,15 @@ class Layer(object):
batch_size = None
batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])
self.batch_input_shape = batch_input_shape
- dtype = kwargs.get('dtype', K.floatx())
+
+ # Set dtype.
+ dtype = kwargs.get('dtype')
+ if dtype is None:
+ dtype = kwargs.get('input_dtype')
+ if dtype is None:
+ dtype = K.floatx()
self.dtype = dtype
+
if 'weights' in kwargs:
self._initial_weights = kwargs['weights']
else:
@@ -485,11 +498,12 @@ class Layer(object):
': expected shape=' + str(spec.shape) +
', found shape=' + str(x_shape))
- def call(self, inputs):
+ def call(self, inputs, **kwargs): # pylint: disable=unused-argument
"""This is where the layer's logic lives.
Arguments:
- inputs: input tensor, or list/tuple of input tensors.
+ inputs: Input tensor, or list/tuple of input tensors.
+ **kwargs: Additional keyword arguments.
Returns:
A tensor or list/tuple of tensors.
@@ -518,6 +532,8 @@ class Layer(object):
ValueError: in case the layer is missing shape information
for its `build` call.
"""
+ if isinstance(inputs, list):
+ inputs = inputs[:]
with K.name_scope(self.name):
# Handle laying building (weight creating, input spec locking).
if not self.built:
@@ -1417,7 +1433,7 @@ class Container(Layer):
get_weights
set_weights
get_config
- get_output_shape_for
+ compute_output_shape
# Class Methods
from_config
@@ -2029,7 +2045,7 @@ class Container(Layer):
for i in range(len(input_shapes)):
layer = self.input_layers[i]
input_shape = input_shapes[i]
- # It's an input layer: get_output_shape_for is identity,
+ # It's an input layer: compute_output_shape is identity,
# and there is only one node and one tensor output.
shape_key = layer.name + '_0_0'
layers_to_output_shapes[shape_key] = input_shape
diff --git a/tensorflow/contrib/keras/python/keras/engine/training.py b/tensorflow/contrib/keras/python/keras/engine/training.py
index efd437f6f6..0097c4a1c2 100644
--- a/tensorflow/contrib/keras/python/keras/engine/training.py
+++ b/tensorflow/contrib/keras/python/keras/engine/training.py
@@ -733,11 +733,12 @@ class Model(Container):
loss_functions = []
for name in self.output_names:
if name not in loss:
- warnings.warn('Output "' + name + '" missing from loss dictionary. '
- 'We assume this was done on purpose, '
- 'and we will not be expecting '
- 'any data to be passed to "' + name +
- '" during training.')
+ warnings.warn(
+ 'Output "' + name + '" missing from loss dictionary. '
+ 'We assume this was done on purpose, '
+ 'and we will not be expecting '
+ 'any data to be passed to "' + name + '" during training.',
+ stacklevel=2)
loss_functions.append(losses.get(loss.get(name)))
elif isinstance(loss, list):
if len(loss) != len(self.outputs):
@@ -1202,7 +1203,7 @@ class Model(Container):
if batch_index == 0:
for batch_out in batch_outs:
shape = (samples,) + batch_out.shape[1:]
- outs.append(np.zeros(shape, dtype=K.floatx()))
+ outs.append(np.zeros(shape, dtype=batch_out.dtype))
for i, batch_out in enumerate(batch_outs):
outs[i][batch_start:batch_end] = batch_out
@@ -1718,7 +1719,7 @@ class Model(Container):
- a tuple (inputs, targets, sample_weights).
All arrays should contain the same number of samples.
The generator is expected to loop over its data
- indefinitely. An epoch finishes when `samples_per_epoch`
+ indefinitely. An epoch finishes when `steps_per_epoch`
samples have been seen by the model.
steps_per_epoch: Total number of steps (batches of samples)
to yield from `generator` before declaring one epoch
@@ -1767,7 +1768,7 @@ class Model(Container):
f.close()
model.fit_generator(generate_arrays_from_file('/my_file.txt'),
- samples_per_epoch=10000, epochs=10)
+ steps_per_epoch=10000, epochs=10)
```
Raises:
@@ -2028,7 +2029,8 @@ class Model(Container):
steps,
max_q_size=10,
workers=1,
- pickle_safe=False):
+ pickle_safe=False,
+ verbose=0):
"""Generates predictions for the input samples from a data generator.
The generator should return the same kind of data as accepted by
@@ -2048,6 +2050,7 @@ class Model(Container):
non picklable arguments to the generator
as they can't be passed
easily to children processes.
+ verbose: verbosity mode, 0 or 1.
Returns:
Numpy array(s) of predictions.
@@ -2067,6 +2070,9 @@ class Model(Container):
enqueuer = GeneratorEnqueuer(generator, pickle_safe=pickle_safe)
enqueuer.start(workers=workers, max_q_size=max_q_size)
+ if verbose == 1:
+ progbar = Progbar(target=steps)
+
while steps_done < steps:
generator_output = None
while enqueuer.is_running():
@@ -2103,6 +2109,8 @@ class Model(Container):
for i, out in enumerate(outs):
all_outs[i].append(out)
steps_done += 1
+ if verbose == 1:
+ progbar.update(steps_done)
finally:
if enqueuer is not None:
diff --git a/tensorflow/contrib/keras/python/keras/initializers.py b/tensorflow/contrib/keras/python/keras/initializers.py
index 621069f424..f9cb35e171 100644
--- a/tensorflow/contrib/keras/python/keras/initializers.py
+++ b/tensorflow/contrib/keras/python/keras/initializers.py
@@ -45,14 +45,16 @@ class Initializer(object):
class Zeros(Initializer):
- """Initializer that generates tensors initialized to 0."""
+ """Initializer that generates tensors initialized to 0.
+ """
def __call__(self, shape, dtype=None):
return K.constant(0, shape=shape, dtype=dtype)
class Ones(Initializer):
- """Initializer that generates tensors initialized to 1."""
+ """Initializer that generates tensors initialized to 1.
+ """
def __call__(self, shape, dtype=None):
return K.constant(1, shape=shape, dtype=dtype)
@@ -130,7 +132,7 @@ class RandomUniform(Initializer):
class TruncatedNormal(Initializer):
"""Initializer that generates a truncated normal distribution.
- These values are similar to values from a `random_normal_initializer`
+ These values are similar to values from a `RandomNormal`
except that values more than two standard deviations from the mean
are discarded and re-drawn. This is the recommended initializer for
neural network weights and filters.
@@ -161,6 +163,7 @@ class VarianceScaling(Initializer):
With `distribution="normal"`, samples are drawn from a truncated normal
distribution centered on zero, with `stddev = sqrt(scale / n)` where n is:
+
- number of input units in the weight tensor, if mode = "fan_in"
- number of output units, if mode = "fan_out"
- average of the numbers of input and output units, if mode = "fan_avg"
diff --git a/tensorflow/contrib/keras/python/keras/layers/convolutional.py b/tensorflow/contrib/keras/python/keras/layers/convolutional.py
index 1a28399a28..3b68022115 100644
--- a/tensorflow/contrib/keras/python/keras/layers/convolutional.py
+++ b/tensorflow/contrib/keras/python/keras/layers/convolutional.py
@@ -244,7 +244,7 @@ class _Conv(Layer):
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
'bias_initializer':
- initializers.serialize(self.kernel_initializer),
+ initializers.serialize(self.bias_initializer),
'kernel_regularizer':
regularizers.serialize(self.kernel_regularizer),
'bias_regularizer':
@@ -289,7 +289,7 @@ class Conv1D(_Conv):
any `dilation_rate` value != 1.
padding: One of `"valid"`, `"causal"` or `"same"` (case-insensitive).
`"causal"` results in causal (dilated) convolutions, e.g. output[t]
- depends solely on input[:t-1]. Useful when modeling temporal data
+ does not depend on input[t+1:]. Useful when modeling temporal data
where the model should not violate the temporal order.
See [WaveNet: A Generative Model for Raw Audio, section
2.1](https://arxiv.org/abs/1609.03499).
@@ -395,9 +395,9 @@ class Conv2D(_Conv):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
- `(batch, width, height, channels)` while `channels_first`
+ `(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
- `(batch, channels, width, height)`.
+ `(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
@@ -621,7 +621,7 @@ class Conv2DTranspose(Conv2D):
Arguments:
filters: Integer, the dimensionality of the output space
- (i.e. the number output of filters in the convolution).
+ (i.e. the number of output filters in the convolution).
kernel_size: An integer or tuple/list of 2 integers, specifying the
width and height of the 2D convolution window.
Can be a single integer to specify the same value for
@@ -637,9 +637,9 @@ class Conv2DTranspose(Conv2D):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
- `(batch, width, height, channels)` while `channels_first`
+ `(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
- `(batch, channels, width, height)`.
+ `(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
@@ -688,7 +688,7 @@ class Conv2DTranspose(Conv2D):
kernel_size,
strides=(1, 1),
padding='valid',
- data_format='channels_last',
+ data_format=None,
activation=None,
use_bias=True,
kernel_initializer='glorot_uniform',
@@ -845,9 +845,9 @@ class SeparableConv2D(Conv2D):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
- `(batch, width, height, channels)` while `channels_first`
+ `(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
- `(batch, channels, width, height)`.
+ `(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
@@ -1079,9 +1079,9 @@ class UpSampling2D(Layer):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
- `(batch, width, height, channels)` while `channels_first`
+ `(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
- `(batch, channels, width, height)`.
+ `(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
@@ -1257,7 +1257,7 @@ class ZeroPadding2D(Layer):
- If tuple of 2 ints:
interpreted as two different
symmetric padding values for height and width:
- `(symmetric_height_pad, symmetrc_width_pad)`.
+ `(symmetric_height_pad, symmetric_width_pad)`.
- If tuple of 2 tuples of 2 ints:
interpreted as
`((top_pad, bottom_pad), (left_pad, right_pad))`
@@ -1265,9 +1265,9 @@ class ZeroPadding2D(Layer):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
- `(batch, width, height, channels)` while `channels_first`
+ `(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
- `(batch, channels, width, height)`.
+ `(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
@@ -1498,7 +1498,7 @@ class Cropping2D(Layer):
- If tuple of 2 ints:
interpreted as two different
symmetric cropping values for height and width:
- `(symmetric_height_crop, symmetrc_width_crop)`.
+ `(symmetric_height_crop, symmetric_width_crop)`.
- If tuple of 2 tuples of 2 ints:
interpreted as
`((top_crop, bottom_crop), (left_crop, right_crop))`
@@ -1506,9 +1506,9 @@ class Cropping2D(Layer):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
- `(batch, width, height, channels)` while `channels_first`
+ `(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
- `(batch, channels, width, height)`.
+ `(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
diff --git a/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py b/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py
index 4ed5046dc3..4d8ef44da7 100644
--- a/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py
+++ b/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py
@@ -357,7 +357,7 @@ class ConvLSTM2D(ConvRecurrent2D):
self.states = [None, None]
if self.data_format == 'channels_first':
- channel_axis = 1
+ channel_axis = 2
else:
channel_axis = -1
if input_shape[channel_axis] is None:
diff --git a/tensorflow/contrib/keras/python/keras/layers/core.py b/tensorflow/contrib/keras/python/keras/layers/core.py
index 1207cc119f..8dd55aaa2e 100644
--- a/tensorflow/contrib/keras/python/keras/layers/core.py
+++ b/tensorflow/contrib/keras/python/keras/layers/core.py
@@ -88,7 +88,7 @@ class Dropout(Layer):
"""Applies Dropout to the input.
Dropout consists in randomly setting
- a fraction `p` of input units to 0 at each update during training time,
+ a fraction `rate` of input units to 0 at each update during training time,
which helps prevent overfitting.
Arguments:
@@ -140,7 +140,7 @@ class SpatialDropout1D(Dropout):
between feature maps and should be used instead.
Arguments:
- p: float between 0 and 1. Fraction of the input units to drop.
+ rate: float between 0 and 1. Fraction of the input units to drop.
Input shape:
3D tensor with shape:
@@ -775,7 +775,7 @@ class Dense(Layer):
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
'bias_initializer':
- initializers.serialize(self.kernel_initializer),
+ initializers.serialize(self.bias_initializer),
'kernel_regularizer':
regularizers.serialize(self.kernel_regularizer),
'bias_regularizer':
diff --git a/tensorflow/contrib/keras/python/keras/layers/local.py b/tensorflow/contrib/keras/python/keras/layers/local.py
index 3bf5ee4f0f..895d6e3727 100644
--- a/tensorflow/contrib/keras/python/keras/layers/local.py
+++ b/tensorflow/contrib/keras/python/keras/layers/local.py
@@ -59,7 +59,8 @@ class LocallyConnected1D(Layer):
specifying the stride length of the convolution.
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
- padding: One of `"valid"` or `"same"` (case-insensitive).
+ padding: Currently only supports `"valid"` (case-insensitive).
+ `"same"` may be supported in the future.
activation: Activation function to use.
If you don't specify anything, no activation is applied
(ie. "linear" activation: `a(x) = x`).
@@ -188,7 +189,7 @@ class LocallyConnected1D(Layer):
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
'bias_initializer':
- initializers.serialize(self.kernel_initializer),
+ initializers.serialize(self.bias_initializer),
'kernel_regularizer':
regularizers.serialize(self.kernel_regularizer),
'bias_regularizer':
@@ -239,16 +240,15 @@ class LocallyConnected2D(Layer):
specifying the strides of the convolution along the width and height.
Can be a single integer to specify the same value for
all spatial dimensions.
- Specifying any stride value != 1 is incompatible with specifying
- any `dilation_rate` value != 1.
- padding: one of `"valid"` or `"same"` (case-insensitive).
+ padding: Currently only support `"valid"` (case-insensitive).
+ `"same"` will be supported in future.
data_format: A string,
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
- `(batch, width, height, channels)` while `channels_first`
+ `(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
- `(batch, channels, width, height)`.
+ `(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
@@ -460,7 +460,7 @@ class LocallyConnected2D(Layer):
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
'bias_initializer':
- initializers.serialize(self.kernel_initializer),
+ initializers.serialize(self.bias_initializer),
'kernel_regularizer':
regularizers.serialize(self.kernel_regularizer),
'bias_regularizer':
diff --git a/tensorflow/contrib/keras/python/keras/layers/merge.py b/tensorflow/contrib/keras/python/keras/layers/merge.py
index eea4313d31..d52bd2bbb3 100644
--- a/tensorflow/contrib/keras/python/keras/layers/merge.py
+++ b/tensorflow/contrib/keras/python/keras/layers/merge.py
@@ -41,6 +41,44 @@ class _Merge(Layer):
def _merge_function(self, inputs):
raise NotImplementedError
+ def _compute_elemwise_op_output_shape(self, shape1, shape2):
+ """Computes the shape of the resultant of an elementwise operation.
+
+ Arguments:
+ shape1: tuple or None. Shape of the first tensor
+ shape2: tuple or None. Shape of the second tensor
+
+ Returns:
+ expected output shape when an element-wise operation is
+ carried out on 2 tensors with shapes shape1 and shape2.
+ tuple or None.
+
+ Raises:
+ ValueError: if shape1 and shape2 are not compatible for
+ element-wise operations.
+ """
+ if None in [shape1, shape2]:
+ return None
+ elif len(shape1) < len(shape2):
+ return self._compute_elemwise_op_output_shape(shape2, shape1)
+ elif not shape2:
+ return shape1
+ output_shape = list(shape1[:-len(shape2)])
+ for i, j in zip(shape1[-len(shape2):], shape2):
+ if i is None or j is None:
+ output_shape.append(None)
+ elif i == 1:
+ output_shape.append(j)
+ elif j == 1:
+ output_shape.append(i)
+ else:
+ if i != j:
+ raise ValueError('Operands could not be broadcast '
+ 'together with shapes ' + str(shape1) + ' ' + str(
+ shape2))
+ output_shape.append(i)
+ return tuple(output_shape)
+
def build(self, input_shape):
# Used purely for shape validation.
if not isinstance(input_shape, list):
@@ -49,23 +87,107 @@ class _Merge(Layer):
raise ValueError('A merge layer should be called '
'on a list of at least 2 inputs. '
'Got ' + str(len(input_shape)) + ' inputs.')
- if all([shape is None for shape in input_shape]):
- return
- input_shapes = [
- tuple(tensor_shape.TensorShape(shape).as_list())
- for shape in input_shape
- ]
- # TODO(fchollet): handle shapes with None entries.
- input_shapes_set = set(input_shapes)
- if None in input_shapes_set:
- input_shapes_set.remove(None)
- if len(input_shapes_set) > 1:
- raise ValueError('Only tensors of same shape can '
- 'be merged by layer' + self.name +
- ' Got input shapes: %s' % input_shapes)
+ batch_sizes = [s[0] for s in input_shape if s is not None]
+ batch_sizes = set(batch_sizes)
+ batch_sizes -= set([None])
+ if len(batch_sizes) > 1:
+ raise ValueError('Can not merge tensors with different '
+ 'batch sizes. Got tensors with shapes : ' + str(
+ input_shape))
+ if input_shape[0] is None:
+ output_shape = None
+ else:
+ output_shape = input_shape[0][1:]
+ for i in range(1, len(input_shape)):
+ if input_shape[i] is None:
+ shape = None
+ else:
+ shape = input_shape[i][1:]
+ output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
+ # If the inputs have different ranks, we have to reshape them
+ # to make them broadcastable.
+ if None not in input_shape and len(set(map(len, input_shape))) == 1:
+ self._reshape_required = False
+ else:
+ self._reshape_required = True
def call(self, inputs):
- return self._merge_function(inputs)
+ if self._reshape_required:
+ reshaped_inputs = []
+ input_ndims = list(map(K.ndim, inputs))
+ if None not in input_ndims:
+ # If ranks of all inputs are available,
+ # we simply expand each of them at axis=1
+ # until all of them have the same rank.
+ max_ndim = max(input_ndims)
+ for x in inputs:
+ x_ndim = K.ndim(x)
+ for _ in range(max_ndim - x_ndim):
+ x = K.expand_dims(x, 1)
+ reshaped_inputs.append(x)
+ return self._merge_function(reshaped_inputs)
+ else:
+ # Transpose all inputs so that batch size is the last dimension.
+ # (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... , batch_size)
+ transposed = False
+ for x in inputs:
+ x_ndim = K.ndim(x)
+ if x_ndim is None:
+ x_shape = K.shape(x)
+ batch_size = x_shape[0]
+ new_shape = K.concatenate([x_shape[1:], K.expand_dims(batch_size)])
+ x_transposed = K.reshape(x,
+ K.stack([batch_size, K.prod(x_shape[1:])]))
+ x_transposed = K.permute_dimensions(x_transposed, (1, 0))
+ x_transposed = K.reshape(x_transposed, new_shape)
+ reshaped_inputs.append(x_transposed)
+ transposed = True
+ elif x_ndim > 1:
+ dims = list(range(1, x_ndim)) + [0]
+ reshaped_inputs.append(K.permute_dimensions(x, dims))
+ transposed = True
+ else:
+ # We don't transpose inputs if they are 1D vectors or scalars.
+ reshaped_inputs.append(x)
+ y = self._merge_function(reshaped_inputs)
+ y_ndim = K.ndim(y)
+ if transposed:
+ # If inputs have been transposed, we have to transpose the output too.
+ if y_ndim is None:
+ y_shape = K.shape(y)
+ y_ndim = K.shape(y_shape)[0]
+ batch_size = y_shape[y_ndim - 1]
+ new_shape = K.concatenate(
+ [K.expand_dims(batch_size), y_shape[:y_ndim - 1]])
+ y = K.reshape(y, (-1, batch_size))
+ y = K.permute_dimensions(y, (1, 0))
+ y = K.reshape(y, new_shape)
+ elif y_ndim > 1:
+ dims = [y_ndim - 1] + list(range(y_ndim - 1))
+ y = K.permute_dimensions(y, dims)
+ return y
+ else:
+ return self._merge_function(inputs)
+
+ def compute_output_shape(self, input_shape):
+ if input_shape[0] is None:
+ output_shape = None
+ else:
+ output_shape = input_shape[0][1:]
+ for i in range(1, len(input_shape)):
+ if input_shape[i] is None:
+ shape = None
+ else:
+ shape = input_shape[i][1:]
+ output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
+ batch_sizes = [s[0] for s in input_shape if s is not None]
+ batch_sizes = set(batch_sizes)
+ batch_sizes -= set([None])
+ if len(batch_sizes) == 1:
+ output_shape = (list(batch_sizes)[0],) + output_shape
+ else:
+ output_shape = (None,) + output_shape
+ return output_shape
def compute_mask(self, inputs, mask=None):
if mask is None:
@@ -219,8 +341,8 @@ class Concatenate(_Merge):
for input_i, mask_i in zip(inputs, mask):
if mask_i is None:
# Input is unmasked. Append all 1s to masks,
- # but cast it to uint8 first
- masks.append(K.cast(K.ones_like(input_i), 'uint8'))
+ # but cast it to bool first
+ masks.append(K.cast(K.ones_like(input_i), 'bool'))
elif K.ndim(mask_i) < K.ndim(input_i):
# Mask is smaller than the input, expand it
masks.append(K.expand_dims(mask_i))
diff --git a/tensorflow/contrib/keras/python/keras/layers/normalization.py b/tensorflow/contrib/keras/python/keras/layers/normalization.py
index 41c618cc79..d429cd6d9b 100644
--- a/tensorflow/contrib/keras/python/keras/layers/normalization.py
+++ b/tensorflow/contrib/keras/python/keras/layers/normalization.py
@@ -154,7 +154,7 @@ class BatchNormalization(Layer):
broadcast_shape[self.axis] = input_shape[self.axis]
# Determines whether broadcasting is needed.
- needs_broadcasting = (sorted(reduction_axes) != range(ndim)[:-1])
+ needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])
normed, mean, variance = K.normalize_batch_in_training(
inputs, self.gamma, self.beta, reduction_axes, epsilon=self.epsilon)
diff --git a/tensorflow/contrib/keras/python/keras/layers/pooling.py b/tensorflow/contrib/keras/python/keras/layers/pooling.py
index e31caed3ec..47c88bf4d0 100644
--- a/tensorflow/contrib/keras/python/keras/layers/pooling.py
+++ b/tensorflow/contrib/keras/python/keras/layers/pooling.py
@@ -199,9 +199,9 @@ class MaxPooling2D(_Pooling2D):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
- `(batch, width, height, channels)` while `channels_first`
+ `(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
- `(batch, channels, width, height)`.
+ `(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
@@ -255,9 +255,9 @@ class AveragePooling2D(_Pooling2D):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
- `(batch, width, height, channels)` while `channels_first`
+ `(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
- `(batch, channels, width, height)`.
+ `(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
@@ -542,9 +542,9 @@ class GlobalAveragePooling2D(_GlobalPooling2D):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
- `(batch, width, height, channels)` while `channels_first`
+ `(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
- `(batch, channels, width, height)`.
+ `(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
@@ -577,9 +577,9 @@ class GlobalMaxPooling2D(_GlobalPooling2D):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
- `(batch, width, height, channels)` while `channels_first`
+ `(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
- `(batch, channels, width, height)`.
+ `(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
diff --git a/tensorflow/contrib/keras/python/keras/layers/recurrent.py b/tensorflow/contrib/keras/python/keras/layers/recurrent.py
index 06986d3eaa..6301132f4d 100644
--- a/tensorflow/contrib/keras/python/keras/layers/recurrent.py
+++ b/tensorflow/contrib/keras/python/keras/layers/recurrent.py
@@ -105,8 +105,16 @@ class Recurrent(Layer):
# now model.output_shape == (None, 32)
# note: `None` is the batch dimension.
- # for subsequent layers, not need to specify the input size:
+ # for subsequent layers, no need to specify the input size:
model.add(LSTM(16))
+
+ # to stack recurrent layers, you must use return_sequences=True
+ # on any recurrent layer that feeds into another recurrent layer.
+ # note that you only need to specify the input size on the first layer.
+ model = Sequential()
+ model.add(LSTM(64, input_dim=64, input_length=10, return_sequences=True))
+ model.add(LSTM(32, return_sequences=True))
+ model.add(LSTM(10))
```
Arguments:
@@ -116,7 +124,8 @@ class Recurrent(Layer):
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence.
go_backwards: Boolean (default False).
- If True, process the input sequence backwards.
+ If True, process the input sequence backwards and return the
+ reversed sequence.
stateful: Boolean (default False). If True, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
@@ -398,6 +407,7 @@ class SimpleRNN(Recurrent):
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
If you don't specify anything, no activation is applied
+ If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
@@ -547,7 +557,7 @@ class SimpleRNN(Recurrent):
def get_constants(self, inputs, training=None):
constants = []
- if self.implementation == 0 and 0 < self.dropout < 1:
+ if self.implementation != 0 and 0 < self.dropout < 1:
input_shape = K.int_shape(inputs)
input_dim = input_shape[-1]
ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
@@ -619,7 +629,7 @@ class GRU(Recurrent):
Arguments:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
- If you don't specify anything, no activation is applied
+ If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
for the recurrent step.
@@ -792,7 +802,7 @@ class GRU(Recurrent):
def get_constants(self, inputs, training=None):
constants = []
- if self.implementation == 0 and 0 < self.dropout < 1:
+ if self.implementation != 0 and 0 < self.dropout < 1:
input_shape = K.int_shape(inputs)
input_dim = input_shape[-1]
ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
@@ -861,7 +871,7 @@ class GRU(Recurrent):
if self.use_bias:
x_z = K.bias_add(x_z, self.bias_z)
x_r = K.bias_add(x_r, self.bias_r)
- x_h = K.bias_add(x_r, self.bias_h)
+ x_h = K.bias_add(x_h, self.bias_h)
else:
raise ValueError('Unknown `implementation` mode.')
z = self.recurrent_activation(x_z + K.dot(h_tm1 * rec_dp_mask[0],
@@ -924,7 +934,7 @@ class LSTM(Recurrent):
Arguments:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
- If you don't specify anything, no activation is applied
+ If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
for the recurrent step.
@@ -1127,7 +1137,7 @@ class LSTM(Recurrent):
def get_constants(self, inputs, training=None):
constants = []
- if self.implementation == 0 and 0 < self.dropout < 1:
+ if self.implementation != 0 and 0 < self.dropout < 1:
input_shape = K.int_shape(inputs)
input_dim = input_shape[-1]
ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
diff --git a/tensorflow/contrib/keras/python/keras/layers/wrappers.py b/tensorflow/contrib/keras/python/keras/layers/wrappers.py
index 75b4810e40..eeb67493ee 100644
--- a/tensorflow/contrib/keras/python/keras/layers/wrappers.py
+++ b/tensorflow/contrib/keras/python/keras/layers/wrappers.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+# pylint: disable=protected-access
"""Wrapper layers: layers that augment the functionality of another layer.
"""
from __future__ import absolute_import
@@ -19,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import copy
+import inspect
from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras.engine import InputSpec
@@ -70,9 +72,10 @@ class Wrapper(Layer):
return dict(list(base_config.items()) + list(config.items()))
@classmethod
- def from_config(cls, config):
+ def from_config(cls, config, custom_objects=None):
from tensorflow.contrib.keras.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
- layer = deserialize_layer(config.pop('layer'))
+ layer = deserialize_layer(
+ config.pop('layer'), custom_objects=custom_objects)
return cls(layer, **config)
@@ -188,12 +191,15 @@ class Bidirectional(Wrapper):
If None, the outputs will not be combined,
they will be returned as a list.
+ Raises:
+ ValueError: In case of invalid `merge_mode` argument.
+
Examples:
```python
model = Sequential()
model.add(Bidirectional(LSTM(10, return_sequences=True), input_shape=(5,
- 10)))
+ 10)))
model.add(Bidirectional(LSTM(10)))
model.add(Dense(5))
model.add(Activation('softmax'))
@@ -242,29 +248,47 @@ class Bidirectional(Wrapper):
shape = self.forward_layer._compute_output_shape(input_shape) # pylint: disable=protected-access
return [shape, copy.copy(shape)]
- def call(self, inputs, mask=None):
- y = self.forward_layer.call(inputs, mask)
- y_rev = self.backward_layer.call(inputs, mask)
+ def call(self, inputs, training=None, mask=None):
+ kwargs = {}
+ func_args = inspect.getargspec(self.layer.call).args
+ if 'training' in func_args:
+ kwargs['training'] = training
+ if 'mask' in func_args:
+ kwargs['mask'] = mask
+
+ y = self.forward_layer.call(inputs, **kwargs)
+ y_rev = self.backward_layer.call(inputs, **kwargs)
if self.return_sequences:
y_rev = K.reverse(y_rev, 1)
if self.merge_mode == 'concat':
- return K.concatenate([y, y_rev])
+ output = K.concatenate([y, y_rev])
elif self.merge_mode == 'sum':
- return y + y_rev
+ output = y + y_rev
elif self.merge_mode == 'ave':
- return (y + y_rev) / 2
+ output = (y + y_rev) / 2
elif self.merge_mode == 'mul':
- return y * y_rev
+ output = y * y_rev
elif self.merge_mode is None:
- return [y, y_rev]
+ output = [y, y_rev]
+
+ # Properly set learning phase
+ if 0 < self.layer.dropout + self.layer.recurrent_dropout:
+ if self.merge_mode is None:
+ for out in output:
+ out._uses_learning_phase = True
+ else:
+ output._uses_learning_phase = True
+ return output
def reset_states(self):
self.forward_layer.reset_states()
self.backward_layer.reset_states()
def build(self, input_shape):
- self.forward_layer.build(input_shape)
- self.backward_layer.build(input_shape)
+ with K.name_scope(self.forward_layer.name):
+ self.forward_layer.build(input_shape)
+ with K.name_scope(self.backward_layer.name):
+ self.backward_layer.build(input_shape)
self.built = True
def compute_mask(self, inputs, mask):
diff --git a/tensorflow/contrib/keras/python/keras/metrics.py b/tensorflow/contrib/keras/python/keras/metrics.py
index d7266c94cf..59d380f73b 100644
--- a/tensorflow/contrib/keras/python/keras/metrics.py
+++ b/tensorflow/contrib/keras/python/keras/metrics.py
@@ -43,12 +43,15 @@ def binary_accuracy(y_true, y_pred):
def categorical_accuracy(y_true, y_pred):
- return K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1))
+ return K.cast(
+ K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1)), K.floatx())
def sparse_categorical_accuracy(y_true, y_pred):
- return K.equal(
- K.max(y_true, axis=-1), K.cast(K.argmax(y_pred, axis=-1), K.floatx()))
+ return K.cast(
+ K.equal(
+ K.max(y_true, axis=-1), K.cast(K.argmax(y_pred, axis=-1),
+ K.floatx())), K.floatx())
def top_k_categorical_accuracy(y_true, y_pred, k=5):
diff --git a/tensorflow/contrib/keras/python/keras/models.py b/tensorflow/contrib/keras/python/keras/models.py
index 2be4431d03..5289bb732b 100644
--- a/tensorflow/contrib/keras/python/keras/models.py
+++ b/tensorflow/contrib/keras/python/keras/models.py
@@ -207,7 +207,7 @@ def load_model(filepath, custom_objects=None):
ValueError: In case of an invalid savefile.
"""
if h5py is None:
- raise ImportError('`save_model` requires h5py.')
+ raise ImportError('`load_model` requires h5py.')
if not custom_objects:
custom_objects = {}
@@ -1006,7 +1006,7 @@ class Sequential(Model):
steps_per_epoch: Total number of steps (batches of samples)
to yield from `generator` before declaring one epoch
finished and starting the next epoch. It should typically
- be equal to the number of unique samples if your dataset
+ be equal to the number of unique samples of your dataset
divided by the batch size.
epochs: Integer, total number of iterations on the data.
verbose: Verbosity mode, 0, 1, or 2.
@@ -1017,8 +1017,10 @@ class Sequential(Model):
- A tuple (inputs, targets, sample_weights).
validation_steps: Only relevant if `validation_data`
is a generator.
- Number of samples to use from validation generator
- at the end of every epoch.
+ Number of steps to yield from validation generator
+ at the end of every epoch. It should typically
+ be equal to the number of unique samples of your
+ validation dataset divided by the batch size.
class_weight: Dictionary mapping class indices to a weight
for the class.
max_q_size: Maximum size for the generator queue
@@ -1050,7 +1052,7 @@ class Sequential(Model):
# and labels, from each line in the file
x, y = process_line(line)
yield (x, y)
- f.close()
+ f.close()
model.fit_generator(generate_arrays_from_file('/my_file.txt'),
samples_per_epoch=10000, epochs=10)
@@ -1119,7 +1121,8 @@ class Sequential(Model):
steps,
max_q_size=10,
workers=1,
- pickle_safe=False):
+ pickle_safe=False,
+ verbose=0):
"""Generates predictions for the input samples from a data generator.
The generator should return the same kind of data as accepted by
@@ -1136,6 +1139,7 @@ class Sequential(Model):
relies on multiprocessing, you should not pass
non picklable arguments to the generator
as they can't be passed easily to children processes.
+ verbose: verbosity mode, 0 or 1.
Returns:
A Numpy array of predictions.
@@ -1147,7 +1151,8 @@ class Sequential(Model):
steps,
max_q_size=max_q_size,
workers=workers,
- pickle_safe=pickle_safe)
+ pickle_safe=pickle_safe,
+ verbose=verbose)
def get_config(self):
config = []
@@ -1159,9 +1164,9 @@ class Sequential(Model):
return copy.deepcopy(config)
@classmethod
- def from_config(cls, config):
+ def from_config(cls, config, custom_objects=None):
model = cls()
for conf in config:
- layer = layer_module.deserialize(conf)
+ layer = layer_module.deserialize(conf, custom_objects=custom_objects)
model.add(layer)
return model
diff --git a/tensorflow/contrib/keras/python/keras/preprocessing/image.py b/tensorflow/contrib/keras/python/keras/preprocessing/image.py
index 86c7650a07..de0749ae02 100644
--- a/tensorflow/contrib/keras/python/keras/preprocessing/image.py
+++ b/tensorflow/contrib/keras/python/keras/preprocessing/image.py
@@ -785,7 +785,7 @@ class Iterator(object):
index_array = np.random.permutation(n)
current_index = (self.batch_index * batch_size) % n
- if n >= current_index + batch_size:
+ if n > current_index + batch_size:
current_batch_size = batch_size
self.batch_index += 1
else:
diff --git a/tensorflow/contrib/keras/python/keras/utils/generic_utils.py b/tensorflow/contrib/keras/python/keras/utils/generic_utils.py
index c1e0296835..6e83fde2c9 100644
--- a/tensorflow/contrib/keras/python/keras/utils/generic_utils.py
+++ b/tensorflow/contrib/keras/python/keras/utils/generic_utils.py
@@ -172,7 +172,8 @@ def deserialize_keras_object(identifier,
else:
fn = module_objects.get(function_name)
if fn is None:
- raise ValueError('Unknown ' + printable_module_name, ':' + class_name)
+ raise ValueError('Unknown ' + printable_module_name,
+ ':' + function_name)
return fn
else:
raise ValueError('Could not interpret serialized ' + printable_module_name +
@@ -215,6 +216,8 @@ def func_load(code, defaults=None, closure=None, globs=None):
"""
if isinstance(code, (tuple, list)): # unpack previous dump
code, defaults, closure = code
+ if isinstance(defaults, list):
+ defaults = tuple(defaults)
code = marshal.loads(code.encode('raw_unicode_escape'))
if globs is None:
globs = globals()
diff --git a/tensorflow/contrib/keras/python/keras/utils/layer_utils.py b/tensorflow/contrib/keras/python/keras/utils/layer_utils.py
index 32e0de7d3d..26878fdd57 100644
--- a/tensorflow/contrib/keras/python/keras/utils/layer_utils.py
+++ b/tensorflow/contrib/keras/python/keras/utils/layer_utils.py
@@ -171,7 +171,7 @@ def count_total_params(layers, layer_set=None):
[K.count_params(p) for p in layer.trainable_weights])
non_trainable_count += np.sum(
[K.count_params(p) for p in layer.non_trainable_weights])
- return trainable_count, non_trainable_count
+ return int(trainable_count), int(non_trainable_count)
def convert_all_kernels_in_model(model):
diff --git a/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py b/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py
index ecda890fec..323c31aee8 100644
--- a/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py
+++ b/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py
@@ -194,6 +194,36 @@ class KerasClassifier(BaseWrapper):
"""Implementation of the scikit-learn classifier API for Keras.
"""
+ def fit(self, x, y, **kwargs):
+ """Constructs a new model with `build_fn` & fit the model to `(x, y)`.
+
+ Arguments:
+ x : array-like, shape `(n_samples, n_features)`
+ Training samples where n_samples in the number of samples
+ and n_features is the number of features.
+ y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
+ True labels for X.
+ **kwargs: dictionary arguments
+ Legal arguments are the arguments of `Sequential.fit`
+
+ Returns:
+ history : object
+ details about the training history at each epoch.
+
+ Raises:
+ ValueError: In case of invalid shape for `y` argument.
+ """
+ y = np.array(y)
+ if len(y.shape) == 2 and y.shape[1] > 1:
+ self.classes_ = np.arange(y.shape[1])
+ elif (len(y.shape) == 2 and y.shape[1] == 1) or len(y.shape) == 1:
+ self.classes_ = np.unique(y)
+ y = np.searchsorted(self.classes_, y)
+ else:
+ raise ValueError('Invalid shape for y: ' + str(y.shape))
+ self.n_classes_ = len(self.classes_)
+ return super(KerasClassifier, self).fit(x, y, **kwargs)
+
def predict(self, x, **kwargs):
"""Returns the class predictions for the given test data.
@@ -210,7 +240,8 @@ class KerasClassifier(BaseWrapper):
Class predictions.
"""
kwargs = self.filter_sk_params(Sequential.predict_classes, kwargs)
- return self.model.predict_classes(x, **kwargs)
+ classes = self.model.predict_classes(x, **kwargs)
+ return self.classes_[classes]
def predict_proba(self, x, **kwargs):
"""Returns class probability estimates for the given test data.
@@ -261,6 +292,7 @@ class KerasClassifier(BaseWrapper):
compute accuracy. You should pass `metrics=["accuracy"]` to
the `.compile()` method of the model.
"""
+ y = np.searchsorted(self.classes_, y)
kwargs = self.filter_sk_params(Sequential.evaluate, kwargs)
loss_name = self.model.loss