-rw-r--r--  tensorflow/contrib/cmake/python_sanity_test.py  18
-rw-r--r--  tensorflow/contrib/layers/python/layers/layers.py  661
-rw-r--r--  tensorflow/contrib/opt/python/training/model_average_optimizer.py  119
-rw-r--r--  tensorflow/contrib/opt/python/training/model_average_optimizer_test.py  103
-rw-r--r--  tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py  39
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py  852
-rw-r--r--  tensorflow/contrib/rnn/python/ops/rnn_cell.py  597
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py  79
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py  99
-rw-r--r--  tensorflow/contrib/verbs/rdma.cc  109
-rw-r--r--  tensorflow/contrib/verbs/rdma.h  11
-rw-r--r--  tensorflow/contrib/verbs/rdma_mgr.cc  29
-rw-r--r--  tensorflow/core/kernels/mkl_aggregate_ops.cc  188
-rw-r--r--  tensorflow/core/kernels/mkl_softmax_op.cc  33
-rw-r--r--  tensorflow/core/kernels/spectrogram_test_utils.cc  4
-rw-r--r--  tensorflow/core/kernels/transpose_functor_cpu.cc  16
-rw-r--r--  tensorflow/examples/tutorials/word2vec/word2vec_basic.py  81
-rw-r--r--  tensorflow/python/data/kernel_tests/batch_dataset_op_test.py  103
-rw-r--r--  tensorflow/python/ops/histogram_ops.py  13
-rw-r--r--  tensorflow/python/ops/histogram_ops_test.py  10
-rw-r--r--  tensorflow/python/ops/image_ops_impl.py  235
-rw-r--r--  tensorflow/python/ops/metrics_impl.py  530
-rw-r--r--  tensorflow/python/ops/nn_impl.py  26
-rw-r--r--  tensorflow/python/ops/nn_test.py  19
-rw-r--r--  tensorflow/python/util/compat.py  5
-rw-r--r--  tensorflow/tools/pip_package/pip_smoke_test.py  29
26 files changed, 2119 insertions, 1889 deletions
diff --git a/tensorflow/contrib/cmake/python_sanity_test.py b/tensorflow/contrib/cmake/python_sanity_test.py
index 3be5bd1b23..e0056823a8 100644
--- a/tensorflow/contrib/cmake/python_sanity_test.py
+++ b/tensorflow/contrib/cmake/python_sanity_test.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""
-Complain about invalid or missing entries in python_*.txt files.
+"""Complain about invalid or missing entries in python_*.txt files.
+
Problematic entries can be commented for temporary whitelisting.
"""
@@ -35,6 +35,7 @@ def abs_path(path):
path = os.path.abspath(path)
return path
+
def read_entries(test):
with open(abs_path(test.entries_file), "r") as f:
lines = f.readlines()
@@ -47,25 +48,28 @@ def read_entries(test):
for line in lines:
# line is comment
- if line.startswith('#'):
+ if line.startswith("#"):
line = line[1:].strip()
# whitelist entry
- if line.startswith('tensorflow/'):
+ if line.startswith("tensorflow/"):
test.whitelist.append(line)
# line has comment -> strip comment
- elif line.find('#') != -1:
- line = line[:line.find('#')].strip()
+ elif line.find("#") != -1:
+ line = line[:line.find("#")].strip()
test.entries.append(line)
else:
test.entries.append(line)
+
def test_invalid_directories(test):
for entry in test.entries:
if not os.path.isdir(abs_path(entry)):
problem = "'" + test.entries_file + "' contains invalid '" + entry + "'"
- solution = "Please remove the invalid entry (or add the missing directory)."
+ solution = ("Please remove the invalid entry (or add the missing "
+ "directory).")
raise AssertionError(problem + "\n" + solution)
+
def test_missing_directory(test, path):
if path in test.whitelist:
return
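
A minimal sketch of the whitelisting convention that the read_entries helper above parses; the paths here are hypothetical, not entries from the real python_*.txt files:

example_lines = [
    "tensorflow/contrib/foo",             # plain entry -> test.entries
    "# tensorflow/contrib/bar",           # commented path -> test.whitelist
    "tensorflow/contrib/baz  # pending",  # trailing comment stripped -> test.entries
]
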
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index ef2b673074..7c52da7b49 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -54,47 +54,17 @@ from tensorflow.python.layers.maxout import maxout
# TODO(b/28426988): Replace legacy_* fns migrated from slim.
# TODO(b/28426988): Remove legacy_* when all uses have migrated to new API.
-__all__ = ['avg_pool2d',
- 'avg_pool3d',
- 'batch_norm',
- 'bias_add',
- 'conv2d',
- 'conv3d',
- 'conv2d_in_plane',
- 'conv2d_transpose',
- 'conv3d_transpose',
- 'convolution',
- 'convolution2d',
- 'convolution2d_in_plane',
- 'convolution2d_transpose',
- 'convolution3d',
- 'convolution3d_transpose',
- 'dropout',
- 'elu',
- 'flatten',
- 'fully_connected',
- 'GDN',
- 'gdn',
- 'layer_norm',
- 'linear',
- 'pool',
- 'max_pool2d',
- 'max_pool3d',
- 'one_hot_encoding',
- 'relu',
- 'relu6',
- 'repeat',
- 'scale_gradient',
- 'separable_conv2d',
- 'separable_convolution2d',
- 'softmax',
- 'spatial_softmax',
- 'stack',
- 'unit_norm',
- 'legacy_fully_connected',
- 'legacy_linear',
- 'legacy_relu',
- 'maxout']
+__all__ = [
+ 'avg_pool2d', 'avg_pool3d', 'batch_norm', 'bias_add', 'conv2d', 'conv3d',
+ 'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', 'convolution',
+ 'convolution2d', 'convolution2d_in_plane', 'convolution2d_transpose',
+ 'convolution3d', 'convolution3d_transpose', 'dropout', 'elu', 'flatten',
+ 'fully_connected', 'GDN', 'gdn', 'layer_norm', 'linear', 'pool',
+ 'max_pool2d', 'max_pool3d', 'one_hot_encoding', 'relu', 'relu6', 'repeat',
+ 'scale_gradient', 'separable_conv2d', 'separable_convolution2d', 'softmax',
+ 'spatial_softmax', 'stack', 'unit_norm', 'legacy_fully_connected',
+ 'legacy_linear', 'legacy_relu', 'maxout'
+]
DATA_FORMAT_NCHW = 'NCHW'
DATA_FORMAT_NHWC = 'NHWC'
@@ -139,13 +109,14 @@ def avg_pool2d(inputs,
raise ValueError('data_format has to be either NCHW or NHWC.')
with ops.name_scope(scope, 'AvgPool2D', [inputs]) as sc:
inputs = ops.convert_to_tensor(inputs)
- df = ('channels_first' if data_format and data_format.startswith('NC')
- else 'channels_last')
- layer = pooling_layers.AveragePooling2D(pool_size=kernel_size,
- strides=stride,
- padding=padding,
- data_format=df,
- _scope=sc)
+ df = ('channels_first'
+ if data_format and data_format.startswith('NC') else 'channels_last')
+ layer = pooling_layers.AveragePooling2D(
+ pool_size=kernel_size,
+ strides=stride,
+ padding=padding,
+ data_format=df,
+ _scope=sc)
outputs = layer.apply(inputs)
return utils.collect_named_outputs(outputs_collections, sc, outputs)
@@ -187,13 +158,14 @@ def avg_pool3d(inputs,
raise ValueError('data_format has to be either NCDHW or NDHWC.')
with ops.name_scope(scope, 'AvgPool3D', [inputs]) as sc:
inputs = ops.convert_to_tensor(inputs)
- df = ('channels_first' if data_format and data_format.startswith('NC')
- else 'channels_last')
- layer = pooling_layers.AveragePooling3D(pool_size=kernel_size,
- strides=stride,
- padding=padding,
- data_format=df,
- _scope=sc)
+ df = ('channels_first'
+ if data_format and data_format.startswith('NC') else 'channels_last')
+ layer = pooling_layers.AveragePooling3D(
+ pool_size=kernel_size,
+ strides=stride,
+ padding=padding,
+ data_format=df,
+ _scope=sc)
outputs = layer.apply(inputs)
return utils.collect_named_outputs(outputs_collections, sc, outputs)
@@ -298,8 +270,8 @@ def _fused_batch_norm(inputs,
raise ValueError('Inputs %s has undefined rank' % inputs.name)
elif original_rank not in [2, 4]:
raise ValueError('Inputs %s has unsupported rank.'
- ' Expected 2 or 4 but got %d' % (
- inputs.name, original_rank))
+ ' Expected 2 or 4 but got %d' % (inputs.name,
+ original_rank))
if original_rank == 2:
channels = inputs.get_shape()[-1].value
if channels is None:
@@ -393,6 +365,7 @@ def _fused_batch_norm(inputs,
def _fused_batch_norm_training():
return nn.fused_batch_norm(
inputs, gamma, beta, epsilon=epsilon, data_format=data_format)
+
def _fused_batch_norm_inference():
return nn.fused_batch_norm(
inputs,
@@ -403,9 +376,9 @@ def _fused_batch_norm(inputs,
epsilon=epsilon,
is_training=False,
data_format=data_format)
- outputs, mean, variance = utils.smart_cond(is_training,
- _fused_batch_norm_training,
- _fused_batch_norm_inference)
+
+ outputs, mean, variance = utils.smart_cond(
+ is_training, _fused_batch_norm_training, _fused_batch_norm_inference)
# If `is_training` doesn't have a constant value, because it is a `Tensor`,
# a `Variable` or `Placeholder` then is_training_value will be None and
@@ -415,6 +388,7 @@ def _fused_batch_norm(inputs,
if need_updates:
if updates_collections is None:
no_updates = lambda: outputs
+
def _force_updates():
"""Internal function forces updates moving_vars if is_training."""
update_moving_mean = moving_averages.assign_moving_average(
@@ -424,9 +398,11 @@ def _fused_batch_norm(inputs,
with ops.control_dependencies(
[update_moving_mean, update_moving_variance]):
return array_ops.identity(outputs)
+
outputs = utils.smart_cond(is_training, _force_updates, no_updates)
else:
moving_vars_fn = lambda: (moving_mean, moving_variance)
+
def _delay_updates():
"""Internal function that delay updates moving_vars if is_training."""
update_moving_mean = moving_averages.assign_moving_average(
@@ -434,9 +410,9 @@ def _fused_batch_norm(inputs,
update_moving_variance = moving_averages.assign_moving_average(
moving_variance, variance, decay, zero_debias=False)
return update_moving_mean, update_moving_variance
- update_mean, update_variance = utils.smart_cond(is_training,
- _delay_updates,
- moving_vars_fn)
+
+ update_mean, update_variance = utils.smart_cond(
+ is_training, _delay_updates, moving_vars_fn)
ops.add_to_collections(updates_collections, update_mean)
ops.add_to_collections(updates_collections, update_variance)
@@ -482,9 +458,10 @@ def batch_norm(inputs,
Can be used as a normalizer function for conv2d and fully_connected. The
normalization is over all but the last dimension if `data_format` is `NHWC`
and all but the second dimension if `data_format` is `NCHW`. In case of a 2D
- tensor this corresponds to the batch dimension, while in case of a 4D tensor this
+ tensor this corresponds to the batch dimension, while in case of a 4D tensor
+ this
corresponds to the batch and space dimensions.
-
+
Note: when training, the moving_mean and moving_variance need to be updated.
By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they
need to be added as a dependency to the `train_op`. For example:
@@ -592,10 +569,9 @@ def batch_norm(inputs,
# implementation in normalization_layers.BatchNormalization.
inputs = ops.convert_to_tensor(inputs)
rank = inputs.get_shape().ndims
- possible_to_fuse = (batch_weights is None and
- not renorm and
- rank in [2, 4] and
- adjustment is None)
+ possible_to_fuse = (
+ batch_weights is None and not renorm and rank in [2, 4] and
+ adjustment is None)
if fused and possible_to_fuse and (
zero_debias_moving_mean or rank == 2 or
updates_collections is not ops.GraphKeys.UPDATE_OPS):
@@ -623,7 +599,9 @@ def batch_norm(inputs,
layer_variable_getter = _build_variable_getter()
with variable_scope.variable_scope(
- scope, 'BatchNorm', [inputs], reuse=reuse,
+ scope,
+ 'BatchNorm', [inputs],
+ reuse=reuse,
custom_getter=layer_variable_getter) as sc:
inputs = ops.convert_to_tensor(inputs)
@@ -671,15 +649,15 @@ def batch_norm(inputs,
outputs = layer.apply(inputs, training=is_training)
# Add variables to collections.
- _add_variable_to_collections(
- layer.moving_mean, variables_collections, 'moving_mean')
- _add_variable_to_collections(
- layer.moving_variance, variables_collections, 'moving_variance')
+ _add_variable_to_collections(layer.moving_mean, variables_collections,
+ 'moving_mean')
+ _add_variable_to_collections(layer.moving_variance, variables_collections,
+ 'moving_variance')
if layer.beta is not None:
_add_variable_to_collections(layer.beta, variables_collections, 'beta')
if layer.gamma is not None:
- _add_variable_to_collections(
- layer.gamma, variables_collections, 'gamma')
+ _add_variable_to_collections(layer.gamma, variables_collections,
+ 'gamma')
if activation_fn is not None:
outputs = activation_fn(outputs)
@@ -719,8 +697,8 @@ def batch_norm(inputs,
params_shape = inputs_shape[-1:]
params_shape_broadcast = None
if not params_shape.is_fully_defined():
- raise ValueError('Inputs %s has undefined channels dimension %s.' % (
- inputs.name, params_shape))
+ raise ValueError('Inputs %s has undefined channels dimension %s.' %
+ (inputs.name, params_shape))
# Allocate parameters for the beta and gamma of the normalization.
beta, gamma = None, None
@@ -731,23 +709,25 @@ def batch_norm(inputs,
'beta')
beta_initializer = param_initializers.get('beta',
init_ops.zeros_initializer())
- beta = variables.model_variable('beta',
- shape=params_shape,
- dtype=dtype,
- initializer=beta_initializer,
- collections=beta_collections,
- trainable=trainable)
+ beta = variables.model_variable(
+ 'beta',
+ shape=params_shape,
+ dtype=dtype,
+ initializer=beta_initializer,
+ collections=beta_collections,
+ trainable=trainable)
if scale:
- gamma_collections = utils.get_variable_collections(variables_collections,
- 'gamma')
+ gamma_collections = utils.get_variable_collections(
+ variables_collections, 'gamma')
gamma_initializer = param_initializers.get('gamma',
init_ops.ones_initializer())
- gamma = variables.model_variable('gamma',
- shape=params_shape,
- dtype=dtype,
- initializer=gamma_initializer,
- collections=gamma_collections,
- trainable=trainable)
+ gamma = variables.model_variable(
+ 'gamma',
+ shape=params_shape,
+ dtype=dtype,
+ initializer=gamma_initializer,
+ collections=gamma_collections,
+ trainable=trainable)
# Create moving_mean and moving_variance variables and add them to the
# appropriate collections. We disable variable partitioning while creating
@@ -796,8 +776,8 @@ def batch_norm(inputs,
mean, variance = nn.moments(inputs, moments_axes)
else:
if data_format == DATA_FORMAT_NCHW:
- mean, variance = nn.weighted_moments(inputs, moments_axes,
- batch_weights, keep_dims=True)
+ mean, variance = nn.weighted_moments(
+ inputs, moments_axes, batch_weights, keep_dims=True)
mean = array_ops.reshape(mean, [-1])
variance = array_ops.reshape(variance, [-1])
else:
@@ -806,19 +786,21 @@ def batch_norm(inputs,
moving_vars_fn = lambda: (moving_mean, moving_variance)
if updates_collections is None:
+
def _force_updates():
"""Internal function forces updates moving_vars if is_training."""
update_moving_mean = moving_averages.assign_moving_average(
moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
update_moving_variance = moving_averages.assign_moving_average(
moving_variance, variance, decay, zero_debias=False)
- with ops.control_dependencies([update_moving_mean,
- update_moving_variance]):
+ with ops.control_dependencies(
+ [update_moving_mean, update_moving_variance]):
return array_ops.identity(mean), array_ops.identity(variance)
- mean, variance = utils.smart_cond(is_training,
- _force_updates,
+
+ mean, variance = utils.smart_cond(is_training, _force_updates,
moving_vars_fn)
else:
+
def _delay_updates():
"""Internal function that delay updates moving_vars if is_training."""
update_moving_mean = moving_averages.assign_moving_average(
@@ -827,9 +809,8 @@ def batch_norm(inputs,
moving_variance, variance, decay, zero_debias=False)
return update_moving_mean, update_moving_variance
- update_mean, update_variance = utils.smart_cond(is_training,
- _delay_updates,
- moving_vars_fn)
+ update_mean, update_variance = utils.smart_cond(
+ is_training, _delay_updates, moving_vars_fn)
ops.add_to_collections(updates_collections, update_mean)
ops.add_to_collections(updates_collections, update_variance)
# Use computed moments during training and moving_vars otherwise.
@@ -897,8 +878,8 @@ def bias_add(inputs,
"""
if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
raise ValueError('data_format has to be either NCHW or NHWC.')
- with variable_scope.variable_scope(scope, 'BiasAdd', [inputs],
- reuse=reuse) as sc:
+ with variable_scope.variable_scope(
+ scope, 'BiasAdd', [inputs], reuse=reuse) as sc:
inputs = ops.convert_to_tensor(inputs)
dtype = inputs.dtype.base_dtype
inputs_shape = inputs.get_shape()
@@ -913,13 +894,16 @@ def bias_add(inputs,
raise ValueError('`C` dimension must be known but is None')
biases_collections = utils.get_variable_collections(variables_collections,
'biases')
- biases = variables.model_variable('biases',
- shape=[num_features,],
- dtype=dtype,
- initializer=initializer,
- regularizer=regularizer,
- collections=biases_collections,
- trainable=trainable)
+ biases = variables.model_variable(
+ 'biases',
+ shape=[
+ num_features,
+ ],
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ collections=biases_collections,
+ trainable=trainable)
outputs = nn.bias_add(inputs, biases, data_format=data_format)
if activation_fn is not None:
outputs = activation_fn(outputs)
@@ -1019,8 +1003,10 @@ def convolution(inputs,
if data_format not in [None, 'NWC', 'NCW', 'NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
raise ValueError('Invalid data_format: %r' % (data_format,))
- layer_variable_getter = _build_variable_getter(
- {'bias': 'biases', 'kernel': 'weights'})
+ layer_variable_getter = _build_variable_getter({
+ 'bias': 'biases',
+ 'kernel': 'weights'
+ })
with variable_scope.variable_scope(
scope, 'Conv', [inputs], reuse=reuse,
@@ -1038,26 +1024,27 @@ def convolution(inputs,
raise ValueError('Convolution not supported for input with rank',
input_rank)
- df = ('channels_first' if data_format and data_format.startswith('NC')
- else 'channels_last')
- layer = layer_class(filters=num_outputs,
- kernel_size=kernel_size,
- strides=stride,
- padding=padding,
- data_format=df,
- dilation_rate=rate,
- activation=None,
- use_bias=not normalizer_fn and biases_initializer,
- kernel_initializer=weights_initializer,
- bias_initializer=biases_initializer,
- kernel_regularizer=weights_regularizer,
- bias_regularizer=biases_regularizer,
- activity_regularizer=None,
- trainable=trainable,
- name=sc.name,
- dtype=inputs.dtype.base_dtype,
- _scope=sc,
- _reuse=reuse)
+ df = ('channels_first'
+ if data_format and data_format.startswith('NC') else 'channels_last')
+ layer = layer_class(
+ filters=num_outputs,
+ kernel_size=kernel_size,
+ strides=stride,
+ padding=padding,
+ data_format=df,
+ dilation_rate=rate,
+ activation=None,
+ use_bias=not normalizer_fn and biases_initializer,
+ kernel_initializer=weights_initializer,
+ bias_initializer=biases_initializer,
+ kernel_regularizer=weights_regularizer,
+ bias_regularizer=biases_regularizer,
+ activity_regularizer=None,
+ trainable=trainable,
+ name=sc.name,
+ dtype=inputs.dtype.base_dtype,
+ _scope=sc,
+ _reuse=reuse)
outputs = layer.apply(inputs)
# Add variables to collections.
@@ -1073,6 +1060,7 @@ def convolution(inputs,
outputs = activation_fn(outputs)
return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
+
convolution2d = convolution
convolution3d = convolution
@@ -1148,13 +1136,14 @@ def convolution2d_in_plane(
weights_shape = [kernel_h, kernel_w, 1, 1]
weights_collections = utils.get_variable_collections(
variables_collections, 'weights')
- weights = variables.model_variable('weights',
- shape=weights_shape,
- dtype=dtype,
- initializer=weights_initializer,
- regularizer=weights_regularizer,
- collections=weights_collections,
- trainable=trainable)
+ weights = variables.model_variable(
+ 'weights',
+ shape=weights_shape,
+ dtype=dtype,
+ initializer=weights_initializer,
+ regularizer=weights_regularizer,
+ collections=weights_collections,
+ trainable=trainable)
depthwise_weights = array_ops.tile(weights, [1, 1, num_filters_in, 1])
outputs = nn.depthwise_conv2d(inputs, depthwise_weights,
[1, stride_h, stride_w, 1], padding)
@@ -1165,13 +1154,16 @@ def convolution2d_in_plane(
if biases_initializer is not None:
biases_collections = utils.get_variable_collections(
variables_collections, 'biases')
- biases = variables.model_variable('biases',
- shape=[num_filters_in,],
- dtype=dtype,
- initializer=biases_initializer,
- regularizer=biases_regularizer,
- collections=biases_collections,
- trainable=trainable)
+ biases = variables.model_variable(
+ 'biases',
+ shape=[
+ num_filters_in,
+ ],
+ dtype=dtype,
+ initializer=biases_initializer,
+ regularizer=biases_regularizer,
+ collections=biases_collections,
+ trainable=trainable)
outputs = nn.bias_add(outputs, biases)
if activation_fn is not None:
@@ -1244,19 +1236,23 @@ def convolution2d_transpose(
ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
ValueError: If `C` dimension of `inputs` is None.
"""
- layer_variable_getter = _build_variable_getter(
- {'bias': 'biases', 'kernel': 'weights'})
+ layer_variable_getter = _build_variable_getter({
+ 'bias': 'biases',
+ 'kernel': 'weights'
+ })
with variable_scope.variable_scope(
- scope, 'Conv2d_transpose', [inputs], reuse=reuse,
+ scope,
+ 'Conv2d_transpose', [inputs],
+ reuse=reuse,
custom_getter=layer_variable_getter) as sc:
if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
raise ValueError('data_format has to be either NCHW or NHWC.')
inputs = ops.convert_to_tensor(inputs)
- df = ('channels_first' if data_format and data_format.startswith('NC')
- else 'channels_last')
+ df = ('channels_first'
+ if data_format and data_format.startswith('NC') else 'channels_last')
layer = convolutional_layers.Convolution2DTranspose(
filters=num_outputs,
kernel_size=kernel_size,
@@ -1353,19 +1349,23 @@ def convolution3d_transpose(
ValueError: If `data_format` is neither `NDHWC` nor `NCDHW`.
ValueError: If `C` dimension of `inputs` is None.
"""
- layer_variable_getter = _build_variable_getter(
- {'bias': 'biases', 'kernel': 'weights'})
+ layer_variable_getter = _build_variable_getter({
+ 'bias': 'biases',
+ 'kernel': 'weights'
+ })
with variable_scope.variable_scope(
- scope, 'Conv3d_transpose', [inputs], reuse=reuse,
+ scope,
+ 'Conv3d_transpose', [inputs],
+ reuse=reuse,
custom_getter=layer_variable_getter) as sc:
if data_format not in (DATA_FORMAT_NCDHW, DATA_FORMAT_NDHWC):
raise ValueError('data_format has to be either NCDHW or NDHWC.')
inputs = ops.convert_to_tensor(inputs)
- df = ('channels_first' if data_format and data_format.startswith('NC')
- else 'channels_last')
+ df = ('channels_first'
+ if data_format and data_format.startswith('NC') else 'channels_last')
layer = convolutional_layers.Convolution3DTranspose(
filters=num_outputs,
kernel_size=kernel_size,
@@ -1434,19 +1434,18 @@ def dropout(inputs,
with variable_scope.variable_scope(
scope, 'Dropout', [inputs], custom_getter=_model_variable_getter) as sc:
inputs = ops.convert_to_tensor(inputs)
- layer = core_layers.Dropout(rate=1 - keep_prob,
- noise_shape=noise_shape,
- seed=seed,
- name=sc.name,
- _scope=sc)
+ layer = core_layers.Dropout(
+ rate=1 - keep_prob,
+ noise_shape=noise_shape,
+ seed=seed,
+ name=sc.name,
+ _scope=sc)
outputs = layer.apply(inputs, training=is_training)
return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
@add_arg_scope
-def flatten(inputs,
- outputs_collections=None,
- scope=None):
+def flatten(inputs, outputs_collections=None, scope=None):
"""Flattens the input while maintaining the batch_size.
Assumes that the first dimension represents the batch.
@@ -1478,8 +1477,8 @@ def _sparse_inner_flatten(inputs, new_rank):
outer_dimensions = inputs.dense_shape[:new_rank - 1]
inner_dimensions = inputs.dense_shape[new_rank - 1:]
- new_shape = array_ops.concat((outer_dimensions,
- [math_ops.reduce_prod(inner_dimensions)]), 0)
+ new_shape = array_ops.concat(
+ (outer_dimensions, [math_ops.reduce_prod(inner_dimensions)]), 0)
flattened = sparse_ops.sparse_reshape(inputs, new_shape)
return flattened
@@ -1545,10 +1544,18 @@ def _inner_flatten(inputs, new_rank, output_collections=None, scope=None):
return utils.collect_named_outputs(output_collections, sc, flattened)
-def _model_variable_getter(getter, name, shape=None, dtype=None,
- initializer=None, regularizer=None, trainable=True,
- collections=None, caching_device=None,
- partitioner=None, rename=None, use_resource=None,
+def _model_variable_getter(getter,
+ name,
+ shape=None,
+ dtype=None,
+ initializer=None,
+ regularizer=None,
+ trainable=True,
+ collections=None,
+ caching_device=None,
+ partitioner=None,
+ rename=None,
+ use_resource=None,
**_):
"""Getter that uses model_variable for compatibility with core layers."""
short_name = name.split('/')[-1]
@@ -1557,25 +1564,34 @@ def _model_variable_getter(getter, name, shape=None, dtype=None,
name_components[-1] = rename[short_name]
name = '/'.join(name_components)
return variables.model_variable(
- name, shape=shape, dtype=dtype, initializer=initializer,
- regularizer=regularizer, collections=collections, trainable=trainable,
- caching_device=caching_device, partitioner=partitioner,
- custom_getter=getter, use_resource=use_resource)
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ collections=collections,
+ trainable=trainable,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ custom_getter=getter,
+ use_resource=use_resource)
def _build_variable_getter(rename=None):
"""Build a model variable getter that respects scope getter and renames."""
+
# VariableScope will nest the getters
def layer_variable_getter(getter, *args, **kwargs):
kwargs['rename'] = rename
return _model_variable_getter(getter, *args, **kwargs)
+
return layer_variable_getter
def _add_variable_to_collections(variable, collections_set, collections_name):
"""Adds variable (or all its parts) to all collections with that name."""
- collections = utils.get_variable_collections(
- collections_set, collections_name) or []
+ collections = utils.get_variable_collections(collections_set,
+ collections_name) or []
variables_list = [variable]
if isinstance(variable, tf_variables.PartitionedVariable):
variables_list = [v for v in variable]
@@ -1644,15 +1660,19 @@ def fully_connected(inputs,
ValueError: If x has rank less than 2 or if its last dimension is not set.
"""
if not isinstance(num_outputs, six.integer_types):
- raise ValueError(
- 'num_outputs should be int or long, got %s.' % (num_outputs,))
+ raise ValueError('num_outputs should be int or long, got %s.' %
+ (num_outputs,))
- layer_variable_getter = _build_variable_getter({'bias': 'biases',
- 'kernel': 'weights'})
+ layer_variable_getter = _build_variable_getter({
+ 'bias': 'biases',
+ 'kernel': 'weights'
+ })
with variable_scope.variable_scope(
- scope, 'fully_connected', [inputs],
- reuse=reuse, custom_getter=layer_variable_getter) as sc:
+ scope,
+ 'fully_connected', [inputs],
+ reuse=reuse,
+ custom_getter=layer_variable_getter) as sc:
inputs = ops.convert_to_tensor(inputs)
layer = core_layers.Dense(
units=num_outputs,
@@ -1758,15 +1778,17 @@ class GDN(base.Layer):
inverse=False,
beta_min=1e-6,
gamma_init=.1,
- reparam_offset=2 ** -18,
+ reparam_offset=2**-18,
data_format='channels_last',
activity_regularizer=None,
trainable=True,
name=None,
**kwargs):
- super(GDN, self).__init__(trainable=trainable, name=name,
- activity_regularizer=activity_regularizer,
- **kwargs)
+ super(GDN, self).__init__(
+ trainable=trainable,
+ name=name,
+ activity_regularizer=activity_regularizer,
+ **kwargs)
self.inverse = inverse
self._beta_min = beta_min
self._gamma_init = gamma_init
@@ -1801,8 +1823,9 @@ class GDN(base.Layer):
with ops.name_scope(name, 'GDNLowerBound', [inputs, bound]) as scope:
inputs = ops.convert_to_tensor(inputs, name='inputs')
bound = ops.convert_to_tensor(bound, name='bound')
- with ops.get_default_graph().gradient_override_map(
- {'Maximum': 'GDNLowerBound'}):
+ with ops.get_default_graph().gradient_override_map({
+ 'Maximum': 'GDNLowerBound'
+ }):
return math_ops.maximum(inputs, bound, name=scope)
@staticmethod
@@ -1829,12 +1852,14 @@ class GDN(base.Layer):
raise ValueError('The channel dimension of the inputs to `GDN` '
'must be defined.')
self._input_rank = input_shape.ndims
- self.input_spec = base.InputSpec(ndim=input_shape.ndims,
- axes={channel_axis: num_channels})
+ self.input_spec = base.InputSpec(
+ ndim=input_shape.ndims, axes={
+ channel_axis: num_channels
+ })
- pedestal = array_ops.constant(self._reparam_offset ** 2, dtype=self.dtype)
+ pedestal = array_ops.constant(self._reparam_offset**2, dtype=self.dtype)
beta_bound = array_ops.constant(
- (self._beta_min + self._reparam_offset ** 2) ** .5, dtype=self.dtype)
+ (self._beta_min + self._reparam_offset**2)**.5, dtype=self.dtype)
gamma_bound = array_ops.constant(self._reparam_offset, dtype=self.dtype)
def beta_initializer(shape, dtype=None, partition_info=None):
@@ -1848,19 +1873,21 @@ class GDN(base.Layer):
eye = linalg_ops.eye(shape[0], dtype=dtype)
return math_ops.sqrt(self._gamma_init * eye + pedestal)
- beta = self.add_variable('reparam_beta',
- shape=[num_channels],
- initializer=beta_initializer,
- dtype=self.dtype,
- trainable=True)
+ beta = self.add_variable(
+ 'reparam_beta',
+ shape=[num_channels],
+ initializer=beta_initializer,
+ dtype=self.dtype,
+ trainable=True)
beta = self._lower_bound(beta, beta_bound)
self.beta = math_ops.square(beta) - pedestal
- gamma = self.add_variable('reparam_gamma',
- shape=[num_channels, num_channels],
- initializer=gamma_initializer,
- dtype=self.dtype,
- trainable=True)
+ gamma = self.add_variable(
+ 'reparam_gamma',
+ shape=[num_channels, num_channels],
+ initializer=gamma_initializer,
+ dtype=self.dtype,
+ trainable=True)
gamma = self._lower_bound(gamma, gamma_bound)
self.gamma = math_ops.square(gamma) - pedestal
@@ -1875,8 +1902,11 @@ class GDN(base.Layer):
# Compute normalization pool.
if self.data_format == 'channels_first':
- norm_pool = nn.convolution(math_ops.square(inputs), gamma, 'VALID',
- data_format='NC' + 'DHW'[-(ndim - 2):])
+ norm_pool = nn.convolution(
+ math_ops.square(inputs),
+ gamma,
+ 'VALID',
+ data_format='NC' + 'DHW' [-(ndim - 2):])
if ndim == 3:
norm_pool = array_ops.expand_dims(norm_pool, 2)
norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NCHW')
@@ -1918,7 +1948,7 @@ def gdn(inputs,
inverse=False,
beta_min=1e-6,
gamma_init=.1,
- reparam_offset=2 ** -18,
+ reparam_offset=2**-18,
data_format='channels_last',
activity_regularizer=None,
trainable=True,
@@ -1984,17 +2014,18 @@ def gdn(inputs,
Returns:
Output tensor.
"""
- layer = GDN(inverse=inverse,
- beta_min=beta_min,
- gamma_init=gamma_init,
- reparam_offset=reparam_offset,
- data_format=data_format,
- activity_regularizer=activity_regularizer,
- trainable=trainable,
- name=name,
- dtype=inputs.dtype.base_dtype,
- _scope=name,
- _reuse=reuse)
+ layer = GDN(
+ inverse=inverse,
+ beta_min=beta_min,
+ gamma_init=gamma_init,
+ reparam_offset=reparam_offset,
+ data_format=data_format,
+ activity_regularizer=activity_regularizer,
+ trainable=trainable,
+ name=name,
+ dtype=inputs.dtype.base_dtype,
+ _scope=name,
+ _reuse=reuse)
return layer.apply(inputs)
@@ -2070,8 +2101,8 @@ def layer_norm(inputs,
or if `inputs.shape[begin_params_axis:]` is not fully defined at
graph build time.
"""
- with variable_scope.variable_scope(scope, 'LayerNorm', [inputs],
- reuse=reuse) as sc:
+ with variable_scope.variable_scope(
+ scope, 'LayerNorm', [inputs], reuse=reuse) as sc:
inputs = ops.convert_to_tensor(inputs)
inputs_shape = inputs.shape
inputs_rank = inputs_shape.ndims
@@ -2081,15 +2112,14 @@ def layer_norm(inputs,
if begin_norm_axis < 0:
begin_norm_axis = inputs_rank + begin_norm_axis
if begin_params_axis >= inputs_rank or begin_norm_axis >= inputs_rank:
- raise ValueError(
- 'begin_params_axis (%d) and begin_norm_axis (%d) '
- 'must be < rank(inputs) (%d)'
- % (begin_params_axis, begin_norm_axis, inputs_rank))
+ raise ValueError('begin_params_axis (%d) and begin_norm_axis (%d) '
+ 'must be < rank(inputs) (%d)' %
+ (begin_params_axis, begin_norm_axis, inputs_rank))
params_shape = inputs_shape[begin_params_axis:]
if not params_shape.is_fully_defined():
raise ValueError(
- 'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' % (
- inputs.name, begin_params_axis, inputs_shape))
+ 'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' %
+ (inputs.name, begin_params_axis, inputs_shape))
# Allocate parameters for the beta and gamma of the normalization.
beta, gamma = None, None
if center:
@@ -2103,8 +2133,8 @@ def layer_norm(inputs,
collections=beta_collections,
trainable=trainable)
if scale:
- gamma_collections = utils.get_variable_collections(variables_collections,
- 'gamma')
+ gamma_collections = utils.get_variable_collections(
+ variables_collections, 'gamma')
gamma = variables.model_variable(
'gamma',
shape=params_shape,
@@ -2118,7 +2148,11 @@ def layer_norm(inputs,
# Compute layer normalization using the batch_normalization function.
variance_epsilon = 1e-12
outputs = nn.batch_normalization(
- inputs, mean, variance, offset=beta, scale=gamma,
+ inputs,
+ mean,
+ variance,
+ offset=beta,
+ scale=gamma,
variance_epsilon=variance_epsilon)
outputs.set_shape(inputs_shape)
if activation_fn is not None:
@@ -2164,13 +2198,14 @@ def max_pool2d(inputs,
raise ValueError('data_format has to be either NCHW or NHWC.')
with ops.name_scope(scope, 'MaxPool2D', [inputs]) as sc:
inputs = ops.convert_to_tensor(inputs)
- df = ('channels_first' if data_format and data_format.startswith('NC')
- else 'channels_last')
- layer = pooling_layers.MaxPooling2D(pool_size=kernel_size,
- strides=stride,
- padding=padding,
- data_format=df,
- _scope=sc)
+ df = ('channels_first'
+ if data_format and data_format.startswith('NC') else 'channels_last')
+ layer = pooling_layers.MaxPooling2D(
+ pool_size=kernel_size,
+ strides=stride,
+ padding=padding,
+ data_format=df,
+ _scope=sc)
outputs = layer.apply(inputs)
return utils.collect_named_outputs(outputs_collections, sc, outputs)
@@ -2213,13 +2248,14 @@ def max_pool3d(inputs,
raise ValueError('data_format has to be either NCDHW or NDHWC.')
with ops.name_scope(scope, 'MaxPool3D', [inputs]) as sc:
inputs = ops.convert_to_tensor(inputs)
- df = ('channels_first' if data_format and data_format.startswith('NC')
- else 'channels_last')
- layer = pooling_layers.MaxPooling3D(pool_size=kernel_size,
- strides=stride,
- padding=padding,
- data_format=df,
- _scope=sc)
+ df = ('channels_first'
+ if data_format and data_format.startswith('NC') else 'channels_last')
+ layer = pooling_layers.MaxPooling3D(
+ pool_size=kernel_size,
+ strides=stride,
+ padding=padding,
+ data_format=df,
+ _scope=sc)
outputs = layer.apply(inputs)
return utils.collect_named_outputs(outputs_collections, sc, outputs)
@@ -2272,8 +2308,8 @@ def pool(inputs,
"""
# pylint: enable=line-too-long
- with ops.name_scope(scope, '%s_pool' %
- (pooling_type.lower()), [inputs]) as sc:
+ with ops.name_scope(scope, '%s_pool' % (pooling_type.lower()),
+ [inputs]) as sc:
inputs = ops.convert_to_tensor(inputs)
input_rank = inputs.get_shape().ndims
if input_rank is None:
@@ -2318,18 +2354,16 @@ def one_hot_encoding(labels,
labels = ops.convert_to_tensor(labels)
if labels.dtype == dtypes.int32:
labels = standard_ops.to_int64(labels)
- outputs = standard_ops.one_hot(labels,
- num_classes,
- on_value=on_value,
- off_value=off_value)
+ outputs = standard_ops.one_hot(
+ labels, num_classes, on_value=on_value, off_value=off_value)
return utils.collect_named_outputs(outputs_collections, sc, outputs)
def _apply_activation(y, activation_fn, output_collections):
if activation_fn is not None:
y = activation_fn(y)
- ops.add_to_collections(list(output_collections or []) +
- [ops.GraphKeys.ACTIVATIONS], y)
+ ops.add_to_collections(
+ list(output_collections or []) + [ops.GraphKeys.ACTIVATIONS], y)
return y
@@ -2374,7 +2408,7 @@ def repeat(inputs, repetitions, layer, *args, **kwargs):
scope = 'repeat'
outputs = inputs
for i in range(repetitions):
- kwargs['scope'] = scope + '_' + str(i+1)
+ kwargs['scope'] = scope + '_' + str(i + 1)
outputs = layer(outputs, *args, **kwargs)
return outputs
@@ -2389,8 +2423,8 @@ def _scale_gradient_grad(op, grad):
return [grad * op.inputs[1], None]
-@function.Defun(python_grad_func=_scale_gradient_grad,
- shape_func=_scale_gradient_shape)
+@function.Defun(
+ python_grad_func=_scale_gradient_grad, shape_func=_scale_gradient_shape)
def scale_gradient(inputs, gradient_multiplier):
"""Identity operation, but with the gradient multiplied by a tensor.
@@ -2495,18 +2529,21 @@ def separable_convolution2d(
"""
if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
raise ValueError('data_format has to be either NCHW or NHWC.')
- layer_variable_getter = _build_variable_getter(
- {'bias': 'biases',
- 'depthwise_kernel': 'depthwise_weights',
- 'pointwise_kernel': 'pointwise_weights'})
+ layer_variable_getter = _build_variable_getter({
+ 'bias': 'biases',
+ 'depthwise_kernel': 'depthwise_weights',
+ 'pointwise_kernel': 'pointwise_weights'
+ })
with variable_scope.variable_scope(
- scope, 'SeparableConv2d', [inputs], reuse=reuse,
+ scope,
+ 'SeparableConv2d', [inputs],
+ reuse=reuse,
custom_getter=layer_variable_getter) as sc:
inputs = ops.convert_to_tensor(inputs)
- df = ('channels_first' if data_format and data_format.startswith('NC')
- else 'channels_last')
+ df = ('channels_first'
+ if data_format and data_format.startswith('NC') else 'channels_last')
if num_outputs is not None:
# Apply separable conv using the SeparableConvolution2D layer.
layer = convolutional_layers.SeparableConvolution2D(
@@ -2539,8 +2576,8 @@ def separable_convolution2d(
_add_variable_to_collections(layer.pointwise_kernel,
variables_collections, 'weights')
if layer.bias is not None:
- _add_variable_to_collections(layer.bias,
- variables_collections, 'biases')
+ _add_variable_to_collections(layer.bias, variables_collections,
+ 'biases')
if normalizer_fn is not None:
normalizer_params = normalizer_params or {}
@@ -2555,8 +2592,7 @@ def separable_convolution2d(
weights_collections = utils.get_variable_collections(
variables_collections, 'weights')
- depthwise_shape = [kernel_h, kernel_w,
- num_filters_in, depth_multiplier]
+ depthwise_shape = [kernel_h, kernel_w, num_filters_in, depth_multiplier]
depthwise_weights = variables.model_variable(
'depthwise_weights',
shape=depthwise_shape,
@@ -2570,9 +2606,13 @@ def separable_convolution2d(
1, stride_h, stride_w, 1
]
- outputs = nn.depthwise_conv2d(inputs, depthwise_weights, strides, padding,
- rate=utils.two_element_tuple(rate),
- data_format=data_format)
+ outputs = nn.depthwise_conv2d(
+ inputs,
+ depthwise_weights,
+ strides,
+ padding,
+ rate=utils.two_element_tuple(rate),
+ data_format=data_format)
num_outputs = depth_multiplier * num_filters_in
if normalizer_fn is not None:
@@ -2582,13 +2622,16 @@ def separable_convolution2d(
if biases_initializer is not None:
biases_collections = utils.get_variable_collections(
variables_collections, 'biases')
- biases = variables.model_variable('biases',
- shape=[num_outputs,],
- dtype=dtype,
- initializer=biases_initializer,
- regularizer=biases_regularizer,
- trainable=trainable,
- collections=biases_collections)
+ biases = variables.model_variable(
+ 'biases',
+ shape=[
+ num_outputs,
+ ],
+ dtype=dtype,
+ initializer=biases_initializer,
+ regularizer=biases_regularizer,
+ trainable=trainable,
+ collections=biases_collections)
outputs = nn.bias_add(outputs, biases, data_format=data_format)
if activation_fn is not None:
@@ -2673,23 +2716,24 @@ def spatial_softmax(features,
with ops.name_scope('spatial_softmax_op', 'spatial_softmax_op', [features]):
# Create tensors for x and y coordinate values, scaled to range [-1, 1].
- pos_x, pos_y = array_ops.meshgrid(math_ops.lin_space(-1., 1., num=height),
- math_ops.lin_space(-1., 1., num=width),
- indexing='ij')
+ pos_x, pos_y = array_ops.meshgrid(
+ math_ops.lin_space(-1., 1., num=height),
+ math_ops.lin_space(-1., 1., num=width),
+ indexing='ij')
pos_x = array_ops.reshape(pos_x, [height * width])
pos_y = array_ops.reshape(pos_y, [height * width])
-
+
if temperature is None:
temp_initializer = init_ops.ones_initializer()
else:
temp_initializer = init_ops.constant_initializer(temperature)
-
+
if not trainable:
temp_collections = None
else:
temp_collections = utils.get_variable_collections(
- variables_collections, 'temperature')
-
+ variables_collections, 'temperature')
+
temperature = variables.model_variable(
'temperature',
shape=(),
@@ -2703,14 +2747,14 @@ def spatial_softmax(features,
features = array_ops.reshape(
array_ops.transpose(features, [0, 3, 1, 2]), [-1, height * width])
- softmax_attention = nn.softmax(features/temperature)
+ softmax_attention = nn.softmax(features / temperature)
expected_x = math_ops.reduce_sum(
pos_x * softmax_attention, [1], keep_dims=True)
expected_y = math_ops.reduce_sum(
pos_y * softmax_attention, [1], keep_dims=True)
expected_xy = array_ops.concat([expected_x, expected_y], 1)
- feature_keypoints = array_ops.reshape(
- expected_xy, [-1, num_channels.value * 2])
+ feature_keypoints = array_ops.reshape(expected_xy,
+ [-1, num_channels.value * 2])
feature_keypoints.set_shape([None, num_channels.value * 2])
return feature_keypoints
@@ -2762,7 +2806,7 @@ def stack(inputs, layer, stack_args, **kwargs):
scope = 'stack'
outputs = inputs
for i in range(len(stack_args)):
- kwargs['scope'] = scope + '_' + str(i+1)
+ kwargs['scope'] = scope + '_' + str(i + 1)
layer_args = stack_args[i]
if not isinstance(layer_args, (list, tuple)):
layer_args = [layer_args]
@@ -2793,11 +2837,10 @@ def unit_norm(inputs, dim, epsilon=1e-7, scope=None):
raise ValueError('The input rank must be known.')
input_rank = len(inputs.get_shape().as_list())
if dim < 0 or dim >= input_rank:
- raise ValueError(
- 'dim must be positive but smaller than the input rank.')
+ raise ValueError('dim must be positive but smaller than the input rank.')
- lengths = math_ops.sqrt(epsilon + math_ops.reduce_sum(
- math_ops.square(inputs), dim, True))
+ lengths = math_ops.sqrt(
+ epsilon + math_ops.reduce_sum(math_ops.square(inputs), dim, True))
multiples = []
if dim > 0:
multiples.append(array_ops.ones([dim], dtypes.int32))
@@ -2938,29 +2981,31 @@ def legacy_fully_connected(x,
raise ValueError('last dimension of x must be known but is None')
dtype = x.dtype.base_dtype
- weight_collections = set(list(weight_collections or []) +
- [ops.GraphKeys.GLOBAL_VARIABLES])
- w = variable_scope.get_variable('weights',
- shape=[num_input_units, num_output_units],
- dtype=dtype,
- initializer=weight_init,
- collections=weight_collections,
- regularizer=weight_regularizer,
- trainable=trainable)
- x_2_dim = x if len(dims) <= 2 else array_ops.reshape(x,
- [-1, num_input_units])
+ weight_collections = set(
+ list(weight_collections or []) + [ops.GraphKeys.GLOBAL_VARIABLES])
+ w = variable_scope.get_variable(
+ 'weights',
+ shape=[num_input_units, num_output_units],
+ dtype=dtype,
+ initializer=weight_init,
+ collections=weight_collections,
+ regularizer=weight_regularizer,
+ trainable=trainable)
+ x_2_dim = x if len(dims) <= 2 else array_ops.reshape(
+ x, [-1, num_input_units])
y = standard_ops.matmul(x_2_dim, w)
if bias_init is not None:
- bias_collections = set(list(bias_collections or []) +
- [ops.GraphKeys.GLOBAL_VARIABLES])
- b = variable_scope.get_variable('bias',
- shape=[num_output_units],
- dtype=dtype,
- initializer=bias_init,
- collections=bias_collections,
- regularizer=bias_regularizer,
- trainable=trainable)
+ bias_collections = set(
+ list(bias_collections or []) + [ops.GraphKeys.GLOBAL_VARIABLES])
+ b = variable_scope.get_variable(
+ 'bias',
+ shape=[num_output_units],
+ dtype=dtype,
+ initializer=bias_init,
+ collections=bias_collections,
+ regularizer=bias_regularizer,
+ trainable=trainable)
y = nn.bias_add(y, b)
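
As context for the batch_norm changes above: the docstring quoted in the hunks notes that the moving_mean/moving_variance update ops are placed in tf.GraphKeys.UPDATE_OPS and must be run as a dependency of the train op. Below is a minimal sketch of that pattern against the contrib.layers API of this era; the toy model, shapes, and learning rate are placeholders, not part of this commit:

import tensorflow as tf
from tensorflow.contrib import layers as contrib_layers

x = tf.placeholder(tf.float32, [None, 8])
is_training = tf.placeholder(tf.bool, [])
net = contrib_layers.fully_connected(
    x, 16,
    normalizer_fn=contrib_layers.batch_norm,
    normalizer_params={'is_training': is_training})
loss = tf.reduce_mean(tf.square(net))

# batch_norm registers its moving-average assignments in UPDATE_OPS;
# running them alongside the training step keeps the statistics current.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
  train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
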
diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer.py b/tensorflow/contrib/opt/python/training/model_average_optimizer.py
index 47509ecca6..a7c97a1da2 100644
--- a/tensorflow/contrib/opt/python/training/model_average_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/model_average_optimizer.py
@@ -12,30 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
-"""Wrapper optimizer for Model Average """
+"""Wrapper optimizer for Model Average."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import constant_op
-from tensorflow.python.training import optimizer
-from tensorflow.python.training import session_run_hook
-from tensorflow.python.ops import math_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import session_run_hook
-GLOBAL_VARIABLE_NAME = 'global_center_variable'
+GLOBAL_VARIABLE_NAME = "global_center_variable"
class ModelAverageCustomGetter(object):
- """Custom_getter class is used to do:
+ """Custom_getter class is used to do.
+
1. Change trainable variables to local collection and place them at worker
device
2. Generate global variables
@@ -73,15 +73,18 @@ class ModelAverageCustomGetter(object):
def __call__(self, getter, name, trainable, collections, *args, **kwargs):
if trainable:
with ops.device(self._worker_device):
- local_var = getter(name, trainable=True,
- collections=[ops.GraphKeys.LOCAL_VARIABLES],
- *args, **kwargs)
+ local_var = getter(
+ name,
+ trainable=True,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ *args,
+ **kwargs)
global_variable = variable_scope.variable(
- name='%s/%s' % (GLOBAL_VARIABLE_NAME, name),
- initial_value=local_var.initialized_value(),
- trainable=False,
- collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+ name="%s/%s" % (GLOBAL_VARIABLE_NAME, name),
+ initial_value=local_var.initialized_value(),
+ trainable=False,
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES])
self._local_2_global[local_var] = global_variable
return local_var
@@ -91,6 +94,7 @@ class ModelAverageCustomGetter(object):
class ModelAverageOptimizer(optimizer.Optimizer):
"""Wrapper optimizer that implements the Model Average algorithm.
+
This is a sync optimizer. During the training, each worker will update
the local variables and maintains its own local_step, which starts from 0
and is incremented by 1 after each update of local variables. Whenever the
@@ -99,15 +103,14 @@ class ModelAverageOptimizer(optimizer.Optimizer):
local variables will be assigned by global center variables.
"""
- def __init__(
- self,
- opt,
- num_worker,
- is_chief,
- ma_custom_getter,
- interval_steps=100,
- use_locking=True,
- name="ModelAverageOptimizer"):
+ def __init__(self,
+ opt,
+ num_worker,
+ is_chief,
+ ma_custom_getter,
+ interval_steps=100,
+ use_locking=True,
+ name="ModelAverageOptimizer"):
"""Construct a new model average optimizer.
Args:
@@ -124,18 +127,18 @@ class ModelAverageOptimizer(optimizer.Optimizer):
self._opt = opt
self._num_worker = num_worker
self._is_chief = is_chief
- self._local_2_global = ma_custom_getter._local_2_global
+ self._local_2_global = ma_custom_getter._local_2_global # pylint:disable=protected-access
self._interval_steps = interval_steps
self._accumulator_list = []
self._chief_init_op = None
self._local_step = variable_scope.get_variable(
- initializer=0,
- trainable=False,
- collections=[ops.GraphKeys.LOCAL_VARIABLES],
- name="local_step")
+ initializer=0,
+ trainable=False,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ name="local_step")
- self._opt._prepare()
+ self._opt._prepare() # pylint:disable=protected-access
def compute_gradients(self, *args, **kwargs):
"""Compute gradients of "loss" for the variables in "var_list".
@@ -159,10 +162,12 @@ class ModelAverageOptimizer(optimizer.Optimizer):
Returns:
An update op
+
+ Raises:
+ ValueError: if var_list is empty.
"""
if not var_list:
- raise ValueError(
- 'The list of local_variables should not be empty')
+ raise ValueError("The list of local_variables should not be empty")
update_ops = []
global_center_vars = [self._local_2_global[var] for var in var_list]
for lvar, gvar in zip(var_list, global_center_vars):
@@ -204,28 +209,29 @@ class ModelAverageOptimizer(optimizer.Optimizer):
apply_updates = self._opt.apply_gradients(grads_and_vars)
with ops.control_dependencies([apply_updates]):
local_update = state_ops.assign_add(
- self._local_step, 1, name='local_step_update').op
+ self._local_step, 1, name="local_step_update").op
# update global variables.
- def _Update_global_variables():
+ def _update_global_variables(): # pylint: disable=missing-docstring
local_vars = [v for g, v in grads_and_vars if g is not None]
global_vars = [self._local_2_global[v] for v in local_vars]
# sync queue
with ops.colocate_with(global_step):
- sync_queue = data_flow_ops.FIFOQueue(-1, [dtypes.bool], shapes=[[]],
- shared_name='sync_queue')
+ sync_queue = data_flow_ops.FIFOQueue(
+ -1, [dtypes.bool], shapes=[[]], shared_name="sync_queue")
train_ops = []
aggregated_vars = []
- with ops.name_scope(None, self._name + '/global'):
+ with ops.name_scope(None, self._name + "/global"):
for var, gvar in zip(local_vars, global_vars):
+ # pylint: disable=protected-access
with ops.device(gvar.device):
if isinstance(var._ref(), ops.Tensor):
var_accum = data_flow_ops.ConditionalAccumulator(
- var.dtype,
- shape=var.get_shape(),
- shared_name=gvar.name + "/var_accum")
+ var.dtype,
+ shape=var.get_shape(),
+ shared_name=gvar.name + "/var_accum")
train_ops.append(
- var_accum.apply_grad(var._ref(), local_step=global_step))
+ var_accum.apply_grad(var._ref(), local_step=global_step))
aggregated_vars.append(var_accum.take_grad(self._num_worker))
else:
raise ValueError("Unknown local variable type!")
@@ -254,24 +260,26 @@ class ModelAverageOptimizer(optimizer.Optimizer):
return local_update_op
with ops.control_dependencies([local_update]):
- condition = math_ops.equal(math_ops.mod(
- self._local_step, self._interval_steps), 0)
+ condition = math_ops.equal(
+ math_ops.mod(self._local_step, self._interval_steps), 0)
conditional_update = control_flow_ops.cond(
- condition, _Update_global_variables, control_flow_ops.no_op)
+ condition, _update_global_variables, control_flow_ops.no_op)
chief_init_ops = []
for accum, dev in self._accumulator_list:
with ops.device(dev):
chief_init_ops.append(
- accum.set_global_step(
- global_step, name="SetGlobalStep"))
+ accum.set_global_step(global_step, name="SetGlobalStep"))
self._chief_init_op = control_flow_ops.group(*(chief_init_ops))
return conditional_update
def get_init_op(self):
- """Returns the op to let all the local variables equal to the global
- variables before the training begins"""
+ """Returns the op.
+
+ This method lets all the local variables equal to the global
+ variables before the training begins.
+ """
return self._local_vars_update(variables.trainable_variables())
def make_session_run_hook(self):
@@ -279,12 +287,13 @@ class ModelAverageOptimizer(optimizer.Optimizer):
return _ModelAverageOptimizerHook(self, self._is_chief)
-class _ModelAverageOptimizerHook(session_run_hook.SessionRunHook):
+class _ModelAverageOptimizerHook(session_run_hook.SessionRunHook): # pylint: disable=missing-docstring
+
def __init__(self, ma_optimizer, is_chief):
"""Creates hook to handle ModelAverageOptimizer initialization ops.
Args:
- ea_optimizer: `ModelAverageOptimizer` which this hook will initialize.
+ ma_optimizer: `ModelAverageOptimizer` which this hook will initialize.
is_chief: `Bool`, whether is this a chief replica or not.
"""
self._ma_optimizer = ma_optimizer
@@ -295,5 +304,5 @@ class _ModelAverageOptimizerHook(session_run_hook.SessionRunHook):
self._global_init_op = None
if self._is_chief:
self._global_init_op = variables.global_variables_initializer()
- self._chief_init_op = self._ma_optimizer._chief_init_op
+ self._chief_init_op = self._ma_optimizer._chief_init_op # pylint: disable=protected-access
self._variable_init_op = self._ma_optimizer.get_init_op()
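
To make the reworked docstrings above concrete, here is a rough usage sketch distilled from them and from the test file below; the device string, worker count, and hyperparameters are illustrative placeholders, and the cluster/device_setter plumbing shown in the test is omitted:

from tensorflow.contrib.opt.python.training import model_average_optimizer
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import gradient_descent

worker_device = "/job:worker/task:0/cpu:0"
ma_getter = model_average_optimizer.ModelAverageCustomGetter(
    worker_device=worker_device)
with variable_scope.variable_scope("", custom_getter=ma_getter):
  # Trainable variables created here become worker-local, and the getter
  # creates matching global center variables for them.
  # ... build model, compute grads_and_vars and global_step ...
  opt = model_average_optimizer.ModelAverageOptimizer(
      opt=gradient_descent.GradientDescentOptimizer(1.0),
      num_worker=2,
      is_chief=True,
      ma_custom_getter=ma_getter,
      interval_steps=100)
  # train_op = opt.apply_gradients(grads_and_vars, global_step)
  # hook = opt.make_session_run_hook()  # pass to MonitoredTrainingSession
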
diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
index a73aa772bb..29ecd22839 100644
--- a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
@@ -18,18 +18,18 @@ from __future__ import division
from __future__ import print_function
import portpicker
+
+from tensorflow.contrib.opt.python.training import model_average_optimizer
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training import device_setter
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import server_lib
from tensorflow.python.training import training
from tensorflow.python.training import training_util
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.training import device_setter
-from tensorflow.contrib.opt.python.training.model_average_optimizer import \
- ModelAverageOptimizer, ModelAverageCustomGetter, GLOBAL_VARIABLE_NAME
def create_local_cluster(num_workers, num_ps, protocol="grpc"):
@@ -37,20 +37,20 @@ def create_local_cluster(num_workers, num_ps, protocol="grpc"):
worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
cluster_dict = {
- "worker": ["localhost:%s" % port for port in worker_ports],
- "ps": ["localhost:%s" % port for port in ps_ports]
+ "worker": ["localhost:%s" % port for port in worker_ports],
+ "ps": ["localhost:%s" % port for port in ps_ports]
}
cs = server_lib.ClusterSpec(cluster_dict)
workers = [
- server_lib.Server(
- cs, job_name="worker", protocol=protocol, task_index=ix, start=True)
- for ix in range(num_workers)
+ server_lib.Server(
+ cs, job_name="worker", protocol=protocol, task_index=ix, start=True)
+ for ix in range(num_workers)
]
ps_servers = [
- server_lib.Server(
- cs, job_name="ps", protocol=protocol, task_index=ix, start=True)
- for ix in range(num_ps)
+ server_lib.Server(
+ cs, job_name="ps", protocol=protocol, task_index=ix, start=True)
+ for ix in range(num_ps)
]
return cluster_dict, workers, ps_servers
@@ -67,16 +67,16 @@ def _get_workers(num_workers, steps, workers):
is_chief = (worker_id == 0)
with graph.as_default():
worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
- ma_coustom = ModelAverageCustomGetter(
- worker_device=worker_device)
- with variable_scope.variable_scope('',
- custom_getter=ma_coustom), ops.device(
- device_setter.replica_device_setter(worker_device=worker_device,
- ps_device="/job:ps/task:0/cpu:0",
- ps_tasks=1)):
-
- global_step = variables.Variable(0, name='global_step',
- trainable=False)
+ ma_coustom = model_average_optimizer.ModelAverageCustomGetter(
+ worker_device=worker_device)
+ with variable_scope.variable_scope(
+ "", custom_getter=ma_coustom), ops.device(
+ device_setter.replica_device_setter(
+ worker_device=worker_device,
+ ps_device="/job:ps/task:0/cpu:0",
+ ps_tasks=1)):
+
+ global_step = variables.Variable(0, name="global_step", trainable=False)
var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
var_1 = variable_scope.get_variable(initializer=1.0, name="v1")
@@ -88,22 +88,20 @@ def _get_workers(num_workers, steps, workers):
grads_0 = constant_op.constant(-2.0)
grads_1 = constant_op.constant(-2.0)
sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
- opt = ModelAverageOptimizer(
- opt=sgd_opt,
- num_worker=num_workers,
- ma_custom_getter=ma_coustom,
- is_chief=is_chief,
- interval_steps=steps
- )
+ opt = model_average_optimizer.ModelAverageOptimizer(
+ opt=sgd_opt,
+ num_worker=num_workers,
+ ma_custom_getter=ma_coustom,
+ is_chief=is_chief,
+ interval_steps=steps)
train_op = [
- opt.apply_gradients(
- [[grads_0, var_0],
- [grads_1, var_1]], global_step)
+ opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]],
+ global_step)
]
easgd_hook = opt.make_session_run_hook()
# Creates MonitoredSession
- sess = training.MonitoredTrainingSession(workers[worker_id].target,
- hooks=[easgd_hook])
+ sess = training.MonitoredTrainingSession(
+ workers[worker_id].target, hooks=[easgd_hook])
sessions.append(sess)
graphs.append(graph)
@@ -112,6 +110,7 @@ def _get_workers(num_workers, steps, workers):
class ModelAverageOptimizerTest(test.TestCase):
+
def _run(self, train_op, sess):
sess.run(train_op)
@@ -119,18 +118,18 @@ class ModelAverageOptimizerTest(test.TestCase):
num_workers = 2
steps = 2
num_ps = 1
- cluster, workers, _ = create_local_cluster(num_workers=num_workers,
- num_ps=num_ps)
+ _, workers, _ = create_local_cluster(
+ num_workers=num_workers, num_ps=num_ps)
- sessions, graphs, train_ops = _get_workers(num_workers,
- steps,
- workers)
+ sessions, graphs, train_ops = _get_workers(num_workers, steps, workers)
- var_0 = graphs[0].get_tensor_by_name('v0:0')
- var_1 = graphs[0].get_tensor_by_name('v1:0')
+ var_0 = graphs[0].get_tensor_by_name("v0:0")
+ var_1 = graphs[0].get_tensor_by_name("v1:0")
global_step = training_util.get_global_step(graphs[0])
- global_var_0 = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v0:0")
- global_var_1 = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v1:0")
+ global_var_0 = graphs[0].get_tensor_by_name(
+ model_average_optimizer.GLOBAL_VARIABLE_NAME + "/v0:0")
+ global_var_1 = graphs[0].get_tensor_by_name(
+ model_average_optimizer.GLOBAL_VARIABLE_NAME + "/v1:0")
# Verify the initialized value.
self.assertAllEqual(0.0, sessions[0].run(var_0))
@@ -150,9 +149,9 @@ class ModelAverageOptimizerTest(test.TestCase):
     # iteration 2, global variable update
thread_0 = self.checkedThread(
- target=self._run, args=(train_ops[0], sessions[0]))
+ target=self._run, args=(train_ops[0], sessions[0]))
thread_1 = self.checkedThread(
- target=self._run, args=(train_ops[1], sessions[1]))
+ target=self._run, args=(train_ops[1], sessions[1]))
thread_0.start()
thread_1.start()
thread_0.join()
@@ -175,20 +174,20 @@ class ModelAverageOptimizerTest(test.TestCase):
def testPS2TasksWithClusterSpecClass(self):
cluster_spec = server_lib.ClusterSpec({
- "ps": ["ps0:2222", "ps1:2222"],
- "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
+ "ps": ["ps0:2222", "ps1:2222"],
+ "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
})
worker_device = "/job:worker/task:0"
- ma_coustom = ModelAverageCustomGetter(
- worker_device=worker_device)
+ ma_coustom = model_average_optimizer.ModelAverageCustomGetter(
+ worker_device=worker_device)
from tensorflow.python.training import device_setter
with ops.device(
device_setter.replica_device_setter(cluster=cluster_spec,
worker_device=worker_device,
ps_device="/job:ps")), \
- variable_scope.variable_scope('', custom_getter=ma_coustom):
+ variable_scope.variable_scope("", custom_getter=ma_coustom):
v = variable_scope.get_variable(initializer=[1, 2], name="v")
- w = variable_scope.get_variable(initializer=[2, 1], name='w')
+ w = variable_scope.get_variable(initializer=[2, 1], name="w")
v_g, w_g = ma_coustom._local_2_global[v], ma_coustom._local_2_global[w]
self.assertDeviceEqual("/job:worker/task:0", v.device)
self.assertDeviceEqual("job:ps/task:0", v_g.device)
@@ -196,5 +195,5 @@ class ModelAverageOptimizerTest(test.TestCase):
self.assertDeviceEqual("job:ps/task:1", w_g.device)
-if __name__ == '__main__':
+if __name__ == "__main__":
test.main()
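The hunks above all exercise one wiring pattern for the model-averaging optimizer: a ModelAverageCustomGetter installed on the variable scope, a ModelAverageOptimizer wrapping plain SGD, and the session hook that performs the periodic averaging. Below is a minimal sketch of that pattern; the import paths and argument names are taken from the diff itself, while the single-worker values and the standalone setup (no ps job, no replica_device_setter) are assumptions made purely for illustration, since the test runs the same wiring inside the local cluster built by create_local_cluster.

    # Sketch only: mirrors the pattern exercised in _get_workers above.
    from tensorflow.contrib.opt.python.training import model_average_optimizer
    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import variable_scope
    from tensorflow.python.ops import variables
    from tensorflow.python.training import gradient_descent

    worker_device = "/job:worker/task:0/cpu:0"  # assumed single-worker layout
    ma_getter = model_average_optimizer.ModelAverageCustomGetter(
        worker_device=worker_device)
    with variable_scope.variable_scope("", custom_getter=ma_getter):
      var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
    global_step = variables.Variable(0, name="global_step", trainable=False)
    opt = model_average_optimizer.ModelAverageOptimizer(
        opt=gradient_descent.GradientDescentOptimizer(1.0),
        num_worker=1,                # workers participating in the average
        ma_custom_getter=ma_getter,
        is_chief=True,
        interval_steps=2)            # average local models every 2 local steps
    train_op = opt.apply_gradients([[constant_op.constant(-2.0), var_0]],
                                   global_step)
    hook = opt.make_session_run_hook()
    # The hook is then passed to training.MonitoredTrainingSession(..., hooks=[hook]).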
diff --git a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
index 30a2077570..a25de55e18 100644
--- a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
+++ b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
@@ -53,12 +53,11 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
def testPeriodicResampleBasic3D(self):
- input_tensor = numpy.arange(2*2*4).reshape((2, 2, 4))
+ input_tensor = numpy.arange(2 * 2 * 4).reshape((2, 2, 4))
desired_shape = numpy.array([4, 4, None])
- output_tensor = numpy.array([[[0], [2], [4], [6]],
- [[1], [3], [5], [7]],
- [[8], [10], [12], [14]],
- [[9], [11], [13], [15]]])
+ output_tensor = numpy.array([[[0], [2], [4], [6]], [[1], [3], [5], [7]],
+ [[8], [10], [12], [14]], [[9], [11], [13],
+ [15]]])
# NOTE: output_tensor != input_tensor.reshape((4, 4, -1))
with self.test_session():
@@ -72,24 +71,18 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
def testPeriodicResampleBasic4D(self):
- input_tensor = numpy.arange(2*2*2*8).reshape((2, 2, 2, 8))
+ input_tensor = numpy.arange(2 * 2 * 2 * 8).reshape((2, 2, 2, 8))
desired_shape = numpy.array([4, 4, 4, None])
- output_tensor = numpy.array([[[[0], [4], [8], [12]],
- [[2], [6], [10], [14]],
- [[16], [20], [24], [28]],
- [[18], [22], [26], [30]]],
- [[[1], [5], [9], [13]],
- [[3], [7], [11], [15]],
- [[17], [21], [25], [29]],
- [[19], [23], [27], [31]]],
- [[[32], [36], [40], [44]],
- [[34], [38], [42], [46]],
- [[48], [52], [56], [60]],
- [[50], [54], [58], [62]]],
- [[[33], [37], [41], [45]],
- [[35], [39], [43], [47]],
- [[49], [53], [57], [61]],
- [[51], [55], [59], [63]]]])
+ output_tensor = numpy.array(
+ [[[[0], [4], [8], [12]], [[2], [6], [10], [14]],
+ [[16], [20], [24], [28]], [[18], [22], [26], [30]]],
+ [[[1], [5], [9], [13]], [[3], [7], [11], [15]], [[17], [21], [25],
+ [29]],
+ [[19], [23], [27],
+ [31]]], [[[32], [36], [40], [44]], [[34], [38], [42], [46]],
+ [[48], [52], [56], [60]], [[50], [54], [58], [62]]],
+ [[[33], [37], [41], [45]], [[35], [39], [43], [47]],
+ [[49], [53], [57], [61]], [[51], [55], [59], [63]]]])
# NOTE: output_tensor != input_tensor.reshape((4, 4, 4, -1))
with self.test_session():
@@ -111,5 +104,5 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
periodic_resample(input_tensor, [None, 4, 4]).eval()
-if __name__ == "__main__":
+if __name__ == '__main__':
googletest.main()
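The 3-D expected_output above is not a plain reshape of the input, as the NOTE in the test points out; for this particular 2x2x4 -> 4x4x1 case the op interleaves samples from the collapsed leading dimensions into the expanded ones. The same values can be reproduced with a NumPy reshape/transpose, shown here only to illustrate the indexing of this one test case (it is not the op's implementation):

    import numpy as np

    x = np.arange(2 * 2 * 4).reshape((2, 2, 4))
    # Split the last axis into (column, row) offsets, then interleave them
    # into the two leading axes; this matches expected_output in the 3-D test.
    y = x.reshape(2, 2, 2, 2).transpose(0, 3, 1, 2).reshape(4, 4, 1)
    # y[2 * i + r, 2 * j + c, 0] == x[i, j, 2 * c + r]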
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 70aaba1728..c780e85d72 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -53,14 +53,12 @@ class RNNCellTest(test.TestCase):
batch_size = 3
input_size = 4
expected_output = np.array(
- [[0.121753, 0.121753],
- [0.103349, 0.103349],
- [0.100178, 0.100178]],
+ [[0.121753, 0.121753], [0.103349, 0.103349], [0.100178, 0.100178]],
dtype=np.float32)
expected_state = np.array(
- [[0.137523, 0.137523, 0.121753, 0.121753],
- [0.105450, 0.105450, 0.103349, 0.103349],
- [0.100742, 0.100742, 0.100178, 0.100178]],
+ [[0.137523, 0.137523, 0.121753, 0.121753], [
+ 0.105450, 0.105450, 0.103349, 0.103349
+ ], [0.100742, 0.100742, 0.100178, 0.100178]],
dtype=np.float32)
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
@@ -69,14 +67,14 @@ class RNNCellTest(test.TestCase):
output, state = contrib_rnn_cell.CoupledInputForgetGateLSTMCell(
num_units=num_units, forget_bias=1.0, state_is_tuple=False)(x, m)
sess.run([variables.global_variables_initializer()])
- res = sess.run([output, state], {
- x.name:
- np.array([[1., 1., 1., 1.],
- [2., 2., 2., 2.],
- [3., 3., 3., 3.]]),
- m.name:
- 0.1 * np.ones((batch_size, state_size))
- })
+ res = sess.run(
+ [output, state], {
+ x.name:
+ np.array([[1., 1., 1., 1.], [2., 2., 2., 2.],
+ [3., 3., 3., 3.]]),
+ m.name:
+ 0.1 * np.ones((batch_size, state_size))
+ })
# This is a smoke test: Only making sure expected values didn't change.
self.assertEqual(len(res), 2)
self.assertAllClose(res[0], expected_output)
@@ -101,14 +99,14 @@ class RNNCellTest(test.TestCase):
frequency_skip=frequency_skip,
forget_bias=1.0)(x, m)
sess.run([variables.global_variables_initializer()])
- res = sess.run([output, state], {
- x.name:
- np.array([[1., 1., 1., 1.],
- [2., 2., 2., 2.],
- [3., 3., 3., 3.]]),
- m.name:
- 0.1 * np.ones((batch_size, int(state_size * (num_shifts))))
- })
+ res = sess.run(
+ [output, state], {
+ x.name:
+ np.array([[1., 1., 1., 1.], [2., 2., 2., 2.],
+ [3., 3., 3., 3.]]),
+ m.name:
+ 0.1 * np.ones((batch_size, int(state_size * (num_shifts))))
+ })
self.assertEqual(len(res), 2)
       # The numbers in results were not calculated; this is mostly just a
# smoke test.
@@ -141,17 +139,14 @@ class RNNCellTest(test.TestCase):
state_is_tuple=True)
inputs = constant_op.constant(
np.array(
- [[1., 1., 1., 1.],
- [2., 2., 2., 2.],
- [3., 3., 3., 3.]],
+ [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
dtype=np.float32),
dtype=dtypes.float32)
state_value = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_units), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
- init_state = cell.state_tuple_type(
- *([state_value, state_value] * num_shifts))
+ init_state = cell.state_tuple_type(*(
+ [state_value, state_value] * num_shifts))
output, state = cell(inputs, init_state)
sess.run([variables.global_variables_initializer()])
res = sess.run([output, state])
@@ -198,11 +193,10 @@ class RNNCellTest(test.TestCase):
dtype=np.float32),
dtype=dtypes.float32)
state_value = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_units), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
- init_state = cell.state_tuple_type(
- *([state_value, state_value] * total_blocks))
+ init_state = cell.state_tuple_type(*(
+ [state_value, state_value] * total_blocks))
output, state = cell(inputs, init_state)
sess.run([variables.global_variables_initializer()])
res = sess.run([output, state])
@@ -230,20 +224,28 @@ class RNNCellTest(test.TestCase):
frequency_skip = 1
num_shifts = int((input_size - feature_size) / frequency_skip + 1)
expected_output = np.array(
- [[0.416383, 0.416383, 0.403238, 0.403238, 0.524020, 0.524020,
- 0.565425, 0.565425, 0.557865, 0.557865, 0.609699, 0.609699],
- [0.627331, 0.627331, 0.622393, 0.622393, 0.688342, 0.688342,
- 0.708078, 0.708078, 0.694245, 0.694245, 0.715171, 0.715171],
- [0.711050, 0.711050, 0.709197, 0.709197, 0.736533, 0.736533,
- 0.744264, 0.744264, 0.737390, 0.737390, 0.745250, 0.745250]],
+ [[
+ 0.416383, 0.416383, 0.403238, 0.403238, 0.524020, 0.524020,
+ 0.565425, 0.565425, 0.557865, 0.557865, 0.609699, 0.609699
+ ], [
+ 0.627331, 0.627331, 0.622393, 0.622393, 0.688342, 0.688342,
+ 0.708078, 0.708078, 0.694245, 0.694245, 0.715171, 0.715171
+ ], [
+ 0.711050, 0.711050, 0.709197, 0.709197, 0.736533, 0.736533,
+ 0.744264, 0.744264, 0.737390, 0.737390, 0.745250, 0.745250
+ ]],
dtype=np.float32)
expected_state = np.array(
- [[0.625556, 0.625556, 0.416383, 0.416383, 0.759134, 0.759134,
- 0.524020, 0.524020, 0.798795, 0.798795, 0.557865, 0.557865],
- [0.875488, 0.875488, 0.627331, 0.627331, 0.936432, 0.936432,
- 0.688342, 0.688342, 0.941961, 0.941961, 0.694245, 0.694245],
- [0.957327, 0.957327, 0.711050, 0.711050, 0.979522, 0.979522,
- 0.736533, 0.736533, 0.980245, 0.980245, 0.737390, 0.737390]],
+ [[
+ 0.625556, 0.625556, 0.416383, 0.416383, 0.759134, 0.759134,
+ 0.524020, 0.524020, 0.798795, 0.798795, 0.557865, 0.557865
+ ], [
+ 0.875488, 0.875488, 0.627331, 0.627331, 0.936432, 0.936432,
+ 0.688342, 0.688342, 0.941961, 0.941961, 0.694245, 0.694245
+ ], [
+ 0.957327, 0.957327, 0.711050, 0.711050, 0.979522, 0.979522,
+ 0.736533, 0.736533, 0.980245, 0.980245, 0.737390, 0.737390
+ ]],
dtype=np.float32)
for state_is_tuple in [False, True]:
with self.test_session() as sess:
@@ -259,18 +261,16 @@ class RNNCellTest(test.TestCase):
couple_input_forget_gates=True,
state_is_tuple=state_is_tuple)
inputs = constant_op.constant(
- np.array([[1., 1., 1., 1.],
- [2., 2., 2., 2.],
- [3., 3., 3., 3.]],
- dtype=np.float32),
+ np.array(
+ [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
+ dtype=np.float32),
dtype=dtypes.float32)
if state_is_tuple:
state_value = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_units), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
- init_state = cell.state_tuple_type(
- *([state_value, state_value] * num_shifts))
+ init_state = cell.state_tuple_type(*(
+ [state_value, state_value] * num_shifts))
else:
init_state = constant_op.constant(
0.1 * np.ones(
@@ -302,32 +302,40 @@ class RNNCellTest(test.TestCase):
frequency_skip = 1
num_shifts = int((input_size - feature_size) / frequency_skip + 1)
expected_output = np.array(
- [[0.464130, 0.464130, 0.419165, 0.419165, 0.593283, 0.593283,
- 0.738350, 0.738350, 0.661638, 0.661638, 0.866774, 0.866774,
- 0.520789, 0.520789, 0.476968, 0.476968, 0.604341, 0.604341,
- 0.760207, 0.760207, 0.635773, 0.635773, 0.850218, 0.850218],
- [0.669636, 0.669636, 0.628966, 0.628966, 0.736057, 0.736057,
- 0.895927, 0.895927, 0.755559, 0.755559, 0.954359, 0.954359,
- 0.692621, 0.692621, 0.652363, 0.652363, 0.737517, 0.737517,
- 0.899558, 0.899558, 0.745984, 0.745984, 0.946840, 0.946840],
- [0.751109, 0.751109, 0.711716, 0.711716, 0.778357, 0.778357,
- 0.940779, 0.940779, 0.784530, 0.784530, 0.980604, 0.980604,
- 0.759940, 0.759940, 0.720652, 0.720652, 0.778552, 0.778552,
- 0.941606, 0.941606, 0.781035, 0.781035, 0.977731, 0.977731]],
+ [[
+ 0.464130, 0.464130, 0.419165, 0.419165, 0.593283, 0.593283,
+ 0.738350, 0.738350, 0.661638, 0.661638, 0.866774, 0.866774,
+ 0.520789, 0.520789, 0.476968, 0.476968, 0.604341, 0.604341,
+ 0.760207, 0.760207, 0.635773, 0.635773, 0.850218, 0.850218
+ ], [
+ 0.669636, 0.669636, 0.628966, 0.628966, 0.736057, 0.736057,
+ 0.895927, 0.895927, 0.755559, 0.755559, 0.954359, 0.954359,
+ 0.692621, 0.692621, 0.652363, 0.652363, 0.737517, 0.737517,
+ 0.899558, 0.899558, 0.745984, 0.745984, 0.946840, 0.946840
+ ], [
+ 0.751109, 0.751109, 0.711716, 0.711716, 0.778357, 0.778357,
+ 0.940779, 0.940779, 0.784530, 0.784530, 0.980604, 0.980604,
+ 0.759940, 0.759940, 0.720652, 0.720652, 0.778552, 0.778552,
+ 0.941606, 0.941606, 0.781035, 0.781035, 0.977731, 0.977731
+ ]],
dtype=np.float32)
expected_state = np.array(
- [[0.710660, 0.710660, 0.464130, 0.464130, 0.877293, 0.877293,
- 0.593283, 0.593283, 0.958505, 0.958505, 0.661638, 0.661638,
- 0.785405, 0.785405, 0.520789, 0.520789, 0.890836, 0.890836,
- 0.604341, 0.604341, 0.928512, 0.928512, 0.635773, 0.635773],
- [0.967579, 0.967579, 0.669636, 0.669636, 1.038811, 1.038811,
- 0.736057, 0.736057, 1.058201, 1.058201, 0.755559, 0.755559,
- 0.993088, 0.993088, 0.692621, 0.692621, 1.040288, 1.040288,
- 0.737517, 0.737517, 1.048773, 1.048773, 0.745984, 0.745984],
- [1.053842, 1.053842, 0.751109, 0.751109, 1.079919, 1.079919,
- 0.778357, 0.778357, 1.085620, 1.085620, 0.784530, 0.784530,
- 1.062455, 1.062455, 0.759940, 0.759940, 1.080101, 1.080101,
- 0.778552, 0.778552, 1.082402, 1.082402, 0.781035, 0.781035]],
+ [[
+ 0.710660, 0.710660, 0.464130, 0.464130, 0.877293, 0.877293,
+ 0.593283, 0.593283, 0.958505, 0.958505, 0.661638, 0.661638,
+ 0.785405, 0.785405, 0.520789, 0.520789, 0.890836, 0.890836,
+ 0.604341, 0.604341, 0.928512, 0.928512, 0.635773, 0.635773
+ ], [
+ 0.967579, 0.967579, 0.669636, 0.669636, 1.038811, 1.038811,
+ 0.736057, 0.736057, 1.058201, 1.058201, 0.755559, 0.755559,
+ 0.993088, 0.993088, 0.692621, 0.692621, 1.040288, 1.040288,
+ 0.737517, 0.737517, 1.048773, 1.048773, 0.745984, 0.745984
+ ], [
+ 1.053842, 1.053842, 0.751109, 0.751109, 1.079919, 1.079919,
+ 0.778357, 0.778357, 1.085620, 1.085620, 0.784530, 0.784530,
+ 1.062455, 1.062455, 0.759940, 0.759940, 1.080101, 1.080101,
+ 0.778552, 0.778552, 1.082402, 1.082402, 0.781035, 0.781035
+ ]],
dtype=np.float32)
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
@@ -339,17 +347,16 @@ class RNNCellTest(test.TestCase):
forget_bias=1.0,
num_frequency_blocks=[num_shifts])
inputs = constant_op.constant(
- np.array([[1.0, 1.1, 1.2, 1.3],
- [2.0, 2.1, 2.2, 2.3],
- [3.0, 3.1, 3.2, 3.3]],
- dtype=np.float32),
+ np.array(
+ [[1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3],
+ [3.0, 3.1, 3.2, 3.3]],
+ dtype=np.float32),
dtype=dtypes.float32)
state_value = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_units), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
- init_state = cell.state_tuple_type(
- *([state_value, state_value] * num_shifts * 2))
+ init_state = cell.state_tuple_type(*(
+ [state_value, state_value] * num_shifts * 2))
output, state = cell(inputs, init_state)
sess.run([variables.global_variables_initializer()])
res = sess.run([output, state])
@@ -375,32 +382,40 @@ class RNNCellTest(test.TestCase):
frequency_skip = 1
num_shifts = int((input_size - feature_size) / frequency_skip + 1)
expected_output = np.array(
- [[0.464130, 0.464130, 0.419165, 0.419165, 0.593283, 0.593283,
- 0.738350, 0.738350, 0.661638, 0.661638, 0.866774, 0.866774,
- 0.322645, 0.322645, 0.276068, 0.276068, 0.584654, 0.584654,
- 0.690292, 0.690292, 0.640446, 0.640446, 0.840071, 0.840071],
- [0.669636, 0.669636, 0.628966, 0.628966, 0.736057, 0.736057,
- 0.895927, 0.895927, 0.755559, 0.755559, 0.954359, 0.954359,
- 0.493625, 0.493625, 0.449236, 0.449236, 0.730828, 0.730828,
- 0.865996, 0.865996, 0.749429, 0.749429, 0.944958, 0.944958],
- [0.751109, 0.751109, 0.711716, 0.711716, 0.778357, 0.778357,
- 0.940779, 0.940779, 0.784530, 0.784530, 0.980604, 0.980604,
- 0.608587, 0.608587, 0.566683, 0.566683, 0.777345, 0.777345,
- 0.925820, 0.925820, 0.782597, 0.782597, 0.976858, 0.976858]],
+ [[
+ 0.464130, 0.464130, 0.419165, 0.419165, 0.593283, 0.593283,
+ 0.738350, 0.738350, 0.661638, 0.661638, 0.866774, 0.866774,
+ 0.322645, 0.322645, 0.276068, 0.276068, 0.584654, 0.584654,
+ 0.690292, 0.690292, 0.640446, 0.640446, 0.840071, 0.840071
+ ], [
+ 0.669636, 0.669636, 0.628966, 0.628966, 0.736057, 0.736057,
+ 0.895927, 0.895927, 0.755559, 0.755559, 0.954359, 0.954359,
+ 0.493625, 0.493625, 0.449236, 0.449236, 0.730828, 0.730828,
+ 0.865996, 0.865996, 0.749429, 0.749429, 0.944958, 0.944958
+ ], [
+ 0.751109, 0.751109, 0.711716, 0.711716, 0.778357, 0.778357,
+ 0.940779, 0.940779, 0.784530, 0.784530, 0.980604, 0.980604,
+ 0.608587, 0.608587, 0.566683, 0.566683, 0.777345, 0.777345,
+ 0.925820, 0.925820, 0.782597, 0.782597, 0.976858, 0.976858
+ ]],
dtype=np.float32)
expected_state = np.array(
- [[0.710660, 0.710660, 0.464130, 0.464130, 0.877293, 0.877293,
- 0.593283, 0.593283, 0.958505, 0.958505, 0.661638, 0.661638,
- 0.516575, 0.516575, 0.322645, 0.322645, 0.866628, 0.866628,
- 0.584654, 0.584654, 0.934002, 0.934002, 0.640446, 0.640446],
- [0.967579, 0.967579, 0.669636, 0.669636, 1.038811, 1.038811,
- 0.736057, 0.736057, 1.058201, 1.058201, 0.755559, 0.755559,
- 0.749836, 0.749836, 0.493625, 0.493625, 1.033488, 1.033488,
- 0.730828, 0.730828, 1.052186, 1.052186, 0.749429, 0.749429],
- [1.053842, 1.053842, 0.751109, 0.751109, 1.079919, 1.079919,
- 0.778357, 0.778357, 1.085620, 1.085620, 0.784530, 0.784530,
- 0.895999, 0.895999, 0.608587, 0.608587, 1.078978, 1.078978,
- 0.777345, 0.777345, 1.083843, 1.083843, 0.782597, 0.782597]],
+ [[
+ 0.710660, 0.710660, 0.464130, 0.464130, 0.877293, 0.877293,
+ 0.593283, 0.593283, 0.958505, 0.958505, 0.661638, 0.661638,
+ 0.516575, 0.516575, 0.322645, 0.322645, 0.866628, 0.866628,
+ 0.584654, 0.584654, 0.934002, 0.934002, 0.640446, 0.640446
+ ], [
+ 0.967579, 0.967579, 0.669636, 0.669636, 1.038811, 1.038811,
+ 0.736057, 0.736057, 1.058201, 1.058201, 0.755559, 0.755559,
+ 0.749836, 0.749836, 0.493625, 0.493625, 1.033488, 1.033488,
+ 0.730828, 0.730828, 1.052186, 1.052186, 0.749429, 0.749429
+ ], [
+ 1.053842, 1.053842, 0.751109, 0.751109, 1.079919, 1.079919,
+ 0.778357, 0.778357, 1.085620, 1.085620, 0.784530, 0.784530,
+ 0.895999, 0.895999, 0.608587, 0.608587, 1.078978, 1.078978,
+ 0.777345, 0.777345, 1.083843, 1.083843, 0.782597, 0.782597
+ ]],
dtype=np.float32)
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
@@ -413,17 +428,16 @@ class RNNCellTest(test.TestCase):
num_frequency_blocks=[num_shifts],
backward_slice_offset=1)
inputs = constant_op.constant(
- np.array([[1.0, 1.1, 1.2, 1.3],
- [2.0, 2.1, 2.2, 2.3],
- [3.0, 3.1, 3.2, 3.3]],
- dtype=np.float32),
+ np.array(
+ [[1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3],
+ [3.0, 3.1, 3.2, 3.3]],
+ dtype=np.float32),
dtype=dtypes.float32)
state_value = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_units), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
- init_state = cell.state_tuple_type(
- *([state_value, state_value] * num_shifts * 2))
+ init_state = cell.state_tuple_type(*(
+ [state_value, state_value] * num_shifts * 2))
output, state = cell(inputs, init_state)
sess.run([variables.global_variables_initializer()])
res = sess.run([output, state])
@@ -474,8 +488,8 @@ class RNNCellTest(test.TestCase):
for state_is_tuple in [False, True]:
with ops.Graph().as_default():
with self.test_session() as sess:
- with variable_scope.variable_scope("state_is_tuple_" + str(
- state_is_tuple)):
+ with variable_scope.variable_scope(
+ "state_is_tuple_" + str(state_is_tuple)):
lstm_cell = rnn_cell.BasicLSTMCell(
num_units, state_is_tuple=state_is_tuple)
cell = contrib_rnn_cell.AttentionCellWrapper(
@@ -525,16 +539,15 @@ class RNNCellTest(test.TestCase):
for state_is_tuple in [False, True]:
with ops.Graph().as_default():
with self.test_session() as sess:
- with variable_scope.variable_scope("state_is_tuple_" + str(
- state_is_tuple)):
+ with variable_scope.variable_scope(
+ "state_is_tuple_" + str(state_is_tuple)):
lstm_cell = rnn_cell.BasicLSTMCell(
num_units, state_is_tuple=state_is_tuple)
cell = contrib_rnn_cell.AttentionCellWrapper(
lstm_cell, attn_length, state_is_tuple=state_is_tuple)
if state_is_tuple:
zeros = constant_op.constant(
- 0.1 * np.ones(
- [batch_size, num_units], dtype=np.float32),
+ 0.1 * np.ones([batch_size, num_units], dtype=np.float32),
dtype=dtypes.float32)
attn_state_zeros = constant_op.constant(
0.1 * np.ones(
@@ -579,22 +592,25 @@ class RNNCellTest(test.TestCase):
[1.018088, 0.378983, -0.572179, 0.268591]],
dtype=np.float32)
expected_state = np.array(
- [[0.74946702, 0.34681597, 0.26474735, 1.06485605, 0.38465962,
- 0.11420801, 0.10272158, 0.30925757, 0.63899988, 0.7181077,
- 0.47534478, 0.33715725, 0.58086717, 0.49446869, 0.7641536,
- 0.12814975, 0.92231739, 0.89857256, 0.21889746, 0.38442063,
- 0.53481543, 0.8876909, 0.45823169, 0.5905602, 0.78038228,
- 0.56501579, 0.03971386, 0.09870267, 0.8074435, 0.66821432,
- 0.99211812, 0.12295902, 1.14606023, 0.34370938, -0.79251152,
- 0.51843399],
- [0.5179342, 0.48682183, -0.25426468, 0.96810579, 0.28809637,
- 0.13607743, -0.11446252, 0.26792109, 0.78047138, 0.63460857,
- 0.49122369, 0.52007174, 0.73000264, 0.66986895, 0.73576689,
- 0.86301267, 0.87887371, 0.35185754, 0.93417215, 0.64732957,
- 0.63173044, 0.66627824, 0.53644657, 0.20477486, 0.98458421,
- 0.38277245, 0.03746676, 0.92510188, 0.57714164, 0.84932971,
- 0.36127412, 0.12125921, 1.1362772, 0.34361625, -0.78150457,
- 0.70582712]],
+ [[
+ 0.74946702, 0.34681597, 0.26474735, 1.06485605, 0.38465962,
+ 0.11420801, 0.10272158, 0.30925757, 0.63899988, 0.7181077,
+ 0.47534478, 0.33715725, 0.58086717, 0.49446869, 0.7641536,
+ 0.12814975, 0.92231739, 0.89857256, 0.21889746, 0.38442063,
+ 0.53481543, 0.8876909, 0.45823169, 0.5905602, 0.78038228,
+ 0.56501579, 0.03971386, 0.09870267, 0.8074435, 0.66821432,
+ 0.99211812, 0.12295902, 1.14606023, 0.34370938, -0.79251152,
+ 0.51843399
+ ], [
+ 0.5179342, 0.48682183, -0.25426468, 0.96810579, 0.28809637,
+ 0.13607743, -0.11446252, 0.26792109, 0.78047138, 0.63460857,
+ 0.49122369, 0.52007174, 0.73000264, 0.66986895, 0.73576689,
+ 0.86301267, 0.87887371, 0.35185754, 0.93417215, 0.64732957,
+ 0.63173044, 0.66627824, 0.53644657, 0.20477486, 0.98458421,
+ 0.38277245, 0.03746676, 0.92510188, 0.57714164, 0.84932971,
+ 0.36127412, 0.12125921, 1.1362772, 0.34361625, -0.78150457,
+ 0.70582712
+ ]],
dtype=np.float32)
seed = 12345
random_seed.set_random_seed(seed)
@@ -602,7 +618,8 @@ class RNNCellTest(test.TestCase):
for state_is_tuple in [False, True]:
with session.Session() as sess:
with variable_scope.variable_scope(
- "state_is_tuple", reuse=state_is_tuple,
+ "state_is_tuple",
+ reuse=state_is_tuple,
initializer=init_ops.glorot_uniform_initializer()):
lstm_cell = rnn_cell.BasicLSTMCell(
num_units, state_is_tuple=state_is_tuple)
@@ -646,36 +663,31 @@ class RNNCellTest(test.TestCase):
def testNASCell(self):
num_units = 6
batch_size = 3
- expected_output = np.array([[0.576751, 0.576751, 0.576751, 0.576751,
- 0.576751, 0.576751],
- [0.618936, 0.618936, 0.618936, 0.618936,
- 0.618936, 0.618936],
- [0.627393, 0.627393, 0.627393, 0.627393,
- 0.627393, 0.627393]])
- expected_state = np.array([[0.71579772, 0.71579772, 0.71579772, 0.71579772,
- 0.71579772, 0.71579772, 0.57675087, 0.57675087,
- 0.57675087, 0.57675087, 0.57675087, 0.57675087],
- [0.78041625, 0.78041625, 0.78041625, 0.78041625,
- 0.78041625, 0.78041625, 0.6189357, 0.6189357,
- 0.61893570, 0.6189357, 0.6189357, 0.6189357],
- [0.79457647, 0.79457647, 0.79457647, 0.79457647,
- 0.79457653, 0.79457653, 0.62739348, 0.62739348,
- 0.62739348, 0.62739348, 0.62739348, 0.62739348]
- ])
+ expected_output = np.array(
+ [[0.576751, 0.576751, 0.576751, 0.576751, 0.576751, 0.576751],
+ [0.618936, 0.618936, 0.618936, 0.618936, 0.618936, 0.618936],
+ [0.627393, 0.627393, 0.627393, 0.627393, 0.627393, 0.627393]])
+ expected_state = np.array([[
+ 0.71579772, 0.71579772, 0.71579772, 0.71579772, 0.71579772, 0.71579772,
+ 0.57675087, 0.57675087, 0.57675087, 0.57675087, 0.57675087, 0.57675087
+ ], [
+ 0.78041625, 0.78041625, 0.78041625, 0.78041625, 0.78041625, 0.78041625,
+ 0.6189357, 0.6189357, 0.61893570, 0.6189357, 0.6189357, 0.6189357
+ ], [
+ 0.79457647, 0.79457647, 0.79457647, 0.79457647, 0.79457653, 0.79457653,
+ 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348
+ ]])
with self.test_session() as sess:
with variable_scope.variable_scope(
- "nas_test",
- initializer=init_ops.constant_initializer(0.5)):
+ "nas_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.NASCell(num_units=num_units)
inputs = constant_op.constant(
- np.array([[1., 1., 1., 1.],
- [2., 2., 2., 2.],
- [3., 3., 3., 3.]],
- dtype=np.float32),
+ np.array(
+ [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
+ dtype=np.float32),
dtype=dtypes.float32)
state_value = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_units), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
init_state = rnn_cell.LSTMStateTuple(state_value, state_value)
output, state = cell(inputs, init_state)
@@ -699,39 +711,34 @@ class RNNCellTest(test.TestCase):
num_units = 6
batch_size = 3
num_proj = 5
- expected_output = np.array([[1.697418, 1.697418, 1.697418, 1.697418,
- 1.697418],
- [1.840037, 1.840037, 1.840037, 1.840037,
- 1.840037],
- [1.873985, 1.873985, 1.873985, 1.873985,
- 1.873985]])
- expected_state = np.array([[0.69855207, 0.69855207, 0.69855207, 0.69855207,
- 0.69855207, 0.69855207, 1.69741797, 1.69741797,
- 1.69741797, 1.69741797, 1.69741797],
- [0.77073824, 0.77073824, 0.77073824, 0.77073824,
- 0.77073824, 0.77073824, 1.84003687, 1.84003687,
- 1.84003687, 1.84003687, 1.84003687],
- [0.78973997, 0.78973997, 0.78973997, 0.78973997,
- 0.78973997, 0.78973997, 1.87398517, 1.87398517,
- 1.87398517, 1.87398517, 1.87398517]])
+ expected_output = np.array(
+ [[1.697418, 1.697418, 1.697418, 1.697418,
+ 1.697418], [1.840037, 1.840037, 1.840037, 1.840037, 1.840037],
+ [1.873985, 1.873985, 1.873985, 1.873985, 1.873985]])
+ expected_state = np.array([[
+ 0.69855207, 0.69855207, 0.69855207, 0.69855207, 0.69855207, 0.69855207,
+ 1.69741797, 1.69741797, 1.69741797, 1.69741797, 1.69741797
+ ], [
+ 0.77073824, 0.77073824, 0.77073824, 0.77073824, 0.77073824, 0.77073824,
+ 1.84003687, 1.84003687, 1.84003687, 1.84003687, 1.84003687
+ ], [
+ 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997,
+ 1.87398517, 1.87398517, 1.87398517, 1.87398517, 1.87398517
+ ]])
with self.test_session() as sess:
with variable_scope.variable_scope(
- "nas_proj_test",
- initializer=init_ops.constant_initializer(0.5)):
+ "nas_proj_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.NASCell(num_units=num_units, num_proj=num_proj)
inputs = constant_op.constant(
- np.array([[1., 1., 1., 1.],
- [2., 2., 2., 2.],
- [3., 3., 3., 3.]],
- dtype=np.float32),
+ np.array(
+ [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
+ dtype=np.float32),
dtype=dtypes.float32)
state_value_c = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_units), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
state_value_h = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_proj), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_proj), dtype=np.float32),
dtype=dtypes.float32)
init_state = rnn_cell.LSTMStateTuple(state_value_c, state_value_h)
output, state = cell(inputs, init_state)
@@ -755,24 +762,20 @@ class RNNCellTest(test.TestCase):
num_units = 2
batch_size = 3
expected_state_and_output = np.array(
- [[0.13752282, 0.13752282],
- [0.10545051, 0.10545051],
+ [[0.13752282, 0.13752282], [0.10545051, 0.10545051],
[0.10074195, 0.10074195]],
dtype=np.float32)
with self.test_session() as sess:
with variable_scope.variable_scope(
- "ugrnn_cell_test",
- initializer=init_ops.constant_initializer(0.5)):
+ "ugrnn_cell_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.UGRNNCell(num_units=num_units)
inputs = constant_op.constant(
- np.array([[1., 1., 1., 1.],
- [2., 2., 2., 2.],
- [3., 3., 3., 3.]],
- dtype=np.float32),
+ np.array(
+ [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
+ dtype=np.float32),
dtype=dtypes.float32)
init_state = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_units), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
output, state = cell(inputs, init_state)
sess.run([variables.global_variables_initializer()])
@@ -786,13 +789,11 @@ class RNNCellTest(test.TestCase):
num_units = 2
batch_size = 3
expected_state = np.array(
- [[0.13752282, 0.13752282],
- [0.10545051, 0.10545051],
+ [[0.13752282, 0.13752282], [0.10545051, 0.10545051],
[0.10074195, 0.10074195]],
dtype=np.float32)
expected_output = np.array(
- [[2.00431061, 2.00431061],
- [4.00060606, 4.00060606],
+ [[2.00431061, 2.00431061], [4.00060606, 4.00060606],
[6.00008249, 6.00008249]],
dtype=np.float32)
with self.test_session() as sess:
@@ -802,14 +803,12 @@ class RNNCellTest(test.TestCase):
cell = contrib_rnn_cell.IntersectionRNNCell(
num_units=num_units, num_in_proj=num_units)
inputs = constant_op.constant(
- np.array([[1., 1., 1., 1.],
- [2., 2., 2., 2.],
- [3., 3., 3., 3.]],
- dtype=np.float32),
+ np.array(
+ [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
+ dtype=np.float32),
dtype=dtypes.float32)
init_state = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_units), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
output, state = cell(inputs, init_state)
sess.run([variables.global_variables_initializer()])
@@ -824,19 +823,17 @@ class RNNCellTest(test.TestCase):
batch_size = 3
cell = contrib_rnn_cell.IntersectionRNNCell(num_units=num_units)
inputs = constant_op.constant(
- np.array([[1., 1., 1., 1.],
- [2., 2., 2., 2.],
- [3., 3., 3., 3.]],
- dtype=np.float32),
+ np.array(
+ [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
+ dtype=np.float32),
dtype=dtypes.float32)
init_state = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_units), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
- with self.assertRaisesRegexp(
- ValueError, "Must have input size == output size for "
- "Intersection RNN. To fix, num_in_proj should "
- "be set to num_units at cell init."):
+ with self.assertRaisesRegexp(ValueError,
+ "Must have input size == output size for "
+ "Intersection RNN. To fix, num_in_proj should "
+ "be set to num_units at cell init."):
cell(inputs, init_state)
def testPhasedLSTMCell(self):
@@ -845,13 +842,11 @@ class RNNCellTest(test.TestCase):
batch_size = 3
input_size = 4
expected_state_c = np.array(
- [[6.450831e-04, 4.697885e-04],
- [9.862894e-05, 7.212213e-04],
+ [[6.450831e-04, 4.697885e-04], [9.862894e-05, 7.212213e-04],
[4.401947e-04, 9.143004e-04]],
dtype=np.float32)
expected_state_h = np.array(
- [[4.621217e-04, 3.365449e-04],
- [7.438179e-05, 5.439147e-04],
+ [[4.621217e-04, 3.365449e-04], [7.438179e-05, 5.439147e-04],
[3.347936e-04, 6.953785e-04]],
dtype=np.float32)
with variable_scope.variable_scope(
@@ -864,14 +859,14 @@ class RNNCellTest(test.TestCase):
output, state = contrib_rnn_cell.PhasedLSTMCell(num_units=num_units)(
(t, x), state0)
sess.run([variables.global_variables_initializer()])
- res = sess.run([output, state], {
- t.name:
- np.array([[1.], [2.], [3.]]),
- x.name:
- np.array([[1., 1., 1., 1.],
- [2., 2., 2., 2.],
- [3., 3., 3., 3.]]),
- })
+ res = sess.run(
+ [output, state], {
+ t.name:
+ np.array([[1.], [2.], [3.]]),
+ x.name:
+ np.array([[1., 1., 1., 1.], [2., 2., 2., 2.],
+ [3., 3., 3., 3.]]),
+ })
# This is a smoke test, making sure expected values are unchanged.
self.assertEqual(len(res), 2)
self.assertAllClose(res[0], res[1].h)
@@ -880,36 +875,32 @@ class RNNCellTest(test.TestCase):
def testConv1DLSTMCell(self):
with self.test_session() as sess:
- shape = [2,1]
+ shape = [2, 1]
filter_size = [3]
num_features = 1
batch_size = 2
expected_state_c = np.array(
- [[[1.4375670191], [1.4375670191]],
- [[2.7542609292], [2.7542609292]]],
+ [[[1.4375670191], [1.4375670191]], [[2.7542609292], [2.7542609292]]],
dtype=np.float32)
expected_state_h = np.array(
- [[[0.6529865603], [0.6529865603]],
- [[0.8736877431], [0.8736877431]]],
+ [[[0.6529865603], [0.6529865603]], [[0.8736877431], [0.8736877431]]],
dtype=np.float32)
with variable_scope.variable_scope(
- "root", initializer=init_ops.constant_initializer(1.0/2.0)):
+ "root", initializer=init_ops.constant_initializer(1.0 / 2.0)):
x = array_ops.placeholder(dtypes.float32, [None, None, 1])
- cell = contrib_rnn_cell.Conv1DLSTMCell(input_shape=shape,
- kernel_shape=filter_size,
- output_channels=num_features)
+ cell = contrib_rnn_cell.Conv1DLSTMCell(
+ input_shape=shape,
+ kernel_shape=filter_size,
+ output_channels=num_features)
hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
output, state = cell(x, hidden)
sess.run([variables.global_variables_initializer()])
- res = sess.run([output, state], {
- hidden[0].name:
- np.array([[[1.],[1.]],
- [[2.],[2.]]]),
- x.name:
- np.array([[[1.],[1.]],
- [[2.],[2.]]]),
- })
+ res = sess.run(
+ [output, state], {
+ hidden[0].name: np.array([[[1.], [1.]], [[2.], [2.]]]),
+ x.name: np.array([[[1.], [1.]], [[2.], [2.]]]),
+ })
# This is a smoke test, making sure expected values are unchanged.
self.assertEqual(len(res), 2)
self.assertAllClose(res[0], res[1].h)
@@ -918,44 +909,40 @@ class RNNCellTest(test.TestCase):
def testConv2DLSTMCell(self):
with self.test_session() as sess:
- shape = [2,2,1]
- filter_size = [3,3]
+ shape = [2, 2, 1]
+ filter_size = [3, 3]
num_features = 1
batch_size = 2
expected_state_c = np.array(
- [[[[1.4375670191], [1.4375670191]],
- [[1.4375670191], [1.4375670191]]],
- [[[2.7542609292], [2.7542609292]],
- [[2.7542609292], [2.7542609292]]]],
+ [[[[1.4375670191], [1.4375670191]], [[1.4375670191], [1.4375670191]]],
+ [[[2.7542609292], [2.7542609292]], [[2.7542609292], [2.7542609292]]
+ ]],
dtype=np.float32)
expected_state_h = np.array(
- [[[[0.6529865603], [0.6529865603]],
- [[0.6529865603], [0.6529865603]]],
- [[[0.8736877431], [0.8736877431]],
- [[0.8736877431], [0.8736877431]]]],
+ [[[[0.6529865603], [0.6529865603]], [[0.6529865603], [0.6529865603]]],
+ [[[0.8736877431], [0.8736877431]], [[0.8736877431], [0.8736877431]]
+ ]],
dtype=np.float32)
with variable_scope.variable_scope(
- "root", initializer=init_ops.constant_initializer(1.0/4.0)):
+ "root", initializer=init_ops.constant_initializer(1.0 / 4.0)):
x = array_ops.placeholder(dtypes.float32, [None, None, None, 1])
- cell = contrib_rnn_cell.Conv2DLSTMCell(input_shape=shape,
- kernel_shape=filter_size,
- output_channels=num_features)
+ cell = contrib_rnn_cell.Conv2DLSTMCell(
+ input_shape=shape,
+ kernel_shape=filter_size,
+ output_channels=num_features)
hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
output, state = cell(x, hidden)
sess.run([variables.global_variables_initializer()])
- res = sess.run([output, state], {
- hidden[0].name:
- np.array([[[[1.],[1.]],
- [[1.],[1.]]],
- [[[2.],[2.]],
- [[2.],[2.]]]]),
- x.name:
- np.array([[[[1.],[1.]],
- [[1.],[1.]]],
- [[[2.],[2.]],
- [[2.],[2.]]]]),
- })
+ res = sess.run(
+ [output, state], {
+ hidden[0].name:
+ np.array([[[[1.], [1.]], [[1.], [1.]]], [[[2.], [2.]],
+ [[2.], [2.]]]]),
+ x.name:
+ np.array([[[[1.], [1.]], [[1.], [1.]]], [[[2.], [2.]],
+ [[2.], [2.]]]]),
+ })
# This is a smoke test, making sure expected values are unchanged.
self.assertEqual(len(res), 2)
self.assertAllClose(res[0], res[1].h)
@@ -964,36 +951,33 @@ class RNNCellTest(test.TestCase):
def testConv3DLSTMCell(self):
with self.test_session() as sess:
- shape = [2,2,2,1]
- filter_size = [3,3,3]
+ shape = [2, 2, 2, 1]
+ filter_size = [3, 3, 3]
num_features = 1
batch_size = 2
expected_state_c = np.array(
- [[[[[1.4375670191], [1.4375670191]],
- [[1.4375670191], [1.4375670191]]],
- [[[1.4375670191], [1.4375670191]],
- [[1.4375670191], [1.4375670191]]]],
- [[[[2.7542609292], [2.7542609292]],
- [[2.7542609292], [2.7542609292]]],
- [[[2.7542609292], [2.7542609292]],
- [[2.7542609292], [2.7542609292]]]]],
+ [[[[[1.4375670191], [1.4375670191]], [[1.4375670191], [1.4375670191]]
+ ], [[[1.4375670191], [1.4375670191]], [[1.4375670191],
+ [1.4375670191]]]],
+ [[[[2.7542609292], [2.7542609292]], [[2.7542609292], [2.7542609292]]
+ ], [[[2.7542609292], [2.7542609292]], [[2.7542609292],
+ [2.7542609292]]]]],
dtype=np.float32)
expected_state_h = np.array(
- [[[[[0.6529865603], [0.6529865603]],
- [[0.6529865603], [0.6529865603]]],
- [[[0.6529865603], [0.6529865603]],
- [[0.6529865603], [0.6529865603]]]],
- [[[[0.8736877431], [0.8736877431]],
- [[0.8736877431], [0.8736877431]]],
- [[[0.8736877431], [0.8736877431]],
- [[0.8736877431], [0.8736877431]]]]],
+ [[[[[0.6529865603], [0.6529865603]], [[0.6529865603], [0.6529865603]]
+ ], [[[0.6529865603], [0.6529865603]], [[0.6529865603],
+ [0.6529865603]]]],
+ [[[[0.8736877431], [0.8736877431]], [[0.8736877431], [0.8736877431]]
+ ], [[[0.8736877431], [0.8736877431]], [[0.8736877431],
+ [0.8736877431]]]]],
dtype=np.float32)
with variable_scope.variable_scope(
- "root", initializer=init_ops.constant_initializer(1.0/8.0)):
+ "root", initializer=init_ops.constant_initializer(1.0 / 8.0)):
x = array_ops.placeholder(dtypes.float32, [None, None, None, None, 1])
- cell = contrib_rnn_cell.Conv3DLSTMCell(input_shape=shape,
- kernel_shape=filter_size,
- output_channels=num_features)
+ cell = contrib_rnn_cell.Conv3DLSTMCell(
+ input_shape=shape,
+ kernel_shape=filter_size,
+ output_channels=num_features)
hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
output, state = cell(x, hidden)
@@ -1056,8 +1040,8 @@ class RNNCellTest(test.TestCase):
num_units=num_units, number_of_groups=number_of_groups)
cell = rnn_cell.LSTMCell(num_units=num_units)
self.assertTrue(isinstance(gcell.state_size, tuple))
- zero_state = gcell.zero_state(batch_size=batch_size,
- dtype=dtypes.float32)
+ zero_state = gcell.zero_state(
+ batch_size=batch_size, dtype=dtypes.float32)
gh, gs = gcell(x, zero_state)
h, g = cell(x, zero_state)
@@ -1080,16 +1064,16 @@ class RNNCellTest(test.TestCase):
glstm_input = array_ops.ones([batch_size, num_units])
gcell = contrib_rnn_cell.GLSTMCell(
num_units=num_units, number_of_groups=number_of_groups)
- gcell_zero_state = gcell.zero_state(batch_size=batch_size,
- dtype=dtypes.float32)
+ gcell_zero_state = gcell.zero_state(
+ batch_size=batch_size, dtype=dtypes.float32)
gh, gs = gcell(glstm_input, gcell_zero_state)
# input for LSTM cell simulating single G-LSTM group
lstm_input = array_ops.ones([batch_size, num_units / number_of_groups])
      # note division by number_of_groups. This cell simulates a single G-LSTM group
cell = rnn_cell.LSTMCell(num_units=int(num_units / number_of_groups))
- cell_zero_state = cell.zero_state(batch_size=batch_size,
- dtype=dtypes.float32)
+ cell_zero_state = cell.zero_state(
+ batch_size=batch_size, dtype=dtypes.float32)
h, g = cell(lstm_input, cell_zero_state)
sess.run([variables.global_variables_initializer()])
@@ -1099,6 +1083,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(gh_res[:, int(num_units / number_of_groups):],
h_res, 1e-5)
+
class LayerNormBasicLSTMCellTest(test.TestCase):
# NOTE: all the values in the current test case have been calculated.
@@ -1119,13 +1104,14 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
cell = rnn_cell.MultiRNNCell([single_cell() for _ in range(2)])
g, out_m = cell(x, state)
sess.run([variables.global_variables_initializer()])
- res = sess.run([g, out_m], {
- x.name: np.array([[1., 1.]]),
- c0.name: 0.1 * np.asarray([[0, 1]]),
- h0.name: 0.1 * np.asarray([[2, 3]]),
- c1.name: 0.1 * np.asarray([[4, 5]]),
- h1.name: 0.1 * np.asarray([[6, 7]]),
- })
+ res = sess.run(
+ [g, out_m], {
+ x.name: np.array([[1., 1.]]),
+ c0.name: 0.1 * np.asarray([[0, 1]]),
+ h0.name: 0.1 * np.asarray([[2, 3]]),
+ c1.name: 0.1 * np.asarray([[4, 5]]),
+ h1.name: 0.1 * np.asarray([[6, 7]]),
+ })
expected_h = np.array([[-0.38079708, 0.38079708]])
expected_state0_c = np.array([[-1.0, 1.0]])
@@ -1155,11 +1141,12 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
cell = contrib_rnn_cell.LayerNormBasicLSTMCell(2)
g, out_m = cell(x, state)
sess.run([variables.global_variables_initializer()])
- res = sess.run([g, out_m], {
- x.name: np.array([[1., 1., 1.]]),
- c.name: 0.1 * np.asarray([[0, 1]]),
- h.name: 0.1 * np.asarray([[2, 3]]),
- })
+ res = sess.run(
+ [g, out_m], {
+ x.name: np.array([[1., 1., 1.]]),
+ c.name: 0.1 * np.asarray([[0, 1]]),
+ h.name: 0.1 * np.asarray([[2, 3]]),
+ })
expected_h = np.array([[-0.38079708, 0.38079708]])
expected_c = np.array([[-1.0, 1.0]])
@@ -1168,7 +1155,6 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
self.assertAllClose(res[1].c, expected_c, 1e-5)
self.assertAllClose(res[1].h, expected_h, 1e-5)
-
   def testBasicLSTMCellWithoutNorm(self):
     """Tests BasicLSTMCell with layer_norm=False."""
with self.test_session() as sess:
@@ -1186,19 +1172,20 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
cell = rnn_cell.MultiRNNCell([single_cell() for _ in range(2)])
g, out_m = cell(x, state)
sess.run([variables.global_variables_initializer()])
- res = sess.run([g, out_m], {
- x.name: np.array([[1., 1.]]),
- c0.name: 0.1 * np.asarray([[0, 1]]),
- h0.name: 0.1 * np.asarray([[2, 3]]),
- c1.name: 0.1 * np.asarray([[4, 5]]),
- h1.name: 0.1 * np.asarray([[6, 7]]),
- })
-
- expected_h = np.array([[ 0.70230919, 0.72581059]])
- expected_state0_c = np.array([[ 0.8020075, 0.89599884]])
- expected_state0_h = np.array([[ 0.56668288, 0.60858738]])
- expected_state1_c = np.array([[ 1.17500675, 1.26892781]])
- expected_state1_h = np.array([[ 0.70230919, 0.72581059]])
+ res = sess.run(
+ [g, out_m], {
+ x.name: np.array([[1., 1.]]),
+ c0.name: 0.1 * np.asarray([[0, 1]]),
+ h0.name: 0.1 * np.asarray([[2, 3]]),
+ c1.name: 0.1 * np.asarray([[4, 5]]),
+ h1.name: 0.1 * np.asarray([[6, 7]]),
+ })
+
+ expected_h = np.array([[0.70230919, 0.72581059]])
+ expected_state0_c = np.array([[0.8020075, 0.89599884]])
+ expected_state0_h = np.array([[0.56668288, 0.60858738]])
+ expected_state1_c = np.array([[1.17500675, 1.26892781]])
+ expected_state1_h = np.array([[0.70230919, 0.72581059]])
actual_h = res[0]
actual_state0_c = res[1][0].c
@@ -1215,21 +1202,22 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
with variable_scope.variable_scope(
"other", initializer=init_ops.constant_initializer(0.5)) as vs:
x = array_ops.zeros(
- [1, 3]) # Test BasicLSTMCell with input_size != num_units.
+ [1, 3]) # Test BasicLSTMCell with input_size != num_units.
c = array_ops.zeros([1, 2])
h = array_ops.zeros([1, 2])
state = rnn_cell.LSTMStateTuple(c, h)
cell = contrib_rnn_cell.LayerNormBasicLSTMCell(2, layer_norm=False)
g, out_m = cell(x, state)
sess.run([variables.global_variables_initializer()])
- res = sess.run([g, out_m], {
- x.name: np.array([[1., 1., 1.]]),
- c.name: 0.1 * np.asarray([[0, 1]]),
- h.name: 0.1 * np.asarray([[2, 3]]),
- })
-
- expected_h = np.array([[ 0.64121795, 0.68166804]])
- expected_c = np.array([[ 0.88477188, 0.98103917]])
+ res = sess.run(
+ [g, out_m], {
+ x.name: np.array([[1., 1., 1.]]),
+ c.name: 0.1 * np.asarray([[0, 1]]),
+ h.name: 0.1 * np.asarray([[2, 3]]),
+ })
+
+ expected_h = np.array([[0.64121795, 0.68166804]])
+ expected_c = np.array([[0.88477188, 0.98103917]])
self.assertEqual(len(res), 2)
self.assertAllClose(res[0], expected_h, 1e-5)
self.assertAllClose(res[1].c, expected_c, 1e-5)
@@ -1250,13 +1238,14 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
[contrib_rnn_cell.LayerNormBasicLSTMCell(2) for _ in range(2)])
h, (s0, s1) = cell(x, (state0, state1))
sess.run([variables.global_variables_initializer()])
- res = sess.run([h, s0, s1], {
- x.name: np.array([[1., 1.]]),
- c0.name: 0.1 * np.asarray([[0, 1]]),
- h0.name: 0.1 * np.asarray([[2, 3]]),
- c1.name: 0.1 * np.asarray([[4, 5]]),
- h1.name: 0.1 * np.asarray([[6, 7]]),
- })
+ res = sess.run(
+ [h, s0, s1], {
+ x.name: np.array([[1., 1.]]),
+ c0.name: 0.1 * np.asarray([[0, 1]]),
+ h0.name: 0.1 * np.asarray([[2, 3]]),
+ c1.name: 0.1 * np.asarray([[4, 5]]),
+ h1.name: 0.1 * np.asarray([[6, 7]]),
+ })
expected_h = np.array([[-0.38079708, 0.38079708]])
expected_h0 = np.array([[-0.38079708, 0.38079708]])
@@ -1344,11 +1333,12 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
g, s = cell(x, state)
sess.run([variables.global_variables_initializer()])
- res = sess.run([g, s], {
- x.name: np.ones([1, 5]),
- c.name: np.ones([1, 5]),
- h.name: np.ones([1, 5]),
- })
+ res = sess.run(
+ [g, s], {
+ x.name: np.ones([1, 5]),
+ c.name: np.ones([1, 5]),
+ h.name: np.ones([1, 5]),
+ })
# Since the returned tensors are of size [1,n]
# get the first component right now.
@@ -1374,35 +1364,35 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
self.assertIn(dropped_count, allowed_low)
-def _create_multi_lstm_cell_ops(batch_size, num_units, input_depth,
- num_layers, max_time, compiled):
+def _create_multi_lstm_cell_ops(batch_size, num_units, input_depth, num_layers,
+ max_time, compiled):
with variable_scope.variable_scope(
"root",
initializer=init_ops.random_uniform_initializer(-0.1, 0.1, seed=2)):
inputs = variable_scope.get_variable(
- "inputs", initializer=random_ops.random_uniform(
+ "inputs",
+ initializer=random_ops.random_uniform(
(max_time, batch_size, input_depth), seed=1))
maybe_xla = lambda c: contrib_rnn_cell.CompiledWrapper(c) if compiled else c
cell = rnn_cell.MultiRNNCell(
[maybe_xla(rnn_cell.LSTMCell(num_units)) for _ in range(num_layers)])
- initial_state = cell.zero_state(
- batch_size=batch_size, dtype=dtypes.float32)
+ initial_state = cell.zero_state(batch_size=batch_size, dtype=dtypes.float32)
outputs, final_state = rnn.dynamic_rnn(
- cell=cell, inputs=inputs, initial_state=initial_state,
- time_major=True)
+ cell=cell, inputs=inputs, initial_state=initial_state, time_major=True)
flat_final_state = nest.flatten(final_state)
trainable_variables = variables.trainable_variables()
outputs_grad = gradients_impl.gradients(
- [outputs],
- trainable_variables + [inputs] + nest.flatten(initial_state))
+ [outputs], trainable_variables + [inputs] + nest.flatten(initial_state))
final_state_grad = gradients_impl.gradients(
flat_final_state,
trainable_variables + [inputs] + nest.flatten(initial_state))
- return {"outputs": outputs,
- "final_state": flat_final_state,
- "outputs_grad": outputs_grad,
- "final_state_grad": final_state_grad}
+ return {
+ "outputs": outputs,
+ "final_state": flat_final_state,
+ "outputs_grad": outputs_grad,
+ "final_state_grad": final_state_grad
+ }
class CompiledWrapperTest(test.TestCase):
@@ -1420,8 +1410,10 @@ class CompiledWrapperTest(test.TestCase):
random_seed.set_random_seed(1234)
with self.test_session(graph=ops.Graph()) as sess:
xla_ops = _create_multi_lstm_cell_ops(
- batch_size=batch_size, num_units=num_units,
- input_depth=input_depth, num_layers=num_layers,
+ batch_size=batch_size,
+ num_units=num_units,
+ input_depth=input_depth,
+ num_layers=num_layers,
max_time=max_time,
compiled=True)
sess.run([variables.global_variables_initializer()])
@@ -1430,8 +1422,10 @@ class CompiledWrapperTest(test.TestCase):
random_seed.set_random_seed(1234)
with self.test_session(graph=ops.Graph()) as sess:
non_xla_ops = _create_multi_lstm_cell_ops(
- batch_size=batch_size, num_units=num_units,
- input_depth=input_depth, num_layers=num_layers,
+ batch_size=batch_size,
+ num_units=num_units,
+ input_depth=input_depth,
+ num_layers=num_layers,
max_time=max_time,
compiled=False)
sess.run([variables.global_variables_initializer()])
@@ -1440,16 +1434,16 @@ class CompiledWrapperTest(test.TestCase):
self.assertAllClose(
non_xla_results["outputs"], xla_results["outputs"], atol=atol)
- for xla_value, non_xla_value in zip(
- xla_results["final_state"], non_xla_results["final_state"]):
+ for xla_value, non_xla_value in zip(xla_results["final_state"],
+ non_xla_results["final_state"]):
self.assertAllClose(xla_value, non_xla_value, atol=atol)
- for xla_g, non_xla_g in zip(
- xla_results["outputs_grad"], non_xla_results["outputs_grad"]):
+ for xla_g, non_xla_g in zip(xla_results["outputs_grad"],
+ non_xla_results["outputs_grad"]):
self.assertAllClose(xla_g, non_xla_g, atol=atol)
- for xla_g, non_xla_g in zip(
- xla_results["final_state_grad"], non_xla_results["final_state_grad"]):
+ for xla_g, non_xla_g in zip(xla_results["final_state_grad"],
+ non_xla_results["final_state_grad"]):
self.assertAllClose(xla_g, non_xla_g, atol=atol)
def testMultiRNNCellWithStateTuple(self):
@@ -1463,19 +1457,20 @@ class CompiledWrapperTest(test.TestCase):
# Test incorrectness of state
with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
rnn_cell.MultiRNNCell(
- [rnn_cell.GRUCell(2)
- for _ in range(2)], state_is_tuple=True)(x, m_bad)
+ [rnn_cell.GRUCell(2) for _ in range(2)],
+ state_is_tuple=True)(x, m_bad)
_, ml = rnn_cell.MultiRNNCell(
- [rnn_cell.GRUCell(2)
- for _ in range(2)], state_is_tuple=True)(x, m_good)
+ [rnn_cell.GRUCell(2) for _ in range(2)],
+ state_is_tuple=True)(x, m_good)
sess.run([variables.global_variables_initializer()])
- res = sess.run(ml, {
- x.name: np.array([[1., 1.]]),
- m_good[0].name: np.array([[0.1, 0.1]]),
- m_good[1].name: np.array([[0.1, 0.1]])
- })
+ res = sess.run(
+ ml, {
+ x.name: np.array([[1., 1.]]),
+ m_good[0].name: np.array([[0.1, 0.1]]),
+ m_good[1].name: np.array([[0.1, 0.1]])
+ })
     # The numbers in results were not calculated; this is just a
# smoke test. However, these numbers should match those of
@@ -1490,24 +1485,20 @@ class BenchmarkLSTMCellXLA(test.Benchmark):
num_layers = 3
max_time = 50
print("benchmarkDynamicRNNWithMultiLSTMCell")
- print("\t" +
- "\t".join(["inter_th", "intra_th",
- "batch_size", "num_units", "input_depth", "device",
- "compiled", "wall_time"]))
+ print("\t" + "\t".join([
+ "inter_th", "intra_th", "batch_size", "num_units", "input_depth",
+ "device", "compiled", "wall_time"
+ ]))
warmup_run = True
- for (threads,
- device,
- num_units,
- batch_size,
- input_depth,
- compiled) in itertools.product(
- [{"inter": 0, "intra": 0}, {"inter": 1, "intra": 4}],
- ["cpu", "gpu"],
- [32, 512],
- [1, 32, 256],
- [32, 512],
- [False, True]):
+ for (threads, device, num_units, batch_size, input_depth,
+ compiled) in itertools.product([{
+ "inter": 0,
+ "intra": 0
+ }, {
+ "inter": 1,
+ "intra": 4
+ }], ["cpu", "gpu"], [32, 512], [1, 32, 256], [32, 512], [False, True]):
if threads["inter"] != 0:
# We only care about testing inter/intra op limitations on
# CPU with small batch size, to mimic embedded devices.
@@ -1523,30 +1514,35 @@ class BenchmarkLSTMCellXLA(test.Benchmark):
with session.Session(config=config, graph=ops.Graph()) as sess:
with ops.device("/%s:0" % device):
ops_dict = _create_multi_lstm_cell_ops(
- batch_size=batch_size, num_units=num_units,
- input_depth=input_depth, num_layers=num_layers,
+ batch_size=batch_size,
+ num_units=num_units,
+ input_depth=input_depth,
+ num_layers=num_layers,
max_time=max_time,
compiled=compiled)
sess.run([variables.global_variables_initializer()])
all_ops = nest.flatten(ops_dict.values())
all_ops_group = control_flow_ops.group(*all_ops)
- name_suffix = (
- "inter_th_%d_intra_th_%d_bs_%d_units_%d_inputdepth_%d"
- "_device_%s_xla_%s" % (
- threads["inter"], threads["intra"],
- batch_size, num_units, input_depth, device, compiled))
+ name_suffix = ("inter_th_%d_intra_th_%d_bs_%d_units_%d_inputdepth_%d"
+ "_device_%s_xla_%s" %
+ (threads["inter"], threads["intra"], batch_size,
+ num_units, input_depth, device, compiled))
if warmup_run:
self.run_op_benchmark(
sess, all_ops_group, min_iters=30, name="ignore_warmup")
warmup_run = False
benchmark_results = self.run_op_benchmark(
- sess, all_ops_group, min_iters=50,
+ sess,
+ all_ops_group,
+ min_iters=50,
name="benchmarkDynamicRNNWithMultiLSTMCell_%s" % name_suffix)
- print("\t" +
- "\t".join(["%s" % x for x in [
- threads["inter"], threads["intra"],
- batch_size, num_units, input_depth, device, compiled,
- benchmark_results["wall_time"]]]))
+ print("\t" + "\t".join([
+ "%s" % x
+ for x in [
+ threads["inter"], threads["intra"], batch_size, num_units,
+ input_depth, device, compiled, benchmark_results["wall_time"]
+ ]
+ ]))
class WeightNormLSTMCellTest(test.TestCase):
@@ -1557,8 +1553,7 @@ class WeightNormLSTMCellTest(test.TestCase):
with self.test_session() as sess:
init = init_ops.constant_initializer(0.5)
- with variable_scope.variable_scope("root",
- initializer=init):
+ with variable_scope.variable_scope("root", initializer=init):
x = array_ops.zeros([1, 2])
c0 = array_ops.zeros([1, 2])
h0 = array_ops.zeros([1, 2])
@@ -1568,11 +1563,12 @@ class WeightNormLSTMCellTest(test.TestCase):
xout, sout = cell()(x, state0)
sess.run([variables.global_variables_initializer()])
- res = sess.run([xout, sout], {
- x.name: np.array([[1., 1.]]),
- c0.name: 0.1 * np.asarray([[0, 1]]),
- h0.name: 0.1 * np.asarray([[2, 3]]),
- })
+ res = sess.run(
+ [xout, sout], {
+ x.name: np.array([[1., 1.]]),
+ c0.name: 0.1 * np.asarray([[0, 1]]),
+ h0.name: 0.1 * np.asarray([[2, 3]]),
+ })
actual_state_c = res[1].c
actual_state_h = res[1].h
@@ -1583,9 +1579,8 @@ class WeightNormLSTMCellTest(test.TestCase):
"""Tests cell w/o peepholes and w/o normalisation"""
def cell():
- return contrib_rnn_cell.WeightNormLSTMCell(2,
- norm=False,
- use_peepholes=False)
+ return contrib_rnn_cell.WeightNormLSTMCell(
+ 2, norm=False, use_peepholes=False)
actual_c, actual_h = self._cell_output(cell)
@@ -1599,9 +1594,8 @@ class WeightNormLSTMCellTest(test.TestCase):
"""Tests cell with peepholes and w/o normalisation"""
def cell():
- return contrib_rnn_cell.WeightNormLSTMCell(2,
- norm=False,
- use_peepholes=True)
+ return contrib_rnn_cell.WeightNormLSTMCell(
+ 2, norm=False, use_peepholes=True)
actual_c, actual_h = self._cell_output(cell)
@@ -1611,14 +1605,12 @@ class WeightNormLSTMCellTest(test.TestCase):
self.assertAllClose(expected_c, actual_c, 1e-5)
self.assertAllClose(expected_h, actual_h, 1e-5)
-
def testBasicCellWithNorm(self):
"""Tests cell w/o peepholes and with normalisation"""
def cell():
- return contrib_rnn_cell.WeightNormLSTMCell(2,
- norm=True,
- use_peepholes=False)
+ return contrib_rnn_cell.WeightNormLSTMCell(
+ 2, norm=True, use_peepholes=False)
actual_c, actual_h = self._cell_output(cell)
@@ -1632,9 +1624,8 @@ class WeightNormLSTMCellTest(test.TestCase):
"""Tests cell with peepholes and with normalisation"""
def cell():
- return contrib_rnn_cell.WeightNormLSTMCell(2,
- norm=True,
- use_peepholes=True)
+ return contrib_rnn_cell.WeightNormLSTMCell(
+ 2, norm=True, use_peepholes=True)
actual_c, actual_h = self._cell_output(cell)
@@ -1644,5 +1635,6 @@ class WeightNormLSTMCellTest(test.TestCase):
self.assertAllClose(expected_c, actual_c, 1e-5)
self.assertAllClose(expected_h, actual_h, 1e-5)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index d7ae6621db..8adf5dce6e 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Module for constructing RNN Cells."""
from __future__ import absolute_import
from __future__ import division
@@ -56,16 +55,15 @@ def _get_concat_variable(name, shape, dtype, num_shards):
return value
concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
- ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
- concat_variable)
+ ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, concat_variable)
return concat_variable
def _get_sharded_variable(name, shape, dtype, num_shards):
"""Get a list of sharded variables with the given dtype."""
if num_shards > shape[0]:
- raise ValueError("Too many shards: shape=%s, num_shards=%d" %
- (shape, num_shards))
+ raise ValueError("Too many shards: shape=%s, num_shards=%d" % (shape,
+ num_shards))
unit_shard_size = int(math.floor(shape[0] / num_shards))
remaining_rows = shape[0] - unit_shard_size * num_shards
@@ -74,8 +72,9 @@ def _get_sharded_variable(name, shape, dtype, num_shards):
current_size = unit_shard_size
if i < remaining_rows:
current_size += 1
- shards.append(vs.get_variable(name + "_%d" % i, [current_size] + shape[1:],
- dtype=dtype))
+ shards.append(
+ vs.get_variable(
+ name + "_%d" % i, [current_size] + shape[1:], dtype=dtype))
return shards
@@ -177,9 +176,8 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
"""
super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
if not state_is_tuple:
- logging.warn(
- "%s: Using a concatenated state is slower and will soon be "
- "deprecated. Use state_is_tuple=True.", self)
+ logging.warn("%s: Using a concatenated state is slower and will soon be "
+ "deprecated. Use state_is_tuple=True.", self)
self._num_units = num_units
self._use_peepholes = use_peepholes
self._initializer = initializer
@@ -196,12 +194,14 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
self._norm_shift = norm_shift
if num_proj:
- self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
- if state_is_tuple else num_units + num_proj)
+ self._state_size = (
+ rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
+ if state_is_tuple else num_units + num_proj)
self._output_size = num_proj
else:
- self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units)
- if state_is_tuple else 2 * num_units)
+ self._state_size = (
+ rnn_cell_impl.LSTMStateTuple(num_units, num_units)
+ if state_is_tuple else 2 * num_units)
self._output_size = num_units
@property
@@ -251,8 +251,8 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
if input_size.value is None:
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
concat_w = _get_concat_variable(
- "W", [input_size.value + num_proj, 3 * self._num_units],
- dtype, self._num_unit_shards)
+ "W", [input_size.value + num_proj, 3 * self._num_units], dtype,
+ self._num_unit_shards)
b = vs.get_variable(
"B",
@@ -299,9 +299,9 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
m = sigmoid(o) * self._activation(c)
if self._num_proj is not None:
- concat_w_proj = _get_concat_variable(
- "W_P", [self._num_units, self._num_proj],
- dtype, self._num_proj_shards)
+ concat_w_proj = _get_concat_variable("W_P",
+ [self._num_units, self._num_proj],
+ dtype, self._num_proj_shards)
m = math_ops.matmul(m, concat_w_proj)
if self._proj_clip is not None:
@@ -309,8 +309,9 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
# pylint: enable=invalid-unary-operand-type
- new_state = (rnn_cell_impl.LSTMStateTuple(c, m)
- if self._state_is_tuple else array_ops.concat([c, m], 1))
+ new_state = (
+ rnn_cell_impl.LSTMStateTuple(c, m)
+ if self._state_is_tuple else array_ops.concat([c, m], 1))
return m, new_state
@@ -326,10 +327,15 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
It uses peep-hole connections and optional cell clipping.
"""
- def __init__(self, num_units, use_peepholes=False,
- cell_clip=None, initializer=None,
- num_unit_shards=1, forget_bias=1.0,
- feature_size=None, frequency_skip=1,
+ def __init__(self,
+ num_units,
+ use_peepholes=False,
+ cell_clip=None,
+ initializer=None,
+ num_unit_shards=1,
+ forget_bias=1.0,
+ feature_size=None,
+ frequency_skip=1,
reuse=None):
"""Initialize the parameters for an LSTM cell.
@@ -399,7 +405,7 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
actual_input_size = freq_inputs[0].get_shape().as_list()[1]
concat_w = _get_concat_variable(
- "W", [actual_input_size + 2*self._num_units, 4 * self._num_units],
+ "W", [actual_input_size + 2 * self._num_units, 4 * self._num_units],
dtype, self._num_unit_shards)
b = vs.get_variable(
@@ -418,23 +424,23 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
"W_O_diag", shape=[self._num_units], dtype=dtype)
# initialize the first freq state to be zero
- m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]),
- self._num_units], dtype)
+ m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]), self._num_units],
+ dtype)
for fq in range(len(freq_inputs)):
- c_prev = array_ops.slice(state, [0, 2*fq*self._num_units],
+ c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units],
[-1, self._num_units])
- m_prev = array_ops.slice(state, [0, (2*fq+1)*self._num_units],
+ m_prev = array_ops.slice(state, [0, (2 * fq + 1) * self._num_units],
[-1, self._num_units])
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
- cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq],
- 1)
+ cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq], 1)
lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
i, j, f, o = array_ops.split(
value=lstm_matrix, num_or_size_splits=4, axis=1)
if self._use_peepholes:
- c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
- sigmoid(i + w_i_diag * c_prev) * tanh(j))
+ c = (
+ sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
+ sigmoid(i + w_i_diag * c_prev) * tanh(j))
else:
c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))
@@ -472,11 +478,11 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
input_size = input_feat.get_shape().with_rank(2)[-1].value
if input_size is None:
raise ValueError("Cannot infer input_size from static shape inference.")
- num_feats = int((input_size - self._feature_size) / (
- self._frequency_skip)) + 1
+ num_feats = int(
+ (input_size - self._feature_size) / (self._frequency_skip)) + 1
freq_inputs = []
for f in range(num_feats):
- cur_input = array_ops.slice(input_feat, [0, f*self._frequency_skip],
+ cur_input = array_ops.slice(input_feat, [0, f * self._frequency_skip],
[-1, self._feature_size])
freq_inputs.append(cur_input)
return freq_inputs
@@ -498,11 +504,16 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
The code uses optional peephole connections, shared_weights and cell clipping.
"""
- def __init__(self, num_units, use_peepholes=False,
+ def __init__(self,
+ num_units,
+ use_peepholes=False,
share_time_frequency_weights=False,
- cell_clip=None, initializer=None,
- num_unit_shards=1, forget_bias=1.0,
- feature_size=None, frequency_skip=None,
+ cell_clip=None,
+ initializer=None,
+ num_unit_shards=1,
+ forget_bias=1.0,
+ feature_size=None,
+ frequency_skip=None,
num_frequency_blocks=None,
start_freqindex_list=None,
end_freqindex_list=None,
@@ -580,10 +591,10 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
for freq_index in range(self._num_frequency_blocks[block_index]):
name_prefix = "state_f%02d_b%02d" % (freq_index, block_index)
state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
- self._state_tuple_type = collections.namedtuple(
- "GridLSTMStateTuple", state_names.strip(","))
- self._state_size = self._state_tuple_type(
- *([num_units, num_units] * self._total_blocks))
+ self._state_tuple_type = collections.namedtuple("GridLSTMStateTuple",
+ state_names.strip(","))
+ self._state_size = self._state_tuple_type(*(
+ [num_units, num_units] * self._total_blocks))
else:
self._state_tuple_type = None
self._state_size = num_units * self._total_blocks * 2
@@ -626,7 +637,10 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
state_out_lst = []
for block in range(len(freq_inputs)):
m_out_lst_current, state_out_lst_current = self._compute(
- freq_inputs[block], block, state, batch_size,
+ freq_inputs[block],
+ block,
+ state,
+ batch_size,
state_is_tuple=self._state_is_tuple)
m_out_lst.extend(m_out_lst_current)
state_out_lst.extend(state_out_lst_current)
@@ -637,7 +651,11 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
m_out = array_ops.concat(m_out_lst, 1)
return m_out, state_out
- def _compute(self, freq_inputs, block, state, batch_size,
+ def _compute(self,
+ freq_inputs,
+ block,
+ state,
+ batch_size,
state_prefix="state",
state_is_tuple=True):
"""Run the actual computation of one step LSTM.
@@ -666,8 +684,8 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
actual_input_size = freq_inputs[0].get_shape().as_list()[1]
concat_w_f = _get_concat_variable(
- "W_f_%d" % block, [actual_input_size + 2 * self._num_units,
- num_gates * self._num_units],
+ "W_f_%d" % block,
+ [actual_input_size + 2 * self._num_units, num_gates * self._num_units],
dtype, self._num_unit_shards)
b_f = vs.get_variable(
"B_f_%d" % block,
@@ -675,10 +693,9 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
initializer=init_ops.zeros_initializer(),
dtype=dtype)
if not self._share_time_frequency_weights:
- concat_w_t = _get_concat_variable(
- "W_t_%d" % block, [actual_input_size + 2 * self._num_units,
- num_gates * self._num_units],
- dtype, self._num_unit_shards)
+ concat_w_t = _get_concat_variable("W_t_%d" % block, [
+ actual_input_size + 2 * self._num_units, num_gates * self._num_units
+ ], dtype, self._num_unit_shards)
b_t = vs.get_variable(
"B_t_%d" % block,
shape=[num_gates * self._num_units],
@@ -691,7 +708,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
w_f_diag_freqf = vs.get_variable(
"W_F_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
w_f_diag_freqt = vs.get_variable(
- "W_F_diag_freqt_%d"% block, shape=[self._num_units], dtype=dtype)
+ "W_F_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
w_i_diag_freqf = vs.get_variable(
"W_I_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
w_i_diag_freqt = vs.get_variable(
@@ -725,8 +742,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
m_prev_time = getattr(state, name_prefix + "_m")
else:
c_prev_time = array_ops.slice(
- state, [0, 2 * freq_index * self._num_units],
- [-1, self._num_units])
+ state, [0, 2 * freq_index * self._num_units], [-1, self._num_units])
m_prev_time = array_ops.slice(
state, [0, (2 * freq_index + 1) * self._num_units],
[-1, self._num_units])
@@ -736,8 +752,8 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
[freq_inputs[freq_index], m_prev_time, m_prev_freq], 1)
# F-LSTM
- lstm_matrix_freq = nn_ops.bias_add(math_ops.matmul(cell_inputs,
- concat_w_f), b_f)
+ lstm_matrix_freq = nn_ops.bias_add(
+ math_ops.matmul(cell_inputs, concat_w_f), b_f)
if self._couple_input_forget_gates:
i_freq, j_freq, o_freq = array_ops.split(
value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
@@ -752,8 +768,8 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
f_time = f_freq
o_time = o_freq
else:
- lstm_matrix_time = nn_ops.bias_add(math_ops.matmul(cell_inputs,
- concat_w_t), b_t)
+ lstm_matrix_time = nn_ops.bias_add(
+ math_ops.matmul(cell_inputs, concat_w_t), b_t)
if self._couple_input_forget_gates:
i_time, j_time, o_time = array_ops.split(
value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)
@@ -765,8 +781,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
# F-LSTM c_freq
# input gate activations
if self._use_peepholes:
- i_freq_g = sigmoid(i_freq +
- w_i_diag_freqf * c_prev_freq +
+ i_freq_g = sigmoid(i_freq + w_i_diag_freqf * c_prev_freq +
w_i_diag_freqt * c_prev_time)
else:
i_freq_g = sigmoid(i_freq)
@@ -775,9 +790,8 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
f_freq_g = 1.0 - i_freq_g
else:
if self._use_peepholes:
- f_freq_g = sigmoid(f_freq + self._forget_bias +
- w_f_diag_freqf * c_prev_freq +
- w_f_diag_freqt * c_prev_time)
+ f_freq_g = sigmoid(f_freq + self._forget_bias + w_f_diag_freqf *
+ c_prev_freq + w_f_diag_freqt * c_prev_time)
else:
f_freq_g = sigmoid(f_freq + self._forget_bias)
# cell state
@@ -792,12 +806,10 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
# input gate activations
if self._use_peepholes:
if self._share_time_frequency_weights:
- i_time_g = sigmoid(i_time +
- w_i_diag_freqf * c_prev_freq +
+ i_time_g = sigmoid(i_time + w_i_diag_freqf * c_prev_freq +
w_i_diag_freqt * c_prev_time)
else:
- i_time_g = sigmoid(i_time +
- w_i_diag_timef * c_prev_freq +
+ i_time_g = sigmoid(i_time + w_i_diag_timef * c_prev_freq +
w_i_diag_timet * c_prev_time)
else:
i_time_g = sigmoid(i_time)
@@ -807,13 +819,11 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
else:
if self._use_peepholes:
if self._share_time_frequency_weights:
- f_time_g = sigmoid(f_time + self._forget_bias +
- w_f_diag_freqf * c_prev_freq +
- w_f_diag_freqt * c_prev_time)
+ f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_freqf *
+ c_prev_freq + w_f_diag_freqt * c_prev_time)
else:
- f_time_g = sigmoid(f_time + self._forget_bias +
- w_f_diag_timef * c_prev_freq +
- w_f_diag_timet * c_prev_time)
+ f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_timef *
+ c_prev_freq + w_f_diag_timet * c_prev_time)
else:
f_time_g = sigmoid(f_time + self._forget_bias)
# cell state
@@ -826,8 +836,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
# F-LSTM m_freq
if self._use_peepholes:
- m_freq = sigmoid(o_freq +
- w_o_diag_freqf * c_freq +
+ m_freq = sigmoid(o_freq + w_o_diag_freqf * c_freq +
w_o_diag_freqt * c_time) * tanh(c_freq)
else:
m_freq = sigmoid(o_freq) * tanh(c_freq)
@@ -835,12 +844,10 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
# T-LSTM m_time
if self._use_peepholes:
if self._share_time_frequency_weights:
- m_time = sigmoid(o_time +
- w_o_diag_freqf * c_freq +
+ m_time = sigmoid(o_time + w_o_diag_freqf * c_freq +
w_o_diag_freqt * c_time) * tanh(c_time)
else:
- m_time = sigmoid(o_time +
- w_o_diag_timef * c_freq +
+ m_time = sigmoid(o_time + w_o_diag_timef * c_freq +
w_o_diag_timet * c_time) * tanh(c_time)
else:
m_time = sigmoid(o_time) * tanh(c_time)
@@ -879,16 +886,18 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
raise ValueError("Cannot infer input_size from static shape inference.")
if slice_offset > 0:
# Padding to the end
- inputs = array_ops.pad(
- input_feat, array_ops.constant([0, 0, 0, slice_offset], shape=[2, 2],
- dtype=dtypes.int32),
- "CONSTANT")
+ inputs = array_ops.pad(input_feat,
+ array_ops.constant(
+ [0, 0, 0, slice_offset],
+ shape=[2, 2],
+ dtype=dtypes.int32), "CONSTANT")
elif slice_offset < 0:
# Padding to the front
- inputs = array_ops.pad(
- input_feat, array_ops.constant([0, 0, -slice_offset, 0], shape=[2, 2],
- dtype=dtypes.int32),
- "CONSTANT")
+ inputs = array_ops.pad(input_feat,
+ array_ops.constant(
+ [0, 0, -slice_offset, 0],
+ shape=[2, 2],
+ dtype=dtypes.int32), "CONSTANT")
slice_offset = 0
else:
inputs = input_feat
@@ -898,13 +907,13 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
raise ValueError("Length of num_frequency_blocks"
" is not 1, but instead is %d",
len(self._num_frequency_blocks))
- num_feats = int((input_size - self._feature_size) / (
- self._frequency_skip)) + 1
+ num_feats = int(
+ (input_size - self._feature_size) / (self._frequency_skip)) + 1
if num_feats != self._num_frequency_blocks[0]:
raise ValueError(
"Invalid num_frequency_blocks, requires %d but gets %d, please"
- " check the input size and filter config are correct." % (
- self._num_frequency_blocks[0], num_feats))
+ " check the input size and filter config are correct." %
+ (self._num_frequency_blocks[0], num_feats))
block_inputs = []
for f in range(num_feats):
cur_input = array_ops.slice(
@@ -927,18 +936,18 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
start_index = self._start_freqindex_list[b]
end_index = self._end_freqindex_list[b]
cur_size = end_index - start_index
- block_feats = int((cur_size - self._feature_size) / (
- self._frequency_skip)) + 1
+ block_feats = int(
+ (cur_size - self._feature_size) / (self._frequency_skip)) + 1
if block_feats != self._num_frequency_blocks[b]:
raise ValueError(
"Invalid num_frequency_blocks, requires %d but gets %d, please"
- " check the input size and filter config are correct." % (
- self._num_frequency_blocks[b], block_feats))
+ " check the input size and filter config are correct." %
+ (self._num_frequency_blocks[b], block_feats))
block_inputs = []
for f in range(block_feats):
cur_input = array_ops.slice(
- inputs, [0, start_index + slice_offset + f *
- self._frequency_skip],
+ inputs,
+ [0, start_index + slice_offset + f * self._frequency_skip],
[-1, self._feature_size])
block_inputs.append(cur_input)
freq_inputs.append(block_inputs)
@@ -954,11 +963,16 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
The current implementation uses different weights for the two directions.
"""
- def __init__(self, num_units, use_peepholes=False,
+ def __init__(self,
+ num_units,
+ use_peepholes=False,
share_time_frequency_weights=False,
- cell_clip=None, initializer=None,
- num_unit_shards=1, forget_bias=1.0,
- feature_size=None, frequency_skip=None,
+ cell_clip=None,
+ initializer=None,
+ num_unit_shards=1,
+ forget_bias=1.0,
+ feature_size=None,
+ frequency_skip=None,
num_frequency_blocks=None,
start_freqindex_list=None,
end_freqindex_list=None,
@@ -1017,8 +1031,8 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
self._state_tuple_type = collections.namedtuple(
"BidirectionalGridLSTMStateTuple", state_names.strip(","))
- self._state_size = self._state_tuple_type(
- *([num_units, num_units] * self._total_blocks * 2))
+ self._state_size = self._state_tuple_type(*(
+ [num_units, num_units] * self._total_blocks * 2))
self._output_size = 2 * num_units * self._total_blocks * 2
def call(self, inputs, state):
@@ -1052,8 +1066,12 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
fwd_state_out_lst = []
for block in range(len(fwd_inputs)):
fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
- fwd_inputs[block], block, state, batch_size,
- state_prefix="fwd_state", state_is_tuple=True)
+ fwd_inputs[block],
+ block,
+ state,
+ batch_size,
+ state_prefix="fwd_state",
+ state_is_tuple=True)
fwd_m_out_lst.extend(fwd_m_out_lst_current)
fwd_state_out_lst.extend(fwd_state_out_lst_current)
# Backward processing
@@ -1064,8 +1082,12 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
# Reverse the blocks
bwd_inputs_reverse = bwd_inputs[block][::-1]
bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
- bwd_inputs_reverse, block, state, batch_size,
- state_prefix="bwd_state", state_is_tuple=True)
+ bwd_inputs_reverse,
+ block,
+ state,
+ batch_size,
+ state_prefix="bwd_state",
+ state_is_tuple=True)
bwd_m_out_lst.extend(bwd_m_out_lst_current)
bwd_state_out_lst.extend(bwd_state_out_lst_current)
state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst))
@@ -1076,6 +1098,7 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
# pylint: disable=protected-access
_Linear = core_rnn_cell._Linear # pylint: disable=invalid-name
+
# pylint: enable=protected-access
@@ -1085,8 +1108,14 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell):
Implementation based on https://arxiv.org/abs/1409.0473.
"""
- def __init__(self, cell, attn_length, attn_size=None, attn_vec_size=None,
- input_size=None, state_is_tuple=True, reuse=None):
+ def __init__(self,
+ cell,
+ attn_length,
+ attn_size=None,
+ attn_vec_size=None,
+ input_size=None,
+ state_is_tuple=True,
+ reuse=None):
"""Create a cell with attention.
Args:
@@ -1116,16 +1145,15 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell):
if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access
raise TypeError("The parameter cell is not RNNCell.")
if nest.is_sequence(cell.state_size) and not state_is_tuple:
- raise ValueError("Cell returns tuple of states, but the flag "
- "state_is_tuple is not set. State size is: %s"
- % str(cell.state_size))
+ raise ValueError(
+ "Cell returns tuple of states, but the flag "
+ "state_is_tuple is not set. State size is: %s" % str(cell.state_size))
if attn_length <= 0:
- raise ValueError("attn_length should be greater than zero, got %s"
- % str(attn_length))
+ raise ValueError(
+ "attn_length should be greater than zero, got %s" % str(attn_length))
if not state_is_tuple:
- logging.warn(
- "%s: Using a concatenated state is slower and will soon be "
- "deprecated. Use state_is_tuple=True.", self)
+ logging.warn("%s: Using a concatenated state is slower and will soon be "
+ "deprecated. Use state_is_tuple=True.", self)
if attn_size is None:
attn_size = cell.output_size
if attn_vec_size is None:
@@ -1161,8 +1189,8 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell):
else:
states = state
state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size])
- attns = array_ops.slice(
- states, [0, self._cell.state_size], [-1, self._attn_size])
+ attns = array_ops.slice(states, [0, self._cell.state_size],
+ [-1, self._attn_size])
attn_states = array_ops.slice(
states, [0, self._cell.state_size + self._attn_size],
[-1, self._attn_size * self._attn_length])
@@ -1200,8 +1228,8 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell):
tanh = math_ops.tanh
with vs.variable_scope("attention"):
- k = vs.get_variable(
- "attn_w", [1, 1, self._attn_size, self._attn_vec_size])
+ k = vs.get_variable("attn_w",
+ [1, 1, self._attn_size, self._attn_vec_size])
v = vs.get_variable("attn_v", [self._attn_vec_size])
hidden = array_ops.reshape(attn_states,
[-1, self._attn_length, 1, self._attn_size])
@@ -1228,7 +1256,8 @@ class HighwayWrapper(rnn_cell_impl.RNNCell):
https://arxiv.org/abs/1505.00387
"""
- def __init__(self, cell,
+ def __init__(self,
+ cell,
couple_carry_transform_gates=True,
carry_bias_init=1.0):
"""Constructs a `HighwayWrapper` for `cell`.
@@ -1260,8 +1289,7 @@ class HighwayWrapper(rnn_cell_impl.RNNCell):
carry_weight = vs.get_variable("carry_w", [input_size, input_size])
carry_bias = vs.get_variable(
"carry_b", [input_size],
- initializer=init_ops.constant_initializer(
- self._carry_bias_init))
+ initializer=init_ops.constant_initializer(self._carry_bias_init))
carry = math_ops.sigmoid(nn_ops.xw_plus_b(inp, carry_weight, carry_bias))
if self._couple_carry_transform_gates:
transform = 1 - carry
@@ -1270,11 +1298,9 @@ class HighwayWrapper(rnn_cell_impl.RNNCell):
[input_size, input_size])
transform_bias = vs.get_variable(
"transform_b", [input_size],
- initializer=init_ops.constant_initializer(
- -self._carry_bias_init))
- transform = math_ops.sigmoid(nn_ops.xw_plus_b(inp,
- transform_weight,
- transform_bias))
+ initializer=init_ops.constant_initializer(-self._carry_bias_init))
+ transform = math_ops.sigmoid(
+ nn_ops.xw_plus_b(inp, transform_weight, transform_bias))
return inp * carry + out * transform
def __call__(self, inputs, state, scope=None):
@@ -1294,9 +1320,11 @@ class HighwayWrapper(rnn_cell_impl.RNNCell):
"""
outputs, new_state = self._cell(inputs, state, scope=scope)
nest.assert_same_structure(inputs, outputs)
+
# Ensure shapes match
def assert_shape_match(inp, out):
inp.get_shape().assert_is_compatible_with(out.get_shape())
+
nest.map_structure(assert_shape_match, inputs, outputs)
res_outputs = nest.map_structure(self._highway, inputs, outputs)
return (res_outputs, new_state)
@@ -1322,10 +1350,16 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth.
"""
- def __init__(self, num_units, forget_bias=1.0,
- input_size=None, activation=math_ops.tanh,
- layer_norm=True, norm_gain=1.0, norm_shift=0.0,
- dropout_keep_prob=1.0, dropout_prob_seed=None,
+ def __init__(self,
+ num_units,
+ forget_bias=1.0,
+ input_size=None,
+ activation=math_ops.tanh,
+ layer_norm=True,
+ norm_gain=1.0,
+ norm_shift=0.0,
+ dropout_keep_prob=1.0,
+ dropout_prob_seed=None,
reuse=None):
"""Initializes the basic LSTM cell.
@@ -1410,8 +1444,8 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)
- new_c = (c * math_ops.sigmoid(f + self._forget_bias)
- + math_ops.sigmoid(i) * g)
+ new_c = (
+ c * math_ops.sigmoid(f + self._forget_bias) + math_ops.sigmoid(i) * g)
if self._layer_norm:
new_c = self._norm(new_c, "state", dtype=dtype)
new_h = self._activation(new_c) * math_ops.sigmoid(o)
@@ -1433,8 +1467,7 @@ class NASCell(rnn_cell_impl.RNNCell):
The class uses an optional projection layer.
"""
- def __init__(self, num_units, num_proj=None,
- use_biases=False, reuse=None):
+ def __init__(self, num_units, num_proj=None, use_biases=False, reuse=None):
"""Initialize the parameters for a NAS cell.
Args:
@@ -1504,12 +1537,10 @@ class NASCell(rnn_cell_impl.RNNCell):
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
# Variables for the NAS cell. W_m is all matrices multiplying the
# hiddenstate and W_inputs is all matrices multiplying the inputs.
- concat_w_m = vs.get_variable(
- "recurrent_kernel", [num_proj, 8 * self._num_units],
- dtype)
+ concat_w_m = vs.get_variable("recurrent_kernel",
+ [num_proj, 8 * self._num_units], dtype)
concat_w_inputs = vs.get_variable(
- "kernel", [input_size.value, 8 * self._num_units],
- dtype)
+ "kernel", [input_size.value, 8 * self._num_units], dtype)
m_matrix = math_ops.matmul(m_prev, concat_w_m)
inputs_matrix = math_ops.matmul(inputs, concat_w_inputs)
@@ -1524,10 +1555,10 @@ class NASCell(rnn_cell_impl.RNNCell):
# The NAS cell branches into 8 different splits for both the hiddenstate
# and the input
- m_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
- value=m_matrix)
- inputs_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
- value=inputs_matrix)
+ m_matrix_splits = array_ops.split(
+ axis=1, num_or_size_splits=8, value=m_matrix)
+ inputs_matrix_splits = array_ops.split(
+ axis=1, num_or_size_splits=8, value=inputs_matrix)
# First layer
layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
@@ -1559,9 +1590,8 @@ class NASCell(rnn_cell_impl.RNNCell):
# Projection layer if specified
if self._num_proj is not None:
- concat_w_proj = vs.get_variable(
- "projection_weights", [self._num_units, self._num_proj],
- dtype)
+ concat_w_proj = vs.get_variable("projection_weights",
+ [self._num_units, self._num_proj], dtype)
new_m = math_ops.matmul(new_m, concat_w_proj)
new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_m)
@@ -1584,8 +1614,12 @@ class UGRNNCell(rnn_cell_impl.RNNCell):
"Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.
"""
- def __init__(self, num_units, initializer=None, forget_bias=1.0,
- activation=math_ops.tanh, reuse=None):
+ def __init__(self,
+ num_units,
+ initializer=None,
+ forget_bias=1.0,
+ activation=math_ops.tanh,
+ reuse=None):
"""Initialize the parameters for an UGRNN cell.
Args:
@@ -1640,8 +1674,8 @@ class UGRNNCell(rnn_cell_impl.RNNCell):
if input_size.value is None:
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
- with vs.variable_scope(vs.get_variable_scope(),
- initializer=self._initializer):
+ with vs.variable_scope(
+ vs.get_variable_scope(), initializer=self._initializer):
cell_inputs = array_ops.concat([inputs, state], 1)
if self._linear is None:
self._linear = _Linear(cell_inputs, 2 * self._num_units, True)
@@ -1681,9 +1715,13 @@ class IntersectionRNNCell(rnn_cell_impl.RNNCell):
RNNs so it may not achieve best performance with depth 1.
"""
- def __init__(self, num_units, num_in_proj=None,
- initializer=None, forget_bias=1.0,
- y_activation=nn_ops.relu, reuse=None):
+ def __init__(self,
+ num_units,
+ num_in_proj=None,
+ initializer=None,
+ forget_bias=1.0,
+ y_activation=nn_ops.relu,
+ reuse=None):
"""Initialize the parameters for an +RNN cell.
Args:
@@ -1747,8 +1785,8 @@ class IntersectionRNNCell(rnn_cell_impl.RNNCell):
if input_size.value is None:
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
- with vs.variable_scope(vs.get_variable_scope(),
- initializer=self._initializer):
+ with vs.variable_scope(
+ vs.get_variable_scope(), initializer=self._initializer):
# read-in projections (should be used for first layer in deep +RNN
# to transform size of inputs from I --> N)
if input_size.value != self._num_units:
@@ -1765,13 +1803,13 @@ class IntersectionRNNCell(rnn_cell_impl.RNNCell):
n_dim = i_dim = self._num_units
cell_inputs = array_ops.concat([inputs, state], 1)
if self._linear2 is None:
- self._linear2 = _Linear(cell_inputs, 2*n_dim + 2*i_dim, True)
+ self._linear2 = _Linear(cell_inputs, 2 * n_dim + 2 * i_dim, True)
rnn_matrix = self._linear2(cell_inputs)
- gh_act = rnn_matrix[:, :n_dim] # b x n
- h_act = rnn_matrix[:, n_dim:2*n_dim] # b x n
- gy_act = rnn_matrix[:, 2*n_dim:2*n_dim+i_dim] # b x i
- y_act = rnn_matrix[:, 2*n_dim+i_dim:2*n_dim+2*i_dim] # b x i
+ gh_act = rnn_matrix[:, :n_dim] # b x n
+ h_act = rnn_matrix[:, n_dim:2 * n_dim] # b x n
+ gy_act = rnn_matrix[:, 2 * n_dim:2 * n_dim + i_dim] # b x i
+ y_act = rnn_matrix[:, 2 * n_dim + i_dim:2 * n_dim + 2 * i_dim] # b x i
h = tanh(h_act)
y = self._y_activation(y_act)
@@ -1817,6 +1855,7 @@ class CompiledWrapper(rnn_cell_impl.RNNCell):
if self._compile_stateful:
compile_ops = True
else:
+
def compile_ops(node_def):
global _REGISTERED_OPS
if _REGISTERED_OPS is None:
@@ -1827,10 +1866,7 @@ class CompiledWrapper(rnn_cell_impl.RNNCell):
return self._cell(inputs, state, scope=scope)
-def _random_exp_initializer(minval,
- maxval,
- seed=None,
- dtype=dtypes.float32):
+def _random_exp_initializer(minval, maxval, seed=None, dtype=dtypes.float32):
"""Returns an exponential distribution initializer.
Args:
@@ -1849,10 +1885,7 @@ def _random_exp_initializer(minval,
del partition_info # Unused.
return math_ops.exp(
random_ops.random_uniform(
- shape,
- math_ops.log(minval),
- math_ops.log(maxval),
- dtype,
+ shape, math_ops.log(minval), math_ops.log(maxval), dtype,
seed=seed))
return _initializer
@@ -1956,8 +1989,7 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell):
if self._linear1 is None:
self._linear1 = _Linear(in_mask_gates, 2 * self._num_units, True)
- mask_gates = math_ops.sigmoid(
- self._linear1(in_mask_gates))
+ mask_gates = math_ops.sigmoid(self._linear1(in_mask_gates))
[input_gate, forget_gate] = array_ops.split(
axis=1, num_or_size_splits=2, value=mask_gates)
@@ -1981,12 +2013,12 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell):
period = vs.get_variable(
"period", [self._num_units],
- initializer=_random_exp_initializer(
- self._period_init_min, self._period_init_max))
+ initializer=_random_exp_initializer(self._period_init_min,
+ self._period_init_max))
phase = vs.get_variable(
"phase", [self._num_units],
- initializer=init_ops.random_uniform_initializer(
- 0., period.initial_value))
+ initializer=init_ops.random_uniform_initializer(0.,
+ period.initial_value))
ratio_on = vs.get_variable(
"ratio_on", [self._num_units],
initializer=init_ops.constant_initializer(self._ratio_on),
@@ -2008,6 +2040,7 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell):
return new_h, new_state
+
class ConvLSTMCell(rnn_cell_impl.RNNCell):
"""Convolutional LSTM recurrent network cell.
@@ -2041,7 +2074,7 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
"""
super(ConvLSTMCell, self).__init__(name=name)
- if conv_ndims != len(input_shape)-1:
+ if conv_ndims != len(input_shape) - 1:
raise ValueError("Invalid input_shape {} for conv_ndims={}.".format(
input_shape, conv_ndims))
@@ -2060,8 +2093,8 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
state_size = tensor_shape.TensorShape(
self._input_shape[:-1] + [self._output_channels])
self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size)
- self._output_size = tensor_shape.TensorShape(self._input_shape[:-1]
- + [self._total_output_channels])
+ self._output_size = tensor_shape.TensorShape(
+ self._input_shape[:-1] + [self._total_output_channels])
@property
def output_size(self):
@@ -2073,13 +2106,10 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
def call(self, inputs, state, scope=None):
cell, hidden = state
- new_hidden = _conv([inputs, hidden],
- self._kernel_shape,
- 4*self._output_channels,
- self._use_bias)
- gates = array_ops.split(value=new_hidden,
- num_or_size_splits=4,
- axis=self._conv_ndims+1)
+ new_hidden = _conv([inputs, hidden], self._kernel_shape,
+ 4 * self._output_channels, self._use_bias)
+ gates = array_ops.split(
+ value=new_hidden, num_or_size_splits=4, axis=self._conv_ndims + 1)
input_gate, new_input, forget_gate, output_gate = gates
new_cell = math_ops.sigmoid(forget_gate + self._forget_bias) * cell
@@ -2091,29 +2121,35 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output)
return output, new_state
+
class Conv1DLSTMCell(ConvLSTMCell):
"""1D Convolutional LSTM recurrent network cell.
https://arxiv.org/pdf/1506.04214v1.pdf
"""
+
def __init__(self, name="conv_1d_lstm_cell", **kwargs):
"""Construct Conv1DLSTM. See `ConvLSTMCell` for more details."""
super(Conv1DLSTMCell, self).__init__(conv_ndims=1, **kwargs)
+
class Conv2DLSTMCell(ConvLSTMCell):
"""2D Convolutional LSTM recurrent network cell.
https://arxiv.org/pdf/1506.04214v1.pdf
"""
+
def __init__(self, name="conv_2d_lstm_cell", **kwargs):
"""Construct Conv2DLSTM. See `ConvLSTMCell` for more details."""
super(Conv2DLSTMCell, self).__init__(conv_ndims=2, **kwargs)
+
class Conv3DLSTMCell(ConvLSTMCell):
"""3D Convolutional LSTM recurrent network cell.
https://arxiv.org/pdf/1506.04214v1.pdf
"""
+
def __init__(self, name="conv_3d_lstm_cell", **kwargs):
"""Construct Conv3DLSTM. See `ConvLSTMCell` for more details."""
super(Conv3DLSTMCell, self).__init__(conv_ndims=3, **kwargs)
@@ -2138,7 +2174,7 @@ def _conv(args, filter_size, num_features, bias, bias_start=0.0):
shapes = [a.get_shape().as_list() for a in args]
shape_length = len(shapes[0])
for shape in shapes:
- if len(shape) not in [3,4,5]:
+ if len(shape) not in [3, 4, 5]:
raise ValueError("Conv Linear expects 3D, 4D "
"or 5D arguments: %s" % str(shapes))
if len(shape) != len(shapes[0]):
@@ -2149,40 +2185,36 @@ def _conv(args, filter_size, num_features, bias, bias_start=0.0):
dtype = [a.dtype for a in args][0]
# determine correct conv operation
- if shape_length == 3:
+ if shape_length == 3:
conv_op = nn_ops.conv1d
strides = 1
elif shape_length == 4:
conv_op = nn_ops.conv2d
- strides = shape_length*[1]
+ strides = shape_length * [1]
elif shape_length == 5:
conv_op = nn_ops.conv3d
- strides = shape_length*[1]
+ strides = shape_length * [1]
# Now the computation.
kernel = vs.get_variable(
- "kernel",
- filter_size + [total_arg_size_depth, num_features],
- dtype=dtype)
+ "kernel", filter_size + [total_arg_size_depth, num_features], dtype=dtype)
if len(args) == 1:
- res = conv_op(args[0],
- kernel,
- strides,
- padding='SAME')
+ res = conv_op(args[0], kernel, strides, padding="SAME")
else:
- res = conv_op(array_ops.concat(axis=shape_length-1, values=args),
- kernel,
- strides,
- padding='SAME')
+ res = conv_op(
+ array_ops.concat(axis=shape_length - 1, values=args),
+ kernel,
+ strides,
+ padding="SAME")
if not bias:
return res
bias_term = vs.get_variable(
"biases", [num_features],
dtype=dtype,
- initializer=init_ops.constant_initializer(
- bias_start, dtype=dtype))
+ initializer=init_ops.constant_initializer(bias_start, dtype=dtype))
return res + bias_term
+
class GLSTMCell(rnn_cell_impl.RNNCell):
"""Group LSTM cell (G-LSTM).
@@ -2194,8 +2226,13 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
"Factorization Tricks for LSTM Networks", ICLR 2017 workshop.
"""
- def __init__(self, num_units, initializer=None, num_proj=None,
- number_of_groups=1, forget_bias=1.0, activation=math_ops.tanh,
+ def __init__(self,
+ num_units,
+ initializer=None,
+ num_proj=None,
+ number_of_groups=1,
+ forget_bias=1.0,
+ activation=math_ops.tanh,
reuse=None):
"""Initialize the parameters of G-LSTM cell.
@@ -2232,11 +2269,15 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
if self._num_proj:
if self._num_proj % self._number_of_groups != 0:
raise ValueError("num_proj must be divisible by number_of_groups")
- self._group_shape = [int(self._num_proj / self._number_of_groups),
- int(self._num_units / self._number_of_groups)]
+ self._group_shape = [
+ int(self._num_proj / self._number_of_groups),
+ int(self._num_units / self._number_of_groups)
+ ]
else:
- self._group_shape = [int(self._num_units / self._number_of_groups),
- int(self._num_units / self._number_of_groups)]
+ self._group_shape = [
+ int(self._num_units / self._number_of_groups),
+ int(self._num_units / self._number_of_groups)
+ ]
if num_proj:
self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
@@ -2268,10 +2309,11 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
subset of inputs corresponding to group "group_id",
a Tensor, 2D, [batch x num_units/number_of_groups]
"""
- return array_ops.slice(input_=inputs,
- begin=[0, group_id * group_size],
- size=[self._batch_size, group_size],
- name=("GLSTM_group%d_input_generation" % group_id))
+ return array_ops.slice(
+ input_=inputs,
+ begin=[0, group_id * group_size],
+ size=[self._batch_size, group_size],
+ name=("GLSTM_group%d_input_generation" % group_id))
def call(self, inputs, state):
"""Run one step of G-LSTM.
@@ -2310,10 +2352,13 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
for group_id in range(self._number_of_groups):
with vs.variable_scope("group%d" % group_id):
x_g_id = array_ops.concat(
- [self._get_input_for_group(inputs, group_id,
- self._group_shape[0]),
- self._get_input_for_group(m_prev, group_id,
- self._group_shape[0])], axis=1)
+ [
+ self._get_input_for_group(inputs, group_id,
+ self._group_shape[0]),
+ self._get_input_for_group(m_prev, group_id,
+ self._group_shape[0])
+ ],
+ axis=1)
if self._linear1 is None:
self._linear1 = _Linear(x_g_id, 4 * self._group_shape[1], False)
R_k = self._linear1(x_g_id) # pylint: disable=invalid-name
@@ -2324,34 +2369,35 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
f_parts.append(f_k)
o_parts.append(o_k)
- bi = vs.get_variable(name="bias_i",
- shape=[self._num_units],
- dtype=dtype,
- initializer=
- init_ops.constant_initializer(0.0, dtype=dtype))
- bj = vs.get_variable(name="bias_j",
- shape=[self._num_units],
- dtype=dtype,
- initializer=
- init_ops.constant_initializer(0.0, dtype=dtype))
- bf = vs.get_variable(name="bias_f",
- shape=[self._num_units],
- dtype=dtype,
- initializer=
- init_ops.constant_initializer(0.0, dtype=dtype))
- bo = vs.get_variable(name="bias_o",
- shape=[self._num_units],
- dtype=dtype,
- initializer=
- init_ops.constant_initializer(0.0, dtype=dtype))
+ bi = vs.get_variable(
+ name="bias_i",
+ shape=[self._num_units],
+ dtype=dtype,
+ initializer=init_ops.constant_initializer(0.0, dtype=dtype))
+ bj = vs.get_variable(
+ name="bias_j",
+ shape=[self._num_units],
+ dtype=dtype,
+ initializer=init_ops.constant_initializer(0.0, dtype=dtype))
+ bf = vs.get_variable(
+ name="bias_f",
+ shape=[self._num_units],
+ dtype=dtype,
+ initializer=init_ops.constant_initializer(0.0, dtype=dtype))
+ bo = vs.get_variable(
+ name="bias_o",
+ shape=[self._num_units],
+ dtype=dtype,
+ initializer=init_ops.constant_initializer(0.0, dtype=dtype))
i = nn_ops.bias_add(array_ops.concat(i_parts, axis=1), bi)
j = nn_ops.bias_add(array_ops.concat(j_parts, axis=1), bj)
f = nn_ops.bias_add(array_ops.concat(f_parts, axis=1), bf)
o = nn_ops.bias_add(array_ops.concat(o_parts, axis=1), bo)
- c = (math_ops.sigmoid(f + self._forget_bias) * c_prev +
- math_ops.sigmoid(i) * math_ops.tanh(j))
+ c = (
+ math_ops.sigmoid(f + self._forget_bias) * c_prev +
+ math_ops.sigmoid(i) * math_ops.tanh(j))
m = math_ops.sigmoid(o) * self._activation(c)
if self._num_proj is not None:
@@ -2636,10 +2682,12 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
class SRUCell(rnn_cell_impl._LayerRNNCell):
"""SRU, Simple Recurrent Unit
+
Implementation based on
Training RNNs as Fast as CNNs (cf. https://arxiv.org/abs/1709.02755).
- This variation of RNN cell is characterized by the simplified data dependence
+ This variation of RNN cell is characterized by the simplified data
+ dependence
between hidden states of two consecutive time steps. Traditionally, hidden
  states from a cell at time step t-1 need to be multiplied by a matrix
W_hh before being fed into the ensuing cell at time step t.
@@ -2657,8 +2705,8 @@ class SRUCell(rnn_cell_impl._LayerRNNCell):
will share weights, but to avoid mistakes we require reuse=True in such
cases.
"""
- def __init__(self, num_units,
- activation=None, reuse=None, name=None):
+
+ def __init__(self, num_units, activation=None, reuse=None, name=None):
super(SRUCell, self).__init__(_reuse=reuse, name=name)
self._num_units = num_units
self._activation = activation or math_ops.tanh
@@ -2676,8 +2724,8 @@ class SRUCell(rnn_cell_impl._LayerRNNCell):
def build(self, inputs_shape):
if inputs_shape[1].value is None:
- raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
- % inputs_shape)
+ raise ValueError(
+ "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)
input_depth = inputs_shape[1].value
@@ -2712,12 +2760,12 @@ class SRUCell(rnn_cell_impl._LayerRNNCell):
"""Simple recurrent unit (SRU) with num_units cells."""
U = math_ops.matmul(inputs, self._kernel)
- x_bar, f_intermediate, r_intermediate = array_ops.split(value=U,
- num_or_size_splits=3,
- axis=1)
+ x_bar, f_intermediate, r_intermediate = array_ops.split(
+ value=U, num_or_size_splits=3, axis=1)
- f_r = math_ops.sigmoid(nn_ops.bias_add(array_ops.concat(
- [f_intermediate, r_intermediate], 1), self._bias))
+ f_r = math_ops.sigmoid(
+ nn_ops.bias_add(
+ array_ops.concat([f_intermediate, r_intermediate], 1), self._bias))
f, r = array_ops.split(value=f_r, num_or_size_splits=2, axis=1)
c = f * state + (1.0 - f) * x_bar
@@ -2750,9 +2798,16 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
large scale acoustic modeling." INTERSPEECH, 2014.
"""
- def __init__(self, num_units, norm=True, use_peepholes=False,
- cell_clip=None, initializer=None, num_proj=None,
- proj_clip=None, forget_bias=1, activation=None,
+ def __init__(self,
+ num_units,
+ norm=True,
+ use_peepholes=False,
+ cell_clip=None,
+ initializer=None,
+ num_proj=None,
+ proj_clip=None,
+ forget_bias=1,
+ activation=None,
reuse=None):
"""Initialize the parameters of a weight-normalized LSTM cell.
@@ -2779,7 +2834,7 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
"""
super(WeightNormLSTMCell, self).__init__(_reuse=reuse)
- self._scope = 'wn_lstm_cell'
+ self._scope = "wn_lstm_cell"
self._num_units = num_units
self._norm = norm
self._initializer = initializer
@@ -2822,7 +2877,8 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
g = vs.get_variable(name, [output_size], dtype=weight.dtype)
return nn_impl.l2_normalize(weight, dim=0) * g
- def _linear(self, args,
+ def _linear(self,
+ args,
output_size,
norm,
bias,
@@ -2877,8 +2933,8 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
with ops.control_dependencies(None):
for i in range(len(args)):
en = st + shapes[i][1].value
- wn.append(self._normalize(weights[st:en, :],
- name='norm_{}'.format(i)))
+ wn.append(
+ self._normalize(weights[st:en, :], name="norm_{}".format(i)))
st = en
weights = array_ops.concat(wn, axis=0)
@@ -2936,8 +2992,8 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
with vs.variable_scope(self._scope, initializer=self._initializer):
- concat = self._linear([inputs, h], 4 * num_units,
- norm=self._norm, bias=True)
+ concat = self._linear(
+ [inputs, h], 4 * num_units, norm=self._norm, bias=True)
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
@@ -2947,11 +3003,13 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
w_i_diag = vs.get_variable("w_i_diag", shape=[num_units], dtype=dtype)
w_o_diag = vs.get_variable("w_o_diag", shape=[num_units], dtype=dtype)
- new_c = (c * sigmoid(f + self._forget_bias + w_f_diag * c)
- + sigmoid(i + w_i_diag * c) * self._activation(j))
+ new_c = (
+ c * sigmoid(f + self._forget_bias + w_f_diag * c) +
+ sigmoid(i + w_i_diag * c) * self._activation(j))
else:
- new_c = (c * sigmoid(f + self._forget_bias)
- + sigmoid(i) * self._activation(j))
+ new_c = (
+ c * sigmoid(f + self._forget_bias) +
+ sigmoid(i) * self._activation(j))
if self._cell_clip is not None:
# pylint: disable=invalid-unary-operand-type
@@ -2964,15 +3022,12 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
if self._num_proj is not None:
with vs.variable_scope("projection"):
- new_h = self._linear(new_h,
- self._num_proj,
- norm=self._norm,
- bias=False)
+ new_h = self._linear(
+ new_h, self._num_proj, norm=self._norm, bias=False)
if self._proj_clip is not None:
# pylint: disable=invalid-unary-operand-type
- new_h = clip_ops.clip_by_value(new_h,
- -self._proj_clip,
+ new_h = clip_ops.clip_by_value(new_h, -self._proj_clip,
self._proj_clip)
# pylint: enable=invalid-unary-operand-type
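All of the rnn_cell.py hunks above are formatting-only; constructor signatures such as `SRUCell(num_units, activation=None, reuse=None, name=None)` keep their arguments. As a quick orientation for the SRU cell whose docstring is reflowed above, here is a hedged usage sketch; the `tf.nn.dynamic_rnn` wiring and the shapes are illustrative assumptions, not taken from this file:

```python
import numpy as np
import tensorflow as tf

from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell

batch_size, time_steps, num_units = 4, 7, 8
# Input depth is kept equal to num_units here, a conservative assumption so
# the cell's elementwise highway/recurrence terms line up.
inputs = tf.placeholder(tf.float32, [batch_size, time_steps, num_units])

cell = contrib_rnn_cell.SRUCell(num_units)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  out = sess.run(
      outputs, {inputs: np.random.randn(batch_size, time_steps, num_units)})
  print(out.shape)  # (4, 7, 8)
```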
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index f498b2bb57..9265540317 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -46,20 +46,18 @@ class TestGatherTree(test.TestCase):
# create (batch_size, max_time, beam_width) matrix and transpose it
predicted_ids = np.array(
- [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
- [[2, 3, 4], [5, 6, 7], [8, 9, 10]]],
+ [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[2, 3, 4], [5, 6, 7], [8, 9, 10]]],
dtype=np.int32).transpose([1, 0, 2])
parent_ids = np.array(
- [[[0, 0, 0], [0, 1, 1], [2, 1, 2]],
- [[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
+ [[[0, 0, 0], [0, 1, 1], [2, 1, 2]], [[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
dtype=np.int32).transpose([1, 0, 2])
# sequence_lengths is shaped (batch_size = 3)
max_sequence_lengths = [3, 3]
- expected_result = np.array(
- [[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
- [[2, 4, 4], [7, 6, 6], [8, 9, 10]]]).transpose([1, 0, 2])
+ expected_result = np.array([[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
+ [[2, 4, 4], [7, 6, 6],
+ [8, 9, 10]]]).transpose([1, 0, 2])
res = beam_search_ops.gather_tree(
predicted_ids,
@@ -157,8 +155,8 @@ class TestBeamStep(test.TestCase):
self.assertAllEqual(outputs_.predicted_ids, [[3, 3, 2], [2, 2, 1]])
self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [2, 1, 0]])
self.assertAllEqual(next_state_.lengths, [[3, 3, 3], [3, 3, 3]])
- self.assertAllEqual(next_state_.finished, [[False, False, False],
- [False, False, False]])
+ self.assertAllEqual(next_state_.finished,
+ [[False, False, False], [False, False, False]])
expected_log_probs = []
expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
@@ -212,8 +210,8 @@ class TestBeamStep(test.TestCase):
self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]])
self.assertAllEqual(outputs_.predicted_ids, [[0, 3, 2], [2, 0, 1]])
self.assertAllEqual(next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
- self.assertAllEqual(next_state_.finished, [[True, False, False],
- [False, True, False]])
+ self.assertAllEqual(next_state_.finished,
+ [[True, False, False], [False, True, False]])
expected_log_probs = []
expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
@@ -226,9 +224,10 @@ class TestBeamStep(test.TestCase):
class TestLargeBeamStep(test.TestCase):
- """
- Tests a single step of beam search in such
- case that beam size is larger than vocabulary size.
+ """Tests large beam step.
+
+ Tests a single step of beam search in such case that beam size is larger than
+ vocabulary size.
"""
def setUp(self):
@@ -239,19 +238,21 @@ class TestLargeBeamStep(test.TestCase):
self.end_token = 0
self.length_penalty_weight = 0.6
-
def test_step(self):
- def get_probs():
- """this simulates the initialize method in BeamSearchDecoder"""
- log_prob_mask = array_ops.one_hot(array_ops.zeros([self.batch_size],
- dtype=dtypes.int32),
- depth=self.beam_width, on_value=True,
- off_value=False, dtype=dtypes.bool)
- log_prob_zeros = array_ops.zeros([self.batch_size, self.beam_width],
- dtype=dtypes.float32)
- log_prob_neg_inf = array_ops.ones([self.batch_size, self.beam_width],
- dtype=dtypes.float32) * -np.Inf
+ def get_probs():
+ """this simulates the initialize method in BeamSearchDecoder."""
+ log_prob_mask = array_ops.one_hot(
+ array_ops.zeros([self.batch_size], dtype=dtypes.int32),
+ depth=self.beam_width,
+ on_value=True,
+ off_value=False,
+ dtype=dtypes.bool)
+
+ log_prob_zeros = array_ops.zeros(
+ [self.batch_size, self.beam_width], dtype=dtypes.float32)
+ log_prob_neg_inf = array_ops.ones(
+ [self.batch_size, self.beam_width], dtype=dtypes.float32) * -np.Inf
log_probs = array_ops.where(log_prob_mask, log_prob_zeros,
log_prob_neg_inf)
@@ -260,12 +261,15 @@ class TestLargeBeamStep(test.TestCase):
log_probs = get_probs()
dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
+ # pylint: disable=invalid-name
_finished = array_ops.one_hot(
array_ops.zeros([self.batch_size], dtype=dtypes.int32),
- depth=self.beam_width, on_value=False,
- off_value=True, dtype=dtypes.bool)
+ depth=self.beam_width,
+ on_value=False,
+ off_value=True,
+ dtype=dtypes.bool)
_lengths = np.zeros([self.batch_size, self.beam_width], dtype=np.int64)
- _lengths[:, 0]=2
+ _lengths[:, 0] = 2
_lengths = constant_op.constant(_lengths, dtype=dtypes.int64)
beam_state = beam_search_decoder.BeamSearchDecoderState(
@@ -298,20 +302,20 @@ class TestLargeBeamStep(test.TestCase):
length_penalty_weight=self.length_penalty_weight)
with self.test_session() as sess:
- outputs_, next_state_, state_, log_probs_ = sess.run(
+ outputs_, next_state_, _, _ = sess.run(
[outputs, next_beam_state, beam_state, log_probs])
self.assertEqual(outputs_.predicted_ids[0, 0], 3)
self.assertEqual(outputs_.predicted_ids[0, 1], 2)
self.assertEqual(outputs_.predicted_ids[1, 0], 1)
neg_inf = -np.Inf
- self.assertAllEqual(next_state_.log_probs[:, -3:],
- [[neg_inf, neg_inf, neg_inf],
- [neg_inf, neg_inf, neg_inf]])
+ self.assertAllEqual(
+ next_state_.log_probs[:, -3:],
+ [[neg_inf, neg_inf, neg_inf], [neg_inf, neg_inf, neg_inf]])
self.assertEqual((next_state_.log_probs[:, :-3] > neg_inf).all(), True)
self.assertEqual((next_state_.lengths[:, :-3] > 0).all(), True)
- self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0],
- [0, 0, 0]])
+ self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]])
+
class BeamSearchDecoderTest(test.TestCase):
@@ -338,8 +342,8 @@ class BeamSearchDecoderTest(test.TestCase):
initial_state = cell.zero_state(batch_size, dtypes.float32)
if has_attention:
inputs = array_ops.placeholder_with_default(
- np.random.randn(batch_size, decoder_max_time,
- input_depth).astype(np.float32),
+ np.random.randn(batch_size, decoder_max_time, input_depth).astype(
+ np.float32),
shape=(None, None, input_depth))
tiled_inputs = beam_search_decoder.tile_batch(
inputs, multiplier=beam_width)
@@ -359,8 +363,7 @@ class BeamSearchDecoderTest(test.TestCase):
cell_state = cell.zero_state(
dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
if has_attention:
- cell_state = cell_state.clone(
- cell_state=initial_state)
+ cell_state = cell_state.clone(cell_state=initial_state)
bsd = beam_search_decoder.BeamSearchDecoder(
cell=cell,
embedding=embedding,
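The `get_probs` helper reformatted above initializes beam 0 with log-probability 0 and every other beam with -inf via `one_hot`/`where`, so the first decoding step can only expand from beam 0. A small NumPy sketch of the resulting values (the TF ops in the test are what the diff shows; the NumPy version below is only for illustration):

```python
import numpy as np

batch_size, beam_width = 2, 3

# Beam 0 of every batch entry starts "live" (log-prob 0); the rest start
# at -inf, mirroring the one_hot mask built in get_probs.
mask = np.arange(beam_width)[np.newaxis, :] == 0   # shape (1, beam_width)
log_probs = np.where(mask, 0.0, -np.inf)           # (1, beam_width)
log_probs = np.tile(log_probs, (batch_size, 1))    # (batch_size, beam_width)
print(log_probs)
# [[  0. -inf -inf]
#  [  0. -inf -inf]]
```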
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index a5f7169c31..d6184d6109 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -37,7 +37,6 @@ from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
-
__all__ = [
"BeamSearchDecoderOutput",
"BeamSearchDecoderState",
@@ -48,8 +47,8 @@ __all__ = [
class BeamSearchDecoderState(
- collections.namedtuple("BeamSearchDecoderState", ("cell_state", "log_probs",
- "finished", "lengths"))):
+ collections.namedtuple("BeamSearchDecoderState",
+ ("cell_state", "log_probs", "finished", "lengths"))):
pass
@@ -85,11 +84,12 @@ def _tile_batch(t, multiplier):
tiled_static_batch_size = (
t.shape[0].value * multiplier if t.shape[0].value is not None else None)
tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling)
- tiled = array_ops.reshape(
- tiled, array_ops.concat(([shape_t[0] * multiplier], shape_t[1:]), 0))
+ tiled = array_ops.reshape(tiled,
+ array_ops.concat(
+ ([shape_t[0] * multiplier], shape_t[1:]), 0))
tiled.set_shape(
- tensor_shape.TensorShape(
- [tiled_static_batch_size]).concatenate(t.shape[1:]))
+ tensor_shape.TensorShape([tiled_static_batch_size]).concatenate(
+ t.shape[1:]))
return tiled
@@ -197,8 +197,8 @@ class BeamSearchDecoder(decoder.Decoder):
"""
if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access
raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
- if (output_layer is not None
- and not isinstance(output_layer, layers_base.Layer)):
+ if (output_layer is not None and
+ not isinstance(output_layer, layers_base.Layer)):
raise TypeError(
"output_layer must be a Layer, received: %s" % type(output_layer))
self._cell = cell
@@ -223,16 +223,17 @@ class BeamSearchDecoder(decoder.Decoder):
self._beam_width = beam_width
self._length_penalty_weight = length_penalty_weight
self._initial_cell_state = nest.map_structure(
- self._maybe_split_batch_beams,
- initial_state, self._cell.state_size)
+ self._maybe_split_batch_beams, initial_state, self._cell.state_size)
self._start_tokens = array_ops.tile(
array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
self._start_inputs = self._embedding_fn(self._start_tokens)
-
+
self._finished = array_ops.one_hot(
array_ops.zeros([self._batch_size], dtype=dtypes.int32),
- depth=self._beam_width, on_value=False,
- off_value=True, dtype=dtypes.bool)
+ depth=self._beam_width,
+ on_value=False,
+ off_value=True,
+ dtype=dtypes.bool)
@property
def batch_size(self):
@@ -250,8 +251,7 @@ class BeamSearchDecoder(decoder.Decoder):
# dimensions to get the output size of the rnn with the layer
# applied to the top.
output_shape_with_unknown_batch = nest.map_structure(
- lambda s: tensor_shape.TensorShape([None]).concatenate(s),
- size)
+ lambda s: tensor_shape.TensorShape([None]).concatenate(s), size)
layer_output_shape = self._output_layer.compute_output_shape(
output_shape_with_unknown_batch)
return nest.map_structure(lambda s: s[1:], layer_output_shape)
@@ -302,10 +302,11 @@ class BeamSearchDecoder(decoder.Decoder):
log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz)
array_ops.zeros([self._batch_size], dtype=dtypes.int32),
- depth=self._beam_width, on_value=0.0, off_value=-np.Inf,
+ depth=self._beam_width,
+ on_value=0.0,
+ off_value=-np.Inf,
dtype=nest.flatten(self._initial_cell_state)[0].dtype)
-
initial_state = BeamSearchDecoderState(
cell_state=self._initial_cell_state,
log_probs=log_probs,
@@ -365,11 +366,12 @@ class BeamSearchDecoder(decoder.Decoder):
t_shape = array_ops.shape(t)
static_batch_size = tensor_util.constant_value(self._batch_size)
batch_size_beam_width = (
- None if static_batch_size is None
- else static_batch_size * self._beam_width)
+ None
+ if static_batch_size is None else static_batch_size * self._beam_width)
reshaped_t = array_ops.reshape(
- t, array_ops.concat(
- ([self._batch_size * self._beam_width], t_shape[2:]), 0))
+ t,
+ array_ops.concat(([self._batch_size * self._beam_width], t_shape[2:]),
+ 0))
reshaped_t.set_shape(
(tensor_shape.TensorShape([batch_size_beam_width]).concatenate(s)))
return reshaped_t
@@ -398,8 +400,9 @@ class BeamSearchDecoder(decoder.Decoder):
s = tensor_shape.TensorShape(s)
t_shape = array_ops.shape(t)
reshaped_t = array_ops.reshape(
- t, array_ops.concat(
- ([self._batch_size, self._beam_width], t_shape[1:]), 0))
+ t,
+ array_ops.concat(([self._batch_size, self._beam_width], t_shape[1:]),
+ 0))
static_batch_size = tensor_util.constant_value(self._batch_size)
expected_reshaped_shape = tensor_shape.TensorShape(
[static_batch_size, self._beam_width]).concatenate(s)
@@ -409,8 +412,8 @@ class BeamSearchDecoder(decoder.Decoder):
"We expected it to have shape "
"(batch_size, beam_width, depth) == %s. Perhaps you "
"forgot to create a zero_state with "
- "batch_size=encoder_batch_size * beam_width?"
- % (reshaped_t.shape, expected_reshaped_shape))
+ "batch_size=encoder_batch_size * beam_width?" %
+ (reshaped_t.shape, expected_reshaped_shape))
reshaped_t.set_shape(expected_reshaped_shape)
return reshaped_t
@@ -482,15 +485,13 @@ class BeamSearchDecoder(decoder.Decoder):
cell_state = state.cell_state
inputs = nest.map_structure(
lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs)
- cell_state = nest.map_structure(
- self._maybe_merge_batch_beams,
- cell_state, self._cell.state_size)
+ cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state,
+ self._cell.state_size)
cell_outputs, next_cell_state = self._cell(inputs, cell_state)
cell_outputs = nest.map_structure(
lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs)
next_cell_state = nest.map_structure(
- self._maybe_split_batch_beams,
- next_cell_state, self._cell.state_size)
+ self._maybe_split_batch_beams, next_cell_state, self._cell.state_size)
if self._output_layer is not None:
cell_outputs = self._output_layer(cell_outputs)
@@ -553,7 +554,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
lengths_to_add = array_ops.one_hot(
indices=array_ops.fill([batch_size, beam_width], end_token),
depth=vocab_size,
- on_value=np.int64(0), off_value=np.int64(1),
+ on_value=np.int64(0),
+ off_value=np.int64(1),
dtype=dtypes.int64)
add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished))
lengths_to_add *= array_ops.expand_dims(add_mask, 2)
@@ -572,8 +574,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
scores_flat = array_ops.reshape(scores, [batch_size, -1])
# Pick the next beams according to the specified successors function
- next_beam_size = ops.convert_to_tensor(beam_width, dtype=dtypes.int32,
- name="beam_width")
+ next_beam_size = ops.convert_to_tensor(
+ beam_width, dtype=dtypes.int32, name="beam_width")
next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=next_beam_size)
next_beam_scores.set_shape([static_batch_size, beam_width])
@@ -592,11 +594,11 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
# name="next_beam_word_ids")
# would be a lot cleaner but for reasons unclear, that hides the results of
# the op which prevents capturing it with tfdbg debug ops.
- raw_next_word_ids = math_ops.mod(word_indices, vocab_size,
- name="next_beam_word_ids")
+ raw_next_word_ids = math_ops.mod(
+ word_indices, vocab_size, name="next_beam_word_ids")
next_word_ids = math_ops.to_int32(raw_next_word_ids)
- next_beam_ids = math_ops.to_int32(word_indices / vocab_size,
- name="next_beam_parent_ids")
+ next_beam_ids = math_ops.to_int32(
+ word_indices / vocab_size, name="next_beam_parent_ids")
# Append new ids to current predictions
previously_finished = _tensor_gather_helper(
@@ -605,9 +607,10 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
batch_size=batch_size,
range_size=beam_width,
gather_shape=[-1])
- next_finished = math_ops.logical_or(previously_finished,
- math_ops.equal(next_word_ids, end_token),
- name="next_beam_finished")
+ next_finished = math_ops.logical_or(
+ previously_finished,
+ math_ops.equal(next_word_ids, end_token),
+ name="next_beam_finished")
# Calculate the length of the next predictions.
# 1. Finished beams remain unchanged.
@@ -768,8 +771,12 @@ def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,
return gather_from
-def _tensor_gather_helper(gather_indices, gather_from, batch_size,
- range_size, gather_shape, name=None):
+def _tensor_gather_helper(gather_indices,
+ gather_from,
+ batch_size,
+ range_size,
+ gather_shape,
+ name=None):
"""Helper for gathering the right indices from the tensor.
This works by reshaping gather_from to gather_shape (e.g. [-1]) and then
@@ -800,9 +807,9 @@ def _tensor_gather_helper(gather_indices, gather_from, batch_size,
array_ops.reshape(gather_from, gather_shape), gather_indices)
final_shape = array_ops.shape(gather_from)[:1 + len(gather_shape)]
static_batch_size = tensor_util.constant_value(batch_size)
- final_static_shape = (tensor_shape.TensorShape([static_batch_size])
- .concatenate(
- gather_from.shape[1:1 + len(gather_shape)]))
+ final_static_shape = (
+ tensor_shape.TensorShape([static_batch_size]).concatenate(
+ gather_from.shape[1:1 + len(gather_shape)]))
output = array_ops.reshape(output, final_shape, name="output")
output.set_shape(final_static_shape)
return output
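For intuition, the reshape-then-gather trick that _tensor_gather_helper's docstring describes can be sketched in plain NumPy: offset each batch's beam indices into a flattened [batch * beam, ...] tensor and gather once. The helper name and shapes below are illustrative assumptions, not the TensorFlow op itself.

import numpy as np

def gather_beams(gather_from, gather_indices, batch_size, range_size):
  # gather_from: [batch_size, range_size, ...]; gather_indices: [batch_size, k]
  offsets = np.arange(batch_size)[:, None] * range_size       # [batch, 1]
  flat_indices = (gather_indices + offsets).reshape(-1)       # [batch * k]
  flat = gather_from.reshape((batch_size * range_size,) + gather_from.shape[2:])
  out = flat[flat_indices]
  return out.reshape((batch_size, -1) + gather_from.shape[2:])

parents = np.array([[2, 0], [1, 1]])                           # e.g. next_beam_ids
scores = np.arange(12.0).reshape(2, 3, 2)                      # [batch=2, beam=3, depth=2]
print(gather_beams(scores, parents, batch_size=2, range_size=3))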
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc
index ec5271abe0..7d95b6522c 100644
--- a/tensorflow/contrib/verbs/rdma.cc
+++ b/tensorflow/contrib/verbs/rdma.cc
@@ -15,10 +15,11 @@ limitations under the License.
#ifdef TENSORFLOW_USE_VERBS
+#include <fcntl.h>
+#include <cstdlib>
+
#include "tensorflow/contrib/verbs/rdma.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
-#include <cstdlib>
-#include <fcntl.h>
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
@@ -27,15 +28,15 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/process_state.h"
#endif
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
-#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/random/random.h"
-#include "tensorflow/core/lib/core/threadpool.h"
namespace tensorflow {
@@ -447,9 +448,9 @@ void RdmaAdapter::Process_CQ() {
CHECK_GE(ne, 0);
for (int i = 0; i < ne; ++i) {
CHECK(wc_[i].status == IBV_WC_SUCCESS)
- << "Failed status \n" << ibv_wc_status_str(wc_[i].status) << " "
- << wc_[i].status << " " << static_cast<int>(wc_[i].wr_id) << " "
- << wc_[i].vendor_err;
+ << "Failed status \n"
+ << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " "
+ << static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err;
if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id);
// put back a recv wr.
@@ -538,7 +539,7 @@ int RdmaChannel::PingPostRecv() {
int RdmaChannel::PingPostSend() {
struct ibv_send_wr wr, *bad_wr;
memset(&wr, 0, sizeof(wr));
- wr.wr_id = (uint64_t) this;
+ wr.wr_id = (uint64_t)this;
wr.sg_list = &ping_sge_list_;
wr.num_sge = 1;
wr.opcode = IBV_WR_SEND;
@@ -658,7 +659,7 @@ void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
void RdmaChannel::Recv() {
struct ibv_recv_wr wr;
memset(&wr, 0, sizeof(wr));
- wr.wr_id = (uint64_t) this;
+ wr.wr_id = (uint64_t)this;
struct ibv_recv_wr* bad_wr;
CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv";
}
@@ -729,11 +730,11 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
attr.ah_attr.grh.traffic_class = adapter_->params_.traffic_class;
int r;
- CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_AV |
- IBV_QP_PATH_MTU |
- IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
- IBV_QP_MAX_DEST_RD_ATOMIC |
- IBV_QP_MIN_RNR_TIMER)))
+ CHECK(!(r = ibv_modify_qp(qp_, &attr,
+ IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU |
+ IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
+ IBV_QP_MAX_DEST_RD_ATOMIC |
+ IBV_QP_MIN_RNR_TIMER)))
<< "QP to Ready to Receive " << r;
memset(&attr, 0, sizeof(ibv_qp_attr));
@@ -744,10 +745,10 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
attr.rnr_retry = 7; /* infinite */
attr.max_rd_atomic = 1;
- CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_TIMEOUT |
- IBV_QP_RETRY_CNT |
- IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
- IBV_QP_MAX_QP_RD_ATOMIC)))
+ CHECK(!(r = ibv_modify_qp(qp_, &attr,
+ IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT |
+ IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
+ IBV_QP_MAX_QP_RD_ATOMIC)))
<< "QP to Ready to Send " << r;
connected_ = true;
@@ -897,16 +898,16 @@ static void CountCopies(const std::string& key, void* src_addr, void* dst_addr,
}
if ((++numTotalCopies % 0x400) == 0) {
RDMA_LOG(0) << "Tensor copies:"
- << " GPU to CPU: " << numGPUToCPUCopies
- << " (" << numGPUToCPUCopiedBytes << " Bytes)"
- << " CPU to GPU: " << numCPUToGPUCopies
- << " (" << numCPUToGPUCopiedBytes << " Bytes)";
+ << " GPU to CPU: " << numGPUToCPUCopies << " ("
+ << numGPUToCPUCopiedBytes << " Bytes)"
+ << " CPU to GPU: " << numCPUToGPUCopies << " ("
+ << numCPUToGPUCopiedBytes << " Bytes)";
}
- RDMA_LOG(2) << "Copying tensor " << key
- << " From: " << src_addr << " To: " << dst_addr;
-#endif // RDMA_COUNT_COPIES
+ RDMA_LOG(2) << "Copying tensor " << key << " From: " << src_addr
+ << " To: " << dst_addr;
+#endif // RDMA_COUNT_COPIES
}
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA
#ifdef RDMA_DATA_VALIDATION
static uint64_t Checksum(Device* device, const DeviceContext* device_context,
@@ -920,7 +921,7 @@ static uint64_t Checksum(Device* device, const DeviceContext* device_context,
checksum = (device_context != nullptr)
? GPUUtil::Checksum(device, device_context, in)
: GPUUtil::Checksum(in);
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA
} else {
string s = in.SummarizeValue(999999);
checksum = Hash64(s.c_str(), s.size(), 0);
@@ -955,17 +956,16 @@ static void ValidateChecksum(uint64_t expected, uint64_t actual,
}
}
}
-#endif // RDMA_DATA_VALIDATION
+#endif // RDMA_DATA_VALIDATION
#if GOOGLE_CUDA
// Sync the 'done' operation on the GPU stream, but without all the data
// copying.
-static void StreamGPUOp(Device* gpu_device,
- const DeviceContext* device_context,
+static void StreamGPUOp(Device* gpu_device, const DeviceContext* device_context,
StatusCallback done) {
Tensor dummy1, dummy2;
- GPUUtil::CopyGPUTensorToCPU(
- gpu_device, device_context, &dummy1, &dummy2, done);
+ GPUUtil::CopyGPUTensorToCPU(gpu_device, device_context, &dummy1, &dummy2,
+ done);
}
#endif // GOOGLE_CUDA
@@ -1072,7 +1072,7 @@ void RdmaTensorResponse::RecvHandler(Rendezvous::ParsedKey parsed,
// skip the copy here as well.
if ((in.TotalBytes() > 0) && !meta_data_changed_ &&
(RdmaMemoryMgr::Singleton().FindMemoryRegion(
- (void*)DMAHelper::base(&in), in.TotalBytes()) != nullptr)) {
+ (void*)DMAHelper::base(&in), in.TotalBytes()) != nullptr)) {
StreamGPUOp(src_dev_, send_dev_context,
[this, in, proto, is_dead](const Status& s) {
Send(in, proto, is_dead, s);
@@ -1118,8 +1118,8 @@ void RdmaTensorResponse::Send(const Tensor& in, const TensorProto& proto,
return;
}
bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
- bool proto_size_changed = (!can_memcpy) &&
- (proto.ByteSize() != rm_.tensor_bytes_);
+ bool proto_size_changed =
+ (!can_memcpy) && (proto.ByteSize() != rm_.tensor_bytes_);
if (meta_data_changed_ || proto_size_changed) {
Clone(in, proto, is_dead);
SendMetaData(in, proto, is_dead);
@@ -1238,9 +1238,8 @@ void RdmaTensorResponse::SendErrorStatus(const Status& status) {
rm.request_index_ = rm_.request_index_;
rm.status_ = status;
LOG(ERROR) << "Step 0x" << std::hex << rm.step_id_ << std::dec
- << ": Sending RDMA_MESSAGE_ERROR_STATUS #"
- << rm.request_index_ << ": " << rm.name_
- << ". Status: " << status.ToString();
+ << ": Sending RDMA_MESSAGE_ERROR_STATUS #" << rm.request_index_
+ << ": " << rm.name_ << ". Status: " << status.ToString();
string message = RdmaMessage::CreateMessage(rm);
channel_->tx_message_buffer_->EnqueueItem(message);
@@ -1336,14 +1335,13 @@ string RdmaMessage::CreateMessage(const RdmaMessage& rm) {
uint32_t gsProtoSize = gsProto.ByteSize();
if (gsProtoSize + 4 > kErrorStatusMaxSize) {
LOG(ERROR) << "Error status (" << gsProtoSize + 4 << " bytes) "
- << "is too big to fit in RDMA message ("
- << kErrorStatusMaxSize << " bytes). Truncated.";
+ << "is too big to fit in RDMA message (" << kErrorStatusMaxSize
+ << " bytes). Truncated.";
gsProtoSize = kErrorStatusMaxSize - 4;
}
uint32_t* proto_size = (uint32_t*)&message[kErrorStatusStartIndex];
*proto_size = gsProtoSize;
- gsProto.SerializeToArray(&message[kErrorStatusStartIndex + 4],
- gsProtoSize);
+ gsProto.SerializeToArray(&message[kErrorStatusStartIndex + 4], gsProtoSize);
message_size += gsProtoSize + 4;
}
return string(message, message_size);
@@ -1393,8 +1391,8 @@ void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) {
if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
ErrorStatusProto gsProto;
uint32_t gsProtoSize = *(uint32_t*)&message[kErrorStatusStartIndex];
- CHECK(ParseProtoUnlimited(
- &gsProto, &message[kErrorStatusStartIndex + 4], gsProtoSize))
+ CHECK(ParseProtoUnlimited(&gsProto, &message[kErrorStatusStartIndex + 4],
+ gsProtoSize))
<< "Failed to parse error status proto from message. Aborting.";
::grpc::Status gs((::grpc::StatusCode)gsProto.error_code(),
gsProto.error_message(), gsProto.error_details());
@@ -1566,8 +1564,8 @@ void RdmaTensorRequest::AllocateTensorsAsync(StatusCallback done) {
if (dst_dev_->tensorflow_gpu_device_info() && !on_host &&
(proxy_tensor_ == nullptr)) {
#if GOOGLE_CUDA
- // We need to sync the memory allocation on the GPU:
- StreamGPUOp(dst_dev_, recv_args_.device_context, done);
+ // We need to sync the memory allocation on the GPU:
+ StreamGPUOp(dst_dev_, recv_args_.device_context, done);
#endif
} else {
done(Status::OK());
@@ -1594,9 +1592,8 @@ void RdmaTensorRequest::Send(RdmaMessageType message_type) {
rm.rkey_ = (mr_ == nullptr) ? 0 : mr_->rkey;
RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
- << ": Sending " << MessageTypeToString(message_type)
- << " #" << index_ << ": "
- << rm.name_ << " on " << rdma_addr_
+ << ": Sending " << MessageTypeToString(message_type) << " #"
+ << index_ << ": " << rm.name_ << " on " << rdma_addr_
<< " (rkey: 0x" << std::hex << rm.rkey_ << ")";
string message = RdmaMessage::CreateMessage(rm);
@@ -1610,9 +1607,8 @@ void RdmaTensorRequest::RecvTensorMetaData(DataType dtype, TensorShape shape,
key_, dtype, shape, is_dead, proto_size);
DeallocateTensors();
- AllocateTensorsAsync([this](const Status& s) {
- Send(RDMA_MESSAGE_TENSOR_RE_REQUEST);
- });
+ AllocateTensorsAsync(
+ [this](const Status& s) { Send(RDMA_MESSAGE_TENSOR_RE_REQUEST); });
}
void RdmaTensorRequest::RecvTensorContent() {
@@ -1620,8 +1616,8 @@ void RdmaTensorRequest::RecvTensorContent() {
size_t message_size =
can_memcpy ? result_tensor_->TotalBytes() : meta_data_->proto_size_;
RDMA_LOG(1) << "Step 0x" << std::hex << step_id_ << std::dec
- << ": Received tensor content #" << index_ << ": "
- << key_ << " (Size: 0x" << std::hex << message_size << ")";
+ << ": Received tensor content #" << index_ << ": " << key_
+ << " (Size: 0x" << std::hex << message_size << ")";
Tensor val;
@@ -1667,9 +1663,8 @@ void RdmaTensorRequest::RecvErrorStatus(const Status& status) {
void RdmaTensorRequest::Start() {
meta_data_ = RdmaMemoryMgr::Singleton().GetTensorMetaData(key_);
if (meta_data_ != nullptr) {
- AllocateTensorsAsync([this](const Status& s) {
- Send(RDMA_MESSAGE_TENSOR_REQUEST);
- });
+ AllocateTensorsAsync(
+ [this](const Status& s) { Send(RDMA_MESSAGE_TENSOR_REQUEST); });
} else {
Send(RDMA_MESSAGE_TENSOR_REQUEST);
}
diff --git a/tensorflow/contrib/verbs/rdma.h b/tensorflow/contrib/verbs/rdma.h
index 68b3d59f56..b6c41de6ee 100644
--- a/tensorflow/contrib/verbs/rdma.h
+++ b/tensorflow/contrib/verbs/rdma.h
@@ -73,15 +73,8 @@ struct RemoteMR {
uint64_t remote_addr;
uint32_t rkey;
};
-enum BufferStatus {
- none,
- idle,
- busy
-};
-enum Location {
- local,
- remote
-};
+enum BufferStatus { none, idle, busy };
+enum Location { local, remote };
enum RdmaMessageType {
RDMA_MESSAGE_META_DATA_UPDATE,
diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc
index f3644af0b4..369bd986df 100644
--- a/tensorflow/contrib/verbs/rdma_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_mgr.cc
@@ -116,9 +116,9 @@ void RdmaMgr::SetupChannels() {
}
CHECK(i == RdmaChannel::kNumMessageBuffers);
} else {
- LOG(ERROR) << "Connecting to " << worker_name
- << ": Got " << s.error_message() << ". Retrying ("
- << (attempts + 1) << "/" << max_num_attempts << ")..." ;
+ LOG(ERROR) << "Connecting to " << worker_name << ": Got "
+ << s.error_message() << ". Retrying (" << (attempts + 1)
+ << "/" << max_num_attempts << ")...";
if (++attempts == max_num_attempts) {
break;
}
@@ -159,19 +159,17 @@ bool RdmaMgr::ConnectivityCheck() {
ibv_wc_status s = rdma_adapter_->wc_[i].status;
// recv complete
if ((int)rdma_adapter_->wc_[i].wr_id == RdmaChannel::kPingRecvWrid) {
- CHECK(s == IBV_WC_SUCCESS) << ": " << ibv_wc_status_str(
- rdma_adapter_->wc_[i].status)
- << "(" << rdma_adapter_->wc_[i].status
- << ") for PING_RECV_WRID";
+ CHECK(s == IBV_WC_SUCCESS)
+ << ": " << ibv_wc_status_str(rdma_adapter_->wc_[i].status) << "("
+ << rdma_adapter_->wc_[i].status << ") for PING_RECV_WRID";
++rcnt;
// send complete
} else {
RdmaChannel* rc =
reinterpret_cast<RdmaChannel*>(rdma_adapter_->wc_[i].wr_id);
- CHECK(s == IBV_WC_SUCCESS) << ": " << ibv_wc_status_str(
- rdma_adapter_->wc_[i].status)
- << "(" << rdma_adapter_->wc_[i].status
- << ") to " << rc->remote_name_;
+ CHECK(s == IBV_WC_SUCCESS)
+ << ": " << ibv_wc_status_str(rdma_adapter_->wc_[i].status) << "("
+ << rdma_adapter_->wc_[i].status << ") to " << rc->remote_name_;
++scnt;
}
} // for
@@ -238,8 +236,9 @@ int TryToReadNumaNode(ibv_device* device) {
if (strings::safe_strto32(content, &value)) {
if (value < 0) {
LOG(INFO) << "Successful NUMA node read from SysFS had negative value ("
- << value << "), but there must be at least one NUMA node"
- ", so returning NUMA node zero";
+ << value
+ << "), but there must be at least one NUMA node"
+ ", so returning NUMA node zero";
return 0;
}
LOG(INFO) << "NUMA node for device: " << device->name << " is " << value;
@@ -302,8 +301,8 @@ void RdmaMgr::InitAllocators() {
&RdmaMemoryMgr::EvictMemoryRegion, &RdmaMemoryMgr::Singleton(), _1, _2);
auto* visitable_allocator = dynamic_cast<VisitableAllocator*>(allocator);
- CHECK(visitable_allocator) << "is not visitable for instrumentation"
- << allocator->Name();
+ CHECK(visitable_allocator)
+ << "is not visitable for instrumentation" << allocator->Name();
// Make sure we don't instrument the same allocator twice
if (instrumented_.find(allocator) == std::end(instrumented_)) {
visitable_allocator->AddAllocVisitor(alloc_visitor);
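The NUMA-node lookup that TryToReadNumaNode performs can be restated as a short Python sketch. The sysfs path layout is an assumption for illustration; the key behavior, mirrored from the log message above, is that a negative value (kernel's "unknown node") is mapped to node zero.

def read_numa_node(device_name, sysfs_root="/sys/class/infiniband"):
  path = "%s/%s/device/numa_node" % (sysfs_root, device_name)
  try:
    with open(path) as f:
      value = int(f.read().strip())
  except (IOError, ValueError):
    return 0  # could not read or parse; fall back to node zero
  # A negative value means "unknown", but there is always at least one
  # NUMA node, so return node zero in that case.
  return value if value >= 0 else 0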
diff --git a/tensorflow/core/kernels/mkl_aggregate_ops.cc b/tensorflow/core/kernels/mkl_aggregate_ops.cc
index bb5eceab27..89d37d2f87 100644
--- a/tensorflow/core/kernels/mkl_aggregate_ops.cc
+++ b/tensorflow/core/kernels/mkl_aggregate_ops.cc
@@ -65,13 +65,11 @@ class MklAddNOp : public OpKernel {
TensorShape src1_shape, src2_shape;
src1_shape = input0.shape();
src2_shape = input1.shape();
- if (!src1_shape.IsSameSize(src2_shape) ){
- ctx->SetStatus(
- errors::InvalidArgument(
- "Inputs to operation ", this->name(), " of type ", this->type_string(),
- " must have the same size and shape. Input 0: ",
- src1_shape.DebugString(), " != input 1: ",
- src2_shape.DebugString()));
+ if (!src1_shape.IsSameSize(src2_shape)) {
+ ctx->SetStatus(errors::InvalidArgument(
+ "Inputs to operation ", this->name(), " of type ",
+ this->type_string(), " must have the same size and shape. Input 0: ",
+ src1_shape.DebugString(), " != input 1: ", src2_shape.DebugString()));
}
// handle the case of a scalar
if (!input1_in_mkl_format && input0.dims() == 0) {
@@ -82,17 +80,16 @@ class MklAddNOp : public OpKernel {
mkl_context.output_shape);
float user_i1 = (input0.scalar<T>()());
float user_i2 = (input1.scalar<T>()());
- out_tensor->scalar<T>()() =
- std::plus<float>{}(user_i1, user_i2);
+ out_tensor->scalar<T>()() = std::plus<float>{}(user_i1, user_i2);
return;
}
mkl_context.in_dims = input1_in_mkl_format
- ? mkl_context.input1_shape.GetDimension()
- : input0.dims();
+ ? mkl_context.input1_shape.GetDimension()
+ : input0.dims();
mkl_context.in_dims = input2_in_mkl_format
- ? mkl_context.input2_shape.GetDimension()
- : input1.dims();
+ ? mkl_context.input2_shape.GetDimension()
+ : input1.dims();
// If there is nothing to compute, return.
if (!input1_in_mkl_format && !input2_in_mkl_format) {
@@ -101,7 +98,7 @@ class MklAddNOp : public OpKernel {
Tensor* out_tensor = nullptr;
mkl_context.output_shape.SetMklTensor(false);
AllocateOutputSetMklShape(ctx, src1_idx, &out_tensor, o_shape,
- mkl_context.output_shape);
+ mkl_context.output_shape);
return;
}
}
@@ -110,9 +107,9 @@ class MklAddNOp : public OpKernel {
mkl_context.in_strides = new size_t[mkl_context.in_dims];
// Generate size, stride for input if input is in MKL format.
if (input1_in_mkl_format || input2_in_mkl_format) {
- const MklShape* tmp_mkl_shape =
- (input1_in_mkl_format) ? &mkl_context.input1_shape :
- &mkl_context.input2_shape;
+ const MklShape* tmp_mkl_shape = (input1_in_mkl_format)
+ ? &mkl_context.input1_shape
+ : &mkl_context.input2_shape;
for (int i = 0; i < mkl_context.in_dims; i++) {
mkl_context.in_sizes[i] = tmp_mkl_shape->GetSizes()[i];
mkl_context.in_strides[i] = tmp_mkl_shape->GetStrides()[i];
@@ -136,32 +133,33 @@ class MklAddNOp : public OpKernel {
Tensor mkl_tmp_input1_buf_tensor, mkl_tmp_input2_buf_tensor;
mkl_context.MklPrepareAddNInputs(ctx, &mkl_tmp_input1_buf_tensor,
- &mkl_tmp_input2_buf_tensor);
+ &mkl_tmp_input2_buf_tensor);
Tensor* output = nullptr;
if (input1_in_mkl_format || input2_in_mkl_format) {
- TensorShape tf_shape;
- mkl_context.output_shape.SetMklTensor(true);
- mkl_context.output_shape.SetMklLayout(mkl_context.Eltwise, dnnResourceDst);
-
- mkl_context.output_shape.SetTfLayout(
- mkl_context.in_dims, mkl_context.in_sizes, mkl_context.in_strides);
- if (input1_in_mkl_format == true) {
- mkl_context.output_shape.SetTfDimOrder(mkl_context.in_dims,
- mkl_context.input1_shape.GetTfToMklDimMap());
- } else {
- mkl_context.output_shape.SetTfDimOrder(mkl_context.in_dims,
- mkl_context.input2_shape.GetTfToMklDimMap());
- }
- tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
- mkl_context.output_shape.GetMklLayout())) /
- sizeof(T));
-
- AllocateOutputSetMklShape(ctx, src1_idx, &output, tf_shape,
- mkl_context.output_shape);
+ TensorShape tf_shape;
+ mkl_context.output_shape.SetMklTensor(true);
+ mkl_context.output_shape.SetMklLayout(mkl_context.Eltwise,
+ dnnResourceDst);
+
+ mkl_context.output_shape.SetTfLayout(
+ mkl_context.in_dims, mkl_context.in_sizes, mkl_context.in_strides);
+ if (input1_in_mkl_format == true) {
+ mkl_context.output_shape.SetTfDimOrder(
+ mkl_context.in_dims, mkl_context.input1_shape.GetTfToMklDimMap());
+ } else {
+ mkl_context.output_shape.SetTfDimOrder(
+ mkl_context.in_dims, mkl_context.input2_shape.GetTfToMklDimMap());
+ }
+ tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
+ mkl_context.output_shape.GetMklLayout())) /
+ sizeof(T));
+
+ AllocateOutputSetMklShape(ctx, src1_idx, &output, tf_shape,
+ mkl_context.output_shape);
} else {
- const TensorShape& o_shape = input1.shape();
- mkl_context.output_shape.SetMklTensor(false);
- AllocateOutputSetMklShape(ctx, src1_idx, &output, o_shape,
+ const TensorShape& o_shape = input1.shape();
+ mkl_context.output_shape.SetMklTensor(false);
+ AllocateOutputSetMklShape(ctx, src1_idx, &output, o_shape,
mkl_context.output_shape);
}
@@ -189,18 +187,16 @@ class MklAddNOp : public OpKernel {
void MklCreateInputLayouts(OpKernelContext* context) {
bool input1_in_mkl_format = input1_shape.IsMklTensor();
if (!input1_in_mkl_format) {
- CHECK_EQ(
- dnnLayoutCreate_F32(&lt_input1, in_dims, in_sizes, in_strides),
- E_SUCCESS);
+ CHECK_EQ(dnnLayoutCreate_F32(&lt_input1, in_dims, in_sizes, in_strides),
+ E_SUCCESS);
} else {
lt_input1 = static_cast<dnnLayout_t>(input1_shape.GetCurLayout());
}
bool input2_in_mkl_format = input2_shape.IsMklTensor();
if (!input2_in_mkl_format) {
- CHECK_EQ(
- dnnLayoutCreate_F32(&lt_input2, in_dims, in_sizes, in_strides),
- E_SUCCESS);
+ CHECK_EQ(dnnLayoutCreate_F32(&lt_input2, in_dims, in_sizes, in_strides),
+ E_SUCCESS);
} else {
lt_input2 = static_cast<dnnLayout_t>(input2_shape.GetCurLayout());
}
@@ -276,14 +272,14 @@ class MklAddNOp : public OpKernel {
bool input2_in_mkl_format = input2_shape.IsMklTensor();
dnnDelete_F32(Eltwise);
if (!input1_in_mkl_format || !input2_in_mkl_format) {
- delete [] in_sizes;
- delete [] in_strides;
+ delete[] in_sizes;
+ delete[] in_strides;
}
if (!input1_in_mkl_format) {
- dnnLayoutDelete_F32(lt_input1);
+ dnnLayoutDelete_F32(lt_input1);
}
if (!input2_in_mkl_format) {
- dnnLayoutDelete_F32(lt_input2);
+ dnnLayoutDelete_F32(lt_input2);
}
}
} MklAddNOpContext;
@@ -315,45 +311,44 @@ class MklAddNOp : public OpKernel {
GetMklShape(ctx, src2_idx, &src2_mkl_shape);
bool input1_in_mkl_format = src1_mkl_shape.IsMklTensor();
bool input2_in_mkl_format = src2_mkl_shape.IsMklTensor();
- int src1_dims_size = input1_in_mkl_format?
- src1_mkl_shape.GetDimension(): src1_tensor.dims();
- int src2_dims_size = input2_in_mkl_format?
- src2_mkl_shape.GetDimension(): src2_tensor.dims();
+ int src1_dims_size = input1_in_mkl_format ? src1_mkl_shape.GetDimension()
+ : src1_tensor.dims();
+ int src2_dims_size = input2_in_mkl_format ? src2_mkl_shape.GetDimension()
+ : src2_tensor.dims();
  // If the shapes of the two tensors are not the same, raise an op error.
TensorShape src1_shape, src2_shape;
src1_shape = src1_tensor.shape();
src2_shape = src2_tensor.shape();
- if (!src1_shape.IsSameSize(src2_shape) ){
- ctx->SetStatus(
- errors::InvalidArgument(
- "Inputs to operation ", this->name(), " of type ", this->type_string(),
+ if (!src1_shape.IsSameSize(src2_shape)) {
+ ctx->SetStatus(errors::InvalidArgument(
+ "Inputs to operation ", this->name(), " of type ",
+ this->type_string(),
" must have the same size and shape. Input 0: ",
- src1_shape.DebugString(), " != input 1: ",
- src2_shape.DebugString()));
+ src1_shape.DebugString(),
+ " != input 1: ", src2_shape.DebugString()));
}
if (!input1_in_mkl_format && src1_dims_size == 0) {
- Tensor* dst_tensor = nullptr;
- MklShape mkl_shape_dst;
- mkl_shape_dst.SetMklTensor(false);
- AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
- src1_tensor.shape(), mkl_shape_dst);
- float user_i1 = (src1_tensor.scalar<T>()());
- float user_i2 = (src2_tensor.scalar<T>()());
- dst_tensor->scalar<T>()() =
- std::plus<float>{}(user_i1, user_i2);
- return;
- }
+ Tensor* dst_tensor = nullptr;
+ MklShape mkl_shape_dst;
+ mkl_shape_dst.SetMklTensor(false);
+ AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
+ src1_tensor.shape(), mkl_shape_dst);
+ float user_i1 = (src1_tensor.scalar<T>()());
+ float user_i2 = (src2_tensor.scalar<T>()());
+ dst_tensor->scalar<T>()() = std::plus<float>{}(user_i1, user_i2);
+ return;
+ }
// If there is nothing to compute, return.
if (!input1_in_mkl_format && !input2_in_mkl_format) {
if (src1_tensor.shape().num_elements() == 0) {
- Tensor* dst_tensor = nullptr;
- MklShape mkl_shape_dst;
- mkl_shape_dst.SetMklTensor(false);
- AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
- src1_tensor.shape(), mkl_shape_dst);
- return;
+ Tensor* dst_tensor = nullptr;
+ MklShape mkl_shape_dst;
+ mkl_shape_dst.SetMklTensor(false);
+ AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
+ src1_tensor.shape(), mkl_shape_dst);
+ return;
}
}
@@ -362,7 +357,7 @@ class MklAddNOp : public OpKernel {
MklDnnData<T> src2(&cpu_engine);
MklDnnData<T> dst(&cpu_engine);
- int tmp_size = input1_in_mkl_format ? src2_dims_size: src1_dims_size;
+ int tmp_size = input1_in_mkl_format ? src2_dims_size : src1_dims_size;
memory::dims dims(tmp_size);
memory::dims strides(tmp_size);
memory::desc md1({}, memory::data_undef, memory::format_undef);
@@ -392,21 +387,19 @@ class MklAddNOp : public OpKernel {
md1 = src1_mkl_shape.GetMklLayout();
memory::format src1_mkl_data_format = src1_mkl_shape.GetTfDataFormat();
- auto src1_tf_data_format = MklDnnDataFormatToTFDataFormat(
- src1_mkl_data_format);
- auto src2_dims = TFShapeToMklDnnDimsInNCHW(src2_tensor.shape(),
- src1_tf_data_format);
- md2 = memory::desc(src2_dims, MklDnnType<T>(),
- src1_mkl_data_format);
+ auto src1_tf_data_format =
+ MklDnnDataFormatToTFDataFormat(src1_mkl_data_format);
+ auto src2_dims =
+ TFShapeToMklDnnDimsInNCHW(src2_tensor.shape(), src1_tf_data_format);
+ md2 = memory::desc(src2_dims, MklDnnType<T>(), src1_mkl_data_format);
} else if (input2_in_mkl_format && !input1_in_mkl_format) {
// Same comment as above.
memory::format src2_mkl_data_format = src2_mkl_shape.GetTfDataFormat();
- auto src2_tf_data_format = MklDnnDataFormatToTFDataFormat(
- src2_mkl_data_format);
- auto src1_dims = TFShapeToMklDnnDimsInNCHW(src1_tensor.shape(),
- src2_tf_data_format);
- md1 = memory::desc(src1_dims, MklDnnType<T>(),
- src2_mkl_data_format);
+ auto src2_tf_data_format =
+ MklDnnDataFormatToTFDataFormat(src2_mkl_data_format);
+ auto src1_dims =
+ TFShapeToMklDnnDimsInNCHW(src1_tensor.shape(), src2_tf_data_format);
+ md1 = memory::desc(src1_dims, MklDnnType<T>(), src2_mkl_data_format);
md2 = src2_mkl_shape.GetMklLayout();
} else {
@@ -480,20 +473,19 @@ class MklAddNOp : public OpKernel {
output_mkl_shape.SetMklTensor(false);
output_tf_shape = src1_tensor.shape();
}
- AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
- output_tf_shape, output_mkl_shape);
+ AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor, output_tf_shape,
+ output_mkl_shape);
dst.SetUsrMemDataHandle(dst_tensor);
// Create Sum op, and submit net for execution.
net.push_back(sum(sum_pd, inputs, dst.GetOpMem()));
stream(stream::kind::eager).submit(net).wait();
- } catch (mkldnn::error &e) {
+ } catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
- ", message: " + string(e.message) +
- ", in file " + string(__FILE__) + ":" +
- std::to_string(__LINE__);
- OP_REQUIRES_OK(ctx, errors::Aborted("Operation received an exception:",
- error_msg));
+ ", message: " + string(e.message) + ", in file " +
+ string(__FILE__) + ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ ctx, errors::Aborted("Operation received an exception:", error_msg));
}
}
};
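The shape contract enforced by MklAddNOp above, sketched in NumPy: element-wise addition requires both inputs to have identical shapes, with a trivial scalar fast path (mirroring the std::plus branch). Names here are illustrative, not the kernel's API.

import numpy as np

def add_n_two(x, y):
  x, y = np.asarray(x), np.asarray(y)
  if x.shape != y.shape:
    raise ValueError(
        "Inputs must have the same size and shape. Input 0: %s != input 1: %s"
        % (x.shape, y.shape))
  if x.ndim == 0:            # scalar fast path
    return np.float32(x) + np.float32(y)
  return x + y

print(add_n_two(np.ones((2, 3)), np.full((2, 3), 2.0)))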
diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc
index 896d562933..c46eabdde1 100644
--- a/tensorflow/core/kernels/mkl_softmax_op.cc
+++ b/tensorflow/core/kernels/mkl_softmax_op.cc
@@ -17,13 +17,13 @@ limitations under the License.
#ifdef INTEL_MKL
#ifdef INTEL_MKL_DNN
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/tensor_format.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "mkldnn.h"
#include "mkldnn_types.h"
@@ -31,16 +31,14 @@ limitations under the License.
#include "tensorflow/core/util/mkl_util.h"
#include "mkldnn.hpp"
-using mkldnn::stream;
using mkldnn::prop_kind;
using mkldnn::softmax_forward;
+using mkldnn::stream;
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
-
-
template <typename Device, typename T>
class MklSoftmaxOp : public OpKernel {
public:
@@ -60,11 +58,11 @@ class MklSoftmaxOp : public OpKernel {
MklDnnShape src_mkl_shape;
GetMklShape(context, src_idx, &src_mkl_shape);
-
    // src_dims is the dimension of src_tensor
    // dim of the dst will also be the same as src_dims
- auto src_tf_shape = src_mkl_shape.IsMklTensor() ?
- src_mkl_shape.GetTfShape() : src_tensor.shape();
+ auto src_tf_shape = src_mkl_shape.IsMklTensor()
+ ? src_mkl_shape.GetTfShape()
+ : src_tensor.shape();
auto src_dims = TFShapeToMklDnnDims(src_tf_shape);
auto output_dims = src_dims;
@@ -77,10 +75,10 @@ class MklSoftmaxOp : public OpKernel {
// construct input Tf layout. For TF layout, although input shape
// (src_dims) required is in MKL-DNN order, the layout is Tensorflow's
// layout
- auto src_md = src_mkl_shape.IsMklTensor()
- ? src_mkl_shape.GetMklLayout()
- : memory::desc(src_dims, MklDnnType<T>(),
- memory::format::nc);
+ auto src_md =
+ src_mkl_shape.IsMklTensor()
+ ? src_mkl_shape.GetMklLayout()
+ : memory::desc(src_dims, MklDnnType<T>(), memory::format::nc);
// src: setting memory descriptor and op memory descriptor
// Basically following two functions maps the TF "src_tensor" to mkl
@@ -95,8 +93,8 @@ class MklSoftmaxOp : public OpKernel {
int axis = 1; // axis to which softmax will be applied
auto softmax_fwd_desc = softmax_forward::desc(prop_kind::forward_scoring,
src.GetOpMemDesc(), axis);
- auto softmax_fwd_pd = softmax_forward::primitive_desc(softmax_fwd_desc,
- cpu_engine);
+ auto softmax_fwd_pd =
+ softmax_forward::primitive_desc(softmax_fwd_desc, cpu_engine);
// add: output
Tensor* output_tensor = nullptr;
@@ -136,9 +134,9 @@ class MklSoftmaxOp : public OpKernel {
net.push_back(softmax_fwd);
stream(stream::kind::eager).submit(net).wait();
} catch (mkldnn::error& e) {
- string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
- string(e.message) + ", in file " + string(__FILE__) +
- ":" + std::to_string(__LINE__);
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + string(e.message) + ", in file " +
+ string(__FILE__) + ":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
context,
errors::Aborted("Operation received an exception:", error_msg));
@@ -148,7 +146,7 @@ class MklSoftmaxOp : public OpKernel {
/* Register DNN kernels for supported operations and supported types - right now
* it is only Softmax and f32 */
-#define REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES(type) \
+#define REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES(type) \
REGISTER_KERNEL_BUILDER(Name("_MklSoftmax") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
@@ -156,7 +154,6 @@ class MklSoftmaxOp : public OpKernel {
MklSoftmaxOp<CPUDevice, type>);
TF_CALL_float(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES);
-
} // namespace tensorflow
#endif // INTEL_MKL_DNN
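What the softmax_forward primitive above computes, restated in NumPy for a 2-D [batch, classes] input with axis = 1 (matching the kernel's hard-coded axis); this is the mathematical definition, not the MKL-DNN implementation.

import numpy as np

def softmax_axis1(x):
  x = np.asarray(x, dtype=np.float32)
  shifted = x - x.max(axis=1, keepdims=True)    # subtract row max for numerical stability
  e = np.exp(shifted)
  return e / e.sum(axis=1, keepdims=True)

print(softmax_axis1([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]))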
diff --git a/tensorflow/core/kernels/spectrogram_test_utils.cc b/tensorflow/core/kernels/spectrogram_test_utils.cc
index bc30330d61..872a6e9d1b 100644
--- a/tensorflow/core/kernels/spectrogram_test_utils.cc
+++ b/tensorflow/core/kernels/spectrogram_test_utils.cc
@@ -72,12 +72,12 @@ bool ReadRawFloatFileToComplexVector(
while (offset < end) {
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
char arr[4];
- for (int i = 0; i < kBytesPerValue; ++i ) {
+ for (int i = 0; i < kBytesPerValue; ++i) {
arr[3 - i] = *(data_string.data() + offset + i);
}
memcpy(&real_out, arr, kBytesPerValue);
offset += kBytesPerValue;
- for (int i = 0; i < kBytesPerValue; ++i ) {
+ for (int i = 0; i < kBytesPerValue; ++i) {
arr[3 - i] = *(data_string.data() + offset + i);
}
memcpy(&imag_out, arr, kBytesPerValue);
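The byte-order handling in the loop above can be restated in Python: the raw file holds little-endian float32 (real, imag) pairs, so on a big-endian host the four bytes of each value are reversed before being copied into a float. The file layout is inferred from that loop and should be treated as an assumption.

import struct

def read_complex_pairs(raw_bytes):
  # '<f' makes the little-endian interpretation explicit on any host.
  values = struct.unpack('<%df' % (len(raw_bytes) // 4), raw_bytes)
  return [complex(re, im) for re, im in zip(values[0::2], values[1::2])]

sample = struct.pack('<4f', 1.0, -1.0, 0.5, 2.0)
print(read_complex_pairs(sample))   # [(1-1j), (0.5+2j)]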
diff --git a/tensorflow/core/kernels/transpose_functor_cpu.cc b/tensorflow/core/kernels/transpose_functor_cpu.cc
index 6594f7ee7b..5198df7e16 100644
--- a/tensorflow/core/kernels/transpose_functor_cpu.cc
+++ b/tensorflow/core/kernels/transpose_functor_cpu.cc
@@ -89,17 +89,17 @@ struct Transpose<CPUDevice, T, conjugate> {
out);
break;
case 6:
- internal::TransposeUsingEigen<CPUDevice, T, 6>(d, in, perm, conjugate,
- out);
- break;
+ internal::TransposeUsingEigen<CPUDevice, T, 6>(d, in, perm, conjugate,
+ out);
+ break;
case 7:
- internal::TransposeUsingEigen<CPUDevice, T, 7>(d, in, perm, conjugate,
- out);
- break;
+ internal::TransposeUsingEigen<CPUDevice, T, 7>(d, in, perm, conjugate,
+ out);
+ break;
case 8:
internal::TransposeUsingEigen<CPUDevice, T, 8>(d, in, perm, conjugate,
- out);
- break;
+ out);
+ break;
default:
TransposeSimple<T, conjugate>(d, in, perm, out);
break;
diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
index 7d1650f05e..f6906b0f79 100644
--- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
+++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
@@ -40,10 +40,10 @@ current_path = os.path.dirname(os.path.realpath(sys.argv[0]))
parser = argparse.ArgumentParser()
parser.add_argument(
- '--log_dir',
- type=str,
- default=os.path.join(current_path, 'log'),
- help='The log directory for TensorBoard summaries.')
+ '--log_dir',
+ type=str,
+ default=os.path.join(current_path, 'log'),
+ help='The log directory for TensorBoard summaries.')
FLAGS, unparsed = parser.parse_known_args()
# Create the directory for TensorBoard variables if it does not exist.
@@ -81,6 +81,7 @@ def read_data(filename):
data = tf.compat.as_str(f.read(f.namelist()[0])).split()
return data
+
vocabulary = read_data(filename)
print('Data size', len(vocabulary))
@@ -106,20 +107,22 @@ def build_dataset(words, n_words):
reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
return data, count, dictionary, reversed_dictionary
+
# Filling 4 global variables:
# data - list of codes (integers from 0 to vocabulary_size-1).
# This is the original text but words are replaced by their codes
# count - map of words(strings) to count of occurrences
# dictionary - map of words(strings) to their codes(integers)
# reverse_dictionary - maps codes(integers) to words(strings)
-data, count, dictionary, reverse_dictionary = build_dataset(vocabulary,
- vocabulary_size)
+data, count, dictionary, reverse_dictionary = build_dataset(
+ vocabulary, vocabulary_size)
del vocabulary # Hint to reduce memory.
print('Most common words (+UNK)', count[:5])
print('Sample data', data[:10], [reverse_dictionary[i] for i in data[:10]])
data_index = 0
+
# Step 3: Function to generate a training batch for the skip-gram model.
def generate_batch(batch_size, num_skips, skip_window):
global data_index
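To make the four globals documented in the hunk above concrete, here is a self-contained toy run on a made-up six-word corpus (the corpus and vocabulary_size are hypothetical, but the mapping logic follows the comments):

import collections

words = ['the', 'cat', 'sat', 'the', 'cat', 'the']
vocabulary_size = 3
count = [['UNK', -1]]
count.extend(collections.Counter(words).most_common(vocabulary_size - 1))
dictionary = {word: i for i, (word, _) in enumerate(count)}
data = [dictionary.get(word, 0) for word in words]
count[0][1] = data.count(0)
reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
print(data)                 # [1, 2, 0, 1, 2, 1]  ('sat' falls out of the top words -> UNK)
print(reverse_dictionary)   # {0: 'UNK', 1: 'the', 2: 'cat'}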
@@ -149,28 +152,28 @@ def generate_batch(batch_size, num_skips, skip_window):
data_index = (data_index + len(data) - span) % len(data)
return batch, labels
+
batch, labels = generate_batch(batch_size=8, num_skips=2, skip_window=1)
for i in range(8):
- print(batch[i], reverse_dictionary[batch[i]],
- '->', labels[i, 0], reverse_dictionary[labels[i, 0]])
+ print(batch[i], reverse_dictionary[batch[i]], '->', labels[i, 0],
+ reverse_dictionary[labels[i, 0]])
# Step 4: Build and train a skip-gram model.
batch_size = 128
embedding_size = 128 # Dimension of the embedding vector.
-skip_window = 1 # How many words to consider left and right.
-num_skips = 2 # How many times to reuse an input to generate a label.
-num_sampled = 64 # Number of negative examples to sample.
+skip_window = 1 # How many words to consider left and right.
+num_skips = 2 # How many times to reuse an input to generate a label.
+num_sampled = 64 # Number of negative examples to sample.
# We pick a random validation set to sample nearest neighbors. Here we limit the
# validation samples to the words that have a low numeric ID, which by
# construction are also the most frequent. These 3 variables are used only for
# displaying model accuracy, they don't affect calculation.
-valid_size = 16 # Random set of words to evaluate similarity on.
+valid_size = 16 # Random set of words to evaluate similarity on.
valid_window = 100 # Only pick dev samples in the head of the distribution.
valid_examples = np.random.choice(valid_window, valid_size, replace=False)
-
graph = tf.Graph()
with graph.as_default():
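The skip_window / num_skips hyperparameters described above determine which (input, label) pairs a batch contains: for each center word, up to num_skips context words within skip_window positions become training pairs. Below is a simplified, deterministic sketch of that definition only, not the generate_batch implementation itself; the example words are made up.

def skipgram_pairs(words, skip_window=1, num_skips=2):
  pairs = []
  for i, center in enumerate(words):
    context = [words[j] for j in range(max(0, i - skip_window),
                                       min(len(words), i + skip_window + 1))
               if j != i]
    for target in context[:num_skips]:
      pairs.append((center, target))
  return pairs

print(skipgram_pairs(['anarchism', 'originated', 'as', 'a', 'term']))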
@@ -192,8 +195,9 @@ with graph.as_default():
# Construct the variables for the NCE loss
with tf.name_scope('weights'):
nce_weights = tf.Variable(
- tf.truncated_normal([vocabulary_size, embedding_size],
- stddev=1.0 / math.sqrt(embedding_size)))
+ tf.truncated_normal(
+ [vocabulary_size, embedding_size],
+ stddev=1.0 / math.sqrt(embedding_size)))
with tf.name_scope('biases'):
nce_biases = tf.Variable(tf.zeros([vocabulary_size]))
@@ -204,12 +208,13 @@ with graph.as_default():
# http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/
with tf.name_scope('loss'):
loss = tf.reduce_mean(
- tf.nn.nce_loss(weights=nce_weights,
- biases=nce_biases,
- labels=train_labels,
- inputs=embed,
- num_sampled=num_sampled,
- num_classes=vocabulary_size))
+ tf.nn.nce_loss(
+ weights=nce_weights,
+ biases=nce_biases,
+ labels=train_labels,
+ inputs=embed,
+ num_sampled=num_sampled,
+ num_classes=vocabulary_size))
# Add the loss value as a scalar to summary.
tf.summary.scalar('loss', loss)
@@ -221,8 +226,8 @@ with graph.as_default():
# Compute the cosine similarity between minibatch examples and all embeddings.
norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))
normalized_embeddings = embeddings / norm
- valid_embeddings = tf.nn.embedding_lookup(
- normalized_embeddings, valid_dataset)
+ valid_embeddings = tf.nn.embedding_lookup(normalized_embeddings,
+ valid_dataset)
similarity = tf.matmul(
valid_embeddings, normalized_embeddings, transpose_b=True)
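The cosine-similarity computation above, restated in NumPy: L2-normalize every embedding row, then a single matmul yields all pairwise similarities for the validation words. Shapes and indices below are assumptions chosen only for illustration.

import numpy as np

embeddings = np.random.randn(1000, 128).astype(np.float32)   # [vocab, dim]
valid_ids = np.array([3, 17, 42])

norm = np.sqrt((embeddings ** 2).sum(axis=1, keepdims=True))
normalized = embeddings / norm                                 # unit-length rows
similarity = normalized[valid_ids] @ normalized.T              # [len(valid_ids), vocab]
nearest = (-similarity).argsort(axis=1)[:, 1:9]                # top-8 neighbors, skipping the word itself
print(nearest.shape)                                           # (3, 8)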
@@ -248,8 +253,8 @@ with tf.Session(graph=graph) as session:
average_loss = 0
for step in xrange(num_steps):
- batch_inputs, batch_labels = generate_batch(
- batch_size, num_skips, skip_window)
+ batch_inputs, batch_labels = generate_batch(batch_size, num_skips,
+ skip_window)
feed_dict = {train_inputs: batch_inputs, train_labels: batch_labels}
# Define metadata variable.
@@ -259,9 +264,12 @@ with tf.Session(graph=graph) as session:
# in the list of returned values for session.run()
# Also, evaluate the merged op to get all summaries from the returned "summary" variable.
# Feed metadata variable to session for visualizing the graph in TensorBoard.
- _, summary, loss_val = session.run([optimizer, merged, loss], feed_dict=feed_dict, run_metadata=run_metadata)
+ _, summary, loss_val = session.run(
+ [optimizer, merged, loss],
+ feed_dict=feed_dict,
+ run_metadata=run_metadata)
average_loss += loss_val
-
+
# Add returned summaries to writer in each step.
writer.add_summary(summary, step)
# Add metadata to visualize the graph for the last run.
@@ -295,7 +303,7 @@ with tf.Session(graph=graph) as session:
f.write(reverse_dictionary[i] + '\n')
# Save the model for checkpoints.
- saver.save(session, os.path.join(FLAGS.log_dir, "model.ckpt"))
+ saver.save(session, os.path.join(FLAGS.log_dir, 'model.ckpt'))
# Create a configuration for visualizing embeddings with the labels in TensorBoard.
config = projector.ProjectorConfig()
@@ -317,21 +325,24 @@ def plot_with_labels(low_dim_embs, labels, filename):
for i, label in enumerate(labels):
x, y = low_dim_embs[i, :]
plt.scatter(x, y)
- plt.annotate(label,
- xy=(x, y),
- xytext=(5, 2),
- textcoords='offset points',
- ha='right',
- va='bottom')
+ plt.annotate(
+ label,
+ xy=(x, y),
+ xytext=(5, 2),
+ textcoords='offset points',
+ ha='right',
+ va='bottom')
plt.savefig(filename)
+
try:
# pylint: disable=g-import-not-at-top
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
- tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000, method='exact')
+ tsne = TSNE(
+ perplexity=30, n_components=2, init='pca', n_iter=5000, method='exact')
plot_only = 500
low_dim_embs = tsne.fit_transform(final_embeddings[:plot_only, :])
labels = [reverse_dictionary[i] for i in xrange(plot_only)]
diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
index eac1c1960d..bd80b9dbf5 100644
--- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
@@ -51,8 +51,9 @@ class BatchDatasetTest(test.TestCase):
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
- iterator = (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
- .repeat(count).batch(batch_size).make_initializable_iterator())
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
+ .repeat(count).batch(batch_size).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -68,7 +69,7 @@ class BatchDatasetTest(test.TestCase):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(14):
- self.assertAllEqual(component[(i*14 + j) % 7]**2,
+ self.assertAllEqual(component[(i * 14 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -83,12 +84,12 @@ class BatchDatasetTest(test.TestCase):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(8):
- self.assertAllEqual(component[(i*8 + j) % 7]**2,
+ self.assertAllEqual(component[(i * 8 + j) % 7]**2,
result_component[j])
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range((14 * 7) % 8):
- self.assertAllEqual(component[((num_batches - 1)*8 + j) % 7]**2,
+ self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -189,33 +190,34 @@ class BatchDatasetTest(test.TestCase):
sess.run(get_next)
def testBatchShapeError(self):
+
def generator():
yield [1.0, 2.0, 3.0]
yield [4.0, 5.0, 6.0]
yield [7.0, 8.0, 9.0, 10.0]
- iterator = (dataset_ops.Dataset.from_generator(generator, dtypes.float32,
- output_shapes=[None])
- .batch(3)
- .make_initializable_iterator())
+ iterator = (
+ dataset_ops.Dataset.from_generator(
+ generator, dtypes.float32, output_shapes=[None]).batch(3)
+ .make_initializable_iterator())
next_element = iterator.get_next()
with self.test_session() as sess:
sess.run(iterator.initializer)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
- r"Cannot batch tensors with different shapes in component 0. "
- r"First element had shape \[3\] and element 2 had shape \[4\]."):
+ r'Cannot batch tensors with different shapes in component 0. '
+ r'First element had shape \[3\] and element 2 had shape \[4\].'):
sess.run(next_element)
def testPaddedBatchDataset(self):
seq_lens = array_ops.placeholder(dtypes.int32, shape=[None])
padded_shape = array_ops.placeholder(dtypes.int64, shape=[1])
- iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens)
- .map(lambda x: array_ops.fill([x], x)).padded_batch(
- 4,
- padded_shapes=padded_shape).make_initializable_iterator())
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(seq_lens)
+ .map(lambda x: array_ops.fill([x], x)).padded_batch(
+ 4, padded_shapes=padded_shape).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -223,35 +225,40 @@ class BatchDatasetTest(test.TestCase):
with self.test_session() as sess:
# Test with random sequence lengths, and max padding.
random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
- sess.run(init_op, feed_dict={padded_shape: [-1],
- seq_lens: random_seq_lens})
+ sess.run(
+ init_op, feed_dict={
+ padded_shape: [-1],
+ seq_lens: random_seq_lens
+ })
for i in range(8):
result = sess.run(get_next)
padded_len = np.max(result)
self.assertEqual((4, padded_len), result.shape)
for j in range(4):
- seq_len = random_seq_lens[(i*4)+j]
+ seq_len = random_seq_lens[(i * 4) + j]
self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test with random sequence lengths, and constant padding.
- sess.run(init_op, feed_dict={padded_shape: [25],
- seq_lens: random_seq_lens})
+ sess.run(
+ init_op, feed_dict={
+ padded_shape: [25],
+ seq_lens: random_seq_lens
+ })
for i in range(8):
result = sess.run(get_next)
self.assertEqual((4, 25), result.shape)
for j in range(4):
- seq_len = random_seq_lens[(i*4)+j]
+ seq_len = random_seq_lens[(i * 4) + j]
self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
self.assertAllEqual(result[j, seq_len:], [0] * (25 - seq_len))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test correct handling of empty tensors.
- sess.run(init_op, feed_dict={padded_shape: [-1],
- seq_lens: [0, 0, 0, 0]})
+ sess.run(init_op, feed_dict={padded_shape: [-1], seq_lens: [0, 0, 0, 0]})
result = sess.run(get_next)
self.assertAllEqual([[], [], [], []], result)
with self.assertRaises(errors.OutOfRangeError):
@@ -259,8 +266,7 @@ class BatchDatasetTest(test.TestCase):
# Test error handling with constant sequence lengths, and
# too-short padding.
- sess.run(init_op, feed_dict={padded_shape: [5],
- seq_lens: [6, 5, 5, 5]})
+ sess.run(init_op, feed_dict={padded_shape: [5], seq_lens: [6, 5, 5, 5]})
with self.assertRaises(errors.DataLossError):
result = sess.run(get_next)
@@ -271,11 +277,13 @@ class BatchDatasetTest(test.TestCase):
def fill_tuple(x):
filled = array_ops.fill([x], x)
return (filled, string_ops.as_string(filled))
- iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple)
- .padded_batch(
- 4,
- padded_shapes=(padded_shape, padded_shape),
- padding_values=(-1, "<end>")).make_initializable_iterator())
+
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple)
+ .padded_batch(
+ 4,
+ padded_shapes=(padded_shape, padded_shape),
+ padding_values=(-1, '<end>')).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -283,46 +291,46 @@ class BatchDatasetTest(test.TestCase):
with self.test_session() as sess:
# Test with random sequence lengths, and max padding.
random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
- sess.run(init_op, feed_dict={padded_shape: [-1],
- seq_lens: random_seq_lens})
+ sess.run(
+ init_op, feed_dict={
+ padded_shape: [-1],
+ seq_lens: random_seq_lens
+ })
for i in range(8):
result = sess.run(get_next)
padded_len = np.max(result[0])
self.assertEqual((4, padded_len), result[0].shape)
self.assertEqual((4, padded_len), result[1].shape)
for j in range(4):
- seq_len = random_seq_lens[(i*4)+j]
+ seq_len = random_seq_lens[(i * 4) + j]
self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len)
self.assertAllEqual(result[0][j, seq_len:],
[-1] * (padded_len - seq_len))
self.assertAllEqual(result[1][j, :seq_len],
[compat.as_bytes(str(seq_len))] * seq_len)
self.assertAllEqual(result[1][j, seq_len:],
- [b"<end>"] * (padded_len - seq_len))
+ [b'<end>'] * (padded_len - seq_len))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testPaddedBatchDatasetUnicode(self):
# See GitHub issue 16149
def generator():
- data = [
- [u'Простой', u'тест', u'юникода'],
- [u'никогда', u'не', u'бывает', u'простым']]
+ data = [[u'Простой', u'тест', u'юникода'],
+ [u'никогда', u'не', u'бывает', u'простым']]
for seq in data:
yield seq, [0, 1, 2, 3]
dataset = dataset_ops.Dataset.from_generator(
- generator,
- (dtypes.string, dtypes.int32),
+ generator, (dtypes.string, dtypes.int32),
(tensor_shape.TensorShape([None]), tensor_shape.TensorShape([None])))
- padded_dataset = dataset.padded_batch(2, padded_shapes=([None], [None]),
- padding_values=('', 0))
+ padded_dataset = dataset.padded_batch(
+ 2, padded_shapes=([None], [None]), padding_values=('', 0))
with self.test_session() as sess:
next_element = padded_dataset.make_one_shot_iterator().get_next()
sess.run(next_element)
-
def testPaddedBatchDatasetShapeSpecifications(self):
int_placeholder = array_ops.placeholder(dtypes.int32)
float_placeholder = array_ops.placeholder(dtypes.float32)
@@ -346,15 +354,16 @@ class BatchDatasetTest(test.TestCase):
constant_op.constant([-1, -1], dtype=dtypes.int64),
constant_op.constant([37], dtype=dtypes.int64)))
- for dataset in [dynamic_padding_from_tensor_shapes,
- dynamic_padding_from_lists,
- dynamic_padding_from_lists_with_minus_one,
- dynamic_padding_from_tensors]:
+ for dataset in [
+ dynamic_padding_from_tensor_shapes, dynamic_padding_from_lists,
+ dynamic_padding_from_lists_with_minus_one, dynamic_padding_from_tensors
+ ]:
self.assertEqual([None, None], dataset.output_shapes[0].as_list())
self.assertEqual([None, None, None], dataset.output_shapes[1].as_list())
self.assertEqual([None, 37], dataset.output_shapes[2].as_list())
def testPaddedBatchSparseError(self):
+
def _map_fn(i):
return sparse_tensor.SparseTensorValue(
indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i
@@ -363,5 +372,5 @@ class BatchDatasetTest(test.TestCase):
_ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10)
-if __name__ == "__main__":
+if __name__ == '__main__':
test.main()
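The padded_batch behaviour these tests exercise can be summarized in a few lines of NumPy: each batch is padded to the longest element in the batch (padded_shape [-1]) or to a fixed length, using a given padding value. The helper name and values below are made up for illustration.

import numpy as np

def padded_batch(sequences, pad_to=None, padding_value=0):
  length = pad_to if pad_to is not None else max(len(s) for s in sequences)
  batch = np.full((len(sequences), length), padding_value, dtype=np.int32)
  for row, seq in enumerate(sequences):
    batch[row, :len(seq)] = seq
  return batch

print(padded_batch([[3, 3, 3], [5, 5, 5, 5, 5]]))               # pads the first row with zeros
print(padded_batch([[3, 3, 3]], pad_to=4, padding_value=-1))    # fixed-length padding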
diff --git a/tensorflow/python/ops/histogram_ops.py b/tensorflow/python/ops/histogram_ops.py
index b2de2e5015..f079e56b10 100644
--- a/tensorflow/python/ops/histogram_ops.py
+++ b/tensorflow/python/ops/histogram_ops.py
@@ -74,7 +74,7 @@ def histogram_fixed_width_bins(values,
```
"""
with ops.name_scope(name, 'histogram_fixed_width_bins',
- [values, value_range, nbins]) as scope:
+ [values, value_range, nbins]):
values = ops.convert_to_tensor(values, name='values')
shape = array_ops.shape(values)
@@ -84,9 +84,10 @@ def histogram_fixed_width_bins(values,
nbins_float = math_ops.cast(nbins, values.dtype)
# Map tensor values that fall within value_range to [0, 1].
- scaled_values = math_ops.truediv(values - value_range[0],
- value_range[1] - value_range[0],
- name='scaled_values')
+ scaled_values = math_ops.truediv(
+ values - value_range[0],
+ value_range[1] - value_range[0],
+ name='scaled_values')
# map tensor values within the open interval value_range to {0,.., nbins-1},
# values outside the open interval will be zero or less, or nbins or more.
@@ -138,5 +139,5 @@ def histogram_fixed_width(values,
"""
with ops.name_scope(name, 'histogram_fixed_width',
[values, value_range, nbins]) as name:
- return gen_math_ops._histogram_fixed_width(values, value_range, nbins,
- dtype=dtype, name=name)
+ return gen_math_ops._histogram_fixed_width( # pylint: disable=protected-access
+ values, value_range, nbins, dtype=dtype, name=name)
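The binning math in histogram_fixed_width_bins above, as a NumPy sketch: scale values into [0, 1] over value_range, multiply by nbins, floor, and clamp anything outside the range into the edge bins. The sample values match the tests below; the function name is illustrative.

import numpy as np

def fixed_width_bins(values, value_range, nbins):
  values = np.asarray(values, dtype=np.float64)
  lo, hi = value_range
  scaled = (values - lo) / (hi - lo)               # in [0, 1] inside the range
  indices = np.floor(scaled * nbins).astype(np.int32)
  return np.clip(indices, 0, nbins - 1)

print(fixed_width_bins([[-1.0, 0.0, 1.5], [2.0, 5.0, 15.0]], [0.0, 5.0], 5))
# -> [[0 0 1]
#     [2 4 4]]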
diff --git a/tensorflow/python/ops/histogram_ops_test.py b/tensorflow/python/ops/histogram_ops_test.py
index 80ee090575..a226ac81bb 100644
--- a/tensorflow/python/ops/histogram_ops_test.py
+++ b/tensorflow/python/ops/histogram_ops_test.py
@@ -36,7 +36,8 @@ class BinValuesFixedWidth(test.TestCase):
values = []
expected_bins = []
with self.test_session():
- bins = histogram_ops.histogram_fixed_width_bins(values, value_range, nbins=5)
+ bins = histogram_ops.histogram_fixed_width_bins(
+ values, value_range, nbins=5)
self.assertEqual(dtypes.int32, bins.dtype)
self.assertAllClose(expected_bins, bins.eval())
@@ -69,8 +70,7 @@ class BinValuesFixedWidth(test.TestCase):
# (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
value_range = [0.0, 5.0]
values = constant_op.constant(
- [[-1.0, 0.0, 1.5], [2.0, 5.0, 15]],
- shape=(2, 3))
+ [[-1.0, 0.0, 1.5], [2.0, 5.0, 15]], shape=(2, 3))
expected_bins = [[0, 0, 1], [2, 4, 4]]
with self.test_session():
bins = histogram_ops.histogram_fixed_width_bins(
@@ -140,8 +140,8 @@ class HistogramFixedWidthTest(test.TestCase):
self.assertEqual(dtypes.int32, hist.dtype)
self.assertAllClose(expected_bin_counts, hist.eval())
- hist = histogram_ops.histogram_fixed_width(values, value_range,
- nbins=placeholder)
+ hist = histogram_ops.histogram_fixed_width(
+ values, value_range, nbins=placeholder)
self.assertEquals(hist.shape.ndims, 1)
self.assertIs(hist.shape[0].value, None)
self.assertEqual(dtypes.int32, hist.dtype)
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index b713c44717..76da3bed31 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Implementation of image ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os
-
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -28,7 +25,6 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import gen_nn_ops
@@ -38,7 +34,6 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.util.tf_export import tf_export
-
ops.NotDifferentiable('RandomCrop')
# TODO(b/31222613): This op may be differentiable, and there may be
# latent bugs here.
@@ -110,8 +105,9 @@ def _ImageDimensions(image, rank):
else:
static_shape = image.get_shape().with_rank(rank).as_list()
dynamic_shape = array_ops.unstack(array_ops.shape(image), rank)
- return [s if s is not None else d
- for s, d in zip(static_shape, dynamic_shape)]
+ return [
+ s if s is not None else d for s, d in zip(static_shape, dynamic_shape)
+ ]
def _Check3DImage(image, require_static=True):
@@ -132,18 +128,19 @@ def _Check3DImage(image, require_static=True):
try:
image_shape = image.get_shape().with_rank(3)
except ValueError:
- raise ValueError("'image' (shape %s) must be three-dimensional." %
- image.shape)
+ raise ValueError(
+ "'image' (shape %s) must be three-dimensional." % image.shape)
if require_static and not image_shape.is_fully_defined():
- raise ValueError("'image' (shape %s) must be fully defined." %
- image_shape)
+ raise ValueError("'image' (shape %s) must be fully defined." % image_shape)
if any(x == 0 for x in image_shape):
- raise ValueError("all dims of 'image.shape' must be > 0: %s" %
- image_shape)
+ raise ValueError("all dims of 'image.shape' must be > 0: %s" % image_shape)
if not image_shape.is_fully_defined():
- return [check_ops.assert_positive(array_ops.shape(image),
- ["all dims of 'image.shape' "
- "must be > 0."])]
+ return [
+ check_ops.assert_positive(
+ array_ops.shape(image),
+ ["all dims of 'image.shape' "
+ 'must be > 0.'])
+ ]
else:
return []
@@ -167,7 +164,7 @@ def _Assert3DImage(image):
added that asserts the correct dynamic shape.
"""
return control_flow_ops.with_dependencies(
- _Check3DImage(image, require_static=False), image)
+ _Check3DImage(image, require_static=False), image)
def _CheckAtLeast3DImage(image, require_static=True):
@@ -195,12 +192,15 @@ def _CheckAtLeast3DImage(image, require_static=True):
if require_static and not image_shape.is_fully_defined():
raise ValueError('\'image\' must be fully defined.')
if any(x == 0 for x in image_shape):
- raise ValueError('all dims of \'image.shape\' must be > 0: %s' %
- image_shape)
+ raise ValueError(
+ 'all dims of \'image.shape\' must be > 0: %s' % image_shape)
if not image_shape.is_fully_defined():
- return [check_ops.assert_positive(array_ops.shape(image),
- ["all dims of 'image.shape' "
- "must be > 0."])]
+ return [
+ check_ops.assert_positive(
+ array_ops.shape(image),
+ ["all dims of 'image.shape' "
+ 'must be > 0.'])
+ ]
else:
return []
@@ -248,10 +248,11 @@ def random_flip_up_down(image, seed=None):
image = _Assert3DImage(image)
uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
mirror_cond = math_ops.less(uniform_random, .5)
- result = control_flow_ops.cond(mirror_cond,
- lambda: array_ops.reverse(image, [0]),
- lambda: image,
- name=scope)
+ result = control_flow_ops.cond(
+ mirror_cond,
+ lambda: array_ops.reverse(image, [0]),
+ lambda: image,
+ name=scope)
return fix_image_flip_shape(image, result)
@@ -279,10 +280,11 @@ def random_flip_left_right(image, seed=None):
image = _Assert3DImage(image)
uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
mirror_cond = math_ops.less(uniform_random, .5)
- result = control_flow_ops.cond(mirror_cond,
- lambda: array_ops.reverse(image, [1]),
- lambda: image,
- name=scope)
+ result = control_flow_ops.cond(
+ mirror_cond,
+ lambda: array_ops.reverse(image, [1]),
+ lambda: image,
+ name=scope)
return fix_image_flip_shape(image, result)
@@ -307,8 +309,8 @@ def flip_left_right(image):
with ops.name_scope(None, 'flip_left_right', [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
image = _Assert3DImage(image)
- return fix_image_flip_shape(image,
- array_ops.reverse(image, [1], name=scope))
+ return fix_image_flip_shape(image, array_ops.reverse(
+ image, [1], name=scope))
@tf_export('image.flip_up_down')
@@ -332,8 +334,8 @@ def flip_up_down(image):
with ops.name_scope(None, 'flip_up_down', [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
image = _Assert3DImage(image)
- return fix_image_flip_shape(image,
- array_ops.reverse(image, [0], name=scope))
+ return fix_image_flip_shape(image, array_ops.reverse(
+ image, [0], name=scope))
@tf_export('image.rot90')
@@ -356,19 +358,19 @@ def rot90(image, k=1, name=None):
k = math_ops.mod(k, 4)
def _rot90():
- return array_ops.transpose(array_ops.reverse_v2(image, [1]),
- [1, 0, 2])
+ return array_ops.transpose(array_ops.reverse_v2(image, [1]), [1, 0, 2])
+
def _rot180():
return array_ops.reverse_v2(image, [0, 1])
+
def _rot270():
- return array_ops.reverse_v2(array_ops.transpose(image, [1, 0, 2]),
- [1])
- cases = [(math_ops.equal(k, 1), _rot90),
- (math_ops.equal(k, 2), _rot180),
+ return array_ops.reverse_v2(array_ops.transpose(image, [1, 0, 2]), [1])
+
+ cases = [(math_ops.equal(k, 1), _rot90), (math_ops.equal(k, 2), _rot180),
(math_ops.equal(k, 3), _rot270)]
- ret = control_flow_ops.case(cases, default=lambda: image, exclusive=True,
- name=scope)
+ ret = control_flow_ops.case(
+ cases, default=lambda: image, exclusive=True, name=scope)
ret.set_shape([None, None, image.get_shape()[2]])
return ret
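For reference, the three rotation cases above are just reverse/transpose
compositions. A NumPy sketch (illustrative only; `img` is an HWC array) that
agrees with np.rot90:

import numpy as np

img = np.arange(2 * 3 * 1).reshape(2, 3, 1)

rot90 = np.transpose(img[:, ::-1, :], (1, 0, 2))    # reverse cols, then transpose
rot180 = img[::-1, ::-1, :]                         # reverse rows and cols
rot270 = np.transpose(img, (1, 0, 2))[:, ::-1, :]   # transpose, then reverse cols

assert np.array_equal(rot90, np.rot90(img))
assert np.array_equal(rot180, np.rot90(img, 2))
assert np.array_equal(rot270, np.rot90(img, 3))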
@@ -518,8 +520,10 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height,
]), [4, 2])
padded = array_ops.pad(image, paddings)
- padded_shape = [None if _is_tensor(i) else i
- for i in [batch, target_height, target_width, depth]]
+ padded_shape = [
+ None if _is_tensor(i) else i
+ for i in [batch, target_height, target_width, depth]
+ ]
padded.set_shape(padded_shape)
if not is_batch:
@@ -593,12 +597,13 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height,
image = control_flow_ops.with_dependencies(assert_ops, image)
cropped = array_ops.slice(
- image,
- array_ops.stack([0, offset_height, offset_width, 0]),
+ image, array_ops.stack([0, offset_height, offset_width, 0]),
array_ops.stack([-1, target_height, target_width, -1]))
- cropped_shape = [None if _is_tensor(i) else i
- for i in [batch, target_height, target_width, depth]]
+ cropped_shape = [
+ None if _is_tensor(i) else i
+ for i in [batch, target_height, target_width, depth]
+ ]
cropped.set_shape(cropped_shape)
if not is_batch:
@@ -663,8 +668,8 @@ def resize_image_with_crop_or_pad(image, target_height, target_width):
target_height = control_flow_ops.with_dependencies(
assert_ops, target_height)
if _is_tensor(target_width):
- target_width = control_flow_ops.with_dependencies(
- assert_ops, target_width)
+ target_width = control_flow_ops.with_dependencies(assert_ops,
+ target_width)
def max_(x, y):
if _is_tensor(x) or _is_tensor(y):
@@ -709,10 +714,12 @@ def resize_image_with_crop_or_pad(image, target_height, target_width):
_, resized_height, resized_width, _ = _ImageDimensions(resized, rank=4)
assert_ops = []
- assert_ops += _assert(equal_(resized_height, target_height), ValueError,
- 'resized height is not correct.')
- assert_ops += _assert(equal_(resized_width, target_width), ValueError,
- 'resized width is not correct.')
+ assert_ops += _assert(
+ equal_(resized_height, target_height), ValueError,
+ 'resized height is not correct.')
+ assert_ops += _assert(
+ equal_(resized_width, target_width), ValueError,
+ 'resized width is not correct.')
resized = control_flow_ops.with_dependencies(assert_ops, resized)
@@ -813,22 +820,17 @@ def resize_images(images,
return images
if method == ResizeMethod.BILINEAR:
- images = gen_image_ops.resize_bilinear(images,
- size,
- align_corners=align_corners)
+ images = gen_image_ops.resize_bilinear(
+ images, size, align_corners=align_corners)
elif method == ResizeMethod.NEAREST_NEIGHBOR:
- images = gen_image_ops.resize_nearest_neighbor(images,
- size,
- align_corners=
- align_corners)
+ images = gen_image_ops.resize_nearest_neighbor(
+ images, size, align_corners=align_corners)
elif method == ResizeMethod.BICUBIC:
- images = gen_image_ops.resize_bicubic(images,
- size,
- align_corners=align_corners)
+ images = gen_image_ops.resize_bicubic(
+ images, size, align_corners=align_corners)
elif method == ResizeMethod.AREA:
- images = gen_image_ops.resize_area(images,
- size,
- align_corners=align_corners)
+ images = gen_image_ops.resize_area(
+ images, size, align_corners=align_corners)
else:
raise ValueError('Resize method is not implemented.')
@@ -869,8 +871,9 @@ def per_image_standardization(image):
image = math_ops.cast(image, dtype=dtypes.float32)
image_mean = math_ops.reduce_mean(image)
- variance = (math_ops.reduce_mean(math_ops.square(image)) -
- math_ops.square(image_mean))
+ variance = (
+ math_ops.reduce_mean(math_ops.square(image)) -
+ math_ops.square(image_mean))
variance = gen_nn_ops.relu(variance)
stddev = math_ops.sqrt(variance)
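The expression above uses the identity Var[x] = E[x**2] - (E[x])**2, with the
relu clamping tiny negative values that floating-point error can produce. A
NumPy stand-in (not the full op, which also guards the division with a minimum
stddev):

import numpy as np

image = np.random.rand(4, 4, 3).astype(np.float32)
mean = image.mean()
variance = max(float((image ** 2).mean() - mean ** 2), 0.0)  # relu(variance)
stddev = np.sqrt(variance)
standardized = (image - mean) / stddev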
@@ -971,9 +974,8 @@ def adjust_brightness(image, delta):
orig_dtype = image.dtype
flt_image = convert_image_dtype(image, dtypes.float32)
- adjusted = math_ops.add(flt_image,
- math_ops.cast(delta, dtypes.float32),
- name=name)
+ adjusted = math_ops.add(
+ flt_image, math_ops.cast(delta, dtypes.float32), name=name)
return convert_image_dtype(adjusted, orig_dtype, saturate=True)
@@ -1012,9 +1014,8 @@ def adjust_contrast(images, contrast_factor):
flt_images = convert_image_dtype(images, dtypes.float32)
# pylint: disable=protected-access
- adjusted = gen_image_ops._adjust_contrastv2(flt_images,
- contrast_factor=contrast_factor,
- name=name)
+ adjusted = gen_image_ops._adjust_contrastv2(
+ flt_images, contrast_factor=contrast_factor, name=name)
# pylint: enable=protected-access
return convert_image_dtype(adjusted, orig_dtype, saturate=True)
@@ -1061,10 +1062,10 @@ def adjust_gamma(image, gamma=1, gain=1):
gamma = control_flow_ops.with_dependencies(assert_op, gamma)
# scale = max(dtype) - min(dtype).
- scale = constant_op.constant(image.dtype.limits[1] - image.dtype.limits[0],
- dtype=dtypes.float32)
+ scale = constant_op.constant(
+ image.dtype.limits[1] - image.dtype.limits[0], dtype=dtypes.float32)
# According to the definition of gamma correction.
- adjusted_img = (img / scale) ** gamma * scale * gain
+ adjusted_img = (img / scale)**gamma * scale * gain
return adjusted_img
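A minimal numeric illustration of the gamma-correction expression above,
assuming an 8-bit image so scale = 255:

import numpy as np

img = np.array([0.0, 64.0, 128.0, 255.0], dtype=np.float32)
gamma, gain, scale = 2.0, 1.0, 255.0
adjusted = (img / scale) ** gamma * scale * gain
# gamma > 1 darkens mid-tones: 128 -> about 64, while 0 and 255 stay fixed.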
@@ -1195,9 +1196,8 @@ def grayscale_to_rgb(images, name=None):
with ops.name_scope(name, 'grayscale_to_rgb', [images]) as name:
images = ops.convert_to_tensor(images, name='images')
rank_1 = array_ops.expand_dims(array_ops.rank(images) - 1, 0)
- shape_list = (
- [array_ops.ones(rank_1,
- dtype=dtypes.int32)] + [array_ops.expand_dims(3, 0)])
+ shape_list = ([array_ops.ones(rank_1, dtype=dtypes.int32)] +
+ [array_ops.expand_dims(3, 0)])
multiples = array_ops.concat(shape_list, 0)
rgb = array_ops.tile(images, multiples, name=name)
rgb.set_shape(images.get_shape()[:-1].concatenate([3]))
@@ -1393,8 +1393,7 @@ def decode_image(contents, channels=None, name=None):
gif_channels = 0 if channels is None else channels
good_channels = math_ops.logical_and(
math_ops.not_equal(gif_channels, 1, name='check_gif_channels'),
- math_ops.not_equal(gif_channels, 4, name='check_gif_channels')
- )
+ math_ops.not_equal(gif_channels, 4, name='check_gif_channels'))
channels_msg = 'Channels must be in (None, 0, 3) when decoding GIF images'
assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
with ops.control_dependencies([assert_channels]):
@@ -1417,8 +1416,8 @@ def decode_image(contents, channels=None, name=None):
def _jpeg():
"""Decodes a jpeg image."""
jpeg_channels = 0 if channels is None else channels
- good_channels = math_ops.not_equal(jpeg_channels, 4,
- name='check_jpeg_channels')
+ good_channels = math_ops.not_equal(
+ jpeg_channels, 4, name='check_jpeg_channels')
channels_msg = ('Channels must be in (None, 0, 1, 3) when decoding JPEG '
'images')
assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
@@ -1496,16 +1495,21 @@ def total_variation(images, name=None):
# Calculate the total variation by taking the absolute value of the
# pixel-differences and summing over the appropriate axis.
- tot_var = (math_ops.reduce_sum(math_ops.abs(pixel_dif1), axis=sum_axis) +
- math_ops.reduce_sum(math_ops.abs(pixel_dif2), axis=sum_axis))
+ tot_var = (
+ math_ops.reduce_sum(math_ops.abs(pixel_dif1), axis=sum_axis) +
+ math_ops.reduce_sum(math_ops.abs(pixel_dif2), axis=sum_axis))
return tot_var
@tf_export('image.sample_distorted_bounding_box')
-def sample_distorted_bounding_box(image_size, bounding_boxes, seed=None,
- seed2=None, min_object_covered=None,
- aspect_ratio_range=None, area_range=None,
+def sample_distorted_bounding_box(image_size,
+ bounding_boxes,
+ seed=None,
+ seed2=None,
+ min_object_covered=None,
+ aspect_ratio_range=None,
+ area_range=None,
max_attempts=None,
use_image_if_no_bounding_boxes=None,
name=None):
@@ -1521,10 +1525,12 @@ def sample_distorted_bounding_box(image_size, bounding_boxes, seed=None,
The output of this Op is a single bounding box that may be used to crop the
original image. The output is returned as 3 tensors: `begin`, `size` and
`bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the
- image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize
+ image. The latter may be supplied to `tf.image.draw_bounding_boxes` to
+ visualize
what the bounding box looks like.
- Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The
+ Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`.
+ The
bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and
height of the underlying image.
@@ -1552,23 +1558,27 @@ def sample_distorted_bounding_box(image_size, bounding_boxes, seed=None,
false and no bounding boxes are supplied, an error is raised.
Args:
- image_size: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int16`, `int32`, `int64`.
+ image_size: A `Tensor`. Must be one of the following types: `uint8`, `int8`,
+ `int16`, `int32`, `int64`.
1-D, containing `[height, width, channels]`.
bounding_boxes: A `Tensor` of type `float32`.
3-D with shape `[batch, N, 4]` describing the N bounding boxes
associated with the image.
seed: An optional `int`. Defaults to `0`.
If either `seed` or `seed2` are set to non-zero, the random number
- generator is seeded by the given `seed`. Otherwise, it is seeded by a random
+ generator is seeded by the given `seed`. Otherwise, it is seeded by a
+ random
seed.
seed2: An optional `int`. Defaults to `0`.
A second seed to avoid seed collision.
min_object_covered: A Tensor of type `float32`. Defaults to `0.1`.
The cropped area of the image must contain at least this
- fraction of any bounding box supplied. The value of this parameter should be
+ fraction of any bounding box supplied. The value of this parameter should
+ be
non-negative. In the case of 0, the cropped area does not need to overlap
any of the bounding boxes supplied.
- aspect_ratio_range: An optional list of `floats`. Defaults to `[0.75, 1.33]`.
+ aspect_ratio_range: An optional list of `floats`. Defaults to `[0.75,
+ 1.33]`.
The cropped area of the image must have an aspect ratio =
width / height within this range.
area_range: An optional list of `floats`. Defaults to `[0.05, 1]`.
@@ -1576,32 +1586,41 @@ def sample_distorted_bounding_box(image_size, bounding_boxes, seed=None,
supplied image within in this range.
max_attempts: An optional `int`. Defaults to `100`.
Number of attempts at generating a cropped region of the image
- of the specified constraints. After `max_attempts` failures, return the entire
+ of the specified constraints. After `max_attempts` failures, return the
+ entire
image.
use_image_if_no_bounding_boxes: An optional `bool`. Defaults to `False`.
Controls behavior if no bounding boxes supplied.
- If true, assume an implicit bounding box covering the whole input. If false,
+ If true, assume an implicit bounding box covering the whole input. If
+ false,
raise an error.
name: A name for the operation (optional).
Returns:
A tuple of `Tensor` objects (begin, size, bboxes).
- begin: A `Tensor`. Has the same type as `image_size`. 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to
+ begin: A `Tensor`. Has the same type as `image_size`. 1-D, containing
+ `[offset_height, offset_width, 0]`. Provide as input to
`tf.slice`.
- size: A `Tensor`. Has the same type as `image_size`. 1-D, containing `[target_height, target_width, -1]`. Provide as input to
+ size: A `Tensor`. Has the same type as `image_size`. 1-D, containing
+ `[target_height, target_width, -1]`. Provide as input to
`tf.slice`.
- bboxes: A `Tensor` of type `float32`. 3-D with shape `[1, 1, 4]` containing the distorted bounding box.
+ bboxes: A `Tensor` of type `float32`. 3-D with shape `[1, 1, 4]` containing
+ the distorted bounding box.
Provide as input to `tf.image.draw_bounding_boxes`.
"""
with ops.name_scope(name, 'sample_distorted_bounding_box'):
- return gen_image_ops._sample_distorted_bounding_box_v2(image_size,
- bounding_boxes, seed=seed,
- seed2=seed2, min_object_covered=min_object_covered,
- aspect_ratio_range=aspect_ratio_range, area_range=area_range,
- max_attempts=max_attempts,
- use_image_if_no_bounding_boxes=use_image_if_no_bounding_boxes,
- name=name)
+ return gen_image_ops._sample_distorted_bounding_box_v2( # pylint: disable=protected-access
+ image_size,
+ bounding_boxes,
+ seed=seed,
+ seed2=seed2,
+ min_object_covered=min_object_covered,
+ aspect_ratio_range=aspect_ratio_range,
+ area_range=area_range,
+ max_attempts=max_attempts,
+ use_image_if_no_bounding_boxes=use_image_if_no_bounding_boxes,
+ name=name)
@tf_export('image.non_max_suppression')
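The docstring above describes feeding `begin`/`size` into tf.slice and
`bboxes` into tf.image.draw_bounding_boxes. A hedged usage sketch (TF 1.x API;
the image and box values are made up for illustration):

import tensorflow as tf

image = tf.zeros([256, 256, 3], dtype=tf.uint8)
boxes = tf.constant([[[0.1, 0.2, 0.5, 0.9]]], tf.float32)  # [y_min, x_min, y_max, x_max]

begin, size, bboxes = tf.image.sample_distorted_bounding_box(
    tf.shape(image), bounding_boxes=boxes, min_object_covered=0.1)

distorted_image = tf.slice(image, begin, size)
image_with_box = tf.image.draw_bounding_boxes(
    tf.expand_dims(tf.image.convert_image_dtype(image, tf.float32), 0), bboxes)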
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 2d77e26081..7776ff08c4 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -100,27 +100,29 @@ def _remove_squeezable_dimensions(predictions, labels, weights):
# Use dynamic rank.
weights_rank_tensor = array_ops.rank(weights)
rank_diff = weights_rank_tensor - array_ops.rank(predictions)
+
def _maybe_expand_weights():
return control_flow_ops.cond(
math_ops.equal(rank_diff, -1),
- lambda: array_ops.expand_dims(weights, [-1]),
- lambda: weights)
+ lambda: array_ops.expand_dims(weights, [-1]), lambda: weights)
+
# Don't attempt squeeze if it will fail based on static check.
if ((weights_rank is not None) and
(not weights_shape.dims[-1].is_compatible_with(1))):
maybe_squeeze_weights = lambda: weights
else:
maybe_squeeze_weights = lambda: array_ops.squeeze(weights, [-1])
+
def _maybe_adjust_weights():
return control_flow_ops.cond(
- math_ops.equal(rank_diff, 1),
- maybe_squeeze_weights,
+ math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
_maybe_expand_weights)
+
# If weights are scalar, do nothing. Otherwise, try to add or remove a
# dimension to match predictions.
weights = control_flow_ops.cond(
- math_ops.equal(weights_rank_tensor, 0),
- lambda: weights, _maybe_adjust_weights)
+ math_ops.equal(weights_rank_tensor, 0), lambda: weights,
+ _maybe_adjust_weights)
return predictions, labels, weights
@@ -165,14 +167,14 @@ def _maybe_expand_labels(labels, predictions):
if predictions_rank == labels_rank + 1:
return array_ops.expand_dims(labels, -1, name=scope)
raise ValueError(
- 'Unexpected labels shape %s for predictions shape %s.' % (
- labels.get_shape(), predictions.get_shape()))
+ 'Unexpected labels shape %s for predictions shape %s.' %
+ (labels.get_shape(), predictions.get_shape()))
# Otherwise, use dynamic shape.
return control_flow_ops.cond(
- math_ops.equal(array_ops.rank(predictions), array_ops.rank(labels) + 1),
- lambda: array_ops.expand_dims(labels, -1, name=scope),
- lambda: labels)
+ math_ops.equal(array_ops.rank(predictions),
+ array_ops.rank(labels) + 1),
+ lambda: array_ops.expand_dims(labels, -1, name=scope), lambda: labels)
def _safe_div(numerator, denominator, name):
@@ -264,8 +266,11 @@ def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None):
@tf_export('metrics.mean')
-def mean(values, weights=None, metrics_collections=None,
- updates_collections=None, name=None):
+def mean(values,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
"""Computes the (weighted) mean of the given values.
The `mean` function creates two local variables, `total` and `count`
@@ -340,8 +345,12 @@ def mean(values, weights=None, metrics_collections=None,
@tf_export('metrics.accuracy')
-def accuracy(labels, predictions, weights=None, metrics_collections=None,
- updates_collections=None, name=None):
+def accuracy(labels,
+ predictions,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
"""Calculates how often `predictions` matches `labels`.
The `accuracy` function creates two local variables, `total` and
@@ -395,12 +404,15 @@ def accuracy(labels, predictions, weights=None, metrics_collections=None,
if labels.dtype != predictions.dtype:
predictions = math_ops.cast(predictions, labels.dtype)
is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
- return mean(is_correct, weights, metrics_collections,
- updates_collections, name or 'accuracy')
+ return mean(is_correct, weights, metrics_collections, updates_collections,
+ name or 'accuracy')
-def _confusion_matrix_at_thresholds(
- labels, predictions, thresholds, weights=None, includes=None):
+def _confusion_matrix_at_thresholds(labels,
+ predictions,
+ thresholds,
+ weights=None,
+ includes=None):
"""Computes true_positives, false_negatives, true_negatives, false_positives.
This function creates up to four local variables, `true_positives`,
@@ -498,8 +510,8 @@ def _confusion_matrix_at_thresholds(
if weights is not None:
weights = weights_broadcast_ops.broadcast_weights(
math_ops.to_float(weights), predictions)
- weights_tiled = array_ops.tile(array_ops.reshape(
- weights, [1, -1]), [num_thresholds, 1])
+ weights_tiled = array_ops.tile(
+ array_ops.reshape(weights, [1, -1]), [num_thresholds, 1])
thresh_tiled.get_shape().assert_is_compatible_with(
weights_tiled.get_shape())
else:
@@ -515,8 +527,9 @@ def _confusion_matrix_at_thresholds(
math_ops.logical_and(label_is_pos, pred_is_pos))
if weights_tiled is not None:
is_true_positive *= weights_tiled
- update_ops['tp'] = state_ops.assign_add(
- true_p, math_ops.reduce_sum(is_true_positive, 1))
+ update_ops['tp'] = state_ops.assign_add(true_p,
+ math_ops.reduce_sum(
+ is_true_positive, 1))
values['tp'] = true_p
if 'fn' in includes:
@@ -526,8 +539,9 @@ def _confusion_matrix_at_thresholds(
math_ops.logical_and(label_is_pos, pred_is_neg))
if weights_tiled is not None:
is_false_negative *= weights_tiled
- update_ops['fn'] = state_ops.assign_add(
- false_n, math_ops.reduce_sum(is_false_negative, 1))
+ update_ops['fn'] = state_ops.assign_add(false_n,
+ math_ops.reduce_sum(
+ is_false_negative, 1))
values['fn'] = false_n
if 'tn' in includes:
@@ -537,8 +551,9 @@ def _confusion_matrix_at_thresholds(
math_ops.logical_and(label_is_neg, pred_is_neg))
if weights_tiled is not None:
is_true_negative *= weights_tiled
- update_ops['tn'] = state_ops.assign_add(
- true_n, math_ops.reduce_sum(is_true_negative, 1))
+ update_ops['tn'] = state_ops.assign_add(true_n,
+ math_ops.reduce_sum(
+ is_true_negative, 1))
values['tn'] = true_n
if 'fp' in includes:
@@ -548,17 +563,24 @@ def _confusion_matrix_at_thresholds(
math_ops.logical_and(label_is_neg, pred_is_pos))
if weights_tiled is not None:
is_false_positive *= weights_tiled
- update_ops['fp'] = state_ops.assign_add(
- false_p, math_ops.reduce_sum(is_false_positive, 1))
+ update_ops['fp'] = state_ops.assign_add(false_p,
+ math_ops.reduce_sum(
+ is_false_positive, 1))
values['fp'] = false_p
return values, update_ops
@tf_export('metrics.auc')
-def auc(labels, predictions, weights=None, num_thresholds=200,
- metrics_collections=None, updates_collections=None,
- curve='ROC', name=None, summation_method='trapezoidal'):
+def auc(labels,
+ predictions,
+ weights=None,
+ num_thresholds=200,
+ metrics_collections=None,
+ updates_collections=None,
+ curve='ROC',
+ name=None,
+ summation_method='trapezoidal'):
"""Computes the approximate AUC via a Riemann sum.
The `auc` function creates four local variables, `true_positives`,
@@ -626,14 +648,14 @@ def auc(labels, predictions, weights=None, num_thresholds=200,
raise RuntimeError('tf.metrics.auc is not supported when eager execution '
'is enabled.')
- with variable_scope.variable_scope(
- name, 'auc', (labels, predictions, weights)):
+ with variable_scope.variable_scope(name, 'auc',
+ (labels, predictions, weights)):
if curve != 'ROC' and curve != 'PR':
- raise ValueError('curve must be either ROC or PR, %s unknown' %
- (curve))
+ raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
kepsilon = 1e-7 # to account for floating point imprecisions
- thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
- for i in range(num_thresholds-2)]
+ thresholds = [
+ (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
+ ]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
values, update_ops = _confusion_matrix_at_thresholds(
@@ -641,6 +663,7 @@ def auc(labels, predictions, weights=None, num_thresholds=200,
# Add epsilons to avoid dividing by 0.
epsilon = 1.0e-6
+
def compute_auc(tp, fn, tn, fp, name):
"""Computes the roc-auc or pr-auc based on confusion counts."""
rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
@@ -671,11 +694,10 @@ def auc(labels, predictions, weights=None, num_thresholds=200,
raise ValueError('Invalid summation_method: %s' % summation_method)
# sum up the areas of all the trapeziums
- auc_value = compute_auc(
- values['tp'], values['fn'], values['tn'], values['fp'], 'value')
- update_op = compute_auc(
- update_ops['tp'], update_ops['fn'], update_ops['tn'], update_ops['fp'],
- 'update_op')
+ auc_value = compute_auc(values['tp'], values['fn'], values['tn'],
+ values['fp'], 'value')
+ update_op = compute_auc(update_ops['tp'], update_ops['fn'],
+ update_ops['tn'], update_ops['fp'], 'update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, auc_value)
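The value above is a trapezoidal Riemann sum over per-threshold confusion
counts. A rough NumPy sketch of that reduction for the ROC curve (tp, fn, tn,
fp are arrays indexed by threshold; epsilon handling simplified):

import numpy as np

def roc_auc_from_counts(tp, fn, tn, fp, epsilon=1.0e-6):
  x = fp / (fp + tn + epsilon)                # false positive rate per threshold
  y = (tp + epsilon) / (tp + fn + epsilon)    # recall / true positive rate
  # Trapezoid rule over consecutive thresholds.
  return np.sum((x[:-1] - x[1:]) * (y[:-1] + y[1:]) / 2.0)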
@@ -687,7 +709,9 @@ def auc(labels, predictions, weights=None, num_thresholds=200,
@tf_export('metrics.mean_absolute_error')
-def mean_absolute_error(labels, predictions, weights=None,
+def mean_absolute_error(labels,
+ predictions,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -746,7 +770,10 @@ def mean_absolute_error(labels, predictions, weights=None,
@tf_export('metrics.mean_cosine_distance')
-def mean_cosine_distance(labels, predictions, dim, weights=None,
+def mean_cosine_distance(labels,
+ predictions,
+ dim,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -802,10 +829,8 @@ def mean_cosine_distance(labels, predictions, dim, weights=None,
radial_diffs, reduction_indices=[
dim,
], keepdims=True)
- mean_distance, update_op = mean(radial_diffs, weights,
- None,
- None,
- name or 'mean_cosine_distance')
+ mean_distance, update_op = mean(radial_diffs, weights, None, None, name or
+ 'mean_cosine_distance')
mean_distance = math_ops.subtract(1.0, mean_distance)
update_op = math_ops.subtract(1.0, update_op)
@@ -906,8 +931,8 @@ def mean_per_class_accuracy(labels,
per_class_accuracy = _safe_div(count, total, None)
- mean_accuracy_v = math_ops.reduce_mean(per_class_accuracy,
- name='mean_accuracy')
+ mean_accuracy_v = math_ops.reduce_mean(
+ per_class_accuracy, name='mean_accuracy')
update_op = _safe_div(update_count_op, update_total_op, name='update_op')
if metrics_collections:
@@ -975,13 +1000,14 @@ def mean_iou(labels,
raise RuntimeError('tf.metrics.mean_iou is not supported when '
'eager execution is enabled.')
- with variable_scope.variable_scope(
- name, 'mean_iou', (predictions, labels, weights)):
+ with variable_scope.variable_scope(name, 'mean_iou',
+ (predictions, labels, weights)):
# Check if shape is compatible.
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
num_classes, weights)
+
def compute_mean_iou(name):
"""Compute the mean intersection-over-union via the confusion matrix."""
sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
@@ -992,22 +1018,21 @@ def mean_iou(labels,
# The mean is only computed over classes that appear in the
# label or prediction tensor. If the denominator is 0, we need to
# ignore the class.
- num_valid_entries = math_ops.reduce_sum(math_ops.cast(
- math_ops.not_equal(denominator, 0), dtype=dtypes.float32))
+ num_valid_entries = math_ops.reduce_sum(
+ math_ops.cast(
+ math_ops.not_equal(denominator, 0), dtype=dtypes.float32))
# If the value of the denominator is 0, set it to 1 to avoid
# zero division.
denominator = array_ops.where(
- math_ops.greater(denominator, 0),
- denominator,
+ math_ops.greater(denominator, 0), denominator,
array_ops.ones_like(denominator))
iou = math_ops.div(cm_diag, denominator)
# If the number of valid entries is 0 (no classes) we return 0.
result = array_ops.where(
math_ops.greater(num_valid_entries, 0),
- math_ops.reduce_sum(iou, name=name) / num_valid_entries,
- 0)
+ math_ops.reduce_sum(iou, name=name) / num_valid_entries, 0)
return result
mean_iou_v = compute_mean_iou('mean_iou')
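A small NumPy stand-in for the per-class IoU reduction above, starting from a
2x2 confusion matrix with rows as labels and columns as predictions
(illustrative values only):

import numpy as np

cm = np.array([[3., 1.],
               [0., 4.]])
sum_over_row = cm.sum(axis=0)                 # per-class prediction counts
sum_over_col = cm.sum(axis=1)                 # per-class label counts
cm_diag = np.diag(cm)
denominator = sum_over_row + sum_over_col - cm_diag
valid = denominator > 0
iou = np.where(valid, cm_diag / np.where(valid, denominator, 1.0), 0.0)
mean_iou = iou[valid].mean() if valid.any() else 0.0   # (3/4 + 4/5) / 2 = 0.775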
@@ -1022,7 +1047,10 @@ def mean_iou(labels,
@tf_export('metrics.mean_relative_error')
-def mean_relative_error(labels, predictions, normalizer, weights=None,
+def mean_relative_error(labels,
+ predictions,
+ normalizer,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1081,15 +1109,16 @@ def mean_relative_error(labels, predictions, normalizer, weights=None,
predictions, normalizer)
predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())
relative_errors = array_ops.where(
- math_ops.equal(normalizer, 0.0),
- array_ops.zeros_like(labels),
+ math_ops.equal(normalizer, 0.0), array_ops.zeros_like(labels),
math_ops.div(math_ops.abs(labels - predictions), normalizer))
return mean(relative_errors, weights, metrics_collections,
updates_collections, name or 'mean_relative_error')
@tf_export('metrics.mean_squared_error')
-def mean_squared_error(labels, predictions, weights=None,
+def mean_squared_error(labels,
+ predictions,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1143,13 +1172,16 @@ def mean_squared_error(labels, predictions, weights=None,
predictions, labels, weights = _remove_squeezable_dimensions(
predictions=predictions, labels=labels, weights=weights)
squared_error = math_ops.square(labels - predictions)
- return mean(squared_error, weights, metrics_collections,
- updates_collections, name or 'mean_squared_error')
+ return mean(squared_error, weights, metrics_collections, updates_collections,
+ name or 'mean_squared_error')
@tf_export('metrics.mean_tensor')
-def mean_tensor(values, weights=None, metrics_collections=None,
- updates_collections=None, name=None):
+def mean_tensor(values,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
"""Computes the element-wise (weighted) mean of the given tensors.
In contrast to the `mean` function which returns a scalar with the
@@ -1216,9 +1248,8 @@ def mean_tensor(values, weights=None, metrics_collections=None,
update_count_op = state_ops.assign_add(count, num_values)
def compute_mean(total, count, name):
- non_zero_count = math_ops.maximum(count,
- array_ops.ones_like(count),
- name=name)
+ non_zero_count = math_ops.maximum(
+ count, array_ops.ones_like(count), name=name)
return math_ops.truediv(total, non_zero_count, name=name)
mean_t = compute_mean(total, count, 'value')
@@ -1234,7 +1265,9 @@ def mean_tensor(values, weights=None, metrics_collections=None,
@tf_export('metrics.percentage_below')
-def percentage_below(values, threshold, weights=None,
+def percentage_below(values,
+ threshold,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1281,14 +1314,13 @@ def percentage_below(values, threshold, weights=None,
'eager execution is enabled.')
is_below_threshold = math_ops.to_float(math_ops.less(values, threshold))
- return mean(is_below_threshold,
- weights,
- metrics_collections,
- updates_collections,
- name or 'percentage_below_threshold')
+ return mean(is_below_threshold, weights, metrics_collections,
+ updates_collections, name or 'percentage_below_threshold')
-def _count_condition(values, weights=None, metrics_collections=None,
+def _count_condition(values,
+ weights=None,
+ metrics_collections=None,
updates_collections=None):
"""Sums the weights of cases where the given values are True.
@@ -1318,8 +1350,8 @@ def _count_condition(values, weights=None, metrics_collections=None,
values = math_ops.to_float(values)
if weights is not None:
- with ops.control_dependencies((
- check_ops.assert_rank_in(weights, (0, array_ops.rank(values))),)):
+ with ops.control_dependencies((check_ops.assert_rank_in(
+ weights, (0, array_ops.rank(values))),)):
weights = math_ops.to_float(weights)
values = math_ops.multiply(values, weights)
@@ -1336,7 +1368,9 @@ def _count_condition(values, weights=None, metrics_collections=None,
@tf_export('metrics.false_negatives')
-def false_negatives(labels, predictions, weights=None,
+def false_negatives(labels,
+ predictions,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1372,21 +1406,24 @@ def false_negatives(labels, predictions, weights=None,
raise RuntimeError('tf.metrics.false_negatives is not supported when '
'eager execution is enabled.')
- with variable_scope.variable_scope(
- name, 'false_negatives', (predictions, labels, weights)):
+ with variable_scope.variable_scope(name, 'false_negatives',
+ (predictions, labels, weights)):
predictions, labels, weights = _remove_squeezable_dimensions(
predictions=math_ops.cast(predictions, dtype=dtypes.bool),
labels=math_ops.cast(labels, dtype=dtypes.bool),
weights=weights)
- is_false_negative = math_ops.logical_and(math_ops.equal(labels, True),
- math_ops.equal(predictions, False))
+ is_false_negative = math_ops.logical_and(
+ math_ops.equal(labels, True), math_ops.equal(predictions, False))
return _count_condition(is_false_negative, weights, metrics_collections,
updates_collections)
@tf_export('metrics.false_negatives_at_thresholds')
-def false_negatives_at_thresholds(labels, predictions, thresholds, weights=None,
+def false_negatives_at_thresholds(labels,
+ predictions,
+ thresholds,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1440,7 +1477,9 @@ def false_negatives_at_thresholds(labels, predictions, thresholds, weights=None,
@tf_export('metrics.false_positives')
-def false_positives(labels, predictions, weights=None,
+def false_positives(labels,
+ predictions,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1477,21 +1516,24 @@ def false_positives(labels, predictions, weights=None,
raise RuntimeError('tf.metrics.false_positives is not supported when '
'eager execution is enabled.')
- with variable_scope.variable_scope(
- name, 'false_positives', (predictions, labels, weights)):
+ with variable_scope.variable_scope(name, 'false_positives',
+ (predictions, labels, weights)):
predictions, labels, weights = _remove_squeezable_dimensions(
predictions=math_ops.cast(predictions, dtype=dtypes.bool),
labels=math_ops.cast(labels, dtype=dtypes.bool),
weights=weights)
- is_false_positive = math_ops.logical_and(math_ops.equal(labels, False),
- math_ops.equal(predictions, True))
+ is_false_positive = math_ops.logical_and(
+ math_ops.equal(labels, False), math_ops.equal(predictions, True))
return _count_condition(is_false_positive, weights, metrics_collections,
updates_collections)
@tf_export('metrics.false_positives_at_thresholds')
-def false_positives_at_thresholds(labels, predictions, thresholds, weights=None,
+def false_positives_at_thresholds(labels,
+ predictions,
+ thresholds,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1545,7 +1587,9 @@ def false_positives_at_thresholds(labels, predictions, thresholds, weights=None,
@tf_export('metrics.true_negatives')
-def true_negatives(labels, predictions, weights=None,
+def true_negatives(labels,
+ predictions,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1582,21 +1626,24 @@ def true_negatives(labels, predictions, weights=None,
raise RuntimeError('tf.metrics.true_negatives is not '
'supported when eager execution is enabled.')
- with variable_scope.variable_scope(
- name, 'true_negatives', (predictions, labels, weights)):
+ with variable_scope.variable_scope(name, 'true_negatives',
+ (predictions, labels, weights)):
predictions, labels, weights = _remove_squeezable_dimensions(
predictions=math_ops.cast(predictions, dtype=dtypes.bool),
labels=math_ops.cast(labels, dtype=dtypes.bool),
weights=weights)
- is_true_negative = math_ops.logical_and(math_ops.equal(labels, False),
- math_ops.equal(predictions, False))
+ is_true_negative = math_ops.logical_and(
+ math_ops.equal(labels, False), math_ops.equal(predictions, False))
return _count_condition(is_true_negative, weights, metrics_collections,
updates_collections)
@tf_export('metrics.true_negatives_at_thresholds')
-def true_negatives_at_thresholds(labels, predictions, thresholds, weights=None,
+def true_negatives_at_thresholds(labels,
+ predictions,
+ thresholds,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1650,7 +1697,9 @@ def true_negatives_at_thresholds(labels, predictions, thresholds, weights=None,
@tf_export('metrics.true_positives')
-def true_positives(labels, predictions, weights=None,
+def true_positives(labels,
+ predictions,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1687,21 +1736,24 @@ def true_positives(labels, predictions, weights=None,
raise RuntimeError('tf.metrics.true_positives is not '
'supported when eager execution is enabled.')
- with variable_scope.variable_scope(
- name, 'true_positives', (predictions, labels, weights)):
+ with variable_scope.variable_scope(name, 'true_positives',
+ (predictions, labels, weights)):
predictions, labels, weights = _remove_squeezable_dimensions(
predictions=math_ops.cast(predictions, dtype=dtypes.bool),
labels=math_ops.cast(labels, dtype=dtypes.bool),
weights=weights)
- is_true_positive = math_ops.logical_and(math_ops.equal(labels, True),
- math_ops.equal(predictions, True))
+ is_true_positive = math_ops.logical_and(
+ math_ops.equal(labels, True), math_ops.equal(predictions, True))
return _count_condition(is_true_positive, weights, metrics_collections,
updates_collections)
@tf_export('metrics.true_positives_at_thresholds')
-def true_positives_at_thresholds(labels, predictions, thresholds, weights=None,
+def true_positives_at_thresholds(labels,
+ predictions,
+ thresholds,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1755,8 +1807,11 @@ def true_positives_at_thresholds(labels, predictions, thresholds, weights=None,
@tf_export('metrics.precision')
-def precision(labels, predictions, weights=None,
- metrics_collections=None, updates_collections=None,
+def precision(labels,
+ predictions,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
name=None):
"""Computes the precision of the predictions with respect to the labels.
@@ -1805,8 +1860,8 @@ def precision(labels, predictions, weights=None,
raise RuntimeError('tf.metrics.precision is not '
'supported when eager execution is enabled.')
- with variable_scope.variable_scope(
- name, 'precision', (predictions, labels, weights)):
+ with variable_scope.variable_scope(name, 'precision',
+ (predictions, labels, weights)):
predictions, labels, weights = _remove_squeezable_dimensions(
predictions=math_ops.cast(predictions, dtype=dtypes.bool),
@@ -1814,22 +1869,27 @@ def precision(labels, predictions, weights=None,
weights=weights)
true_p, true_positives_update_op = true_positives(
- labels, predictions, weights, metrics_collections=None,
- updates_collections=None, name=None)
+ labels,
+ predictions,
+ weights,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None)
false_p, false_positives_update_op = false_positives(
- labels, predictions, weights, metrics_collections=None,
- updates_collections=None, name=None)
+ labels,
+ predictions,
+ weights,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None)
def compute_precision(tp, fp, name):
return array_ops.where(
- math_ops.greater(tp + fp, 0),
- math_ops.div(tp, tp + fp),
- 0,
- name)
+ math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name)
p = compute_precision(true_p, false_p, 'value')
- update_op = compute_precision(
- true_positives_update_op, false_positives_update_op, 'update_op')
+ update_op = compute_precision(true_positives_update_op,
+ false_positives_update_op, 'update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, p)
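The precision value above reduces to tp / (tp + fp), guarded against an empty
denominator; a one-line NumPy check under that assumption:

import numpy as np

tp, fp = 8.0, 2.0
precision = np.where(tp + fp > 0, np.true_divide(tp, tp + fp), 0.0)  # 0.8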
@@ -1841,10 +1901,13 @@ def precision(labels, predictions, weights=None,
@tf_export('metrics.precision_at_thresholds')
-def precision_at_thresholds(labels, predictions, thresholds,
+def precision_at_thresholds(labels,
+ predictions,
+ thresholds,
weights=None,
metrics_collections=None,
- updates_collections=None, name=None):
+ updates_collections=None,
+ name=None):
"""Computes precision values for different `thresholds` on `predictions`.
The `precision_at_thresholds` function creates four local variables,
@@ -1900,12 +1963,13 @@ def precision_at_thresholds(labels, predictions, thresholds,
# Avoid division by zero.
epsilon = 1e-7
+
def compute_precision(tp, fp, name):
return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)
prec = compute_precision(values['tp'], values['fp'], 'value')
- update_op = compute_precision(
- update_ops['tp'], update_ops['fp'], 'update_op')
+ update_op = compute_precision(update_ops['tp'], update_ops['fp'],
+ 'update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, prec)
@@ -1917,8 +1981,11 @@ def precision_at_thresholds(labels, predictions, thresholds,
@tf_export('metrics.recall')
-def recall(labels, predictions, weights=None,
- metrics_collections=None, updates_collections=None,
+def recall(labels,
+ predictions,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
name=None):
"""Computes the recall of the predictions with respect to the labels.
@@ -1965,30 +2032,36 @@ def recall(labels, predictions, weights=None,
    raise RuntimeError('tf.metrics.recall is not supported '
                       'when eager execution is enabled.')
- with variable_scope.variable_scope(
- name, 'recall', (predictions, labels, weights)):
+ with variable_scope.variable_scope(name, 'recall',
+ (predictions, labels, weights)):
predictions, labels, weights = _remove_squeezable_dimensions(
predictions=math_ops.cast(predictions, dtype=dtypes.bool),
labels=math_ops.cast(labels, dtype=dtypes.bool),
weights=weights)
true_p, true_positives_update_op = true_positives(
- labels, predictions, weights, metrics_collections=None,
- updates_collections=None, name=None)
+ labels,
+ predictions,
+ weights,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None)
false_n, false_negatives_update_op = false_negatives(
- labels, predictions, weights, metrics_collections=None,
- updates_collections=None, name=None)
+ labels,
+ predictions,
+ weights,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None)
def compute_recall(true_p, false_n, name):
return array_ops.where(
math_ops.greater(true_p + false_n, 0),
- math_ops.div(true_p, true_p + false_n),
- 0,
- name)
+ math_ops.div(true_p, true_p + false_n), 0, name)
rec = compute_recall(true_p, false_n, 'value')
- update_op = compute_recall(
- true_positives_update_op, false_negatives_update_op, 'update_op')
+ update_op = compute_recall(true_positives_update_op,
+ false_negatives_update_op, 'update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, rec)
@@ -2022,8 +2095,8 @@ def _select_class_id(ids, selected_id):
"""
ids = sparse_tensor.convert_to_tensor_or_sparse_tensor(ids)
if isinstance(ids, sparse_tensor.SparseTensor):
- return sparse_ops.sparse_retain(
- ids, math_ops.equal(ids.values, selected_id))
+ return sparse_ops.sparse_retain(ids, math_ops.equal(ids.values,
+ selected_id))
# TODO(ptucker): Make this more efficient, maybe add a sparse version of
# tf.equal and tf.reduce_any?
@@ -2031,12 +2104,13 @@ def _select_class_id(ids, selected_id):
# Shape of filled IDs is the same as `ids` with the last dim collapsed to 1.
ids_shape = array_ops.shape(ids, out_type=dtypes.int64)
ids_last_dim = array_ops.size(ids_shape) - 1
- filled_selected_id_shape = math_ops.reduced_shape(
- ids_shape, array_ops.reshape(ids_last_dim, [1]))
+ filled_selected_id_shape = math_ops.reduced_shape(ids_shape,
+ array_ops.reshape(
+ ids_last_dim, [1]))
# Intersect `ids` with the selected ID.
- filled_selected_id = array_ops.fill(
- filled_selected_id_shape, math_ops.to_int64(selected_id))
+ filled_selected_id = array_ops.fill(filled_selected_id_shape,
+ math_ops.to_int64(selected_id))
result = sets.set_intersection(filled_selected_id, ids)
return sparse_tensor.SparseTensor(
indices=result.indices, values=result.values, dense_shape=ids_shape)
@@ -2096,15 +2170,15 @@ def _sparse_true_positive_at_k(labels,
Returns:
A [D1, ... DN] `Tensor` of true positive counts.
"""
- with ops.name_scope(
- name, 'true_positives', (predictions_idx, labels, weights)):
- labels, predictions_idx = _maybe_select_class_id(
- labels, predictions_idx, class_id)
+ with ops.name_scope(name, 'true_positives',
+ (predictions_idx, labels, weights)):
+ labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
+ class_id)
tp = sets.set_size(sets.set_intersection(predictions_idx, labels))
tp = math_ops.to_double(tp)
if weights is not None:
- with ops.control_dependencies((
- weights_broadcast_ops.assert_broadcastable(weights, tp),)):
+ with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
+ weights, tp),)):
weights = math_ops.to_double(weights)
tp = math_ops.multiply(tp, weights)
return tp
@@ -2148,11 +2222,12 @@ def _streaming_sparse_true_positive_at_k(labels,
Raises:
ValueError: If `weights` is not `None` and has an incompatible shape.
"""
- with ops.name_scope(
- name, _at_k_name('true_positive', k, class_id=class_id),
- (predictions_idx, labels, weights)) as scope:
+ with ops.name_scope(name, _at_k_name('true_positive', k, class_id=class_id),
+ (predictions_idx, labels, weights)) as scope:
tp = _sparse_true_positive_at_k(
- predictions_idx=predictions_idx, labels=labels, class_id=class_id,
+ predictions_idx=predictions_idx,
+ labels=labels,
+ class_id=class_id,
weights=weights)
batch_total_tp = math_ops.to_double(math_ops.reduce_sum(tp))
@@ -2189,18 +2264,16 @@ def _sparse_false_negative_at_k(labels,
Returns:
A [D1, ... DN] `Tensor` of false negative counts.
"""
- with ops.name_scope(
- None, 'false_negatives', (predictions_idx, labels, weights)):
- labels, predictions_idx = _maybe_select_class_id(labels,
- predictions_idx,
+ with ops.name_scope(None, 'false_negatives',
+ (predictions_idx, labels, weights)):
+ labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
class_id)
- fn = sets.set_size(sets.set_difference(predictions_idx,
- labels,
- aminusb=False))
+ fn = sets.set_size(
+ sets.set_difference(predictions_idx, labels, aminusb=False))
fn = math_ops.to_double(fn)
if weights is not None:
- with ops.control_dependencies((
- weights_broadcast_ops.assert_broadcastable(weights, fn),)):
+ with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
+ weights, fn),)):
weights = math_ops.to_double(weights)
fn = math_ops.multiply(fn, weights)
return fn
@@ -2244,11 +2317,12 @@ def _streaming_sparse_false_negative_at_k(labels,
Raises:
ValueError: If `weights` is not `None` and has an incompatible shape.
"""
- with ops.name_scope(
- name, _at_k_name('false_negative', k, class_id=class_id),
- (predictions_idx, labels, weights)) as scope:
+ with ops.name_scope(name, _at_k_name('false_negative', k, class_id=class_id),
+ (predictions_idx, labels, weights)) as scope:
fn = _sparse_false_negative_at_k(
- predictions_idx=predictions_idx, labels=labels, class_id=class_id,
+ predictions_idx=predictions_idx,
+ labels=labels,
+ class_id=class_id,
weights=weights)
batch_total_fn = math_ops.to_double(math_ops.reduce_sum(fn))
@@ -2335,9 +2409,8 @@ def recall_at_k(labels,
raise RuntimeError('tf.metrics.recall_at_k is not '
'supported when eager execution is enabled.')
- with ops.name_scope(
- name, _at_k_name('recall', k, class_id=class_id),
- (predictions, labels, weights)) as scope:
+ with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
+ (predictions, labels, weights)) as scope:
_, top_k_idx = nn.top_k(predictions, k)
return recall_at_top_k(
labels=labels,
@@ -2404,16 +2477,21 @@ def recall_at_top_k(labels,
`predictions`, or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
"""
- with ops.name_scope(name,
- _at_k_name('recall', k, class_id=class_id),
+ with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
(predictions_idx, labels, weights)) as scope:
labels = _maybe_expand_labels(labels, predictions_idx)
top_k_idx = math_ops.to_int64(predictions_idx)
tp, tp_update = _streaming_sparse_true_positive_at_k(
- predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
+ predictions_idx=top_k_idx,
+ labels=labels,
+ k=k,
+ class_id=class_id,
weights=weights)
fn, fn_update = _streaming_sparse_false_negative_at_k(
- predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
+ predictions_idx=top_k_idx,
+ labels=labels,
+ k=k,
+ class_id=class_id,
weights=weights)
metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
@@ -2427,9 +2505,13 @@ def recall_at_top_k(labels,
@tf_export('metrics.recall_at_thresholds')
-def recall_at_thresholds(labels, predictions, thresholds,
- weights=None, metrics_collections=None,
- updates_collections=None, name=None):
+def recall_at_thresholds(labels,
+ predictions,
+ thresholds,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
"""Computes various recall values for different `thresholds` on `predictions`.
The `recall_at_thresholds` function creates four local variables,
@@ -2483,6 +2565,7 @@ def recall_at_thresholds(labels, predictions, thresholds,
# Avoid division by zero.
epsilon = 1e-7
+
def compute_recall(tp, fn, name):
return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
@@ -2499,7 +2582,9 @@ def recall_at_thresholds(labels, predictions, thresholds,
@tf_export('metrics.root_mean_squared_error')
-def root_mean_squared_error(labels, predictions, weights=None,
+def root_mean_squared_error(labels,
+ predictions,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -2552,9 +2637,9 @@ def root_mean_squared_error(labels, predictions, weights=None,
predictions, labels, weights = _remove_squeezable_dimensions(
predictions=predictions, labels=labels, weights=weights)
- mse, update_mse_op = mean_squared_error(
- labels, predictions, weights, None, None,
- name or 'root_mean_squared_error')
+ mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
+ None, name or
+ 'root_mean_squared_error')
rmse = math_ops.sqrt(mse)
update_rmse_op = math_ops.sqrt(update_mse_op)
@@ -2569,9 +2654,14 @@ def root_mean_squared_error(labels, predictions, weights=None,
@tf_export('metrics.sensitivity_at_specificity')
-def sensitivity_at_specificity(
- labels, predictions, specificity, weights=None, num_thresholds=200,
- metrics_collections=None, updates_collections=None, name=None):
+def sensitivity_at_specificity(labels,
+ predictions,
+ specificity,
+ weights=None,
+ num_thresholds=200,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
"""Computes the specificity at a given sensitivity.
The `sensitivity_at_specificity` function creates four local
@@ -2632,8 +2722,9 @@ def sensitivity_at_specificity(
with variable_scope.variable_scope(name, 'sensitivity_at_specificity',
(predictions, labels, weights)):
kepsilon = 1e-7 # to account for floating point imprecisions
- thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
- for i in range(num_thresholds-2)]
+ thresholds = [
+ (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
+ ]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
values, update_ops = _confusion_matrix_at_thresholds(
@@ -2645,8 +2736,7 @@ def sensitivity_at_specificity(
tf_index = math_ops.cast(tf_index, dtypes.int32)
# Now, we have the implicit threshold, so compute the sensitivity:
- return math_ops.div(tp[tf_index],
- tp[tf_index] + fn[tf_index] + kepsilon,
+ return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon,
name)
sensitivity = compute_sensitivity_at_specificity(
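The helper above picks the threshold whose specificity is closest to the
requested value and reports the sensitivity at that same threshold. A compact
NumPy sketch of that selection (arrays indexed by threshold; simplified):

import numpy as np

def sensitivity_at_specificity_sketch(tp, fn, tn, fp, specificity, eps=1e-7):
  specificities = tn / (tn + fp + eps)
  tf_index = np.argmin(np.abs(specificities - specificity))
  return tp[tf_index] / (tp[tf_index] + fn[tf_index] + eps)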
@@ -2685,8 +2775,8 @@ def _expand_and_tile(tensor, multiple, dim=0, name=None):
"""
if multiple < 1:
raise ValueError('Invalid multiple %s, must be > 0.' % multiple)
- with ops.name_scope(
- name, 'expand_and_tile', (tensor, multiple, dim)) as scope:
+ with ops.name_scope(name, 'expand_and_tile',
+ (tensor, multiple, dim)) as scope:
# Sparse.
tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor)
if isinstance(tensor, sparse_tensor.SparseTensor):
@@ -2786,8 +2876,8 @@ def _sparse_average_precision_at_top_k(labels, predictions_idx):
Raises:
ValueError: if the last dimension of predictions_idx is not set.
"""
- with ops.name_scope(
- None, 'average_precision', (predictions_idx, labels)) as scope:
+ with ops.name_scope(None, 'average_precision',
+ (predictions_idx, labels)) as scope:
predictions_idx = math_ops.to_int64(predictions_idx, name='predictions_idx')
if predictions_idx.get_shape().ndims == 0:
raise ValueError('The rank of predictions_idx must be at least 1.')
@@ -2824,10 +2914,12 @@ def _sparse_average_precision_at_top_k(labels, predictions_idx):
retrieved_per_k = math_ops.cumsum(
array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k')
precision_per_k = math_ops.div(
- math_ops.to_double(tp_per_k), math_ops.to_double(retrieved_per_k),
+ math_ops.to_double(tp_per_k),
+ math_ops.to_double(retrieved_per_k),
name='precision_per_k')
relevant_precision_per_k = math_ops.multiply(
- precision_per_k, math_ops.to_double(relevant_per_k),
+ precision_per_k,
+ math_ops.to_double(relevant_per_k),
name='relevant_precision_per_k')
# Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor.
@@ -3017,9 +3109,8 @@ def average_precision_at_k(labels,
if k < 1:
raise ValueError('Invalid k=%s.' % k)
- with ops.name_scope(
- name, _at_k_name('average_precision', k),
- (predictions, labels, weights)) as scope:
+ with ops.name_scope(name, _at_k_name('average_precision', k),
+ (predictions, labels, weights)) as scope:
# Calculate top k indices to produce [D1, ... DN, k] tensor.
_, predictions_idx = nn.top_k(predictions, k)
return _streaming_sparse_average_precision_at_top_k(
@@ -3060,17 +3151,16 @@ def _sparse_false_positive_at_k(labels,
Returns:
A [D1, ... DN] `Tensor` of false positive counts.
"""
- with ops.name_scope(
- None, 'false_positives', (predictions_idx, labels, weights)):
- labels, predictions_idx = _maybe_select_class_id(labels,
- predictions_idx,
+ with ops.name_scope(None, 'false_positives',
+ (predictions_idx, labels, weights)):
+ labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
class_id)
- fp = sets.set_size(sets.set_difference(
- predictions_idx, labels, aminusb=True))
+ fp = sets.set_size(
+ sets.set_difference(predictions_idx, labels, aminusb=True))
fp = math_ops.to_double(fp)
if weights is not None:
- with ops.control_dependencies((
- weights_broadcast_ops.assert_broadcastable(weights, fp),)):
+ with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
+ weights, fp),)):
weights = math_ops.to_double(weights)
fp = math_ops.multiply(fp, weights)
return fp
@@ -3114,11 +3204,12 @@ def _streaming_sparse_false_positive_at_k(labels,
Raises:
ValueError: If `weights` is not `None` and has an incompatible shape.
"""
- with ops.name_scope(
- name, _at_k_name('false_positive', k, class_id=class_id),
- (predictions_idx, labels, weights)) as scope:
+ with ops.name_scope(name, _at_k_name('false_positive', k, class_id=class_id),
+ (predictions_idx, labels, weights)) as scope:
fp = _sparse_false_positive_at_k(
- predictions_idx=predictions_idx, labels=labels, class_id=class_id,
+ predictions_idx=predictions_idx,
+ labels=labels,
+ class_id=class_id,
weights=weights)
batch_total_fp = math_ops.to_double(math_ops.reduce_sum(fp))
@@ -3190,10 +3281,16 @@ def precision_at_top_k(labels,
labels = _maybe_expand_labels(labels, predictions_idx)
top_k_idx = math_ops.to_int64(predictions_idx)
tp, tp_update = _streaming_sparse_true_positive_at_k(
- predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
+ predictions_idx=top_k_idx,
+ labels=labels,
+ k=k,
+ class_id=class_id,
weights=weights)
fp, fp_update = _streaming_sparse_false_positive_at_k(
- predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
+ predictions_idx=top_k_idx,
+ labels=labels,
+ k=k,
+ class_id=class_id,
weights=weights)
metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
@@ -3323,9 +3420,14 @@ def precision_at_k(labels,
@tf_export('metrics.specificity_at_sensitivity')
-def specificity_at_sensitivity(
- labels, predictions, sensitivity, weights=None, num_thresholds=200,
- metrics_collections=None, updates_collections=None, name=None):
+def specificity_at_sensitivity(labels,
+ predictions,
+ sensitivity,
+ weights=None,
+ num_thresholds=200,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
"""Computes the specificity at a given sensitivity.
The `specificity_at_sensitivity` function creates four local
@@ -3386,8 +3488,9 @@ def specificity_at_sensitivity(
with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
(predictions, labels, weights)):
kepsilon = 1e-7 # to account for floating point imprecisions
- thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
- for i in range(num_thresholds-2)]
+ thresholds = [
+ (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
+ ]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon]
values, update_ops = _confusion_matrix_at_thresholds(
@@ -3419,8 +3522,7 @@ def specificity_at_sensitivity(
tf_index = math_ops.cast(tf_index, dtypes.int32)
# Now, we have the implicit threshold, so compute the specificity:
- return math_ops.div(tn[tf_index],
- tn[tf_index] + fp[tf_index] + kepsilon,
+ return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon,
name)
specificity = compute_specificity_at_sensitivity(
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 837ee02e64..3268fd0e0a 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -196,9 +196,12 @@ def weighted_cross_entropy_with_logits(targets, logits, pos_weight, name=None):
targets * -log(sigmoid(logits)) +
(1 - targets) * -log(1 - sigmoid(logits))
- A value `pos_weights > 1` decreases the false negative count, hence increasing the recall.
- Conversely setting `pos_weights < 1` decreases the false positive count and increases the precision.
- This can be seen from the fact that `pos_weight` is introduced as a multiplicative coefficient for the positive targets term
+ A value `pos_weight > 1` decreases the false negative count, hence increasing
+ the recall.
+ Conversely, setting `pos_weight < 1` decreases the false positive count and
+ increases the precision.
+ This can be seen from the fact that `pos_weight` is introduced as a
+ multiplicative coefficient for the positive targets term
in the loss expression:
targets * -log(sigmoid(logits)) * pos_weight +
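
A quick NumPy check of the loss described in this docstring, showing that pos_weight scales only the positive-target term. This is an illustrative sketch, not numerically hardened the way the TensorFlow op is.

import numpy as np

def weighted_cross_entropy_with_logits(targets, logits, pos_weight):
  log_sig = -np.log1p(np.exp(-logits))   # log(sigmoid(logits))
  log_one_minus_sig = -logits + log_sig  # log(1 - sigmoid(logits))
  return -(targets * log_sig * pos_weight + (1 - targets) * log_one_minus_sig)

targets = np.array([1.0, 0.0])
logits = np.array([0.0, 0.0])
print(weighted_cross_entropy_with_logits(targets, logits, pos_weight=2.0))
# [1.38629436 0.69314718]: the positive example's -log(0.5) is doubled.
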
@@ -646,9 +649,12 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
@tf_export("nn.moments")
-def moments(x, axes,
- shift=None, # pylint: disable=unused-argument
- name=None, keep_dims=False):
+def moments(
+ x,
+ axes,
+ shift=None, # pylint: disable=unused-argument
+ name=None,
+ keep_dims=False):
"""Calculate the mean and variance of `x`.
The mean and variance are calculated by aggregating the contents of `x`
@@ -692,8 +698,8 @@ def moments(x, axes,
mean = array_ops.squeeze(mean, axes)
variance = array_ops.squeeze(variance, axes)
if x.dtype == dtypes.float16:
- return (math_ops.cast(mean, dtypes.float16), math_ops.cast(
- variance, dtypes.float16))
+ return (math_ops.cast(mean, dtypes.float16),
+ math_ops.cast(variance, dtypes.float16))
else:
return (mean, variance)
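
For comparison, the aggregation nn.moments performs can be sketched in plain NumPy. This shows the semantics only, not the fused kernel or the float16 cast handled in the hunk above.

import numpy as np

def moments(x, axes, keep_dims=False):
  # Mean and variance aggregated over the given axes.
  mean = np.mean(x, axis=tuple(axes), keepdims=keep_dims)
  centered = x - np.mean(x, axis=tuple(axes), keepdims=True)
  variance = np.mean(np.square(centered), axis=tuple(axes), keepdims=keep_dims)
  return mean, variance

x = np.array([[1.0, 2.0], [3.0, 4.0]])
print(moments(x, axes=[0]))  # (array([2., 3.]), array([1., 1.]))
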
@@ -824,8 +830,8 @@ def batch_normalization(x,
inv = math_ops.rsqrt(variance + variance_epsilon)
if scale is not None:
inv *= scale
- return x * inv + (offset - mean * inv
- if offset is not None else -mean * inv)
+ return x * inv + (
+ offset - mean * inv if offset is not None else -mean * inv)
@tf_export("nn.fused_batch_norm")
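
The expression reformatted above is the usual factored form of batch normalization: scale * (x - mean) / sqrt(variance + epsilon) + offset, rewritten as x * inv + (offset - mean * inv). A NumPy sketch of that equivalence, for illustration only:

import numpy as np

def batch_normalization(x, mean, variance, offset, scale,
                        variance_epsilon=1e-3):
  inv = 1.0 / np.sqrt(variance + variance_epsilon)
  if scale is not None:
    inv = inv * scale
  return x * inv + (offset - mean * inv if offset is not None else -mean * inv)

x = np.random.randn(4, 3)
mean, var = x.mean(axis=0), x.var(axis=0)
y = batch_normalization(x, mean, var, offset=np.ones(3), scale=2.0 * np.ones(3))
reference = 2.0 * (x - mean) / np.sqrt(var + 1e-3) + 1.0
print(np.allclose(y, reference))  # True
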
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 6767564024..5a45bdc1e5 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -131,8 +131,7 @@ class LogPoissonLossTest(test_lib.TestCase):
y_np = self._log_poisson_loss(x_np, z_np, compute_full_loss=False)
y_np_stirling = self._log_poisson_loss(x_np, z_np, compute_full_loss=True)
y_tf = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=False)
- y_tf_stirling = nn_impl.log_poisson_loss(
- z_np, x_np, compute_full_loss=True)
+ y_tf_stirling = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=True)
y_tf_np = self.evaluate(y_tf)
y_tf_np_stirling = self.evaluate(y_tf_stirling)
eps = 1e-3
@@ -773,8 +772,8 @@ class ComputeSampledLogitsTest(test_lib.TestCase):
def _SoftmaxCrossEntropyWithLogits(logits, targets):
# logits, targets: float arrays of the same shape.
assert logits.shape == targets.shape
- stable_exp_logits = np.exp(logits - np.amax(
- logits, axis=1, keepdims=True))
+ stable_exp_logits = np.exp(
+ logits - np.amax(logits, axis=1, keepdims=True))
pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True)
return -np.sum(targets * np.log(pred + 1.0e-20), axis=1)
@@ -865,8 +864,8 @@ class LeakyReluTest(test_lib.TestCase):
batch_size = 3
height, width = 4, 4
np.random.seed(1) # Make it reproducible.
- inputs = np.random.uniform(
- size=(batch_size, height, width, 3)).astype(np.float32)
+ inputs = np.random.uniform(size=(batch_size, height, width, 3)).astype(
+ np.float32)
inputs = constant_op.constant(inputs)
outputs = nn_ops.leaky_relu(inputs)
@@ -884,7 +883,8 @@ class LeakyReluTest(test_lib.TestCase):
with self.test_session() as sess:
outputs = sess.run(outputs)
tol = 2e-3 if dtype == np.float16 else 1e-6
- self.assertAllClose(outputs, [-0.4, -0.2, 0.0, 1.0, 2.0], rtol=tol, atol=tol)
+ self.assertAllClose(
+ outputs, [-0.4, -0.2, 0.0, 1.0, 2.0], rtol=tol, atol=tol)
class SwishTest(test_lib.TestCase):
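
The expected values in the LeakyReluTest hunk above follow from leaky_relu(x) = max(x, alpha * x), assuming the op's default alpha of 0.2. A quick NumPy check:

import numpy as np

def leaky_relu(x, alpha=0.2):
  # Elementwise max(x, alpha * x); alpha scales only the negative inputs.
  return np.maximum(x, alpha * x)

print(leaky_relu(np.array([-2.0, -1.0, 0.0, 1.0, 2.0])))
# [-0.4 -0.2  0.   1.   2. ]
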
@@ -915,7 +915,10 @@ class SwishTest(test_lib.TestCase):
class MomentsTest(test_lib.TestCase):
- def doOutputTest(self, input_shape, moments_axes, tol=1e-4,
+ def doOutputTest(self,
+ input_shape,
+ moments_axes,
+ tol=1e-4,
check_gradients=False):
for mu in [0.0, 1.0, 1e3]:
for sigma in [1.0, 0.1]:
diff --git a/tensorflow/python/util/compat.py b/tensorflow/python/util/compat.py
index 3ab0bd16fa..270d96a3c7 100644
--- a/tensorflow/python/util/compat.py
+++ b/tensorflow/python/util/compat.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Functions for Python 2 vs. 3 compatibility.
## Conversion routines
@@ -118,7 +117,7 @@ def path_to_str(path):
Returns:
A `str` object.
"""
- if hasattr(path, "__fspath__"):
+ if hasattr(path, '__fspath__'):
path = as_str_any(path.__fspath__())
return path
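
path_to_str simply unwraps os.PathLike objects; a simplified stand-alone sketch, dropping the bytes handling that as_str_any provides:

import pathlib

def path_to_str(path):
  # Anything implementing __fspath__ (os.PathLike) is converted to str.
  if hasattr(path, '__fspath__'):
    path = str(path.__fspath__())
  return path

print(path_to_str(pathlib.Path('/tmp') / 'checkpoint'))  # /tmp/checkpoint
print(path_to_str('already_a_str'))                      # already_a_str
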
@@ -129,11 +128,9 @@ integral_types = (_numbers.Integral, _np.integer)
real_types = (_numbers.Real, _np.integer, _np.floating)
complex_types = (_numbers.Complex, _np.number)
-
# Either bytes or text.
bytes_or_text_types = (bytes, _six.text_type)
-
_allowed_symbols = [
'as_str',
'bytes_or_text_types',
diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py
index 8eee489e2d..38a9007387 100644
--- a/tensorflow/tools/pip_package/pip_smoke_test.py
+++ b/tensorflow/tools/pip_package/pip_smoke_test.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""This pip smoke test verifies dependency files exist in the pip package.
This script runs bazel queries to see what python files are required by the
@@ -26,13 +25,12 @@ from __future__ import print_function
import os
import subprocess
-os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..')))
-
-
-PIP_PACKAGE_QUERY_EXPRESSION = \
- 'deps(//tensorflow/tools/pip_package:build_pip_package)'
+os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
+PIP_PACKAGE_QUERY_EXPRESSION = (
+ "deps(//tensorflow/tools/pip_package:build_pip_package)")
+# pylint: disable=g-backslash-continuation
PY_TEST_QUERY_EXPRESSION = 'deps(\
filter("^((?!benchmark).)*$",\
kind(py_test,\
@@ -40,6 +38,7 @@ PY_TEST_QUERY_EXPRESSION = 'deps(\
+ //tensorflow/contrib/... \
- //tensorflow/contrib/tensorboard/... \
- attr(tags, "manual|no_pip", //tensorflow/...))), 1)'
+# pylint: enable=g-backslash-continuation
# Hard-coded blacklist of files if not included in pip package
# TODO(amitpatankar): Clean up blacklist.
@@ -90,15 +89,15 @@ def main():
"""
# pip_package_dependencies_list is the list of included files in pip packages
- pip_package_dependencies = subprocess.check_output([
- 'bazel', 'query', PIP_PACKAGE_QUERY_EXPRESSION])
+ pip_package_dependencies = subprocess.check_output(
+ ["bazel", "query", PIP_PACKAGE_QUERY_EXPRESSION])
pip_package_dependencies_list = pip_package_dependencies.strip().split("\n")
print("Pip package superset size: %d" % len(pip_package_dependencies_list))
# tf_py_test_dependencies is the list of dependencies for all python
# tests in tensorflow
- tf_py_test_dependencies = subprocess.check_output([
- 'bazel', 'query', PY_TEST_QUERY_EXPRESSION])
+ tf_py_test_dependencies = subprocess.check_output(
+ ["bazel", "query", PY_TEST_QUERY_EXPRESSION])
tf_py_test_dependencies_list = tf_py_test_dependencies.strip().split("\n")
print("Pytest dependency subset size: %d" % len(tf_py_test_dependencies_list))
@@ -119,8 +118,7 @@ def main():
# Check if the dependency is in the pip package, the blacklist, or
# should be ignored because of its file extension
- if not (ignore or
- dependency in pip_package_dependencies_list or
+ if not (ignore or dependency in pip_package_dependencies_list or
dependency in BLACKLIST):
missing_dependencies.append(dependency)
@@ -131,9 +129,9 @@ def main():
for missing_dependency in missing_dependencies:
print("\nMissing dependency: %s " % missing_dependency)
print("Affected Tests:")
- rdep_query = 'rdeps(kind(py_test, \
- //tensorflow/python/...), %s)' % missing_dependency
- affected_tests = subprocess.check_output(['bazel', 'query', rdep_query])
+ rdep_query = ("rdeps(kind(py_test, //tensorflow/python/...), %s)" %
+ missing_dependency)
+ affected_tests = subprocess.check_output(["bazel", "query", rdep_query])
affected_tests_list = affected_tests.split("\n")[:-2]
print("\n".join(affected_tests_list))
@@ -145,5 +143,6 @@ or add them to //tensorflow/tools/pip_package/BUILD.""")
else:
print("TEST PASSED")
+
if __name__ == "__main__":
main()
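
Stripped of the bazel plumbing, the smoke test above reduces to a set difference between two dependency lists. A self-contained sketch of that core check; the target names and the ignored suffixes here are illustrative, not the script's exact rules:

def missing_test_dependencies(pip_deps, test_deps, blacklist=(),
                              ignored_suffixes=(".h", ".md", ".html")):
  # Test dependencies that are neither shipped in the pip package,
  # blacklisted, nor ignorable by file extension.
  pip_set = set(pip_deps)
  return [dep for dep in test_deps
          if dep and not dep.endswith(ignored_suffixes)
          and dep not in pip_set and dep not in blacklist]

print(missing_test_dependencies(
    pip_deps=["//tensorflow/python:util.py"],
    test_deps=["//tensorflow/python:util.py",
               "//tensorflow/python:missing_helper.py",
               "//tensorflow/docs:overview.md"]))
# ['//tensorflow/python:missing_helper.py']
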