path: root/tensorflow/python/ops
author    wangsiyu <siyu.wsy@gmail.com>    2018-09-25 13:01:50 +0800
committer wangsiyu <siyu.wsy@gmail.com>    2018-09-25 13:01:50 +0800
commit    351ef3409fb913067cc26eccb3c6de350e84ca52 (patch)
tree      c97e38fea0f3a62000128312ebe83415f18debea /tensorflow/python/ops
parent    6dd7a09211cc74d11ff1554624b527c432020cbc (diff)
parent    c1644948d23cae271b140d67101c1a386e5495fd (diff)
Merge branch 'master' of github.com:tensorflow/tensorflow into assign_in_part_vars
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--  tensorflow/python/ops/cond_v2_impl.py                     6
-rw-r--r--  tensorflow/python/ops/distributions/bijector_impl.py     39
-rw-r--r--  tensorflow/python/ops/distributions/util.py              12
-rw-r--r--  tensorflow/python/ops/functional_ops.py                  40
-rw-r--r--  tensorflow/python/ops/image_ops_impl.py                  14
-rw-r--r--  tensorflow/python/ops/image_ops_test.py                  12
-rw-r--r--  tensorflow/python/ops/lookup_ops.py                      40
-rw-r--r--  tensorflow/python/ops/nn_ops.py                          34
8 files changed, 120 insertions, 77 deletions
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index c6a6b2a7fa..f8b1ddb140 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -119,7 +119,11 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
attr_value_pb2.AttrValue(b=True))
# pylint: enable=protected-access
- return tuple(tensors[:num_cond_outputs])
+ result = tuple(tensors[:num_cond_outputs])
+ if len(result) == 1:
+ return result[0]
+ else:
+ return result
@ops.RegisterGradient("If")
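The cond_v2 hunk above unwraps a single-element output tuple so callers get a lone tensor back instead of a 1-tuple. A minimal plain-Python sketch of just that unwrapping, with hypothetical list values standing in for the op's output tensors:

```python
# Sketch of the return-value unwrapping added above, using plain Python
# lists in place of real Tensors (hypothetical values for illustration).
def _unwrap(tensors, num_cond_outputs):
    result = tuple(tensors[:num_cond_outputs])
    if len(result) == 1:
        return result[0]   # single output: return the tensor itself
    return result          # multiple outputs: keep the tuple

assert _unwrap(["t0", "extra"], 1) == "t0"
assert _unwrap(["t0", "t1"], 2) == ("t0", "t1")
```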
diff --git a/tensorflow/python/ops/distributions/bijector_impl.py b/tensorflow/python/ops/distributions/bijector_impl.py
index 2e7aa30296..9c63385dd0 100644
--- a/tensorflow/python/ops/distributions/bijector_impl.py
+++ b/tensorflow/python/ops/distributions/bijector_impl.py
@@ -825,10 +825,21 @@ class Bijector(object):
min_event_ndims=self.inverse_min_event_ndims,
event_ndims=event_ndims)):
if not self._is_injective: # No caching for non-injective
- ildjs = self._inverse_log_det_jacobian(y, **kwargs)
- return tuple(self._reduce_jacobian_det_over_event(
- y, ildj, self.inverse_min_event_ndims, event_ndims)
- for ildj in ildjs)
+ try:
+ ildjs = self._inverse_log_det_jacobian(y, **kwargs)
+ return tuple(self._reduce_jacobian_det_over_event(
+ y, ildj, self.inverse_min_event_ndims, event_ndims)
+ for ildj in ildjs)
+ except NotImplementedError as original_exception:
+ try:
+ x = self._inverse(y, **kwargs)
+ fldjs = self._forward_log_det_jacobian(x, **kwargs)
+ return tuple(self._reduce_jacobian_det_over_event(
+ x, -fldj, self.forward_min_event_ndims, event_ndims)
+ for fldj in fldjs)
+ except NotImplementedError:
+ raise original_exception
+
mapping = self._lookup(y=y, kwargs=kwargs)
if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
return mapping.ildj_map[event_ndims]
@@ -917,11 +928,21 @@ class Bijector(object):
return -1. * self._constant_ildj_map[event_ndims]
x = ops.convert_to_tensor(x, name="x")
self._maybe_assert_dtype(x)
- if not self._is_injective:
- fldjs = self._forward_log_det_jacobian(x, **kwargs) # No caching.
- return tuple(self._reduce_jacobian_det_over_event(
- x, fldj, self.forward_min_event_ndims, event_ndims)
- for fldj in fldjs)
+ if not self._is_injective: # No caching for non-injective
+ try:
+ fldjs = self._forward_log_det_jacobian(x, **kwargs) # No caching.
+ return tuple(self._reduce_jacobian_det_over_event(
+ x, fldj, self.forward_min_event_ndims, event_ndims)
+ for fldj in fldjs)
+ except NotImplementedError as original_exception:
+ try:
+ y = self._forward(x, **kwargs)
+ ildjs = self._inverse_log_det_jacobian(y, **kwargs)
+ return tuple(self._reduce_jacobian_det_over_event(
+ y, -ildj, self.inverse_min_event_ndims, event_ndims)
+ for ildj in ildjs)
+ except NotImplementedError:
+ raise original_exception
mapping = self._lookup(x=x, kwargs=kwargs)
if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
return -mapping.ildj_map[event_ndims]
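Both bijector hunks above add the same fallback: if one log-det-Jacobian method raises NotImplementedError, compute the other one at the mapped point and negate it, relying on the identity ildj(y) = -fldj(f^{-1}(y)). A small numpy check of that identity for the scalar Exp bijector; this is an illustration of the math, not the TF code path:

```python
# Numerical check of the identity the fallback above relies on:
#   inverse_log_det_jacobian(y) == -forward_log_det_jacobian(inverse(y))
# worked out for the scalar Exp bijector (illustration only).
import numpy as np

def inverse(y):                    # f^{-1}(y) = log(y) for f(x) = exp(x)
    return np.log(y)

def forward_log_det_jacobian(x):   # log |d exp(x)/dx| = x
    return x

y = np.array([0.5, 1.0, 2.0])
ildj_direct = -np.log(y)                                   # known ILDJ of Exp
ildj_via_fallback = -forward_log_det_jacobian(inverse(y))  # path taken above
np.testing.assert_allclose(ildj_direct, ildj_via_fallback)
```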
diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py
index 3e480a79f5..ad848dfee6 100644
--- a/tensorflow/python/ops/distributions/util.py
+++ b/tensorflow/python/ops/distributions/util.py
@@ -155,7 +155,8 @@ def get_logits_and_probs(logits=None,
probs=None,
multidimensional=False,
validate_args=False,
- name="get_logits_and_probs"):
+ name="get_logits_and_probs",
+ dtype=None):
"""Converts logit to probabilities (or vice-versa), and returns both.
Args:
@@ -169,6 +170,7 @@ def get_logits_and_probs(logits=None,
`0 <= probs <= 1` (if not `multidimensional`) or that the last dimension
of `probs` sums to one.
name: A name for this operation (optional).
+ dtype: `tf.DType` to prefer when converting args to `Tensor`s.
Returns:
logits, probs: Tuple of `Tensor`s. If `probs` has an entry that is `0` or
@@ -183,7 +185,7 @@ def get_logits_and_probs(logits=None,
raise ValueError("Must pass probs or logits, but not both.")
if probs is None:
- logits = ops.convert_to_tensor(logits, name="logits")
+ logits = ops.convert_to_tensor(logits, name="logits", dtype=dtype)
if not logits.dtype.is_floating:
raise TypeError("logits must having floating type.")
# We can early return since we constructed probs and therefore know
@@ -194,7 +196,7 @@ def get_logits_and_probs(logits=None,
return logits, nn.softmax(logits, name="probs")
return logits, math_ops.sigmoid(logits, name="probs")
- probs = ops.convert_to_tensor(probs, name="probs")
+ probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
if not probs.dtype.is_floating:
raise TypeError("probs must having floating type.")
@@ -524,6 +526,8 @@ def matrix_diag_transform(matrix, transform=None, name=None):
Example of heteroskedastic 2-D linear regression.
```python
+ tfd = tfp.distributions
+
# Get a trainable Cholesky factor.
matrix_values = tf.contrib.layers.fully_connected(activations, 4)
matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
@@ -533,7 +537,7 @@ def matrix_diag_transform(matrix, transform=None, name=None):
mu = tf.contrib.layers.fully_connected(activations, 2)
# This is a fully trainable multivariate normal!
- dist = tf.contrib.distributions.MVNCholesky(mu, chol)
+ dist = tfd.MultivariateNormalTriL(mu, chol)
# Standard log loss. Minimizing this will "train" mu and chol, and then dist
# will be a distribution predicting labels as multivariate Gaussians.
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index a4e7c84ae4..119d9522bd 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -41,6 +41,7 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.gen_functional_ops import remote_call
# pylint: enable=unused-import
from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@@ -263,7 +264,7 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
@tf_export("map_fn")
-def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
+def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
swap_memory=False, infer_shape=True, name=None):
"""map on the list of tensors unpacked from `elems` on dimension 0.
@@ -305,6 +306,25 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
instead.
+ When executing eagerly, map_fn does not execute in parallel even if
+ `parallel_iterations` is set to a value > 1. You can still get the
+ performance benefits of running a function in parallel by using the
+ `tf.contrib.eager.defun` decorator,
+
+ ```python
+ # Assume the function being used in map_fn is fn.
+ # To ensure map_fn calls fn in parallel, use the defun decorator.
+ @tf.contrib.eager.defun
+ def func(tensor):
+ return tf.map_fn(fn, tensor)
+ ```
+
+ Note that if you use the defun decorator, any non-TensorFlow Python code
+ that you may have written in your function won't get executed. See
+ `tf.contrib.eager.defun` for more details. The recommendation would be to
+ debug without defun but switch to defun to get performance benefits of
+ running map_fn in parallel.
+
Args:
fn: The callable to be performed. It accepts one argument, which will
have the same (possibly nested) structure as `elems`. Its output
@@ -317,7 +337,8 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
of Tensors differing from the structure of `elems`, then `dtype` is not
optional and must have the same structure as the output of `fn`.
parallel_iterations: (optional) The number of iterations allowed to run
- in parallel.
+ in parallel. When graph building, the default value is 10. While executing
+ eagerly, the default value is set to 1.
back_prop: (optional) True enables support for back propagation.
swap_memory: (optional) True enables GPU-CPU memory swapping.
infer_shape: (optional) False disables tests for consistent output shapes.
@@ -363,6 +384,20 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
" SparseTensor(input.indices, map_fn(fn, input.values), "
"input.dense_shape)")
+ in_graph_mode = not context.executing_eagerly()
+ # Set the default number of parallel_iterations depending on graph/eager mode.
+ if in_graph_mode and not parallel_iterations:
+ parallel_iterations = 10
+ elif not in_graph_mode and not parallel_iterations:
+ parallel_iterations = 1
+
+ if not in_graph_mode and parallel_iterations > 1:
+ logging.log_first_n(logging.WARN, "Setting parallel_iterations > 1 has no "
+ "effect when executing eagerly. Consider calling map_fn"
+ " with tf.contrib.eager.defun to execute fn in "
+ "parallel.", 1)
+ parallel_iterations = 1
+
input_is_sequence = nest.is_sequence(elems)
input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]
def input_pack(x):
@@ -381,7 +416,6 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
elems_flat = input_flatten(elems)
- in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "map", elems_flat):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager
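The defaulting logic added to map_fn above boils down to: 10 parallel iterations when building a graph, 1 when executing eagerly, and any eager request for more than one is warned about and clamped. A standalone sketch of just that decision, with a hypothetical helper and flag standing in for the TF internals:

```python
# Standalone sketch of the parallel_iterations defaulting added above.
# `executing_eagerly` is a hypothetical flag standing in for
# context.executing_eagerly(); the warning is reduced to a print.
def resolve_parallel_iterations(parallel_iterations, executing_eagerly):
    if parallel_iterations is None:
        return 1 if executing_eagerly else 10
    if executing_eagerly and parallel_iterations > 1:
        print("parallel_iterations > 1 has no effect when executing eagerly")
        return 1
    return parallel_iterations

assert resolve_parallel_iterations(None, executing_eagerly=False) == 10
assert resolve_parallel_iterations(None, executing_eagerly=True) == 1
assert resolve_parallel_iterations(8, executing_eagerly=True) == 1
```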
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 325418d5f7..1c75aab578 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -329,8 +329,6 @@ def _random_flip(image, flip_index, seed, scope_name):
lambda: image,
name=scope
)
- if isinstance(result, tuple):
- result = result[0] # TODO(b/111124878) remove this logic (CondV2).
return fix_image_flip_shape(image, result)
elif shape.ndims == 4:
batch_size = array_ops.shape(image)[0]
@@ -1029,10 +1027,10 @@ def resize_images(images,
scale_factor_width = (math_ops.to_float(new_width_const) /
math_ops.to_float(current_width))
scale_factor = math_ops.minimum(scale_factor_height, scale_factor_width)
- scaled_height_const = math_ops.to_int32(scale_factor *
- math_ops.to_float(current_height))
- scaled_width_const = math_ops.to_int32(scale_factor *
- math_ops.to_float(current_width))
+ scaled_height_const = math_ops.to_int32(
+ math_ops.round(scale_factor * math_ops.to_float(current_height)))
+ scaled_width_const = math_ops.to_int32(
+ math_ops.round(scale_factor * math_ops.to_float(current_width)))
# NOTE: Reset the size and other constants used later.
size = ops.convert_to_tensor([scaled_height_const, scaled_width_const],
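Without the added math_ops.round, to_int32 truncates toward zero, so a scaled size that floating-point arithmetic leaves just under a whole number loses a pixel; the 299x299 -> 320x320 test added in image_ops_test.py below exercises exactly this. A plain-Python illustration of truncation versus rounding, using a generic float product rather than the exact TF scale factors:

```python
# Truncation vs. rounding when converting a scaled size to an integer.
# The product here is a generic example of a value that floating-point
# arithmetic leaves just below a whole number; the exact TF scale factors
# are not reproduced.
scale, size = 4.35, 100             # 4.35 * 100 is stored as 434.99999999...
truncated = int(scale * size)       # truncates toward zero -> 434
rounded = int(round(scale * size))  # round first, as the fix does -> 435
assert truncated == 434
assert rounded == 435
```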
@@ -1176,7 +1174,7 @@ def resize_image_with_pad(image,
@tf_export('image.per_image_standardization')
def per_image_standardization(image):
- """Linearly scales `image` to have zero mean and unit norm.
+ """Linearly scales `image` to have zero mean and unit variance.
This op computes `(x - mean) / adjusted_stddev`, where `mean` is the average
of all values in image, and
@@ -1379,7 +1377,7 @@ def adjust_gamma(image, gamma=1, gain=1):
[1] http://en.wikipedia.org/wiki/Gamma_correction
"""
- with ops.op_scope([image, gamma, gain], None, 'adjust_gamma'):
+ with ops.name_scope(None, 'adjust_gamma', [image, gamma, gain]) as name:
# Convert pixel value to DT_FLOAT for computing adjusted image.
img = ops.convert_to_tensor(image, name='img', dtype=dtypes.float32)
# Keep image dtype for computing the scale of corresponding dtype.
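The per_image_standardization docstring corrected above describes `(x - mean) / adjusted_stddev`. A numpy sketch of that scaling; the lower bound on the stddev is assumed from the full docstring, which this hunk truncates:

```python
# Numpy sketch of the "zero mean and unit variance" scaling described in the
# per_image_standardization docstring above: (x - mean) / adjusted_stddev,
# with the stddev floored at 1/sqrt(num_elements) (floor assumed from the
# full docstring, not shown in this hunk).
import numpy as np

def standardize(image):
    image = image.astype(np.float64)
    mean = image.mean()
    adjusted_stddev = max(image.std(), 1.0 / np.sqrt(image.size))
    return (image - mean) / adjusted_stddev

img = np.random.uniform(size=(4, 4, 3))
out = standardize(img)
print(out.mean(), out.var())  # ~0 and ~1 for a non-constant image
```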
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 795e6bbc3e..35fdee4fad 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -2687,6 +2687,12 @@ class ResizeImagesTest(test_util.TensorFlowTestCase):
self._assertResizeCheckShape(x, x_shape, [3840, 2160], [3840, 2160, 3])
+ def testPreserveAspectRatioSquare(self):
+ x_shape = [299, 299, 3]
+ x = np.random.uniform(size=x_shape)
+
+ self._assertResizeCheckShape(x, x_shape, [320, 320], [320, 320, 3])
+
class ResizeImageWithPadTest(test_util.TensorFlowTestCase):
@@ -3667,7 +3673,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
# Note: There are multiple versions of non_max_suppression v2, v3, v4.
# gen_image_ops.non_max_suppression_v2:
for dtype in [np.float16, np.float32]:
- with self.test_session():
+ with self.cached_session():
boxes = constant_op.constant(boxes_np, dtype=dtype)
scores = constant_op.constant(scores_np, dtype=dtype)
max_output_size = constant_op.constant(max_output_size_np)
@@ -3677,7 +3683,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
self.assertAllClose(selected_indices, [3, 0, 5])
# image_ops.non_max_suppression = gen_image_ops.non_max_suppression_v3.
for dtype in [np.float16, np.float32]:
- with self.test_session():
+ with self.cached_session():
boxes = constant_op.constant(boxes_np, dtype=dtype)
scores = constant_op.constant(scores_np, dtype=dtype)
max_output_size = constant_op.constant(max_output_size_np)
@@ -3688,7 +3694,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
# gen_image_ops.non_max_suppression_v4.
score_threshold = float('-inf')
for dtype in [np.float16, np.float32]:
- with self.test_session():
+ with self.cached_session():
boxes = constant_op.constant(boxes_np, dtype=dtype)
scores = constant_op.constant(scores_np, dtype=dtype)
max_output_size = constant_op.constant(max_output_size_np)
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index 561a341cf3..5443699ddd 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -422,7 +422,7 @@ class TextFileInitializer(TableInitializerBase):
* `palmer -> 30`
```python
- table = tf.contrib.lookup.HashTable(tf.contrib.lookup.TextFileInitializer(
+ table = tf.lookup.HashTable(tf.lookup.TextFileInitializer(
"test.txt", tf.string, 0, tf.int64, 1, delimiter=" "), -1)
...
table.init.run()
@@ -435,9 +435,9 @@ class TextFileInitializer(TableInitializerBase):
* `palmer 30 -> 2`
```python
- table = tf.contrib.lookup.HashTable(tf.contrib.lookup.TextFileInitializer(
- "test.txt", tf.string, tf.contrib.lookup.TextFileIndex.WHOLE_LINE,
- tf.int64, tf.contrib.lookup.TextFileIndex.LINE_NUMBER, delimiter=" "), -1)
+ table = tf.lookup.HashTable(tf.lookup.TextFileInitializer(
+ "test.txt", tf.string, tf.lookup.TextFileIndex.WHOLE_LINE,
+ tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER, delimiter=" "), -1)
...
table.init.run()
```
@@ -953,7 +953,7 @@ def index_table_from_file(vocabulary_file=None,
```python
features = tf.constant(["emerson", "lake", "and", "palmer"])
- table = tf.contrib.lookup.index_table_from_file(
+ table = tf.lookup.index_table_from_file(
vocabulary_file="test.txt", num_oov_buckets=1)
ids = table.lookup(features)
...
@@ -1054,21 +1054,21 @@ def index_table_from_tensor(vocabulary_list,
Any lookup of an out-of-vocabulary token will return a bucket ID based on its
hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
- `default_value`.
- The bucket ID range is `[mapping size, mapping size + num_oov_buckets - 1]`.
+ `default_value`. The bucket ID range is
+ `[vocabulary list size, vocabulary list size + num_oov_buckets - 1]`.
The underlying table must be initialized by calling
`tf.tables_initializer.run()` or `table.init.run()` once.
- Elements in `mapping` cannot have duplicates, otherwise when executing the
- table initializer op, it will throw a `FailedPreconditionError`.
+ Elements in `vocabulary_list` cannot have duplicates, otherwise when executing
+ the table initializer op, it will throw a `FailedPreconditionError`.
Sample Usages:
```python
vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
- table = tf.contrib.lookup.index_table_from_tensor(
- mapping=vocabulary_list, num_oov_buckets=1, default_value=-1)
+ table = tf.lookup.index_table_from_tensor(
+ vocabulary_list=vocabulary_list, num_oov_buckets=1, default_value=-1)
features = tf.constant(["emerson", "lake", "and", "palmer"])
ids = table.lookup(features)
...
@@ -1093,7 +1093,7 @@ def index_table_from_tensor(vocabulary_list,
The lookup table to map an input `Tensor` to index `int64` `Tensor`.
Raises:
- ValueError: If `mapping` is invalid.
+ ValueError: If `vocabulary_list` is invalid.
ValueError: If `num_oov_buckets` is negative.
"""
if vocabulary_list is None:
@@ -1185,7 +1185,7 @@ def index_to_string_table_from_file(vocabulary_file,
```python
indices = tf.constant([1, 5], tf.int64)
- table = tf.contrib.lookup.index_to_string_table_from_file(
+ table = tf.lookup.index_to_string_table_from_file(
vocabulary_file="test.txt", default_value="UNKNOWN")
values = table.lookup(indices)
...
@@ -1250,25 +1250,25 @@ def index_to_string_table_from_tensor(vocabulary_list,
"""Returns a lookup table that maps a `Tensor` of indices into strings.
This operation constructs a lookup table to map int64 indices into string
- values. The mapping is initialized from a string `mapping` 1-D `Tensor` where
- each element is a value and the corresponding index within the tensor is the
- key.
+ values. The mapping is initialized from a string `vocabulary_list` 1-D
+ `Tensor` where each element is a value and the corresponding index within the
+ tensor is the key.
- Any input which does not have a corresponding index in 'mapping'
+ Any input which does not have a corresponding index in 'vocabulary_list'
(an out-of-vocabulary entry) is assigned the `default_value`
The underlying table must be initialized by calling
`tf.tables_initializer.run()` or `table.init.run()` once.
- Elements in `mapping` cannot have duplicates, otherwise when executing the
- table initializer op, it will throw a `FailedPreconditionError`.
+ Elements in `vocabulary_list` cannot have duplicates, otherwise when executing
+ the table initializer op, it will throw a `FailedPreconditionError`.
Sample Usages:
```python
vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
indices = tf.constant([1, 5], tf.int64)
- table = tf.contrib.lookup.index_to_string_table_from_tensor(
+ table = tf.lookup.index_to_string_table_from_tensor(
vocabulary_list, default_value="UNKNOWN")
values = table.lookup(indices)
...
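The lookup_ops docstring updates above rename `mapping` to `vocabulary_list` and switch the examples from `tf.contrib.lookup` to `tf.lookup`. A hedged round-trip sketch combining the two table kinds, mirroring the updated examples; it assumes the `tf.lookup` endpoints shown above resolve in the reader's build (under 1.x they may still live in `tf.contrib.lookup`):

```python
# Round trip through the two vocabulary tables documented above, following
# the updated docstring examples (tf.lookup endpoints assumed available).
import tensorflow as tf

vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
features = tf.constant(["emerson", "lake", "and", "palmer"])

to_ids = tf.lookup.index_table_from_tensor(
    vocabulary_list=vocabulary_list, num_oov_buckets=1)
to_strings = tf.lookup.index_to_string_table_from_tensor(
    vocabulary_list, default_value="UNKNOWN")

ids = to_ids.lookup(features)      # "and" falls into the single OOV bucket
strings = to_strings.lookup(ids)   # the OOV id maps back to "UNKNOWN"

with tf.Session() as sess:
  sess.run(tf.tables_initializer())
  print(sess.run(ids), sess.run(strings))
```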
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 2526e6fee2..9ef177e97b 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -22,7 +22,6 @@ import numbers
import numpy as np
-from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_util
@@ -1670,47 +1669,24 @@ def _softmax(logits, compute_op, dim=-1, name=None):
shape = logits.get_shape()
is_last_dim = (dim is -1) or (dim == shape.ndims - 1)
- # TODO(phawkins): remove after 2018/8/27 and simplify this code.
- softmax_accepts_r1_or_greater = compat.forward_compatible(2018, 8, 27)
- reshape_required = (not softmax_accepts_r1_or_greater) and shape.ndims != 2
if is_last_dim:
- if reshape_required:
- # If dim is the last dimension, simply reshape the logits to a matrix and
- # apply the internal softmax.
- input_shape = array_ops.shape(logits)
- logits = _flatten_outer_dims(logits)
- output = compute_op(logits)
- output = array_ops.reshape(output, input_shape, name=name)
- return output
return compute_op(logits, name=name)
- # If dim is not the last dimension, we have to do a reshape and transpose so
- # that we can still perform softmax on its last dimension.
+ # If dim is not the last dimension, we have to do a transpose so that we can
+ # still perform softmax on its last dimension.
# Swap logits' dimension of dim and its last dimension.
input_rank = array_ops.rank(logits)
dim_axis = dim % shape.ndims
logits = _swap_axis(logits, dim_axis, math_ops.subtract(input_rank, 1))
- shape_after_swap = array_ops.shape(logits)
- if reshape_required:
- # Reshape logits into a matrix.
- logits = _flatten_outer_dims(logits)
-
- # Do the actual softmax on its last dimension.
- output = compute_op(logits)
-
- # Transform back the output tensor.
- output = array_ops.reshape(output, shape_after_swap)
- else:
- # Do the actual softmax on its last dimension.
- output = compute_op(logits)
+ # Do the actual softmax on its last dimension.
+ output = compute_op(logits)
output = _swap_axis(
output, dim_axis, math_ops.subtract(input_rank, 1), name=name)
- # Make shape inference work since reshape and transpose may erase its static
- # shape.
+ # Make shape inference work since transpose may erase its static shape.
output.set_shape(shape)
return output
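With the forward-compatibility reshape path removed, the simplified _softmax above always relies on the transpose trick: swap the requested axis with the last one, run the last-axis softmax kernel, then swap back and restore the static shape. A numpy sketch of that trick (not the TF kernels):

```python
# Numpy illustration of the transpose trick the simplified code above keeps:
# softmax over an arbitrary axis == swap that axis to the end, do a
# last-axis softmax, swap back.
import numpy as np

def last_axis_softmax(x):
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def softmax_over_axis(x, axis):
    x = np.swapaxes(x, axis, -1)       # swap requested axis with the last one
    out = last_axis_softmax(x)         # last-axis softmax (stand-in for compute_op)
    return np.swapaxes(out, axis, -1)  # swap back

x = np.random.randn(2, 3, 4)
ref = np.exp(x) / np.exp(x).sum(axis=1, keepdims=True)
np.testing.assert_allclose(softmax_over_axis(x, 1), ref)
```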