author     Cao Zongyan <zongyan.cao@alibaba-inc.com>  2018-09-26 11:54:30 +0800
committer  Cao Zongyan <zongyan.cao@alibaba-inc.com>  2018-09-26 11:54:30 +0800
commit     35174f46b973c66a2e6894a12b3018d60e8414ec (patch)
tree       5bdae0172159bc02ec3a470722bf959b14dd47ba /tensorflow/python/ops
parent     f0886f7269de900d226455d4831722f6fc94a71b (diff)
parent     6666516f390f125ed70ddbd4e6f89b83d953c408 (diff)
Merge remote-tracking branch 'origin'
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--  tensorflow/python/ops/array_ops.py  61
-rw-r--r--  tensorflow/python/ops/boosted_trees_ops.py  6
-rw-r--r--  tensorflow/python/ops/cond_v2_impl.py  6
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py  11
-rw-r--r--  tensorflow/python/ops/ctc_ops.py  6
-rw-r--r--  tensorflow/python/ops/distributions/beta.py  9
-rw-r--r--  tensorflow/python/ops/distributions/bijector_impl.py  45
-rw-r--r--  tensorflow/python/ops/distributions/categorical.py  4
-rw-r--r--  tensorflow/python/ops/distributions/dirichlet.py  9
-rw-r--r--  tensorflow/python/ops/distributions/distribution.py  113
-rw-r--r--  tensorflow/python/ops/distributions/gamma.py  9
-rw-r--r--  tensorflow/python/ops/distributions/kullback_leibler.py  4
-rw-r--r--  tensorflow/python/ops/distributions/normal.py  9
-rw-r--r--  tensorflow/python/ops/distributions/student_t.py  14
-rw-r--r--  tensorflow/python/ops/distributions/util.py  12
-rw-r--r--  tensorflow/python/ops/embedding_ops.py  8
-rw-r--r--  tensorflow/python/ops/functional_ops.py  40
-rw-r--r--  tensorflow/python/ops/gradients_impl.py  58
-rw-r--r--  tensorflow/python/ops/gradients_test.py  39
-rw-r--r--  tensorflow/python/ops/image_ops_impl.py  54
-rw-r--r--  tensorflow/python/ops/image_ops_test.py  12
-rw-r--r--  tensorflow/python/ops/linalg/linear_operator_addition.py  432
-rw-r--r--  tensorflow/python/ops/linalg/linear_operator_circulant.py  18
-rw-r--r--  tensorflow/python/ops/linalg/linear_operator_test_util.py  16
-rw-r--r--  tensorflow/python/ops/logging_ops.py  260
-rw-r--r--  tensorflow/python/ops/lookup_ops.py  40
-rw-r--r--  tensorflow/python/ops/losses/util_test.py  6
-rw-r--r--  tensorflow/python/ops/math_ops.py  25
-rw-r--r--  tensorflow/python/ops/nn_ops.py  34
-rw-r--r--  tensorflow/python/ops/parallel_for/BUILD  2
-rw-r--r--  tensorflow/python/ops/parallel_for/control_flow_ops_test.py  192
-rw-r--r--  tensorflow/python/ops/parallel_for/gradients.py  2
-rw-r--r--  tensorflow/python/ops/parallel_for/gradients_test.py  26
-rw-r--r--  tensorflow/python/ops/parallel_for/pfor.py  98
-rw-r--r--  tensorflow/python/ops/parsing_ops.py  10
-rw-r--r--  tensorflow/python/ops/resource_variable_ops.py  74
-rw-r--r--  tensorflow/python/ops/rnn.py  4
-rw-r--r--  tensorflow/python/ops/rnn_cell_impl.py  10
-rw-r--r--  tensorflow/python/ops/string_ops.py  102
-rw-r--r--  tensorflow/python/ops/summary_ops_v2.py  1
-rw-r--r--  tensorflow/python/ops/while_v2.py  580
41 files changed, 2139 insertions, 322 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index c8b883350d..a7f57e94e3 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -2787,4 +2787,65 @@ def quantize(input, # pylint: disable=redefined-builtin
name=name)
+@tf_export("searchsorted")
+def searchsorted(sorted_sequence,
+ values,
+ side="left",
+ out_type=dtypes.int32,
+ name=None):
+ """Searches input tensor for values on the innermost dimension.
+
+ A 2-D example:
+
+ ```
+ sorted_sequence = [[0, 3, 9, 9, 10],
+ [1, 2, 3, 4, 5]]
+ values = [[2, 4, 9],
+ [0, 2, 6]]
+
+ result = searchsorted(sorted_sequence, values, side="left")
+
+ result == [[1, 2, 2],
+ [0, 1, 5]]
+
+ result = searchsorted(sorted_sequence, values, side="right")
+
+ result == [[1, 2, 4],
+ [0, 2, 5]]
+ ```
+
+ Args:
+ sorted_sequence: N-D `Tensor` containing a sorted sequence.
+ values: N-D `Tensor` containing the search values.
+ side: 'left' or 'right'; 'left' corresponds to lower_bound and 'right' to
+ upper_bound.
+ out_type: The output type (`int32` or `int64`). Default is `tf.int32`.
+ name: Optional name for the operation.
+
+ Returns:
+ An N-D `Tensor` the size of values containing the result of applying either
+ lower_bound or upper_bound (depending on side) to each value. The result
+ is not a global index to the entire `Tensor`, but the index in the last
+ dimension.
+
+ Raises:
+ ValueError: If the last dimension of `sorted_sequence` has more than `2^31-1` elements.
+ If the total size of values exceeds `2^31 - 1` elements.
+ If the first `N-1` dimensions of the two tensors don't match.
+ """
+ sequence_size = shape_internal(sorted_sequence)[-1]
+ values_size = shape_internal(values)[-1]
+ sorted_sequence_2d = reshape(sorted_sequence, [-1, sequence_size])
+ values_2d = reshape(values, [-1, values_size])
+ if side == "right":
+ output = gen_array_ops.upper_bound(sorted_sequence_2d, values_2d, out_type,
+ name)
+ elif side == "left":
+ output = gen_array_ops.lower_bound(sorted_sequence_2d, values_2d, out_type,
+ name)
+ else:
+ raise ValueError("side must be either 'right' or 'left'. Saw: %s." % side)
+ return reshape(output, shape_internal(values))
+
+
quantize.__doc__ = gen_array_ops.quantize_v2.__doc__
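
A quick usage sketch of the op added above, assuming a build where `tf.searchsorted` is exported (the inputs mirror the docstring's 2-D example; the Session usage is TF 1.x style):

```python
import tensorflow as tf

sorted_sequence = tf.constant([[0, 3, 9, 9, 10],
                               [1, 2, 3, 4, 5]])
values = tf.constant([[2, 4, 9],
                      [0, 2, 6]])

# side="left" gives lower_bound semantics, side="right" gives upper_bound.
left = tf.searchsorted(sorted_sequence, values, side="left")
right = tf.searchsorted(sorted_sequence, values, side="right")

with tf.Session() as sess:
  print(sess.run(left))   # [[1 2 2], [0 1 5]]
  print(sess.run(right))  # [[1 2 4], [0 2 5]]
```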
diff --git a/tensorflow/python/ops/boosted_trees_ops.py b/tensorflow/python/ops/boosted_trees_ops.py
index f7cbfe0312..720f9f4d41 100644
--- a/tensorflow/python/ops/boosted_trees_ops.py
+++ b/tensorflow/python/ops/boosted_trees_ops.py
@@ -24,11 +24,17 @@ from tensorflow.python.ops import resources
# Re-exporting ops used by other modules.
# pylint: disable=unused-import
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_bucketize
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_center_bias as center_bias
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_create_quantile_stream_resource as create_quantile_stream_resource
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_example_debug_outputs as example_debug_outputs
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_quantile_summaries as make_quantile_summaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_stats_summary as make_stats_summary
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_predict as predict
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_add_summaries as quantile_add_summaries
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_flush as quantile_flush
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_get_bucket_boundaries as get_bucket_boundaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble as update_ensemble
# pylint: enable=unused-import
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index c6a6b2a7fa..f8b1ddb140 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -119,7 +119,11 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
attr_value_pb2.AttrValue(b=True))
# pylint: enable=protected-access
- return tuple(tensors[:num_cond_outputs])
+ result = tuple(tensors[:num_cond_outputs])
+ if len(result) == 1:
+ return result[0]
+ else:
+ return result
@ops.RegisterGradient("If")
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index e3c1aa3d5a..87f8bd85a5 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -61,7 +61,7 @@ from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export
-_ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0"
+ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0"
# We override the 'tuple' for a control flow op, so we keep python's
@@ -610,9 +610,10 @@ def _EnforceShapeInvariant(merge_var, next_var):
"less-specific shape." %
(input_t.name, input_t.shape, n_shape))
else:
- if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
- raise TypeError("Type %s not supported" % type(var))
- if isinstance(var, ops.IndexedSlices):
+ if not isinstance(merge_var,
+ (ops.IndexedSlices, sparse_tensor.SparseTensor)):
+ raise TypeError("Type %s not supported" % type(merge_var))
+ if isinstance(merge_var, ops.IndexedSlices):
m_values_shape = merge_var.values.get_shape()
m_indices_shape = merge_var.indices.get_shape()
m_shape_shape = tensor_shape.TensorShape(None)
@@ -2026,7 +2027,7 @@ def cond(pred,
```
"""
- if _ENABLE_COND_V2:
+ if ENABLE_COND_V2 and not context.executing_eagerly():
return cond_v2_impl.cond_v2(pred, true_fn, false_fn, name)
# We needed to make true_fn/false_fn keyword arguments for
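
For context, a minimal sketch of how the renamed flag is expected to be toggled from user code; `TF_ENABLE_COND_V2` is read when `control_flow_ops` is imported, so it has to be set before `tensorflow` is imported (the constants below are illustrative):

```python
import os

# Must be set before importing tensorflow: control_flow_ops reads the
# environment variable at module import time.
os.environ["TF_ENABLE_COND_V2"] = "1"

import tensorflow as tf

x = tf.constant(2.0)
# With the flag set and graph mode active, tf.cond dispatches to
# cond_v2_impl.cond_v2; under eager execution it takes the old path.
y = tf.cond(x > 1.0, lambda: x * 2.0, lambda: x)
```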
diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py
index 908e793902..32d455bdad 100644
--- a/tensorflow/python/ops/ctc_ops.py
+++ b/tensorflow/python/ops/ctc_ops.py
@@ -242,11 +242,11 @@ def ctc_beam_search_decoder(inputs, sequence_length, beam_width=100,
If `merge_repeated` is `True`, merge repeated classes in the output beams.
This means that if consecutive entries in a beam are the same,
- only the first of these is emitted. That is, when the top path
- is `A B B B B`, the return value is:
+ only the first of these is emitted. That is, when the sequence is
+ `A B B * B * B` (where '*' is the blank label), the return value is:
* `A B` if `merge_repeated = True`.
- * `A B B B B` if `merge_repeated = False`.
+ * `A B B B` if `merge_repeated = False`.
Args:
inputs: 3-D `float` `Tensor`, size
diff --git a/tensorflow/python/ops/distributions/beta.py b/tensorflow/python/ops/distributions/beta.py
index 99d30b0bd1..2ba1ea6744 100644
--- a/tensorflow/python/ops/distributions/beta.py
+++ b/tensorflow/python/ops/distributions/beta.py
@@ -98,10 +98,13 @@ class Beta(distribution.Distribution):
#### Examples
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Create a batch of three Beta distributions.
alpha = [1, 2, 3]
beta = [1, 2, 3]
- dist = tf.distributions.Beta(alpha, beta)
+ dist = tfd.Beta(alpha, beta)
dist.sample([4, 5]) # Shape [4, 5, 3]
@@ -117,7 +120,7 @@ class Beta(distribution.Distribution):
# Create batch_shape=[2, 3] via parameter broadcast:
alpha = [[1.], [2]] # Shape [2, 1]
beta = [3., 4, 5] # Shape [3]
- dist = tf.distributions.Beta(alpha, beta)
+ dist = tfd.Beta(alpha, beta)
# alpha broadcast as: [[1., 1, 1,],
# [2, 2, 2]]
@@ -138,7 +141,7 @@ class Beta(distribution.Distribution):
```python
alpha = tf.constant(1.0)
beta = tf.constant(2.0)
- dist = tf.distributions.Beta(alpha, beta)
+ dist = tfd.Beta(alpha, beta)
samples = dist.sample(5) # Shape [5]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
diff --git a/tensorflow/python/ops/distributions/bijector_impl.py b/tensorflow/python/ops/distributions/bijector_impl.py
index b65e64d401..9c63385dd0 100644
--- a/tensorflow/python/ops/distributions/bijector_impl.py
+++ b/tensorflow/python/ops/distributions/bijector_impl.py
@@ -825,10 +825,21 @@ class Bijector(object):
min_event_ndims=self.inverse_min_event_ndims,
event_ndims=event_ndims)):
if not self._is_injective: # No caching for non-injective
- ildjs = self._inverse_log_det_jacobian(y, **kwargs)
- return tuple(self._reduce_jacobian_det_over_event(
- y, ildj, self.inverse_min_event_ndims, event_ndims)
- for ildj in ildjs)
+ try:
+ ildjs = self._inverse_log_det_jacobian(y, **kwargs)
+ return tuple(self._reduce_jacobian_det_over_event(
+ y, ildj, self.inverse_min_event_ndims, event_ndims)
+ for ildj in ildjs)
+ except NotImplementedError as original_exception:
+ try:
+ x = self._inverse(y, **kwargs)
+ fldjs = self._forward_log_det_jacobian(x, **kwargs)
+ return tuple(self._reduce_jacobian_det_over_event(
+ x, -fldj, self.forward_min_event_ndims, event_ndims)
+ for fldj in fldjs)
+ except NotImplementedError:
+ raise original_exception
+
mapping = self._lookup(y=y, kwargs=kwargs)
if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
return mapping.ildj_map[event_ndims]
@@ -917,11 +928,21 @@ class Bijector(object):
return -1. * self._constant_ildj_map[event_ndims]
x = ops.convert_to_tensor(x, name="x")
self._maybe_assert_dtype(x)
- if not self._is_injective:
- fldjs = self._forward_log_det_jacobian(x, **kwargs) # No caching.
- return tuple(self._reduce_jacobian_det_over_event(
- x, fldj, self.forward_min_event_ndims, event_ndims)
- for fldj in fldjs)
+ if not self._is_injective: # No caching for non-injective
+ try:
+ fldjs = self._forward_log_det_jacobian(x, **kwargs) # No caching.
+ return tuple(self._reduce_jacobian_det_over_event(
+ x, fldj, self.forward_min_event_ndims, event_ndims)
+ for fldj in fldjs)
+ except NotImplementedError as original_exception:
+ try:
+ y = self._forward(x, **kwargs)
+ ildjs = self._inverse_log_det_jacobian(y, **kwargs)
+ return tuple(self._reduce_jacobian_det_over_event(
+ y, -ildj, self.inverse_min_event_ndims, event_ndims)
+ for ildj in ildjs)
+ except NotImplementedError:
+ raise original_exception
mapping = self._lookup(x=x, kwargs=kwargs)
if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
return -mapping.ildj_map[event_ndims]
@@ -1011,12 +1032,6 @@ class Bijector(object):
def _reduce_jacobian_det_over_event(
self, y, ildj, min_event_ndims, event_ndims):
"""Reduce jacobian over event_ndims - min_event_ndims."""
-
- if not self.is_constant_jacobian:
- return math_ops.reduce_sum(
- ildj,
- self._get_event_reduce_dims(min_event_ndims, event_ndims))
-
# In this case, we need to tile the Jacobian over the event and reduce.
y_rank = array_ops.rank(y)
y_shape = array_ops.shape(y)[
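
The new fallbacks rely on the identity `inverse_log_det_jacobian(y) == -forward_log_det_jacobian(inverse(y))` (and its mirror image). A minimal numerical sketch with a toy exp bijector in plain NumPy, independent of the `Bijector` API:

```python
import numpy as np

# Toy bijector: forward(x) = exp(x), inverse(y) = log(y).
def forward_log_det_jacobian(x):
  return x              # log|d exp(x)/dx| = x

def inverse_log_det_jacobian(y):
  return -np.log(y)     # log|d log(y)/dy| = -log(y)

y = 3.0
x = np.log(y)           # inverse(y)
# When only the forward Jacobian is implemented, the fallback computes
# ILDJ(y) as -FLDJ(inverse(y)); both paths agree.
assert np.isclose(inverse_log_det_jacobian(y), -forward_log_det_jacobian(x))
```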
diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py
index dd25fce2ec..fbbacf2521 100644
--- a/tensorflow/python/ops/distributions/categorical.py
+++ b/tensorflow/python/ops/distributions/categorical.py
@@ -69,7 +69,7 @@ class Categorical(distribution.Distribution):
The Categorical distribution is closely related to the `OneHotCategorical` and
`Multinomial` distributions. The Categorical distribution can be intuited as
generating samples according to `argmax{ OneHotCategorical(probs) }` itself
- being identical to `argmax{ Multinomial(probs, total_count=1) }.
+ being identical to `argmax{ Multinomial(probs, total_count=1) }`.
#### Mathematical Details
@@ -83,7 +83,7 @@ class Categorical(distribution.Distribution):
The number of classes, `K`, must not exceed:
- the largest integer representable by `self.dtype`, i.e.,
- `2**(mantissa_bits+1)` (IEE754),
+ `2**(mantissa_bits+1)` (IEEE 754),
- the maximum `Tensor` index, i.e., `2**31-1`.
In other words,
diff --git a/tensorflow/python/ops/distributions/dirichlet.py b/tensorflow/python/ops/distributions/dirichlet.py
index 9104a1d071..415249a958 100644
--- a/tensorflow/python/ops/distributions/dirichlet.py
+++ b/tensorflow/python/ops/distributions/dirichlet.py
@@ -104,10 +104,13 @@ class Dirichlet(distribution.Distribution):
#### Examples
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Create a single trivariate Dirichlet, with the 3rd class being three times
# more frequent than the first. I.e., batch_shape=[], event_shape=[3].
alpha = [1., 2, 3]
- dist = tf.distributions.Dirichlet(alpha)
+ dist = tfd.Dirichlet(alpha)
dist.sample([4, 5]) # shape: [4, 5, 3]
@@ -129,7 +132,7 @@ class Dirichlet(distribution.Distribution):
# Create batch_shape=[2], event_shape=[3]:
alpha = [[1., 2, 3],
[4, 5, 6]] # shape: [2, 3]
- dist = tf.distributions.Dirichlet(alpha)
+ dist = tfd.Dirichlet(alpha)
dist.sample([4, 5]) # shape: [4, 5, 2, 3]
@@ -144,7 +147,7 @@ class Dirichlet(distribution.Distribution):
```python
alpha = tf.constant([1.0, 2.0, 3.0])
- dist = tf.distributions.Dirichlet(alpha)
+ dist = tfd.Dirichlet(alpha)
samples = dist.sample(5) # Shape [5, 3]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index 578e7b7dd2..76d980679e 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -601,7 +601,8 @@ class Distribution(_BaseDistribution):
return type(self)(**parameters)
def _batch_shape_tensor(self):
- raise NotImplementedError("batch_shape_tensor is not implemented")
+ raise NotImplementedError(
+ "batch_shape_tensor is not implemented: {}".format(type(self).__name__))
def batch_shape_tensor(self, name="batch_shape_tensor"):
"""Shape of a single sample from a single event index as a 1-D `Tensor`.
@@ -640,7 +641,8 @@ class Distribution(_BaseDistribution):
return tensor_shape.as_shape(self._batch_shape())
def _event_shape_tensor(self):
- raise NotImplementedError("event_shape_tensor is not implemented")
+ raise NotImplementedError(
+ "event_shape_tensor is not implemented: {}".format(type(self).__name__))
def event_shape_tensor(self, name="event_shape_tensor"):
"""Shape of a single sample from a single batch as a 1-D int32 `Tensor`.
@@ -701,7 +703,8 @@ class Distribution(_BaseDistribution):
name="is_scalar_batch")
def _sample_n(self, n, seed=None):
- raise NotImplementedError("sample_n is not implemented")
+ raise NotImplementedError("sample_n is not implemented: {}".format(
+ type(self).__name__))
def _call_sample_n(self, sample_shape, seed, name, **kwargs):
with self._name_scope(name, values=[sample_shape]):
@@ -733,15 +736,19 @@ class Distribution(_BaseDistribution):
return self._call_sample_n(sample_shape, seed, name)
def _log_prob(self, value):
- raise NotImplementedError("log_prob is not implemented")
+ raise NotImplementedError("log_prob is not implemented: {}".format(
+ type(self).__name__))
def _call_log_prob(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._log_prob(value, **kwargs)
- except NotImplementedError:
- return math_ops.log(self._prob(value, **kwargs))
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.log(self._prob(value, **kwargs))
+ except NotImplementedError:
+ raise original_exception
def log_prob(self, value, name="log_prob"):
"""Log probability density/mass function.
@@ -757,15 +764,19 @@ class Distribution(_BaseDistribution):
return self._call_log_prob(value, name)
def _prob(self, value):
- raise NotImplementedError("prob is not implemented")
+ raise NotImplementedError("prob is not implemented: {}".format(
+ type(self).__name__))
def _call_prob(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._prob(value, **kwargs)
- except NotImplementedError:
- return math_ops.exp(self._log_prob(value, **kwargs))
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.exp(self._log_prob(value, **kwargs))
+ except NotImplementedError:
+ raise original_exception
def prob(self, value, name="prob"):
"""Probability density/mass function.
@@ -781,15 +792,19 @@ class Distribution(_BaseDistribution):
return self._call_prob(value, name)
def _log_cdf(self, value):
- raise NotImplementedError("log_cdf is not implemented")
+ raise NotImplementedError("log_cdf is not implemented: {}".format(
+ type(self).__name__))
def _call_log_cdf(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._log_cdf(value, **kwargs)
- except NotImplementedError:
- return math_ops.log(self._cdf(value, **kwargs))
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.log(self._cdf(value, **kwargs))
+ except NotImplementedError:
+ raise original_exception
def log_cdf(self, value, name="log_cdf"):
"""Log cumulative distribution function.
@@ -815,15 +830,19 @@ class Distribution(_BaseDistribution):
return self._call_log_cdf(value, name)
def _cdf(self, value):
- raise NotImplementedError("cdf is not implemented")
+ raise NotImplementedError("cdf is not implemented: {}".format(
+ type(self).__name__))
def _call_cdf(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._cdf(value, **kwargs)
- except NotImplementedError:
- return math_ops.exp(self._log_cdf(value, **kwargs))
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.exp(self._log_cdf(value, **kwargs))
+ except NotImplementedError:
+ raise original_exception
def cdf(self, value, name="cdf"):
"""Cumulative distribution function.
@@ -845,15 +864,20 @@ class Distribution(_BaseDistribution):
return self._call_cdf(value, name)
def _log_survival_function(self, value):
- raise NotImplementedError("log_survival_function is not implemented")
+ raise NotImplementedError(
+ "log_survival_function is not implemented: {}".format(
+ type(self).__name__))
def _call_log_survival_function(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._log_survival_function(value, **kwargs)
- except NotImplementedError:
- return math_ops.log1p(-self.cdf(value, **kwargs))
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.log1p(-self.cdf(value, **kwargs))
+ except NotImplementedError:
+ raise original_exception
def log_survival_function(self, value, name="log_survival_function"):
"""Log survival function.
@@ -880,15 +904,19 @@ class Distribution(_BaseDistribution):
return self._call_log_survival_function(value, name)
def _survival_function(self, value):
- raise NotImplementedError("survival_function is not implemented")
+ raise NotImplementedError("survival_function is not implemented: {}".format(
+ type(self).__name__))
def _call_survival_function(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._survival_function(value, **kwargs)
- except NotImplementedError:
- return 1. - self.cdf(value, **kwargs)
+ except NotImplementedError as original_exception:
+ try:
+ return 1. - self.cdf(value, **kwargs)
+ except NotImplementedError:
+ raise original_exception
def survival_function(self, value, name="survival_function"):
"""Survival function.
@@ -912,7 +940,8 @@ class Distribution(_BaseDistribution):
return self._call_survival_function(value, name)
def _entropy(self):
- raise NotImplementedError("entropy is not implemented")
+ raise NotImplementedError("entropy is not implemented: {}".format(
+ type(self).__name__))
def entropy(self, name="entropy"):
"""Shannon entropy in nats."""
@@ -920,7 +949,8 @@ class Distribution(_BaseDistribution):
return self._entropy()
def _mean(self):
- raise NotImplementedError("mean is not implemented")
+ raise NotImplementedError("mean is not implemented: {}".format(
+ type(self).__name__))
def mean(self, name="mean"):
"""Mean."""
@@ -928,7 +958,8 @@ class Distribution(_BaseDistribution):
return self._mean()
def _quantile(self, value):
- raise NotImplementedError("quantile is not implemented")
+ raise NotImplementedError("quantile is not implemented: {}".format(
+ type(self).__name__))
def _call_quantile(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
@@ -955,7 +986,8 @@ class Distribution(_BaseDistribution):
return self._call_quantile(value, name)
def _variance(self):
- raise NotImplementedError("variance is not implemented")
+ raise NotImplementedError("variance is not implemented: {}".format(
+ type(self).__name__))
def variance(self, name="variance"):
"""Variance.
@@ -979,11 +1011,15 @@ class Distribution(_BaseDistribution):
with self._name_scope(name):
try:
return self._variance()
- except NotImplementedError:
- return math_ops.square(self._stddev())
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.square(self._stddev())
+ except NotImplementedError:
+ raise original_exception
def _stddev(self):
- raise NotImplementedError("stddev is not implemented")
+ raise NotImplementedError("stddev is not implemented: {}".format(
+ type(self).__name__))
def stddev(self, name="stddev"):
"""Standard deviation.
@@ -1008,11 +1044,15 @@ class Distribution(_BaseDistribution):
with self._name_scope(name):
try:
return self._stddev()
- except NotImplementedError:
- return math_ops.sqrt(self._variance())
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.sqrt(self._variance())
+ except NotImplementedError:
+ raise original_exception
def _covariance(self):
- raise NotImplementedError("covariance is not implemented")
+ raise NotImplementedError("covariance is not implemented: {}".format(
+ type(self).__name__))
def covariance(self, name="covariance"):
"""Covariance.
@@ -1054,7 +1094,8 @@ class Distribution(_BaseDistribution):
return self._covariance()
def _mode(self):
- raise NotImplementedError("mode is not implemented")
+ raise NotImplementedError("mode is not implemented: {}".format(
+ type(self).__name__))
def mode(self, name="mode"):
"""Mode."""
@@ -1080,7 +1121,7 @@ class Distribution(_BaseDistribution):
where `F` denotes the support of the random variable `X ~ P`.
Args:
- other: `tf.distributions.Distribution` instance.
+ other: `tfp.distributions.Distribution` instance.
name: Python `str` prepended to names of ops created by this function.
Returns:
@@ -1111,7 +1152,7 @@ class Distribution(_BaseDistribution):
denotes (Shanon) cross entropy, and `H[.]` denotes (Shanon) entropy.
Args:
- other: `tf.distributions.Distribution` instance.
+ other: `tfp.distributions.Distribution` instance.
name: Python `str` prepended to names of ops created by this function.
Returns:
@@ -1123,7 +1164,7 @@ class Distribution(_BaseDistribution):
return self._kl_divergence(other)
def __str__(self):
- return ("tf.distributions.{type_name}("
+ return ("tfp.distributions.{type_name}("
"\"{self_name}\""
"{maybe_batch_shape}"
"{maybe_event_shape}"
@@ -1139,7 +1180,7 @@ class Distribution(_BaseDistribution):
dtype=self.dtype.name))
def __repr__(self):
- return ("<tf.distributions.{type_name} "
+ return ("<tfp.distributions.{type_name} "
"'{self_name}'"
" batch_shape={batch_shape}"
" event_shape={event_shape}"
diff --git a/tensorflow/python/ops/distributions/gamma.py b/tensorflow/python/ops/distributions/gamma.py
index b631f0247c..3293cda874 100644
--- a/tensorflow/python/ops/distributions/gamma.py
+++ b/tensorflow/python/ops/distributions/gamma.py
@@ -100,8 +100,11 @@ class Gamma(distribution.Distribution):
#### Examples
```python
- dist = tf.distributions.Gamma(concentration=3.0, rate=2.0)
- dist2 = tf.distributions.Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
+ dist = tfd.Gamma(concentration=3.0, rate=2.0)
+ dist2 = tfd.Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
```
Compute the gradients of samples w.r.t. the parameters:
@@ -109,7 +112,7 @@ class Gamma(distribution.Distribution):
```python
concentration = tf.constant(3.0)
rate = tf.constant(2.0)
- dist = tf.distributions.Gamma(concentration, rate)
+ dist = tfd.Gamma(concentration, rate)
samples = dist.sample(5) # Shape [5]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
diff --git a/tensorflow/python/ops/distributions/kullback_leibler.py b/tensorflow/python/ops/distributions/kullback_leibler.py
index e3c6f3e789..fdeb97bf64 100644
--- a/tensorflow/python/ops/distributions/kullback_leibler.py
+++ b/tensorflow/python/ops/distributions/kullback_leibler.py
@@ -127,8 +127,8 @@ def cross_entropy(ref, other,
where `F` denotes the support of the random variable `X ~ P`.
Args:
- ref: `tf.distributions.Distribution` instance.
- other: `tf.distributions.Distribution` instance.
+ ref: `tfd.Distribution` instance.
+ other: `tfd.Distribution` instance.
allow_nan_stats: Python `bool`, default `True`. When `True`,
statistics (e.g., mean, mode, variance) use the value "`NaN`" to
indicate the result is undefined. When `False`, an exception is raised
diff --git a/tensorflow/python/ops/distributions/normal.py b/tensorflow/python/ops/distributions/normal.py
index d0a987ba7c..2feaf806c0 100644
--- a/tensorflow/python/ops/distributions/normal.py
+++ b/tensorflow/python/ops/distributions/normal.py
@@ -71,15 +71,18 @@ class Normal(distribution.Distribution):
Examples of initialization of one or a batch of distributions.
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Define a single scalar Normal distribution.
- dist = tf.distributions.Normal(loc=0., scale=3.)
+ dist = tfd.Normal(loc=0., scale=3.)
# Evaluate the cdf at 1, returning a scalar.
dist.cdf(1.)
# Define a batch of two scalar valued Normals.
# The first has mean 1 and standard deviation 11, the second 2 and 22.
- dist = tf.distributions.Normal(loc=[1, 2.], scale=[11, 22.])
+ dist = tfd.Normal(loc=[1, 2.], scale=[11, 22.])
# Evaluate the pdf of the first distribution on 0, and the second on 1.5,
# returning a length two tensor.
@@ -94,7 +97,7 @@ class Normal(distribution.Distribution):
```python
# Define a batch of two scalar valued Normals.
# Both have mean 1, but different standard deviations.
- dist = tf.distributions.Normal(loc=1., scale=[11, 22.])
+ dist = tfd.Normal(loc=1., scale=[11, 22.])
# Evaluate the pdf of both distributions on the same point, 3.0,
# returning a length 2 tensor.
diff --git a/tensorflow/python/ops/distributions/student_t.py b/tensorflow/python/ops/distributions/student_t.py
index e0cf6f86f1..e8d214bbe0 100644
--- a/tensorflow/python/ops/distributions/student_t.py
+++ b/tensorflow/python/ops/distributions/student_t.py
@@ -91,8 +91,11 @@ class StudentT(distribution.Distribution):
Examples of initialization of one or a batch of distributions.
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Define a single scalar Student t distribution.
- single_dist = tf.distributions.StudentT(df=3)
+ single_dist = tfd.StudentT(df=3)
# Evaluate the pdf at 1, returning a scalar Tensor.
single_dist.prob(1.)
@@ -100,9 +103,7 @@ class StudentT(distribution.Distribution):
# Define a batch of two scalar valued Student t's.
# The first has degrees of freedom 2, mean 1, and scale 11.
# The second 3, 2 and 22.
- multi_dist = tf.distributions.StudentT(df=[2, 3],
- loc=[1, 2.],
- scale=[11, 22.])
+ multi_dist = tfd.StudentT(df=[2, 3], loc=[1, 2.], scale=[11, 22.])
# Evaluate the pdf of the first distribution on 0, and the second on 1.5,
# returning a length two tensor.
@@ -117,7 +118,7 @@ class StudentT(distribution.Distribution):
```python
# Define a batch of two Student's t distributions.
# Both have df 2 and mean 1, but different scales.
- dist = tf.distributions.StudentT(df=2, loc=1, scale=[11, 22.])
+ dist = tfd.StudentT(df=2, loc=1, scale=[11, 22.])
# Evaluate the pdf of both distributions on the same point, 3.0,
# returning a length 2 tensor.
@@ -130,7 +131,7 @@ class StudentT(distribution.Distribution):
df = tf.constant(2.0)
loc = tf.constant(2.0)
scale = tf.constant(11.0)
- dist = tf.distributions.StudentT(df=df, loc=loc, scale=scale)
+ dist = tfd.StudentT(df=df, loc=loc, scale=scale)
samples = dist.sample(5) # Shape [5]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
@@ -138,7 +139,6 @@ class StudentT(distribution.Distribution):
```
"""
- # pylint: enable=line-too-long
def __init__(self,
df,
diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py
index 3e480a79f5..ad848dfee6 100644
--- a/tensorflow/python/ops/distributions/util.py
+++ b/tensorflow/python/ops/distributions/util.py
@@ -155,7 +155,8 @@ def get_logits_and_probs(logits=None,
probs=None,
multidimensional=False,
validate_args=False,
- name="get_logits_and_probs"):
+ name="get_logits_and_probs",
+ dtype=None):
"""Converts logit to probabilities (or vice-versa), and returns both.
Args:
@@ -169,6 +170,7 @@ def get_logits_and_probs(logits=None,
`0 <= probs <= 1` (if not `multidimensional`) or that the last dimension
of `probs` sums to one.
name: A name for this operation (optional).
+ dtype: `tf.DType` to prefer when converting args to `Tensor`s.
Returns:
logits, probs: Tuple of `Tensor`s. If `probs` has an entry that is `0` or
@@ -183,7 +185,7 @@ def get_logits_and_probs(logits=None,
raise ValueError("Must pass probs or logits, but not both.")
if probs is None:
- logits = ops.convert_to_tensor(logits, name="logits")
+ logits = ops.convert_to_tensor(logits, name="logits", dtype=dtype)
if not logits.dtype.is_floating:
raise TypeError("logits must having floating type.")
# We can early return since we constructed probs and therefore know
@@ -194,7 +196,7 @@ def get_logits_and_probs(logits=None,
return logits, nn.softmax(logits, name="probs")
return logits, math_ops.sigmoid(logits, name="probs")
- probs = ops.convert_to_tensor(probs, name="probs")
+ probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
if not probs.dtype.is_floating:
raise TypeError("probs must having floating type.")
@@ -524,6 +526,8 @@ def matrix_diag_transform(matrix, transform=None, name=None):
Example of heteroskedastic 2-D linear regression.
```python
+ tfd = tfp.distributions
+
# Get a trainable Cholesky factor.
matrix_values = tf.contrib.layers.fully_connected(activations, 4)
matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
@@ -533,7 +537,7 @@ def matrix_diag_transform(matrix, transform=None, name=None):
mu = tf.contrib.layers.fully_connected(activations, 2)
# This is a fully trainable multivariate normal!
- dist = tf.contrib.distributions.MVNCholesky(mu, chol)
+ dist = tfd.MultivariateNormalTriL(mu, chol)
# Standard log loss. Minimizing this will "train" mu and chol, and then dist
# will be a distribution predicting labels as multivariate Gaussians.
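
A small sketch of the new `dtype` argument, using the internal module path (`tensorflow.python.ops.distributions.util`); the logits values are illustrative:

```python
import tensorflow as tf
from tensorflow.python.ops.distributions import util as distribution_util

# Python floats would otherwise be converted at the default float32; the new
# dtype argument lets the caller pin the conversion dtype explicitly.
logits, probs = distribution_util.get_logits_and_probs(
    logits=[0.0, 1.0, -1.0], dtype=tf.float64)

print(logits.dtype)  # <dtype: 'float64'>
print(probs.dtype)   # <dtype: 'float64'>
```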
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 6263041b8d..60d73a1693 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -550,9 +550,11 @@ def safe_embedding_lookup_sparse(embedding_weights,
raise ValueError('Missing embedding_weights %s.' % embedding_weights)
dtype = sparse_weights.dtype if sparse_weights is not None else None
- embedding_weights = [
- ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
- ]
+ if not isinstance(embedding_weights[0],
+ resource_variable_ops.ResourceVariable):
+ embedding_weights = [
+ ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
+ ]
with ops.name_scope(name, 'embedding_lookup',
embedding_weights + [sparse_ids,
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index a4e7c84ae4..119d9522bd 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -41,6 +41,7 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.gen_functional_ops import remote_call
# pylint: enable=unused-import
from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@@ -263,7 +264,7 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
@tf_export("map_fn")
-def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
+def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
swap_memory=False, infer_shape=True, name=None):
"""map on the list of tensors unpacked from `elems` on dimension 0.
@@ -305,6 +306,25 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
instead.
+ When executing eagerly, map_fn does not execute in parallel even if
+ `parallel_iterations` is set to a value > 1. You can still get the
+ performance benefits of running a function in parallel by using the
+ `tf.contrib.eager.defun` decorator,
+
+ ```python
+ # Assume the function being used in map_fn is fn.
+ # To ensure map_fn calls fn in parallel, use the defun decorator.
+ @tf.contrib.eager.defun
+ def func(tensor):
+ return tf.map_fn(fn, tensor)
+ ```
+
+ Note that if you use the defun decorator, any non-TensorFlow Python code
+ that you may have written in your function won't get executed. See
+ `tf.contrib.eager.defun` for more details. The recommendation would be to
+ debug without defun but switch to defun to get performance benefits of
+ running map_fn in parallel.
+
Args:
fn: The callable to be performed. It accepts one argument, which will
have the same (possibly nested) structure as `elems`. Its output
@@ -317,7 +337,8 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
of Tensors differing from the structure of `elems`, then `dtype` is not
optional and must have the same structure as the output of `fn`.
parallel_iterations: (optional) The number of iterations allowed to run
- in parallel.
+ in parallel. When graph building, the default value is 10. While executing
+ eagerly, the default value is set to 1.
back_prop: (optional) True enables support for back propagation.
swap_memory: (optional) True enables GPU-CPU memory swapping.
infer_shape: (optional) False disables tests for consistent output shapes.
@@ -363,6 +384,20 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
" SparseTensor(input.indices, map_fn(fn, input.values), "
"input.dense_shape)")
+ in_graph_mode = not context.executing_eagerly()
+ # Set the default number of parallel_iterations depending on graph/eager mode.
+ if in_graph_mode and not parallel_iterations:
+ parallel_iterations = 10
+ elif not in_graph_mode and not parallel_iterations:
+ parallel_iterations = 1
+
+ if not in_graph_mode and parallel_iterations > 1:
+ logging.log_first_n(logging.WARN, "Setting parallel_iterations > 1 has no "
+ "effect when executing eagerly. Consider calling map_fn"
+ " with tf.contrib.eager.defun to execute fn in "
+ "parallel.", 1)
+ parallel_iterations = 1
+
input_is_sequence = nest.is_sequence(elems)
input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]
def input_pack(x):
@@ -381,7 +416,6 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
elems_flat = input_flatten(elems)
- in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "map", elems_flat):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager
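
A usage sketch of the eager behavior described in the new docstring text, using TF 1.x-era APIs (`tf.enable_eager_execution`, `tf.contrib.eager.defun`); the squaring function is just a placeholder:

```python
import tensorflow as tf

tf.enable_eager_execution()

elems = tf.constant([1.0, 2.0, 3.0])

# Eagerly, map_fn forces parallel_iterations to 1 (and warns once if a larger
# value was requested).
squares = tf.map_fn(lambda x: x * x, elems)

# Wrapping the call in defun traces it as a graph function, where
# parallel_iterations > 1 can actually run iterations in parallel.
@tf.contrib.eager.defun
def parallel_squares(t):
  return tf.map_fn(lambda x: x * x, t, parallel_iterations=10)

print(squares.numpy())                   # [1. 4. 9.]
print(parallel_squares(elems).numpy())   # [1. 4. 9.]
```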
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 3268b38b86..056015d6b6 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -184,7 +184,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
between_op_list.append(op)
# Clear the boolean so we won't add the inputs again.
reached_ops.remove(op)
- for inp in _Inputs(op, xs):
+ for inp in _NonEagerInputs(op, xs):
queue.append(inp.op)
# X in between_ops iff X is on a path of zero or more backpropagatable tensors
# between from_ops and to_ops
@@ -196,7 +196,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
# Initialize pending count for between ops.
pending_count = collections.defaultdict(int)
for op in between_op_list:
- for x in _Inputs(op, xs):
+ for x in _NonEagerInputs(op, xs):
if x.op in between_ops:
pending_count[x.op] += 1
@@ -260,6 +260,12 @@ def _DefaultGradYs(grad_ys,
"Gradient type %s generated for complex-valued "
"tensor %s with type %s must be real" % (dtypes.as_dtype(
grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
+ elif y.dtype == dtypes.variant:
+ if grad_y.dtype != dtypes.variant:
+ raise TypeError(
+ "Gradient type %s generated for variant "
+ "tensor %s with type %s must be variant" % (dtypes.as_dtype(
+ grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
else:
raise TypeError(
"Tensor %s with type %s must be numeric "
@@ -298,7 +304,7 @@ def _IsBackpropagatable(tensor):
if _IsTrainable(tensor):
return True
dtype = dtypes.as_dtype(tensor.dtype)
- return dtype.base_dtype in (dtypes.bfloat16, dtypes.resource, dtypes.variant)
+ return dtype.base_dtype in (dtypes.bfloat16, dtypes.variant)
def _VerifyGeneratedGradients(grads, op):
@@ -341,7 +347,7 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
stop_ops = set()
for op in from_ops:
is_stop_op = True
- for inp in _Inputs(op, xs):
+ for inp in _NonEagerInputs(op, xs):
if pending_count[inp.op] > 0:
is_stop_op = False
break
@@ -365,10 +371,10 @@ def _IsPartitionedCall(op):
return op.type == "PartitionedCall" or op.type == "StatefulPartitionedCall"
-def _SymGrad(op, out_grads, xs):
+def _SymGrad(op, out_grads):
"""Backprop through a function call node op given its outputs' gradients."""
- f_in = [x for x in _Inputs(op, xs)] + out_grads
- f_types = [x.dtype for x in _Inputs(op, xs)]
+ f_in = [x for x in op.inputs] + out_grads
+ f_types = [x.dtype for x in op.inputs]
f = attr_value_pb2.NameAttrList()
if _IsPartitionedCall(op):
f.name = op.get_attr("f").name
@@ -435,7 +441,7 @@ def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
if curr_op in from_ops:
target_op = curr_op
break
- queue.extend(t.op for t in _Inputs(curr_op, xs))
+ queue.extend(t.op for t in _NonEagerInputs(curr_op, xs))
assert target_op
raise ValueError(
"Cannot compute gradient inside while loop with respect to op '%s'. "
@@ -468,7 +474,8 @@ def _MaybeCaptured(t):
A tensor, potentially from a different Graph/_function.FuncGraph.
"""
# pylint: disable=protected-access
- if _IsFunction(t.op.graph) and t.op.type == "Placeholder":
+ if (not isinstance(t, ops.EagerTensor) and
+ _IsFunction(t.op.graph) and t.op.type == "Placeholder"):
for input_t, placeholder_t in _Captures(t.op.graph).items():
if t == placeholder_t:
return _MaybeCaptured(input_t)
@@ -478,9 +485,12 @@ def _MaybeCaptured(t):
# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
# _GradientsHelper a class with xs as a member variable.
-def _Inputs(op, xs):
+def _NonEagerInputs(op, xs):
"""Returns the inputs of op, crossing closure boundaries where necessary.
+ Does not return any captured EagerTensors, i.e., the number of tensors
+ returned may be less than the actual number of inputs.
+
Args:
op: Operation
xs: list of Tensors we are differentiating w.r.t.
@@ -491,12 +501,19 @@ def _Inputs(op, xs):
captured inputs.
"""
if _IsFunction(op.graph): # pylint: disable=protected-access
- # If we're differentiating w.r.t. `t`, do not attempt to traverse through it
- # to a captured value. The algorithm needs to "see" `t` in this case, even
- # if it's a function input for a captured value, whereas usually we'd like
- # to traverse through these closures as if the captured value was the direct
- # input to op.
- return [t if (t in xs) else _MaybeCaptured(t) for t in op.inputs]
+ inputs = []
+ for t in op.inputs:
+ # If we're differentiating w.r.t. `t`, do not attempt to traverse through
+ # it to a captured value. The algorithm needs to "see" `t` in this case,
+ # even if it's a function input for a captured value, whereas usually we'd
+ # like to traverse through these closures as if the captured value was the
+ # direct input to op.
+ if t not in xs:
+ t = _MaybeCaptured(t)
+ # Skip captured eager inputs.
+ if isinstance(t, ops.EagerTensor): continue
+ inputs.append(t)
+ return inputs
else:
return op.inputs
@@ -799,7 +816,7 @@ def _GradientsHelper(ys,
# For function call ops, we add a 'SymbolicGradient'
# node to the graph to compute gradients.
in_grads = _MaybeCompile(grad_scope, op, func_call,
- lambda: _SymGrad(op, out_grads, xs))
+ lambda: _SymGrad(op, out_grads))
in_grads = _AsList(in_grads)
_VerifyGeneratedGradients(in_grads, op)
if gate_gradients and len([x for x in in_grads
@@ -814,8 +831,9 @@ def _GradientsHelper(ys,
else:
# If no grad_fn is defined or none of out_grads is available,
# just propagate a list of None backwards.
- in_grads = [None] * len(_Inputs(op, xs))
- for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs), in_grads)):
+ in_grads = [None] * len(_NonEagerInputs(op, xs))
+ for i, (t_in, in_grad) in enumerate(zip(_NonEagerInputs(op, xs),
+ in_grads)):
if in_grad is not None:
if (isinstance(in_grad, ops.Tensor) and
t_in.dtype != dtypes.resource):
@@ -856,7 +874,7 @@ def _HasAnyNotNoneGrads(grads, op):
def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
xs):
"""Update pending count for the inputs of op and enqueue ready ops."""
- for x in _Inputs(op, xs):
+ for x in _NonEagerInputs(op, xs):
pending_count[x.op] -= 1
ready = (pending_count[x.op] == 0)
if loop_state and not ready:
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 3759d8a543..4f6e5dc473 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -45,6 +45,7 @@ from tensorflow.python.ops import data_flow_ops # pylint: disable=unused-import
from tensorflow.python.ops import functional_ops # pylint: disable=unused-import
from tensorflow.python.ops import gradients
from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
@@ -530,6 +531,24 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
with self.cached_session() as sess:
self.assertEqual(sess.run(z_grad), 3.0)
+ def testCapturedEagerTensors(self):
+ # Test that we can handle captured eager tensors unrelated to the gradient
+ # computation (i.e. we need to ignore them).
+ # TODO(skyewm): make it an error if you try to take the gradient wrt a
+ # captured EagerTensor
+ with context.eager_mode():
+ c = constant_op.constant(2.0, name="c")
+
+ @function.defun
+ def Foo():
+ x = constant_op.constant(10.0, name="x")
+ y = math_ops.multiply(x, c, name="y")
+ z = math_ops.multiply(y, 3.0, name="z")
+ g = gradients_impl.gradients(z, x)
+ return g[0]
+
+ self.assertEqual(Foo().numpy(), 6.0)
+
class StopGradientTest(test_util.TensorFlowTestCase):
@@ -1004,5 +1023,25 @@ class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):
self._assert_indexed_slices_equal(total, result)
+class TensorListGradientsTest(test_util.TensorFlowTestCase):
+
+ def testDefaultGradYs(self):
+ with ops.Graph().as_default():
+ tl = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32,
+ element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
+ a = constant(1.0)
+ tl = list_ops.tensor_list_push_back(tl, a)
+
+ grad_tl = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32,
+ element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
+ grad_tl = list_ops.tensor_list_push_back(tl, constant(5.0))
+
+ grad = gradients.gradients(tl, a, grad_ys=grad_tl)[0]
+ with self.cached_session() as sess:
+ self.assertEquals(sess.run(grad), 5.)
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index de260f3140..1c75aab578 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -29,7 +29,6 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
@@ -301,21 +300,21 @@ def random_flip_left_right(image, seed=None):
def _random_flip(image, flip_index, seed, scope_name):
"""Randomly (50% chance) flip an image along axis `flip_index`.
- Args:
- image: 4-D Tensor of shape `[batch, height, width, channels]` or
- 3-D Tensor of shape `[height, width, channels]`.
- flip_index: The dimension along which to flip the image.
- Vertical: 0, Horizontal: 1
- seed: A Python integer. Used to create a random seed. See
- `tf.set_random_seed`
- for behavior.
- scope_name: Name of the scope in which the ops are added.
- Returns:
- A tensor of the same type and shape as `image`.
+ Args:
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
+ flip_index: Dimension along which to flip image. Vertical: 0, Horizontal: 1
+ seed: A Python integer. Used to create a random seed. See
+ `tf.set_random_seed`
+ for behavior.
+ scope_name: Name of the scope in which the ops are added.
- Raises:
- ValueError: if the shape of `image` not supported.
+ Returns:
+ A tensor of the same type and shape as `image`.
+
+ Raises:
+ ValueError: if the shape of `image` not supported.
"""
with ops.name_scope(None, scope_name, [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
@@ -330,19 +329,18 @@ def _random_flip(image, flip_index, seed, scope_name):
lambda: image,
name=scope
)
- if isinstance(result, tuple):
- result = result[0] # TODO(b/111124878) remove this logic (CondV2).
return fix_image_flip_shape(image, result)
elif shape.ndims == 4:
+ batch_size = array_ops.shape(image)[0]
uniform_random = random_ops.random_uniform(
- [array_ops.shape(image)[0]], 0, 1.0, seed=seed
+ [batch_size], 0, 1.0, seed=seed
)
- mirror_cond = math_ops.less(uniform_random, .5)
- return array_ops.where(
- mirror_cond,
- image,
- functional_ops.map_fn(lambda x: array_ops.reverse(x, [flip_index]), image, dtype=image.dtype)
+ flips = math_ops.round(
+ array_ops.reshape(uniform_random, [batch_size, 1, 1, 1])
)
+ flips = math_ops.cast(flips, image.dtype)
+ flipped_input = array_ops.reverse(image, [flip_index + 1])
+ return flips * flipped_input + (1 - flips) * image
else:
raise ValueError('\'image\' must have either 3 or 4 dimensions.')
@@ -1029,10 +1027,10 @@ def resize_images(images,
scale_factor_width = (math_ops.to_float(new_width_const) /
math_ops.to_float(current_width))
scale_factor = math_ops.minimum(scale_factor_height, scale_factor_width)
- scaled_height_const = math_ops.to_int32(scale_factor *
- math_ops.to_float(current_height))
- scaled_width_const = math_ops.to_int32(scale_factor *
- math_ops.to_float(current_width))
+ scaled_height_const = math_ops.to_int32(
+ math_ops.round(scale_factor * math_ops.to_float(current_height)))
+ scaled_width_const = math_ops.to_int32(
+ math_ops.round(scale_factor * math_ops.to_float(current_width)))
# NOTE: Reset the size and other constants used later.
size = ops.convert_to_tensor([scaled_height_const, scaled_width_const],
@@ -1176,7 +1174,7 @@ def resize_image_with_pad(image,
@tf_export('image.per_image_standardization')
def per_image_standardization(image):
- """Linearly scales `image` to have zero mean and unit norm.
+ """Linearly scales `image` to have zero mean and unit variance.
This op computes `(x - mean) / adjusted_stddev`, where `mean` is the average
of all values in image, and
@@ -1379,7 +1377,7 @@ def adjust_gamma(image, gamma=1, gain=1):
[1] http://en.wikipedia.org/wiki/Gamma_correction
"""
- with ops.op_scope([image, gamma, gain], None, 'adjust_gamma'):
+ with ops.name_scope(None, 'adjust_gamma', [image, gamma, gain]) as name:
# Convert pixel value to DT_FLOAT for computing adjusted image.
img = ops.convert_to_tensor(image, name='img', dtype=dtypes.float32)
# Keep image dtype for computing the scale of corresponding dtype.
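
The new 4-D branch of `_random_flip` replaces the per-image `map_fn` with a single blend between the batch and its flipped copy. A standalone sketch of the same arithmetic (an illustration of the idea, not the production code path):

```python
import tensorflow as tf


def random_flip_batch(images, flip_index=1, seed=None):
  """Flips each image in a [batch, h, w, c] tensor with probability 0.5."""
  batch_size = tf.shape(images)[0]
  # One coin flip per image: round(uniform[0, 1)) is 0 or 1 with equal odds.
  uniform_random = tf.random_uniform([batch_size], 0, 1.0, seed=seed)
  flips = tf.round(tf.reshape(uniform_random, [batch_size, 1, 1, 1]))
  flips = tf.cast(flips, images.dtype)
  # flip_index: 0 = vertical, 1 = horizontal; +1 skips the batch dimension.
  flipped = tf.reverse(images, [flip_index + 1])
  # Per-image blend: where flips == 1 take the flipped image, else the original.
  return flips * flipped + (1 - flips) * images
```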
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 795e6bbc3e..35fdee4fad 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -2687,6 +2687,12 @@ class ResizeImagesTest(test_util.TensorFlowTestCase):
self._assertResizeCheckShape(x, x_shape, [3840, 2160], [3840, 2160, 3])
+ def testPreserveAspectRatioSquare(self):
+ x_shape = [299, 299, 3]
+ x = np.random.uniform(size=x_shape)
+
+ self._assertResizeCheckShape(x, x_shape, [320, 320], [320, 320, 3])
+
class ResizeImageWithPadTest(test_util.TensorFlowTestCase):
@@ -3667,7 +3673,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
# Note: There are multiple versions of non_max_suppression v2, v3, v4.
# gen_image_ops.non_max_suppression_v2:
for dtype in [np.float16, np.float32]:
- with self.test_session():
+ with self.cached_session():
boxes = constant_op.constant(boxes_np, dtype=dtype)
scores = constant_op.constant(scores_np, dtype=dtype)
max_output_size = constant_op.constant(max_output_size_np)
@@ -3677,7 +3683,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
self.assertAllClose(selected_indices, [3, 0, 5])
# image_ops.non_max_suppression = gen_image_ops.non_max_suppression_v3.
for dtype in [np.float16, np.float32]:
- with self.test_session():
+ with self.cached_session():
boxes = constant_op.constant(boxes_np, dtype=dtype)
scores = constant_op.constant(scores_np, dtype=dtype)
max_output_size = constant_op.constant(max_output_size_np)
@@ -3688,7 +3694,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
# gen_image_ops.non_max_suppression_v4.
score_threshold = float('-inf')
for dtype in [np.float16, np.float32]:
- with self.test_session():
+ with self.cached_session():
boxes = constant_op.constant(boxes_np, dtype=dtype)
scores = constant_op.constant(scores_np, dtype=dtype)
max_output_size = constant_op.constant(max_output_size_np)
diff --git a/tensorflow/python/ops/linalg/linear_operator_addition.py b/tensorflow/python/ops/linalg/linear_operator_addition.py
new file mode 100644
index 0000000000..86130a2c07
--- /dev/null
+++ b/tensorflow/python/ops/linalg/linear_operator_addition.py
@@ -0,0 +1,432 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Add one or more `LinearOperators` efficiently."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+import six
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops.linalg import linear_operator
+from tensorflow.python.ops.linalg import linear_operator_diag
+from tensorflow.python.ops.linalg import linear_operator_full_matrix
+from tensorflow.python.ops.linalg import linear_operator_identity
+from tensorflow.python.ops.linalg import linear_operator_lower_triangular
+
+__all__ = []
+
+
+def add_operators(operators,
+ operator_name=None,
+ addition_tiers=None,
+ name=None):
+ """Efficiently add one or more linear operators.
+
+ Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of
+ operators `[B1, B2,...]` such that
+
+ ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).```
+
+ The operators `Bk` result by adding some of the `Ak`, as allowed by
+ `addition_tiers`.
+
+ Example of efficient adding of diagonal operators.
+
+ ```python
+ A1 = LinearOperatorDiag(diag=[1., 1.], name="A1")
+ A2 = LinearOperatorDiag(diag=[2., 2.], name="A2")
+
+ # Use two tiers, the first contains an Adder that returns Diag. Since both
+ # A1 and A2 are Diag, they can use this Adder. The second tier will not be
+ # used.
+ addition_tiers = [
+ [_AddAndReturnDiag()],
+ [_AddAndReturnMatrix()]]
+ B_list = add_operators([A1, A2], addition_tiers=addition_tiers)
+
+ len(B_list)
+ ==> 1
+
+ B_list[0].__class__.__name__
+ ==> 'LinearOperatorDiag'
+
+ B_list[0].to_dense()
+ ==> [[3., 0.],
+ [0., 3.]]
+
+ B_list[0].name
+ ==> 'Add/A1__A2/'
+ ```
+
+ Args:
+ operators: Iterable of `LinearOperator` objects with same `dtype`, domain
+ and range dimensions, and broadcastable batch shapes.
+ operator_name: String name for returned `LinearOperator`. Defaults to
+ concatenation of "Add/A__B/" that indicates the order of addition steps.
+ addition_tiers: List tiers, like `[tier_0, tier_1, ...]`, where `tier_i`
+ is a list of `Adder` objects. This function attempts to do all additions
+ in tier `i` before trying tier `i + 1`.
+ name: A name for this `Op`. Defaults to `add_operators`.
+
+ Returns:
+    A list of `LinearOperator` subclasses. The classes and order of addition
+      may change as new (and better) addition strategies emerge.
+
+ Raises:
+ ValueError: If `operators` argument is empty.
+ ValueError: If shapes are incompatible.
+ """
+ # Default setting
+ if addition_tiers is None:
+ addition_tiers = _DEFAULT_ADDITION_TIERS
+
+ # Argument checking.
+ check_ops.assert_proper_iterable(operators)
+ operators = list(reversed(operators))
+ if len(operators) < 1:
+ raise ValueError(
+ "Argument 'operators' must contain at least one operator. "
+ "Found: %s" % operators)
+ if not all(
+ isinstance(op, linear_operator.LinearOperator) for op in operators):
+ raise TypeError(
+ "Argument 'operators' must contain only LinearOperator instances. "
+ "Found: %s" % operators)
+ _static_check_for_same_dimensions(operators)
+ _static_check_for_broadcastable_batch_shape(operators)
+
+ graph_parents = []
+ for operator in operators:
+ graph_parents.extend(operator.graph_parents)
+
+ with ops.name_scope(name or "add_operators", values=graph_parents):
+
+ # Additions done in one of the tiers. Try tier 0, 1,...
+ ops_to_try_at_next_tier = list(operators)
+ for tier in addition_tiers:
+ ops_to_try_at_this_tier = ops_to_try_at_next_tier
+ ops_to_try_at_next_tier = []
+ while ops_to_try_at_this_tier:
+ op1 = ops_to_try_at_this_tier.pop()
+ op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier)
+ if op2 is not None:
+ # Will try to add the result of this again at this same tier.
+ new_operator = adder.add(op1, op2, operator_name)
+ ops_to_try_at_this_tier.append(new_operator)
+ else:
+ ops_to_try_at_next_tier.append(op1)
+
+ return ops_to_try_at_next_tier
+
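To make the tier mechanics described above concrete, here is a small usage sketch (illustrative only, not part of this patch). It assumes the module import path introduced by this file and the public `tf.linalg` operator classes; the names `diag1`, `diag2`, and `dense` are arbitrary. With the default tiers, the two diagonal operators are merged first, and the last tier folds the result into the dense operator.

```python
import tensorflow as tf
from tensorflow.python.ops.linalg import linear_operator_addition

diag1 = tf.linalg.LinearOperatorDiag([1., 2.], name="diag1")
diag2 = tf.linalg.LinearOperatorDiag([3., 4.], name="diag2")
dense = tf.linalg.LinearOperatorFullMatrix([[1., 0.], [0., 1.]], name="dense")

# With the default tiers, the two diagonal operators are merged first
# (_AddAndReturnDiag); the final _AddAndReturnMatrix tier then folds that
# result into the dense operator, so a single operator comes back.
result = linear_operator_addition.add_operators([diag1, diag2, dense])
print(len(result), result[0].__class__.__name__)
# 1 LinearOperatorFullMatrix
```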
+
+def _pop_a_match_at_tier(op1, operator_list, tier):
+ # Search from the back of list to the front in order to create nice default
+ # order of operations.
+ for i in range(1, len(operator_list) + 1):
+ op2 = operator_list[-i]
+ for adder in tier:
+ if adder.can_add(op1, op2):
+ return operator_list.pop(-i), adder
+ return None, None
+
+
+def _infer_hints_allowing_override(op1, op2, hints):
+ """Infer hints from op1 and op2. hints argument is an override.
+
+ Args:
+ op1: LinearOperator
+ op2: LinearOperator
+ hints: _Hints object holding "is_X" boolean hints to use for returned
+ operator.
+ If some hint is None, try to set using op1 and op2. If the
+ hint is provided, ignore op1 and op2 hints. This allows an override
+ of previous hints, but does not allow forbidden hints (e.g. you still
+      cannot say a real diagonal operator is not self-adjoint).
+
+ Returns:
+ _Hints object.
+ """
+ hints = hints or _Hints()
+ # If A, B are self-adjoint, then so is A + B.
+ if hints.is_self_adjoint is None:
+ is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint
+ else:
+ is_self_adjoint = hints.is_self_adjoint
+
+ # If A, B are positive definite, then so is A + B.
+ if hints.is_positive_definite is None:
+ is_positive_definite = op1.is_positive_definite and op2.is_positive_definite
+ else:
+ is_positive_definite = hints.is_positive_definite
+
+ # A positive definite operator is always non-singular.
+  if is_positive_definite and hints.is_non_singular is None:
+ is_non_singular = True
+ else:
+ is_non_singular = hints.is_non_singular
+
+ return _Hints(
+ is_non_singular=is_non_singular,
+ is_self_adjoint=is_self_adjoint,
+ is_positive_definite=is_positive_definite)
+
+
+def _static_check_for_same_dimensions(operators):
+ """ValueError if operators determined to have different dimensions."""
+ if len(operators) < 2:
+ return
+
+ domain_dimensions = [(op.name, op.domain_dimension.value) for op in operators
+ if op.domain_dimension.value is not None]
+ if len(set(value for name, value in domain_dimensions)) > 1:
+ raise ValueError("Operators must have the same domain dimension. Found: %s"
+ % domain_dimensions)
+
+ range_dimensions = [(op.name, op.range_dimension.value) for op in operators
+ if op.range_dimension.value is not None]
+ if len(set(value for name, value in range_dimensions)) > 1:
+ raise ValueError("Operators must have the same range dimension. Found: %s" %
+ range_dimensions)
+
+
+def _static_check_for_broadcastable_batch_shape(operators):
+ """ValueError if operators determined to have non-broadcastable shapes."""
+ if len(operators) < 2:
+ return
+
+ # This will fail if they cannot be broadcast together.
+ batch_shape = operators[0].batch_shape
+ for op in operators[1:]:
+ batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape)
+
+
+class _Hints(object):
+ """Holds 'is_X' flags that every LinearOperator is initialized with."""
+
+ def __init__(self,
+ is_non_singular=None,
+ is_positive_definite=None,
+ is_self_adjoint=None):
+ self.is_non_singular = is_non_singular
+ self.is_positive_definite = is_positive_definite
+ self.is_self_adjoint = is_self_adjoint
+
+
+################################################################################
+# Classes to add two linear operators.
+################################################################################
+
+
+@six.add_metaclass(abc.ABCMeta)
+class _Adder(object):
+ """Abstract base class to add two operators.
+
+  Each `Adder` acts independently, adding everything it can, paying no attention
+  to whether another `Adder` could have done the addition more efficiently.
+ """
+
+ @property
+ def name(self):
+ return self.__class__.__name__
+
+ @abc.abstractmethod
+ def can_add(self, op1, op2):
+ """Returns `True` if this `Adder` can add `op1` and `op2`. Else `False`."""
+ pass
+
+ @abc.abstractmethod
+ def _add(self, op1, op2, operator_name, hints):
+ # Derived classes can assume op1 and op2 have been validated, e.g. they have
+ # the same dtype, and their domain/range dimensions match.
+ pass
+
+ def add(self, op1, op2, operator_name, hints=None):
+ """Return new `LinearOperator` acting like `op1 + op2`.
+
+ Args:
+ op1: `LinearOperator`
+ op2: `LinearOperator`, with `shape` and `dtype` such that adding to
+ `op1` is allowed.
+ operator_name: `String` name to give to returned `LinearOperator`
+ hints: `_Hints` object. Returned `LinearOperator` will be created with
+ these hints.
+
+ Returns:
+ `LinearOperator`
+ """
+ updated_hints = _infer_hints_allowing_override(op1, op2, hints)
+
+ if operator_name is None:
+ operator_name = "Add/" + op1.name + "__" + op2.name + "/"
+
+ values = op1.graph_parents + op2.graph_parents
+ scope_name = self.name
+ if scope_name.startswith("_"):
+ scope_name = scope_name[1:]
+ with ops.name_scope(scope_name, values=values):
+ return self._add(op1, op2, operator_name, updated_hints)
+
+
+class _AddAndReturnScaledIdentity(_Adder):
+ """Handles additions resulting in an Identity family member.
+
+ The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`) family
+  is closed under addition. This `Adder` respects that, and returns an Identity
+  family member.
+ """
+
+ def can_add(self, op1, op2):
+ types = {_type(op1), _type(op2)}
+ return not types.difference(_IDENTITY_FAMILY)
+
+ def _add(self, op1, op2, operator_name, hints):
+ # Will build a LinearOperatorScaledIdentity.
+
+ if _type(op1) == _SCALED_IDENTITY:
+ multiplier_1 = op1.multiplier
+ else:
+ multiplier_1 = array_ops.ones(op1.batch_shape_tensor(), dtype=op1.dtype)
+
+ if _type(op2) == _SCALED_IDENTITY:
+ multiplier_2 = op2.multiplier
+ else:
+ multiplier_2 = array_ops.ones(op2.batch_shape_tensor(), dtype=op2.dtype)
+
+ return linear_operator_identity.LinearOperatorScaledIdentity(
+ num_rows=op1.range_dimension_tensor(),
+ multiplier=multiplier_1 + multiplier_2,
+ is_non_singular=hints.is_non_singular,
+ is_self_adjoint=hints.is_self_adjoint,
+ is_positive_definite=hints.is_positive_definite,
+ name=operator_name)
+
+
+class _AddAndReturnDiag(_Adder):
+ """Handles additions resulting in a Diag operator."""
+
+ def can_add(self, op1, op2):
+ types = {_type(op1), _type(op2)}
+ return not types.difference(_DIAG_LIKE)
+
+ def _add(self, op1, op2, operator_name, hints):
+ return linear_operator_diag.LinearOperatorDiag(
+ diag=op1.diag_part() + op2.diag_part(),
+ is_non_singular=hints.is_non_singular,
+ is_self_adjoint=hints.is_self_adjoint,
+ is_positive_definite=hints.is_positive_definite,
+ name=operator_name)
+
+
+class _AddAndReturnTriL(_Adder):
+ """Handles additions resulting in a TriL operator."""
+
+ def can_add(self, op1, op2):
+ types = {_type(op1), _type(op2)}
+ return not types.difference(_DIAG_LIKE.union({_TRIL}))
+
+ def _add(self, op1, op2, operator_name, hints):
+ if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
+ op_add_to_tensor, op_other = op1, op2
+ else:
+ op_add_to_tensor, op_other = op2, op1
+
+ return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
+ tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
+ is_non_singular=hints.is_non_singular,
+ is_self_adjoint=hints.is_self_adjoint,
+ is_positive_definite=hints.is_positive_definite,
+ name=operator_name)
+
+
+class _AddAndReturnMatrix(_Adder):
+ """"Handles additions resulting in a `LinearOperatorFullMatrix`."""
+
+ def can_add(self, op1, op2): # pylint: disable=unused-argument
+ return isinstance(op1, linear_operator.LinearOperator) and isinstance(
+ op2, linear_operator.LinearOperator)
+
+ def _add(self, op1, op2, operator_name, hints):
+ if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
+ op_add_to_tensor, op_other = op1, op2
+ else:
+ op_add_to_tensor, op_other = op2, op1
+ return linear_operator_full_matrix.LinearOperatorFullMatrix(
+ matrix=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
+ is_non_singular=hints.is_non_singular,
+ is_self_adjoint=hints.is_self_adjoint,
+ is_positive_definite=hints.is_positive_definite,
+ name=operator_name)
+
+
+################################################################################
+# Constants designating types of LinearOperators
+################################################################################
+
+# Type name constants for LinearOperator classes.
+_IDENTITY = "identity"
+_SCALED_IDENTITY = "scaled_identity"
+_DIAG = "diag"
+_TRIL = "tril"
+_MATRIX = "matrix"
+
+# Groups of operators.
+_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY}
+_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY}
+# operators with an efficient .add_to_tensor() method.
+_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE
+
+
+def _type(operator):
+ """Returns the type name constant (e.g. _TRIL) for operator."""
+ if isinstance(operator, linear_operator_diag.LinearOperatorDiag):
+ return _DIAG
+ if isinstance(operator,
+ linear_operator_lower_triangular.LinearOperatorLowerTriangular):
+ return _TRIL
+ if isinstance(operator, linear_operator_full_matrix.LinearOperatorFullMatrix):
+ return _MATRIX
+ if isinstance(operator, linear_operator_identity.LinearOperatorIdentity):
+ return _IDENTITY
+ if isinstance(operator,
+ linear_operator_identity.LinearOperatorScaledIdentity):
+ return _SCALED_IDENTITY
+ raise TypeError("Operator type unknown: %s" % operator)
+
+
+################################################################################
+# Addition tiers:
+# We attempt to use Adders in tier K before K+1.
+#
+# Organize tiers to
+# (i) reduce O(..) complexity of forming final operator, and
+# (ii) produce the "most efficient" final operator.
+# Dev notes:
+# * Results of addition at tier K will be added at tier K or higher.
+#  * Tiers may change, and we warn the user that they may change.
+################################################################################
+
+# Note that the final tier, _AddAndReturnMatrix, will convert everything to a
+# dense matrix. So it is sometimes very inefficient.
+_DEFAULT_ADDITION_TIERS = [
+ [_AddAndReturnScaledIdentity()],
+ [_AddAndReturnDiag()],
+ [_AddAndReturnTriL()],
+ [_AddAndReturnMatrix()],
+]
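As a rough check of the hint inference implemented above (a sketch assuming the module is importable at the path added by this patch, not a definitive test): summing two operators flagged self-adjoint and positive definite should yield an operator that keeps both flags and is additionally marked non-singular, since positive definiteness implies non-singularity.

```python
import tensorflow as tf
from tensorflow.python.ops.linalg import linear_operator_addition

op1 = tf.linalg.LinearOperatorDiag(
    [1., 2.], is_self_adjoint=True, is_positive_definite=True)
op2 = tf.linalg.LinearOperatorDiag(
    [3., 4.], is_self_adjoint=True, is_positive_definite=True)

# The default tiers collapse two diagonal operators into one LinearOperatorDiag,
# and _infer_hints_allowing_override carries the "is_X" flags over to the sum.
total = linear_operator_addition.add_operators([op1, op2])[0]
print(total.is_self_adjoint,        # True
      total.is_positive_definite,   # True
      total.is_non_singular)        # True
```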
diff --git a/tensorflow/python/ops/linalg/linear_operator_circulant.py b/tensorflow/python/ops/linalg/linear_operator_circulant.py
index c367ed25ad..021ef47383 100644
--- a/tensorflow/python/ops/linalg/linear_operator_circulant.py
+++ b/tensorflow/python/ops/linalg/linear_operator_circulant.py
@@ -160,20 +160,20 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
`block_depth = 1` means `A` is symmetric circulant. For example,
```
- A = |x y z y|
- |y x y z|
- |z y x y|
- |y z y x|
+ A = |w z y x|
+ |x w z y|
+ |y x w z|
+ |z y x w|
```
    `block_depth = 2` means `A` is block symmetric circulant with symmetric
- circulant blocks. For example, with `X`, `Y`, `Z` symmetric circulant,
+ circulant blocks. For example, with `W`, `X`, `Y`, `Z` symmetric circulant,
```
- A = |X Y Z Y|
- |Y X Y Z|
- |Z Y X Y|
- |Y Z Y X|
+ A = |W Z Y X|
+ |X W Z Y|
+ |Y X W Z|
+ |Z Y X W|
```
`block_depth = 3` means `A` is block symmetric circulant with block
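A plain NumPy sketch of the `block_depth = 1` structure pictured above (not TensorFlow code and not part of this patch): a circulant matrix is determined by its first row, and each later row is a cyclic shift of it.

```python
import numpy as np

def circulant(first_row):
  """Rows are successive right cyclic shifts of the first row."""
  first_row = np.asarray(first_row)
  return np.stack([np.roll(first_row, i) for i in range(len(first_row))])

print(circulant(["w", "z", "y", "x"]))
# [['w' 'z' 'y' 'x']
#  ['x' 'w' 'z' 'y']
#  ['y' 'x' 'w' 'z']
#  ['z' 'y' 'x' 'w']]
# For block_depth = 2, the same construction applies at the block level, with
# W, X, Y, Z themselves circulant blocks.
```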
diff --git a/tensorflow/python/ops/linalg/linear_operator_test_util.py b/tensorflow/python/ops/linalg/linear_operator_test_util.py
index 78c85db557..76d659f109 100644
--- a/tensorflow/python/ops/linalg/linear_operator_test_util.py
+++ b/tensorflow/python/ops/linalg/linear_operator_test_util.py
@@ -184,7 +184,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -199,7 +199,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -215,7 +215,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -240,7 +240,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for dtype in self._dtypes_to_test:
for adjoint in self._adjoint_options:
for adjoint_arg in self._adjoint_arg_options:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -283,7 +283,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for dtype in self._dtypes_to_test:
for adjoint in self._adjoint_options:
for adjoint_arg in self._adjoint_arg_options:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -319,7 +319,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -335,7 +335,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -353,7 +353,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py
index df41933f8a..4c53f33af1 100644
--- a/tensorflow/python/ops/logging_ops.py
+++ b/tensorflow/python/ops/logging_ops.py
@@ -19,13 +19,24 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import pprint
+import random
+import sys
+
+import six
+
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import gen_logging_ops
+from tensorflow.python.ops import string_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_logging_ops import *
# pylint: enable=wildcard-import
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.util import nest
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@@ -40,7 +51,32 @@ from tensorflow.python.util.tf_export import tf_export
# For users with Python 3 or Python 2.7
# with `from __future__ import print_function`, we could also allow lowercase.
# See https://github.com/tensorflow/tensorflow/issues/18053
-@tf_export("Print")
+
+
+# pylint: disable=invalid-name
+@deprecated("2018-08-20", "Use tf.print instead of tf.Print. Note that "
+ "tf.print returns a no-output operator that directly "
+ "prints the output. Outside of defuns or eager mode, "
+ "this operator will not be executed unless it is "
+ "directly specified in session.run or used as a "
+ "control dependency for other operators. This is "
+ "only a concern in graph mode. Below is an example "
+ "of how to ensure tf.print executes in graph mode:\n"
+ """```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor = tf.range(10)
+ print_op = tf.print(tensor)
+ with tf.control_dependencies([print_op]):
+ out = tf.add(tensor, tensor)
+ sess.run(out)
+ ```
+Additionally, to use tf.print in python 2.7, users must make sure to import
+the following:
+
+ `from __future__ import print_function`
+""")
+@tf_export(v1=["Print"])
def Print(input_, data, message=None, first_n=None, summarize=None,
name=None):
"""Prints a list of tensors.
@@ -66,6 +102,228 @@ def Print(input_, data, message=None, first_n=None, summarize=None,
A `Tensor`. Has the same type and contents as `input_`.
"""
return gen_logging_ops._print(input_, data, message, first_n, summarize, name)
+# pylint: enable=invalid-name
+
+
+def _generate_placeholder_string(x, default_placeholder="{}"):
+ """Generate and return a string that does not appear in `x`."""
+ placeholder = default_placeholder
+ rng = random.Random(5)
+ while placeholder in x:
+ placeholder = placeholder + str(rng.randint(0, 9))
+ return placeholder
+
+
+# Temporarily disable pylint g-doc-args error to allow giving more context
+# about what the kwargs are.
+# Because we are using arbitrary-length positional arguments, python 2
+# does not support explicitly specifying the keyword arguments in the
+# function definition.
+# pylint: disable=g-doc-args
+@tf_export("print")
+def print_v2(*inputs, **kwargs):
+ """Print the specified inputs.
+
+ Returns an operator that prints the specified inputs to a desired
+ output stream or logging level. The inputs may be dense or sparse Tensors,
+ primitive python objects, data structures that contain Tensors, and printable
+ python objects. Printed tensors will recursively show the first and last
+ `summarize` elements of each dimension.
+
+ With eager execution enabled and/or inside a `tf.contrib.eager.defun` this
+ operator will automatically execute, and users only need to call `tf.print`
+ without using the return value. When constructing graphs outside of a
+ `tf.contrib.eager.defun`, one must either include the returned op
+ in the input to `session.run`, or use the operator as a control dependency for
+ executed ops by specifying `with tf.control_dependencies([print_op])`.
+
+ @compatibility(python2)
+ In python 2.7, make sure to import the following:
+ `from __future__ import print_function`
+ @end_compatibility
+
+ Example:
+ Single-input usage:
+ ```python
+ tf.enable_eager_execution()
+ tensor = tf.range(10)
+ tf.print(tensor, output_stream=sys.stderr)
+ ```
+ (This prints "[0 1 2 ... 7 8 9]" to sys.stderr)
+
+ Multi-input usage:
+ ```python
+ tf.enable_eager_execution()
+ tensor = tf.range(10)
+ tf.print("tensors:", tensor, {2: tensor * 2}, output_stream=sys.stdout)
+ ```
+ (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to
+ sys.stdout)
+
+ Usage in a defun:
+ ```python
+ tf.enable_eager_execution()
+
+ @tf.contrib.eager.defun
+ def f():
+ tensor = tf.range(10)
+ tf.print(tensor, output_stream=sys.stderr)
+ return tensor
+
+ range_tensor = f()
+ ```
+ (This prints "[0 1 2 ... 7 8 9]" to sys.stderr)
+
+ Usage when constructing graphs:
+ ```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor = tf.range(10)
+ print_op = tf.print("tensors:", tensor, {2: tensor * 2},
+ output_stream=sys.stdout)
+ with tf.control_dependencies([print_op]):
+ tripled_tensor = tensor * 3
+ sess.run(tripled_tensor)
+ ```
+ (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to
+ sys.stdout)
+
+ Note: This op is only partially compatible with Jupyter notebooks and colabs.
+ Because it prints to the C++ standard out / standard error, this will go
+ in the notebook kernel's console output, not in the notebook cell output.
+
+ Args:
+ *inputs: Positional arguments that are the inputs to print. Inputs in the
+ printed output will be separated by spaces. Inputs may be python
+ primitives, tensors, data structures such as dicts and lists that
+ may contain tensors (with the data structures possibly nested in
+ arbitrary ways), and printable python objects.
+ output_stream: The output stream or logging level to print to. Defaults to
+ sys.stderr, but sys.stdout, tf.logging.info, tf.logging.warning, and
+ tf.logging.error are also supported.
+ summarize: The first and last `summarize` elements within each dimension are
+ recursively printed per Tensor. If None, then the first 3 and last 3
+ elements of each dimension are printed for each tensor. If set to -1, it
+ will print all elements of every tensor.
+ name: A name for the operation (optional).
+
+ Returns:
+ A print operator that prints the specified inputs in the specified output
+ stream or logging level.
+
+ Raises:
+ ValueError: If an unsupported output stream is specified.
+ """
+ # Because we are using arbitrary-length positional arguments, python 2
+ # does not support explicitly specifying the keyword arguments in the
+ # function definition. So, we manually get the keyword arguments w/ default
+ # values here.
+ output_stream = kwargs.pop("output_stream", sys.stderr)
+ name = kwargs.pop("name", None)
+ summarize = kwargs.pop("summarize", 3)
+ if kwargs:
+ raise ValueError("Unrecognized keyword arguments for tf.print: %s" % kwargs)
+ format_name = None
+ if name:
+ format_name = name + "_format"
+
+ # Match the C++ string constants representing the different output streams.
+ # Keep this updated!
+ output_stream_to_constant = {
+ sys.stdout: "stdout",
+ sys.stderr: "stderr",
+ tf_logging.INFO: "log(info)",
+ tf_logging.info: "log(info)",
+ tf_logging.WARN: "log(warning)",
+ tf_logging.warning: "log(warning)",
+ tf_logging.warn: "log(warning)",
+ tf_logging.ERROR: "log(error)",
+ tf_logging.error: "log(error)",
+ }
+
+ output_stream_string = output_stream_to_constant.get(output_stream)
+ if not output_stream_string:
+ raise ValueError(
+ "Unsupported output stream or logging level " +
+ str(output_stream) + ". Supported streams are sys.stdout, "
+ "sys.stderr, tf.logging.info, "
+ "tf.logging.warning, tf.logging.error")
+
+ # If we are only printing a single string scalar, there is no need to format
+ if (len(inputs) == 1 and tensor_util.is_tensor(inputs[0])
+ and (not isinstance(inputs[0], sparse_tensor.SparseTensor))
+ and inputs[0].shape and (inputs[0].dtype == dtypes.string)):
+ formatted_string = inputs[0]
+ # Otherwise, we construct an appropriate template for the tensors we are
+ # printing, and format the template using those tensors.
+ else:
+ # For each input to this print function, we extract any nested tensors,
+ # and construct an appropriate template to format representing the
+ # printed input.
+ templates = []
+ tensors = []
+ tensor_free_structure = nest.map_structure(
+ lambda x: "" if tensor_util.is_tensor(x) else x,
+ inputs)
+ tensor_free_template = " ".join(pprint.pformat(x)
+ for x in tensor_free_structure)
+ placeholder = _generate_placeholder_string(tensor_free_template)
+
+ for input_ in inputs:
+ placeholders = []
+ # Use the nest utilities to flatten & process any nested elements in this
+ # input. The placeholder for a tensor in the template should be the
+ # placeholder string, and the placeholder for a non-tensor can just be
+ # the printed value of the non-tensor itself.
+ for x in nest.flatten(input_):
+ # support sparse tensors
+ if isinstance(x, sparse_tensor.SparseTensor):
+ tensors.extend([x.indices, x.values, x.dense_shape])
+ placeholders.append(
+ "SparseTensor(indices={}, values={}, shape={})".format(
+ placeholder, placeholder, placeholder)
+ )
+ elif tensor_util.is_tensor(x):
+ tensors.append(x)
+ placeholders.append(placeholder)
+ else:
+ placeholders.append(x)
+
+ if isinstance(input_, six.string_types):
+ # If the current input to format/print is a normal string, that string
+ # can act as the template.
+ cur_template = input_
+ else:
+ # We pack the placeholders into a data structure that matches the
+ # input data structure format, then format that data structure
+ # into a string template.
+ #
+ # NOTE: We must use pprint.pformat here for building the template for
+ # unordered data structures such as `dict`, because `str` doesn't
+ # guarantee orderings, while pprint prints in sorted order. pprint
+ # will match the ordering of `nest.flatten`.
+ # This even works when nest.flatten reorders OrderedDicts, because
+ # pprint is printing *after* the OrderedDicts have been reordered.
+ cur_template = pprint.pformat(
+ nest.pack_sequence_as(input_, placeholders))
+ templates.append(cur_template)
+
+ # We join the templates for the various inputs into a single larger
+ # template. We also remove all quotes surrounding the placeholders, so that
+ # the formatted/printed output will not contain quotes around tensors.
+ # (example of where these quotes might appear: if we have added a
+ # placeholder string into a list, then pretty-formatted that list)
+ template = " ".join(templates)
+ template = template.replace("'" + placeholder + "'", placeholder)
+ formatted_string = string_ops.string_format(
+ inputs=tensors, template=template, placeholder=placeholder,
+ summarize=summarize,
+ name=format_name)
+
+ return gen_logging_ops.print_v2(formatted_string,
+ output_stream=output_stream_string,
+ name=name)
+# pylint: enable=g-doc-args
@ops.RegisterGradient("Print")
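Below is a standalone sketch of the template construction used by `print_v2` above; it needs no TensorFlow and is not part of the patch. The string `"TENSOR"` stands in for any tensor that `nest.flatten` would find, and `build_template` is a hypothetical helper mimicking the placeholder, `pprint`, and quote-stripping steps. The resulting template is what `string_ops.string_format` would then fill with summarized tensor values.

```python
import pprint

def build_template(inputs, placeholder="{}"):
  """Mimics print_v2's template logic for strings, bare tensors, and flat dicts."""
  templates = []
  for inp in inputs:
    if inp == "TENSOR":
      # A bare tensor becomes a quoted placeholder via pprint ...
      templates.append(pprint.pformat(placeholder))
    elif isinstance(inp, str):
      # ... while a plain string input is used as the template directly.
      templates.append(inp)
    else:
      # Structures keep their shape; tensors inside become placeholders.
      templates.append(pprint.pformat(
          {k: placeholder if v == "TENSOR" else v for k, v in inp.items()}))
  template = " ".join(templates)
  # Strip the quotes pprint added around placeholders, as the real op does.
  return template.replace("'" + placeholder + "'", placeholder)

print(build_template(["tensors:", "TENSOR", {2: "TENSOR"}]))
# tensors: {} {2: {}}
```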
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index 561a341cf3..5443699ddd 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -422,7 +422,7 @@ class TextFileInitializer(TableInitializerBase):
* `palmer -> 30`
```python
- table = tf.contrib.lookup.HashTable(tf.contrib.lookup.TextFileInitializer(
+ table = tf.lookup.HashTable(tf.lookup.TextFileInitializer(
"test.txt", tf.string, 0, tf.int64, 1, delimiter=" "), -1)
...
table.init.run()
@@ -435,9 +435,9 @@ class TextFileInitializer(TableInitializerBase):
* `palmer 30 -> 2`
```python
- table = tf.contrib.lookup.HashTable(tf.contrib.lookup.TextFileInitializer(
- "test.txt", tf.string, tf.contrib.lookup.TextFileIndex.WHOLE_LINE,
- tf.int64, tf.contrib.lookup.TextFileIndex.LINE_NUMBER, delimiter=" "), -1)
+ table = tf.lookup.HashTable(tf.lookup.TextFileInitializer(
+ "test.txt", tf.string, tf.lookup.TextFileIndex.WHOLE_LINE,
+ tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER, delimiter=" "), -1)
...
table.init.run()
```
@@ -953,7 +953,7 @@ def index_table_from_file(vocabulary_file=None,
```python
features = tf.constant(["emerson", "lake", "and", "palmer"])
- table = tf.contrib.lookup.index_table_from_file(
+ table = tf.lookup.index_table_from_file(
vocabulary_file="test.txt", num_oov_buckets=1)
ids = table.lookup(features)
...
@@ -1054,21 +1054,21 @@ def index_table_from_tensor(vocabulary_list,
Any lookup of an out-of-vocabulary token will return a bucket ID based on its
hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
- `default_value`.
- The bucket ID range is `[mapping size, mapping size + num_oov_buckets - 1]`.
+ `default_value`. The bucket ID range is
+ `[vocabulary list size, vocabulary list size + num_oov_buckets - 1]`.
The underlying table must be initialized by calling
`tf.tables_initializer.run()` or `table.init.run()` once.
- Elements in `mapping` cannot have duplicates, otherwise when executing the
- table initializer op, it will throw a `FailedPreconditionError`.
+ Elements in `vocabulary_list` cannot have duplicates, otherwise when executing
+ the table initializer op, it will throw a `FailedPreconditionError`.
Sample Usages:
```python
vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
- table = tf.contrib.lookup.index_table_from_tensor(
- mapping=vocabulary_list, num_oov_buckets=1, default_value=-1)
+ table = tf.lookup.index_table_from_tensor(
+ vocabulary_list=vocabulary_list, num_oov_buckets=1, default_value=-1)
features = tf.constant(["emerson", "lake", "and", "palmer"])
ids = table.lookup(features)
...
@@ -1093,7 +1093,7 @@ def index_table_from_tensor(vocabulary_list,
The lookup table to map an input `Tensor` to index `int64` `Tensor`.
Raises:
- ValueError: If `mapping` is invalid.
+ ValueError: If `vocabulary_list` is invalid.
ValueError: If `num_oov_buckets` is negative.
"""
if vocabulary_list is None:
@@ -1185,7 +1185,7 @@ def index_to_string_table_from_file(vocabulary_file,
```python
indices = tf.constant([1, 5], tf.int64)
- table = tf.contrib.lookup.index_to_string_table_from_file(
+ table = tf.lookup.index_to_string_table_from_file(
vocabulary_file="test.txt", default_value="UNKNOWN")
values = table.lookup(indices)
...
@@ -1250,25 +1250,25 @@ def index_to_string_table_from_tensor(vocabulary_list,
"""Returns a lookup table that maps a `Tensor` of indices into strings.
This operation constructs a lookup table to map int64 indices into string
- values. The mapping is initialized from a string `mapping` 1-D `Tensor` where
- each element is a value and the corresponding index within the tensor is the
- key.
+ values. The mapping is initialized from a string `vocabulary_list` 1-D
+ `Tensor` where each element is a value and the corresponding index within the
+ tensor is the key.
- Any input which does not have a corresponding index in 'mapping'
+ Any input which does not have a corresponding index in 'vocabulary_list'
(an out-of-vocabulary entry) is assigned the `default_value`
The underlying table must be initialized by calling
`tf.tables_initializer.run()` or `table.init.run()` once.
- Elements in `mapping` cannot have duplicates, otherwise when executing the
- table initializer op, it will throw a `FailedPreconditionError`.
+ Elements in `vocabulary_list` cannot have duplicates, otherwise when executing
+ the table initializer op, it will throw a `FailedPreconditionError`.
Sample Usages:
```python
vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
indices = tf.constant([1, 5], tf.int64)
- table = tf.contrib.lookup.index_to_string_table_from_tensor(
+ table = tf.lookup.index_to_string_table_from_tensor(
vocabulary_list, default_value="UNKNOWN")
values = table.lookup(indices)
...
diff --git a/tensorflow/python/ops/losses/util_test.py b/tensorflow/python/ops/losses/util_test.py
index 7fa7a41fca..df2e60e2e4 100644
--- a/tensorflow/python/ops/losses/util_test.py
+++ b/tensorflow/python/ops/losses/util_test.py
@@ -28,7 +28,7 @@ class LossesUtilTest(test.TestCase):
def testGetRegularizationLoss(self):
# Empty regularization collection should evaluate to 0.0.
- with self.test_session():
+ with self.cached_session():
self.assertEqual(0.0, util.get_regularization_loss().eval())
# Loss should sum.
@@ -36,14 +36,14 @@ class LossesUtilTest(test.TestCase):
ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(2.0))
ops.add_to_collection(
ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(3.0))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(5.0, util.get_regularization_loss().eval())
# Check scope capture mechanism.
with ops.name_scope('scope1'):
ops.add_to_collection(
ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(-1.0))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(-1.0, util.get_regularization_loss('scope1').eval())
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 33e7a5533b..f57abf6704 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1088,9 +1088,6 @@ def floordiv(x, y, name=None):
`x // y` floor division in Python 3 and in Python 2.7 with
`from __future__ import division`.
- Note that for efficiency, `floordiv` uses C semantics for negative numbers
- (unlike Python and Numpy).
-
`x` and `y` must have the same type, and the result will have the same type
as well.
@@ -1100,7 +1097,7 @@ def floordiv(x, y, name=None):
name: A name for the operation (optional).
Returns:
- `x / y` rounded down (except possibly towards zero for negative integers).
+ `x / y` rounded down.
Raises:
TypeError: If the inputs are complex.
@@ -2901,21 +2898,23 @@ def tensordot(a, b, axes, name=None):
shape_a = a.get_shape().as_list()
axes = [i if i >= 0 else i + len(shape_a) for i in axes]
free = [i for i in xrange(len(shape_a)) if i not in axes]
- free_dims_static = [shape_a[i] for i in free]
+ axes_dims = [shape_a[i] for i in axes]
+ free_dims = [shape_a[i] for i in free]
+ free_dims_static = free_dims
+ axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
+ free = ops.convert_to_tensor(free, dtype=dtypes.int32, name="free")
+ shape_a = array_ops.shape(a)
else:
free_dims_static = None
- shape_a = array_ops.shape(a)
- rank_a = array_ops.rank(a)
- axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
- axes = cast(axes >= 0, dtypes.int32) * axes + cast(
- axes < 0, dtypes.int32) * (
- axes + rank_a)
- free, _ = array_ops.setdiff1d(range(rank_a), axes)
+ shape_a = array_ops.shape(a)
+ rank_a = array_ops.rank(a)
+ axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
+ axes = array_ops.where(axes >= 0, axes, axes + rank_a)
+ free, _ = array_ops.setdiff1d(range(rank_a), axes)
free_dims = array_ops.gather(shape_a, free)
axes_dims = array_ops.gather(shape_a, axes)
prod_free_dims = reduce_prod(free_dims)
prod_axes_dims = reduce_prod(axes_dims)
- perm = array_ops.concat([axes_dims, free_dims], 0)
if flipped:
perm = array_ops.concat([axes, free], 0)
new_shape = array_ops.stack([prod_axes_dims, prod_free_dims])
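A quick NumPy illustration (not part of the patch) of the shape bookkeeping in `tensordot` above: the free and contraction dimensions of each argument are flattened so the contraction reduces to a single matrix multiply.

```python
import numpy as np

a = np.random.rand(2, 3, 4)   # contract axes [1, 2] of a ...
b = np.random.rand(3, 4, 5)   # ... against axes [0, 1] of b
c = np.tensordot(a, b, axes=[[1, 2], [0, 1]])
print(c.shape)                # (2, 5): free dims of a, then free dims of b

# The reshape-to-matmul path used by the implementation above:
a_mat = a.reshape(2, 12)      # [prod(free_dims), prod(axes_dims)]
b_mat = b.reshape(12, 5)      # [prod(axes_dims), prod(free dims of b)]
print(np.allclose(c, a_mat.dot(b_mat)))  # True
```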
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 2861f40586..3f64f0af9a 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -22,7 +22,6 @@ import numbers
import numpy as np
-from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_util
@@ -1672,47 +1671,24 @@ def _softmax(logits, compute_op, dim=-1, name=None):
shape = logits.get_shape()
is_last_dim = (dim is -1) or (dim == shape.ndims - 1)
- # TODO(phawkins): remove after 2018/8/27 and simplify this code.
- softmax_accepts_r1_or_greater = compat.forward_compatible(2018, 8, 27)
- reshape_required = (not softmax_accepts_r1_or_greater) and shape.ndims != 2
if is_last_dim:
- if reshape_required:
- # If dim is the last dimension, simply reshape the logits to a matrix and
- # apply the internal softmax.
- input_shape = array_ops.shape(logits)
- logits = _flatten_outer_dims(logits)
- output = compute_op(logits)
- output = array_ops.reshape(output, input_shape, name=name)
- return output
return compute_op(logits, name=name)
- # If dim is not the last dimension, we have to do a reshape and transpose so
- # that we can still perform softmax on its last dimension.
+ # If dim is not the last dimension, we have to do a transpose so that we can
+ # still perform softmax on its last dimension.
# Swap logits' dimension of dim and its last dimension.
input_rank = array_ops.rank(logits)
dim_axis = dim % shape.ndims
logits = _swap_axis(logits, dim_axis, math_ops.subtract(input_rank, 1))
- shape_after_swap = array_ops.shape(logits)
- if reshape_required:
- # Reshape logits into a matrix.
- logits = _flatten_outer_dims(logits)
-
- # Do the actual softmax on its last dimension.
- output = compute_op(logits)
-
- # Transform back the output tensor.
- output = array_ops.reshape(output, shape_after_swap)
- else:
- # Do the actual softmax on its last dimension.
- output = compute_op(logits)
+ # Do the actual softmax on its last dimension.
+ output = compute_op(logits)
output = _swap_axis(
output, dim_axis, math_ops.subtract(input_rank, 1), name=name)
- # Make shape inference work since reshape and transpose may erase its static
- # shape.
+ # Make shape inference work since transpose may erase its static shape.
output.set_shape(shape)
return output
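A plain NumPy sketch (not part of the patch) of the transpose trick the simplified `_softmax` now relies on: swap the requested axis with the last one, apply a last-axis softmax, and swap back.

```python
import numpy as np

def softmax_last_axis(x):
  e = np.exp(x - x.max(axis=-1, keepdims=True))  # subtract max for stability
  return e / e.sum(axis=-1, keepdims=True)

def softmax(x, axis=-1):
  x = np.swapaxes(x, axis, -1)          # move the target axis to the end
  out = softmax_last_axis(x)            # softmax on the last dimension
  return np.swapaxes(out, axis, -1)     # move it back

x = np.random.rand(2, 3, 4)
print(np.allclose(softmax(x, axis=1).sum(axis=1), 1.0))  # True
```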
diff --git a/tensorflow/python/ops/parallel_for/BUILD b/tensorflow/python/ops/parallel_for/BUILD
index 015181af47..07fc9433a2 100644
--- a/tensorflow/python/ops/parallel_for/BUILD
+++ b/tensorflow/python/ops/parallel_for/BUILD
@@ -123,6 +123,8 @@ cuda_py_test(
"//third_party/py/numpy",
"//tensorflow/python:layers",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:functional_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python/ops/losses",
],
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index d403b0c61a..6e276dee55 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -31,6 +31,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import bitwise_ops
+from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gradients as gradient_ops
@@ -300,28 +302,129 @@ class ArrayTest(PForTest):
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2)
+class BitwiseTest(PForTest):
+
+ def test_unary_cwise(self):
+ for op in [bitwise_ops.invert]:
+ x = random_ops.random_uniform([7, 3, 5], maxval=10, dtype=dtypes.int32)
+
+ # pylint: disable=cell-var-from-loop
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ return op(x1)
+ # pylint: enable=cell-var-from-loop
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32])
+
+ def test_binary_cwise(self):
+ binary_ops = [
+ bitwise_ops.bitwise_and,
+ bitwise_ops.bitwise_or,
+ bitwise_ops.bitwise_xor,
+ bitwise_ops.left_shift,
+ bitwise_ops.right_shift,
+ ]
+ for op in binary_ops:
+ x = random_ops.random_uniform([7, 3, 5], maxval=10, dtype=dtypes.int32)
+ y = random_ops.random_uniform([3, 5], maxval=10, dtype=dtypes.int32)
+
+ output_dtypes = []
+ # pylint: disable=cell-var-from-loop
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ y1 = array_ops.gather(y, i)
+ outputs = [op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1)]
+ del output_dtypes[:]
+ output_dtypes.extend([t.dtype for t in outputs])
+ return outputs
+ # pylint: enable=cell-var-from-loop
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=output_dtypes)
+
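For context on what `_test_loop_fn` asserts in these tests, here is a rough standalone sketch (not part of the patch; TF 1.x graph mode assumed): `pfor(loop_fn, n)` should produce the same values as stacking the results of running `loop_fn` sequentially.

```python
import tensorflow as tf
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_ops

x = tf.random_uniform([3, 5], maxval=10, dtype=tf.int32)

def loop_fn(i):
  return tf.bitwise.invert(tf.gather(x, i))

vectorized = pfor_ops.pfor(loop_fn, 3)                  # vectorized loop body
sequential = tf.stack([loop_fn(i) for i in range(3)])   # reference result

with tf.Session() as sess:
  a, b = sess.run([vectorized, sequential])
  print((a == b).all())  # True
```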
+
class MathTest(PForTest):
def test_unary_cwise_ops(self):
- for op in [
- math_ops.tanh, nn.relu, math_ops.sigmoid, math_ops.negative,
- math_ops.square
- ]:
+ complex_ops = [
+ math_ops.angle,
+ math_ops.imag,
+ math_ops.complex_abs,
+ math_ops.real,
+ math_ops.conj,
+ ]
+ real_ops = [
+ lambda x: math_ops.acosh(1 + math_ops.square(x)),
+ math_ops.abs,
+ math_ops.acos,
+ math_ops.asin,
+ math_ops.asinh,
+ math_ops.atan,
+ math_ops.atanh,
+ math_ops.bessel_i0e,
+ math_ops.bessel_i1e,
+ math_ops.cos,
+ math_ops.cosh,
+ math_ops.digamma,
+ math_ops.erf,
+ math_ops.erfc,
+ math_ops.exp,
+ math_ops.expm1,
+ math_ops.inv,
+ math_ops.is_finite,
+ math_ops.is_inf,
+ math_ops.lgamma,
+ math_ops.log,
+ math_ops.log1p,
+ math_ops.neg,
+ math_ops.negative,
+ math_ops.reciprocal,
+ math_ops.rint,
+ math_ops.round,
+ math_ops.rsqrt,
+ math_ops.sigmoid,
+ math_ops.sign,
+ math_ops.sin,
+ math_ops.sinh,
+ math_ops.sqrt,
+ math_ops.square,
+ math_ops.tan,
+ math_ops.tanh,
+ math_ops.tanh,
+ nn.elu,
+ nn.relu,
+ nn.relu6,
+ nn.selu,
+ nn.softplus,
+ nn.softsign,
+ ]
+ for op in complex_ops + real_ops:
x = random_ops.random_uniform([3, 5])
+ if op in complex_ops:
+ y = random_ops.random_uniform([3, 5])
+ x = math_ops.complex(x, y)
# pylint: disable=cell-var-from-loop
+ output_dtypes = []
def loop_fn(i):
x1 = array_ops.gather(x, i)
- y = op(x1)
- loss = math_ops.reduce_sum(y * y)
- return op(x), y, gradient_ops.gradients(loss, x1)
+ y1 = op(x1)
+ outputs = [op(x), y1]
+ if y1.dtype == dtypes.float32:
+ loss = math_ops.reduce_sum(y1 * y1)
+ grad = gradient_ops.gradients(loss, x1)
+ if grad and grad[0] is not None:
+ outputs.extend(grad)
+ del output_dtypes[:]
+ output_dtypes.extend([t.dtype for t in outputs])
+ return outputs
# pylint: enable=cell-var-from-loop
- self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 3)
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=output_dtypes)
def test_unary_cwise_no_grad(self):
- for op in [math_ops.ceil, math_ops.floor, math_ops.logical_not]:
+ for op in [math_ops.ceil,
+ math_ops.floor,
+ math_ops.logical_not]:
x = random_ops.random_uniform([3, 5])
if op == math_ops.logical_not:
x = x > 0
@@ -336,33 +439,80 @@ class MathTest(PForTest):
def test_binary_cwise_ops(self):
logical_ops = [
- math_ops.logical_and, math_ops.logical_or, math_ops.logical_xor
- ]
- bool_ops = [
- math_ops.less, math_ops.less_equal, math_ops.greater,
- math_ops.greater_equal, math_ops.equal, math_ops.not_equal
+ math_ops.logical_and,
+ math_ops.logical_or,
+ math_ops.logical_xor
]
+
+ # Wrapper functions restricting the range of inputs of zeta and polygamma.
+ def safe_polygamma(x, y):
+ return math_ops.polygamma(
+ math_ops.round(clip_ops.clip_by_value(y, 1, 10)),
+ x * x + 1)
+
+ def safe_zeta(x, y):
+ return math_ops.zeta(x * x + 1, y * y)
+
float_ops = [
- math_ops.add, math_ops.subtract, math_ops.multiply, math_ops.divide,
- math_ops.maximum, math_ops.minimum
+ math_ops.add,
+ math_ops.add_v2,
+ math_ops.atan2,
+ math_ops.complex,
+ math_ops.div,
+ math_ops.divide,
+ math_ops.div_no_nan,
+ math_ops.equal,
+ math_ops.floor_div,
+ math_ops.floor_mod,
+ math_ops.greater,
+ math_ops.greater_equal,
+ math_ops.igamma,
+ math_ops.igammac,
+ math_ops.igamma_grad_a,
+ math_ops.less,
+ math_ops.less_equal,
+ math_ops.maximum,
+ math_ops.minimum,
+ math_ops.mod,
+ math_ops.multiply,
+ math_ops.not_equal,
+ math_ops.pow,
+ math_ops.squared_difference,
+ math_ops.subtract,
+ math_ops.truncate_mod,
+ safe_polygamma,
+ safe_zeta,
]
- for op in logical_ops + bool_ops + float_ops:
+ for op in logical_ops + float_ops:
x = random_ops.random_uniform([7, 3, 5])
y = random_ops.random_uniform([3, 5])
if op in logical_ops:
x = x > 0
y = y > 0
+ output_dtypes = []
# pylint: disable=cell-var-from-loop
def loop_fn(i):
x1 = array_ops.gather(x, i)
y1 = array_ops.gather(y, i)
- return op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1)
-
+ outputs = [op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1)]
+ del output_dtypes[:]
+ output_dtypes.extend([t.dtype for t in outputs])
+ return outputs
# pylint: enable=cell-var-from-loop
- dtype = dtypes.float32 if op in float_ops else dtypes.bool
- self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtype] * 5)
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=output_dtypes)
+
+ def test_approximate_equal(self):
+ x = random_ops.random_uniform([3, 5])
+ y = random_ops.random_uniform([3, 5])
+
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ y1 = array_ops.gather(y, i)
+ return math_ops.approximate_equal(x1, y1)
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.bool])
def test_addn(self):
x = random_ops.random_uniform([2, 3, 5])
diff --git a/tensorflow/python/ops/parallel_for/gradients.py b/tensorflow/python/ops/parallel_for/gradients.py
index 460de0a97f..1f026b3660 100644
--- a/tensorflow/python/ops/parallel_for/gradients.py
+++ b/tensorflow/python/ops/parallel_for/gradients.py
@@ -42,6 +42,7 @@ def jacobian(output, inputs, use_pfor=True):
[y_1, ..., y_n, x_1, ..., x_m].
"""
flat_inputs = nest.flatten(inputs)
+ output_tensor_shape = output.shape
output_shape = array_ops.shape(output)
output = array_ops.reshape(output, [-1])
@@ -65,6 +66,7 @@ def jacobian(output, inputs, use_pfor=True):
new_shape = array_ops.concat(
[output_shape, array_ops.shape(out)[1:]], axis=0)
out = array_ops.reshape(out, new_shape)
+ out.set_shape(output_tensor_shape.concatenate(flat_inputs[i].shape))
pfor_outputs[i] = out
return nest.pack_sequence_as(inputs, pfor_outputs)
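An illustrative use of `jacobian` with the static-shape fix added above (a sketch, not part of the patch; the tensors are arbitrary examples): the result's static shape should now come out as the output shape concatenated with the input shape.

```python
import tensorflow as tf
from tensorflow.python.ops.parallel_for import gradients

x = tf.random_uniform([3, 4])
y = tf.reduce_sum(tf.sin(x), axis=1)         # shape [3]
j = gradients.jacobian(y, x, use_pfor=True)  # d y[i] / d x[k, l]
print(j.shape.as_list())                     # [3, 3, 4]
```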
diff --git a/tensorflow/python/ops/parallel_for/gradients_test.py b/tensorflow/python/ops/parallel_for/gradients_test.py
index 628c6764cd..5467f55af6 100644
--- a/tensorflow/python/ops/parallel_for/gradients_test.py
+++ b/tensorflow/python/ops/parallel_for/gradients_test.py
@@ -32,6 +32,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.layers import layers as tf_layers
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops as tf_control_flow_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients as gradient_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
@@ -355,6 +357,30 @@ class GradientsTest(test.TestCase):
self.run_and_assert_equal(answer, jacobian_pfor)
self.run_and_assert_equal(answer, jacobian_while)
+ def test_jacobian_scan_shape(self):
+ # Shape x: [3, 4]
+ x = random_ops.random_uniform([3, 4])
+ elems = random_ops.random_uniform([6])
+ # Shape y: [6, 3, 4]
+ y = functional_ops.scan(lambda a, e: a + e, elems, initializer=x)
+ jacobian = gradients.jacobian(y, x)
+
+ expected_shape = [6, 3, 4, 3, 4]
+ self.assertAllEqual(expected_shape, jacobian.shape.as_list())
+
+ def test_jacobian_while_loop_shape(self):
+ # Shape x: [3, 4]
+ x = random_ops.random_uniform([3, 4])
+ _, y = tf_control_flow_ops.while_loop(lambda i, a: i > 5.,
+ lambda i, a: (i + 1, a + i),
+ (constant_op.constant(0.), x))
+ # Shape y: [2, 3]
+ y = y[:2, :3]
+ jacobian = gradients.jacobian(y, x)
+
+ expected_shape = [2, 3, 3, 4]
+ self.assertAllEqual(expected_shape, jacobian.shape.as_list())
+
def test_jacobian_unknown_shape(self):
with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32, shape=[None, None])
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index f9153b6d7d..e0f6d51881 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
@@ -1922,37 +1923,114 @@ def _convert_cast(pfor_input):
return wrap(math_ops.cast(inp, dtype), True)
-# Note that ops handled here do not have attributes except "T", and hence don't
-# need extra arguments passed to the cwise_op call below.
+@RegisterPForWithArgs("Abs", math_ops.abs)
+@RegisterPForWithArgs("Acosh", math_ops.acosh)
+@RegisterPForWithArgs("Acos", math_ops.acos)
@RegisterPForWithArgs("Add", math_ops.add)
+@RegisterPForWithArgs("AddV2", math_ops.add_v2)
+@RegisterPForWithArgs("Angle", math_ops.angle)
+@RegisterPForWithArgs("Asinh", math_ops.asinh)
+@RegisterPForWithArgs("Asin", math_ops.asin)
+@RegisterPForWithArgs("Atan2", math_ops.atan2)
+@RegisterPForWithArgs("Atanh", math_ops.atanh)
+@RegisterPForWithArgs("Atan", math_ops.atan)
+@RegisterPForWithArgs("BesselI0e", math_ops.bessel_i0e)
+@RegisterPForWithArgs("BesselI1e", math_ops.bessel_i1e)
+@RegisterPForWithArgs("BitwiseAnd", bitwise_ops.bitwise_and)
+@RegisterPForWithArgs("BitwiseOr", bitwise_ops.bitwise_or)
+@RegisterPForWithArgs("BitwiseXor", bitwise_ops.bitwise_xor)
@RegisterPForWithArgs("Ceil", math_ops.ceil)
+@RegisterPForWithArgs("ComplexAbs", math_ops.complex_abs)
+@RegisterPForWithArgs("Complex", math_ops.complex)
+@RegisterPForWithArgs("Conj", math_ops.conj)
+@RegisterPForWithArgs("Cosh", math_ops.cosh)
+@RegisterPForWithArgs("Cos", math_ops.cos)
+@RegisterPForWithArgs("Digamma", math_ops.digamma)
+@RegisterPForWithArgs("Div", math_ops.div)
+@RegisterPForWithArgs("DivNoNan", math_ops.div_no_nan)
+@RegisterPForWithArgs("Elu", nn_ops.elu)
@RegisterPForWithArgs("Equal", math_ops.equal)
-@RegisterPForWithArgs("NotEqual", math_ops.not_equal)
+@RegisterPForWithArgs("Erfc", math_ops.erfc)
+@RegisterPForWithArgs("Erf", math_ops.erf)
+@RegisterPForWithArgs("Expm1", math_ops.expm1)
+@RegisterPForWithArgs("Exp", math_ops.exp)
+@RegisterPForWithArgs("FloorDiv", math_ops.floor_div)
@RegisterPForWithArgs("Floor", math_ops.floor)
-@RegisterPForWithArgs("Greater", math_ops.greater)
+@RegisterPForWithArgs("FloorMod", math_ops.floor_mod)
@RegisterPForWithArgs("GreaterEqual", math_ops.greater_equal)
-@RegisterPForWithArgs("Less", math_ops.less)
+@RegisterPForWithArgs("Greater", math_ops.greater)
+@RegisterPForWithArgs("Igammac", math_ops.igammac)
+@RegisterPForWithArgs("IgammaGradA", math_ops.igamma_grad_a)
+@RegisterPForWithArgs("Igamma", math_ops.igamma)
+@RegisterPForWithArgs("Imag", math_ops.imag)
+@RegisterPForWithArgs("Invert", bitwise_ops.invert)
+@RegisterPForWithArgs("Inv", math_ops.inv)
+@RegisterPForWithArgs("IsFinite", math_ops.is_finite)
+@RegisterPForWithArgs("IsInf", math_ops.is_inf)
+@RegisterPForWithArgs("LeftShift", bitwise_ops.left_shift)
@RegisterPForWithArgs("LessEqual", math_ops.less_equal)
-@RegisterPForWithArgs("LogicalOr", math_ops.logical_or)
+@RegisterPForWithArgs("Less", math_ops.less)
+@RegisterPForWithArgs("Lgamma", math_ops.lgamma)
+@RegisterPForWithArgs("Log1p", math_ops.log1p)
@RegisterPForWithArgs("LogicalAnd", math_ops.logical_and)
@RegisterPForWithArgs("LogicalNot", math_ops.logical_not)
+@RegisterPForWithArgs("LogicalOr", math_ops.logical_or)
@RegisterPForWithArgs("LogicalXor", math_ops.logical_xor)
+@RegisterPForWithArgs("Log", math_ops.log)
@RegisterPForWithArgs("Maximum", math_ops.maximum)
@RegisterPForWithArgs("Minimum", math_ops.minimum)
+@RegisterPForWithArgs("Mod", math_ops.mod)
@RegisterPForWithArgs("Mul", math_ops.multiply)
@RegisterPForWithArgs("Neg", math_ops.negative)
+@RegisterPForWithArgs("NotEqual", math_ops.not_equal)
+@RegisterPForWithArgs("Polygamma", math_ops.polygamma)
+@RegisterPForWithArgs("Pow", math_ops.pow)
@RegisterPForWithArgs("RealDiv", math_ops.divide)
+@RegisterPForWithArgs("Real", math_ops.real)
+@RegisterPForWithArgs("ReciprocalGrad", math_ops.reciprocal_grad)
+@RegisterPForWithArgs("Reciprocal", math_ops.reciprocal)
+@RegisterPForWithArgs("Relu6", nn_ops.relu6)
@RegisterPForWithArgs("Relu", nn_ops.relu)
+@RegisterPForWithArgs("RightShift", bitwise_ops.right_shift)
+@RegisterPForWithArgs("Rint", math_ops.rint)
+@RegisterPForWithArgs("Round", math_ops.round)
+@RegisterPForWithArgs("RsqrtGrad", math_ops.rsqrt_grad)
+@RegisterPForWithArgs("Rsqrt", math_ops.rsqrt)
+@RegisterPForWithArgs("Selu", nn_ops.selu)
@RegisterPForWithArgs("Sigmoid", math_ops.sigmoid)
+@RegisterPForWithArgs("Sign", math_ops.sign)
+@RegisterPForWithArgs("Sinh", math_ops.sinh)
+@RegisterPForWithArgs("Sin", math_ops.sin)
+@RegisterPForWithArgs("Softplus", nn_ops.softplus)
+@RegisterPForWithArgs("Softsign", nn_ops.softsign)
+@RegisterPForWithArgs("SqrtGrad", math_ops.sqrt_grad)
+@RegisterPForWithArgs("Sqrt", math_ops.sqrt)
+@RegisterPForWithArgs("SquaredDifference", math_ops.squared_difference)
@RegisterPForWithArgs("Square", math_ops.square)
@RegisterPForWithArgs("Sub", math_ops.subtract)
@RegisterPForWithArgs("Tanh", math_ops.tanh)
+@RegisterPForWithArgs("Tan", math_ops.tan)
+@RegisterPForWithArgs("TruncateDiv", math_ops.truncate_div)
+@RegisterPForWithArgs("TruncateMod", math_ops.truncate_mod)
+@RegisterPForWithArgs("Zeta", math_ops.zeta)
def _convert_cwise(pfor_input, op_type, op_func):
- del op_type
+ # Note that ops handled here do not have attributes except "T" and "Tout", and
+ # hence don't need extra arguments passed to the cwise_op call below.
+ for attr in pfor_input.op.node_def.attr.keys():
+ assert attr in [u"T", u"Tout"], (op_type, attr)
pfor_input.expanddim_inputs_for_broadcast()
return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
+@RegisterPFor("ApproximateEqual")
+def _convert_approximate_equal(pfor_input):
+ pfor_input.expanddim_inputs_for_broadcast()
+ x = pfor_input.input(0)[0]
+ y = pfor_input.input(1)[0]
+ tolerance = pfor_input.get_attr("tolerance")
+ return wrap(math_ops.approximate_equal(x, y, tolerance=tolerance), True)
+
+
@RegisterPFor("Shape")
def _convert_shape(pfor_input):
out_type = pfor_input.get_attr("out_type")
@@ -2009,10 +2087,14 @@ def _convert_biasaddgrad(pfor_input):
# Some required ops are not exposed under the tf namespace. Hence relying on
# _create_op to create them.
+@RegisterPForWithArgs("EluGrad")
+@RegisterPForWithArgs("Relu6Grad")
@RegisterPForWithArgs("ReluGrad")
-@RegisterPForWithArgs("TanhGrad")
+@RegisterPForWithArgs("SeluGrad")
@RegisterPForWithArgs("SigmoidGrad")
@RegisterPForWithArgs("SoftplusGrad")
+@RegisterPForWithArgs("SoftsignGrad")
+@RegisterPForWithArgs("TanhGrad")
def _convert_grads(pfor_input, op_type, *args, **kw_args):
del args
del kw_args
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index 8224097ac4..b3e03a0135 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -981,9 +981,10 @@ def parse_sequence_example(serialized,
name: A name for this operation (optional).
Returns:
- A tuple of two `dict`s, each mapping keys to `Tensor`s and `SparseTensor`s.
- The first dict contains the context key/values.
- The second dict contains the feature_list key/values.
+ A tuple of three `dict`s, each mapping keys to `Tensor`s and
+ `SparseTensor`s. The first dict contains the context key/values,
+ the second dict contains the feature_list key/values, and the final dict
+ contains the lengths of any dense feature_list features.
Raises:
ValueError: if any feature is invalid.
@@ -1584,7 +1585,8 @@ def decode_csv(records,
record_defaults: A list of `Tensor` objects with specific types.
Acceptable types are `float32`, `float64`, `int32`, `int64`, `string`.
One tensor per column of the input record, with either a
- scalar default value for that column or empty if the column is required.
+ scalar default value for that column or an empty vector if the column is
+ required.
field_delim: An optional `string`. Defaults to `","`.
char delimiter to separate fields in a record.
use_quote_delim: An optional `bool`. Defaults to `True`.
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 55c2eb5fa4..4a126e9d7a 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -48,14 +48,14 @@ def get_resource_handle_data(graph_op):
assert ops._USE_C_SHAPES # pylint: disable=protected-access
assert type(graph_op) == ops.Tensor # pylint: disable=unidiomatic-typecheck
- handle_data = pywrap_tensorflow.GetResourceHandleShapeAndType(
+ handle_data = pywrap_tensorflow.GetHandleShapeAndType(
graph_op.graph._c_graph, graph_op._as_tf_output()) # pylint: disable=protected-access
return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
compat.as_bytes(handle_data))
-def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
+def eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
"""Creates a variable handle with information to do shape inference."""
container = ops.get_default_graph()._container # pylint: disable=protected-access
if container is None:
@@ -397,61 +397,33 @@ class ResourceVariable(variables.RefVariable):
# When in eager mode use a uid for the shared_name, to prevent
# accidental sharing.
shared_name = "%s_%d" % (handle_name, ops.uid())
- if init_from_fn:
- # Use attr_scope and device(None) to simulate the behavior of
- # colocate_with when the variable we want to colocate with doesn't
- # yet exist.
- if self._in_graph_mode:
- attr = attr_value_pb2.AttrValue(
- list=attr_value_pb2.AttrValue.ListValue(
- s=[compat.as_bytes("loc:@%s" % handle_name)]))
- with ops.get_default_graph()._attr_scope({"_class": attr}):
- with ops.name_scope("Initializer"), ops.device(None):
- initial_value = ops.convert_to_tensor(
- initial_value(), name="initial_value", dtype=dtype)
- self._handle = _eager_safe_variable_handle(
- shape=initial_value.get_shape(),
- dtype=initial_value.dtype.base_dtype,
- shared_name=shared_name,
- name=name,
- graph_mode=self._in_graph_mode)
- self._shape = initial_value.get_shape()
- else:
- initial_value = initial_value()
- with ops.name_scope("Initializer"):
- initial_value = ops.convert_to_tensor(
- initial_value, name="initial_value", dtype=dtype)
- self._handle = _eager_safe_variable_handle(
- shape=initial_value.get_shape(),
- dtype=initial_value.dtype.base_dtype,
- shared_name=shared_name,
- name=name,
- graph_mode=False)
- self._shape = initial_value.get_shape()
- # pylint: enable=protected-access
-
- # Or get the initial value from a Tensor or Python object.
- else:
- with ops.name_scope("Initializer"):
+ # Use attr_scope and device(None) to simulate the behavior of
+ # colocate_with when the variable we want to colocate with doesn't
+ # yet exist.
+ attr = attr_value_pb2.AttrValue(
+ list=attr_value_pb2.AttrValue.ListValue(
+ s=[compat.as_bytes("loc:@%s" % handle_name)]))
+ with ops.get_default_graph()._attr_scope({"_class": attr}):
+ with ops.name_scope("Initializer"), ops.device(None):
initial_value = ops.convert_to_tensor(
- initial_value, name="initial_value", dtype=dtype)
- # pylint: disable=protected-access
- if (self._in_graph_mode and initial_value is not None and
- initial_value.op._get_control_flow_context() is not None):
- raise ValueError(
- "Initializer for variable %s is from inside a control-flow "
- "construct, such as a loop or conditional. When creating a "
- "variable inside a loop or conditional, use a lambda as the "
- "initializer." % name)
- # pylint: enable=protected-access
- self._handle = _eager_safe_variable_handle(
+ initial_value() if init_from_fn else initial_value,
+ name="initial_value", dtype=dtype)
+ self._handle = eager_safe_variable_handle(
shape=initial_value.get_shape(),
dtype=initial_value.dtype.base_dtype,
shared_name=shared_name,
name=name,
graph_mode=self._in_graph_mode)
- self._shape = initial_value.get_shape()
-
+ self._shape = initial_value.shape
+ # pylint: disable=protected-access
+ if (self._in_graph_mode and initial_value is not None and
+ initial_value.op._get_control_flow_context() is not None):
+ raise ValueError(
+ "Initializer for variable %s is from inside a control-flow "
+ "construct, such as a loop or conditional. When creating a "
+ "variable inside a loop or conditional, use a lambda as the "
+ "initializer." % name)
+ # pylint: enable=protected-access
self._unique_id = shared_name
self._initial_value = initial_value if self._in_graph_mode else None
self._handle_name = handle_name + ":0"
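
The refactor above folds the callable and tensor initializer paths into one, so the handle is always created under the colocation attr_scope and the control-flow check runs for both cases. A minimal sketch of the two accepted initializer forms (not from the patch; it uses the internal ResourceVariable class directly):

```python
import tensorflow as tf
from tensorflow.python.ops import resource_variable_ops

# A ready tensor, or a zero-argument callable producing one, both go through
# the unified handle-creation path above.
v1 = resource_variable_ops.ResourceVariable(tf.ones([2, 2]), name="v1")
v2 = resource_variable_ops.ResourceVariable(lambda: tf.ones([2, 2]), name="v2")

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run([v1.read_value(), v2.read_value()]))
```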
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 5c00d929bf..5a3a5cc225 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -709,6 +709,10 @@ def _dynamic_rnn_loop(cell,
Raises:
ValueError: If the input depth cannot be inferred via shape inference
from the inputs.
+ ValueError: If time_step is not the same for all the elements in the
+ inputs.
+ ValueError: If batch_size is not the same for all the elements in the
+ inputs.
"""
state = initial_state
assert isinstance(parallel_iterations, int), "parallel_iterations must be int"
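
For reference alongside the new Raises entries above, a standard `dynamic_rnn` call in which every element of the (possibly nested) `inputs` structure shares the same time and batch dimensions; shapes here are hypothetical:

```python
import tensorflow as tf

cell = tf.nn.rnn_cell.BasicLSTMCell(8)
# All tensors fed as inputs must agree on [batch, time]; depth may differ.
inputs = tf.placeholder(tf.float32, [None, 10, 4])  # [batch, time, depth]
outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
```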
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index c11c9ccaae..43cca1a498 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -428,7 +428,7 @@ class BasicRNNCell(LayerRNNCell):
def build(self, inputs_shape):
if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
- % str(input_shape))
+ % str(inputs_shape))
input_depth = inputs_shape[-1]
self._kernel = self.add_variable(
@@ -525,7 +525,7 @@ class GRUCell(LayerRNNCell):
def build(self, inputs_shape):
if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
- % str(input_shape))
+ % str(inputs_shape))
input_depth = inputs_shape[-1]
self._gate_kernel = self.add_variable(
@@ -705,7 +705,7 @@ class BasicLSTMCell(LayerRNNCell):
def build(self, inputs_shape):
if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
- % str(input_shape))
+ % str(inputs_shape))
input_depth = inputs_shape[-1]
h_depth = self._num_units
@@ -908,7 +908,7 @@ class LSTMCell(LayerRNNCell):
def build(self, inputs_shape):
if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
- % str(input_shape))
+ % str(inputs_shape))
input_depth = inputs_shape[-1]
h_depth = self._num_units if self._num_proj is None else self._num_proj
@@ -954,7 +954,7 @@ class LSTMCell(LayerRNNCell):
"""Run one step of LSTM.
Args:
- inputs: input Tensor, 2D, `[batch, num_units].
+ inputs: input Tensor, must be 2-D, `[batch, input_size]`.
state: if `state_is_tuple` is False, this must be a state Tensor,
`2-D, [batch, state_size]`. If `state_is_tuple` is True, this must be a
tuple of state Tensors, both `2-D`, with column sizes `c_state` and
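
A small usage sketch matching the corrected docstring above: `inputs` is `[batch, input_size]`, independent of `num_units` (illustrative shapes only):

```python
import tensorflow as tf

cell = tf.nn.rnn_cell.LSTMCell(num_units=16)
inputs = tf.placeholder(tf.float32, [32, 8])             # [batch, input_size]
state = cell.zero_state(batch_size=32, dtype=tf.float32)
output, new_state = cell(inputs, state)                  # output: [32, 16]
```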
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index 29fefbe3a5..046a48d192 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -29,16 +29,19 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.util import compat as util_compat
# go/tf-wildcard-import
# pylint: disable=wildcard-import
+# pylint: disable=g-bad-import-order
from tensorflow.python.ops.gen_string_ops import *
+from tensorflow.python.util import compat as util_compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
+# pylint: enable=g-bad-import-order
# pylint: enable=wildcard-import
@@ -90,11 +93,6 @@ def regex_replace(source, pattern, rewrite, replace_global=True):
Returns:
string `Tensor` of the same shape as `source` with specified replacements.
"""
- # TODO(b/112455102): Remove compat.forward_compatible once past the horizon.
- if not compat.forward_compatible(2018, 10, 10):
- return gen_string_ops.regex_replace(
- input=source, pattern=pattern,
- rewrite=rewrite, replace_global=replace_global)
if (isinstance(pattern, util_compat.bytes_or_text_types) and
isinstance(rewrite, util_compat.bytes_or_text_types)):
# When `pattern` and `rewrite` are static through the life of the op we can
@@ -108,6 +106,87 @@ def regex_replace(source, pattern, rewrite, replace_global=True):
rewrite=rewrite, replace_global=replace_global)
+@tf_export("strings.format")
+def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
+ r"""Formats a string template using a list of tensors.
+
+ Formats a string template using a list of tensors, abbreviating tensors by
+ only printing the first and last `summarize` elements of each dimension
+ (recursively). If formatting only one tensor into a template, the tensor does
+ not have to be wrapped in a list.
+
+ Example:
+ Formatting a single-tensor template:
+ ```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor = tf.range(10)
+ formatted = tf.strings.format("tensor: {}, suffix", tensor)
+ out = sess.run(formatted)
+ expected = "tensor: [0 1 2 ... 7 8 9], suffix"
+
+ assert(out.decode() == expected)
+ ```
+
+ Formatting a multi-tensor template:
+ ```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor_one = tf.reshape(tf.range(100), [10, 10])
+ tensor_two = tf.range(10)
+ formatted = tf.strings.format("first: {}, second: {}, suffix",
+ (tensor_one, tensor_two))
+
+ out = sess.run(formatted)
+ expected = ("first: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]], second: [0 1 2 ... 7 8 9], suffix")
+
+ assert(out.decode() == expected)
+ ```
+
+ Args:
+ template: A string template to format tensor values into.
+ inputs: A list of `Tensor` objects, or a single Tensor.
+ The list of tensors to format into the template string. If a solitary
+ tensor is passed in, the input tensor will automatically be wrapped as a
+ list.
+ placeholder: An optional `string`. Defaults to `{}`.
+ At each placeholder occurring in the template, a subsequent tensor
+ will be inserted.
+ summarize: An optional `int`. Defaults to `3`.
+ When formatting the tensors, show the first and last `summarize`
+ entries of each tensor dimension (recursively). If set to -1, all
+ elements of the tensor will be shown.
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar `Tensor` of type `string`.
+
+ Raises:
+ ValueError: if the number of placeholders does not match the number of
+ inputs.
+ """
+ # If there is only one tensor to format, we will automatically wrap it in a
+ # list to simplify the user experience
+ if tensor_util.is_tensor(inputs):
+ inputs = [inputs]
+ if template.count(placeholder) != len(inputs):
+ raise ValueError("%s placeholder(s) in template does not match %s tensor(s)"
+ " provided as input" % (template.count(placeholder),
+ len(inputs)))
+
+ return gen_string_ops.string_format(inputs,
+ template=template,
+ placeholder=placeholder,
+ summarize=summarize,
+ name=name)
+
+
@tf_export("string_split")
def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=invalid-name
"""Split elements of `source` based on `delimiter` into a `SparseTensor`.
@@ -251,6 +330,17 @@ def reduce_join(inputs, axis=None,
reduce_join.__doc__ = deprecation.rewrite_argument_docstring(
gen_string_ops.reduce_join.__doc__, "reduction_indices", "axis")
+
+# This wrapper provides backwards compatibility for code that predates the
+# unit argument and that passed 'name' as a positional argument.
+@tf_export("strings.length")
+def string_length(input, name=None, unit="BYTE"):
+ return gen_string_ops.string_length(input, unit=unit, name=name)
+
+
+string_length.__doc__ = gen_string_ops.string_length.__doc__
+
+
ops.NotDifferentiable("RegexReplace")
ops.NotDifferentiable("StringToHashBucket")
ops.NotDifferentiable("StringToHashBucketFast")
diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py
index 94c7d88b5c..a404507627 100644
--- a/tensorflow/python/ops/summary_ops_v2.py
+++ b/tensorflow/python/ops/summary_ops_v2.py
@@ -234,6 +234,7 @@ def create_file_writer(logdir,
"""
if logdir is None:
return SummaryWriter(None, None)
+ logdir = str(logdir)
with ops.device("cpu:0"):
if max_queue is None:
max_queue = constant_op.constant(10)
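
The `str(logdir)` coercion above lets path-like objects be passed as `logdir`. A small sketch, assuming the `tf.contrib.summary` alias for this module that exists in TF 1.x:

```python
import pathlib
import tensorflow as tf

# Any object with a sensible str(), e.g. pathlib.Path, now works as logdir.
writer = tf.contrib.summary.create_file_writer(pathlib.Path("/tmp/logs"))
```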
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
new file mode 100644
index 0000000000..875be31602
--- /dev/null
+++ b/tensorflow/python/ops/while_v2.py
@@ -0,0 +1,580 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""while_v2 and gradient.
+
+This is a version of while_loop that emits a single While op, as well as the
+gradient function for While ops produced by while_loop. This will eventually
+replace the current tf.while_loop implementation once it reaches feature and
+performance parity.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.eager import function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import function_def_to_graph
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import cond_v2_impl as cond_v2
+from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import gen_functional_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import list_ops
+from tensorflow.python.util import nest
+
+# pylint: disable=protected-access
+
+# TODO(b/79881896): Handle external control dependencies. tf.while_loop allows
+# control dependencies on external nodes with at least 1 output.
+# Another idea is to create const nodes outside the loop and add control edges
+# to them and then pass those in as data inputs. This should probably be
+# handled in the CapturingGraph itself.
+
+
+def while_loop(cond, body, loop_vars, name=None):
+ """Like tf.while_loop, except emits a single While op."""
+ if not name:
+ name = "while"
+
+ with ops.name_scope(name) as scope:
+ with ops.name_scope(None):
+ cond_name = _get_unique_name(("%scond" % scope).replace("/", "_"))
+ body_name = _get_unique_name(("%sbody" % scope).replace("/", "_"))
+
+ flattened_loop_vars = nest.flatten(loop_vars)
+ num_outputs = len(flattened_loop_vars)
+
+ # Add loop counter needed for computing gradients.
+ flattened_loop_vars = [constant_op.constant(0., name="loop_counter")
+ ] + flattened_loop_vars
+
+ # Build a `cond` wrapper that can handle the extra counter loop_var.
+ def wrapped_cond(unused_loop_counter, *loop_vars):
+ return cond(*loop_vars)
+
+ cond_graph = function.func_graph_from_py_func(cond_name, wrapped_cond,
+ flattened_loop_vars, {})
+
+ # Add external_captures of cond to the list of loop vars.
+ # Note that external tensors will be treated as loop invariants, i.e.,
+ # the value of that tensor in each iteration is the same as it was at the
+ # beginning of the loop execution.
+ flattened_loop_vars = flattened_loop_vars + cond_graph.external_captures
+
+ def wrapped_body(loop_counter, *args):
+ """Loop body augmented with counter update.
+
+ Args:
+ loop_counter: Loop counter which needs to be incremented in the body.
+ *args: List of args
+ args[:num_outputs] - Args for the original loop body.
+ args[num_outputs:] - External captures of cond. These get passed
+ through as is.
+
+ Returns:
+ A list of tensors the same length as args.
+ """
+ outputs = body(*args[:num_outputs])
+ if not isinstance(outputs, collections.Sequence):
+ outputs = [outputs]
+
+ # Return the external_captures of cond_graph as is, i.e., treat them as
+ # loop invariants.
+ # TODO(srbs): Update lowering code to create _Enter nodes with
+ # is_constant=True for inputs that are directly passed to outputs.
+ return [loop_counter + 1] + list(outputs) + list(args[num_outputs:])
+
+ body_graph = function.func_graph_from_py_func(body_name, wrapped_body,
+ flattened_loop_vars, {})
+ # Add external captures of body to the list of loop vars.
+ # Note that external tensors will be treated as loop invariants, i.e.,
+ # the value of that tensor in each iteration is the same as it was at the
+ # beginning of the loop execution.
+ flattened_loop_vars = flattened_loop_vars + body_graph.external_captures
+ # TODO(srbs): Update lowering code to create _Enter nodes with
+ # is_constant=True for inputs that are directly passed to outputs.
+ body_graph.outputs.extend(body_graph.internal_captures)
+
+ # Capture `external_captures` of `body_graph` in `cond_graph` so that it
+ # expects to receive those as arguments.
+ # TODO(srbs): Dedup tensors that are captured in both the cond and body.
+ # This logic already exists in cond_v2.
+ with cond_graph.as_default():
+ for external_capture in body_graph.external_captures:
+ cond_graph.capture(external_capture)
+
+ # Export all tensors in the loop body that may be needed for gradient
+ # computation. We do this by accumulating the intermediate values in
+ # TensorLists.
+ intermediate_tensors = _get_intermediates(body_graph)
+
+ for intermediate_tensor in intermediate_tensors:
+ # TODO(srbs): Cache and re-use empty tensor lists.
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=intermediate_tensor.dtype,
+ element_shape=_get_tensor_convertible_shape(
+ intermediate_tensor.shape))
+ flattened_loop_vars.append(tensor_list)
+ with cond_graph.as_default():
+ # Add a placeholder to cond_graph's inputs corresponding to the
+ # tensor_list.
+ cond_graph.capture(tensor_list)
+ with body_graph.as_default():
+ # Push the intermediate tensor to the tensor list. This captures the
+ # `tensor_list` as well.
+ appended_tensor_list = list_ops.tensor_list_push_back(
+ tensor_list,
+ intermediate_tensor)
+ # Add this modified tensor list to the list of outputs.
+ body_graph.outputs.append(appended_tensor_list)
+
+ outputs = gen_functional_ops._while(
+ flattened_loop_vars,
+ cond_v2._create_new_tf_function(cond_graph),
+ cond_v2._create_new_tf_function(body_graph),
+ name=scope)
+
+ _copy_handle_data(body_graph.outputs, outputs)
+ _maybe_set_lowering_attr(outputs[0].op)
+
+ # First var is loop counter.
+ if num_outputs == 1:
+ return outputs[1]
+ else:
+ return nest.pack_sequence_as(loop_vars, outputs[1:1 + num_outputs])
+
+
+@ops.RegisterGradient("While")
+def _WhileGrad(op, *grads): # pylint: disable=invalid-name
+ """The gradient of a While op produced by while_loop."""
+ body_graph = _get_body_graph(op)
+
+ # Replace None gradients with zeros. This is needed because `grads` could have
+ # None incoming gradients for the TensorLists. If we pass None's through, the
+ # custom gradient of TensorListPopBack will create an EmptyTensorList inside
+ # the FuncGraph which is undesirable.
+ # TODO(b/80444525): There might be an issue with treating no gradient as zero
+ # gradient in certain cases. Consider replacing None gradients with Zeros
+ # for accumulators only.
+ grads = [
+ g if g is not None else array_ops.zeros_like(output)
+ for g, output in zip(grads, op.outputs)
+ ]
+
+ body_grad_graph, args = _create_grad_func(
+ body_graph, grads,
+ _get_unique_name("%s_grad" % body_graph.name), op)
+
+ intermediate_tensors = _get_intermediates(body_grad_graph)
+
+ for intermediate_tensor in intermediate_tensors:
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=intermediate_tensor.dtype,
+ element_shape=_get_tensor_convertible_shape(intermediate_tensor.shape))
+ with body_grad_graph.as_default():
+ tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
+ # Push the intermediate tensor to the tensor list.
+ appended_tensor_list = list_ops.tensor_list_push_back(tensor_list_ph,
+ intermediate_tensor)
+ # Add this modified tensor list to the list of outputs.
+ body_grad_graph.outputs.append(appended_tensor_list)
+
+ def grad_cond(counter, max_iters, *unused_args):
+ return counter < max_iters
+
+ loop_vars = args + body_grad_graph.external_captures
+ cond_grad_graph = function.func_graph_from_py_func(
+ _get_unique_name("%s_grad_cond" % op.name),
+ grad_cond, loop_vars, {})
+
+ assert len(loop_vars) == len(body_grad_graph.inputs)
+ assert len(loop_vars) == len(body_grad_graph.outputs)
+ assert len(loop_vars) == len(cond_grad_graph.inputs)
+
+ outputs = gen_functional_ops._while(
+ loop_vars,
+ cond_v2._create_new_tf_function(cond_grad_graph),
+ cond_v2._create_new_tf_function(body_grad_graph),
+ name=_get_unique_name("%s_grad" % op.name))
+
+ _copy_handle_data(body_grad_graph.outputs, outputs)
+ _maybe_set_lowering_attr(outputs[0].op)
+
+ # outputs[0] is the loop counter.
+ # outputs[1] is the total number of loop iterations.
+ return outputs[2:2 + len(op.inputs)]
+
+
+# TODO(srbs): Pull this into common utils for cond_v2 and while_v2.
+def _get_body_graph(while_op):
+ """Returns `FuncGraph` for the while body.
+
+ Args:
+ while_op: The While Operation.
+
+ Returns:
+ `FuncGraph` for the while body.
+ """
+ extra_inputs = list(while_op.inputs)
+ input_shapes = [t.shape for t in extra_inputs]
+ func_name = while_op.get_attr("body").name
+ fdef = while_op.graph._get_function(func_name).definition
+ func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
+ func_graph._while = while_op
+ return func_graph
+
+
+def _create_grad_func(func_graph, grads, name, while_op):
+ """Builds and returns the gradient FuncGraph of `func_graph` and its args.
+
+ The returned grad_func_graph must be called with the returned
+ args + grad_func_graph.captures.
+
+ Args:
+ func_graph: FuncGraph for the forward body function.
+ grads: The incoming grads for `func_graph`'s outputs.
+ name: Name of the returned gradient function.
+ while_op: The forward While op.
+
+ Returns:
+ 2-tuple of (grad_func_graph, args).
+ """
+ assert len(func_graph.outputs) == len(grads)
+
+ loop_counter = constant_op.constant(0.)
+ # TODO(srbs): For nested while loops will need to lookup this value from
+ # the accumulator of the enclosing while loop. For now use as is assuming
+ # there is no nesting.
+ num_iters_t = while_op.outputs[0]
+
+ args = [loop_counter, num_iters_t] + grads
+
+ # Note: The returned function does not have `args` in the list of
+ # `external_captures`.
+ grad_func_graph = function.func_graph_from_py_func(
+ name,
+ lambda *args: _grad_fn(func_graph, args),
+ args, {},
+ func_graph=_WhileBodyGradFuncGraph(name, func_graph))
+
+ # Add the popped accumulators to the list of outputs.
+ for internal_capture in grad_func_graph.internal_captures:
+ grad_func_graph.outputs.append(
+ grad_func_graph.popped_tensor_lists[internal_capture])
+
+ return grad_func_graph, args
+
+
+def _grad_fn(func_graph, args):
+ """Computes the gradient of `func_graph` in the current graph.
+
+ This function builds the gradient graph of the corresponding forward-pass
+ `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.
+
+ Args:
+ func_graph: function.FuncGraph. The corresponding forward-pass function.
+    args: The input arguments.
+      args[0] - Loop counter.
+      args[1] - Total number of iterations.
+      args[2:] - Incoming gradients for `func_graph.outputs`.
+
+ Returns:
+ The output gradient Tensors.
+ """
+ xs = func_graph.inputs
+ ys = func_graph.outputs
+ grad_ys = args[2:]
+
+ # Build the gradient graph. Note that this builds the gradient computation of
+ # func_graph in the current graph, which requires capturing tensors from
+ # func_graph. The captured func_graph tensors are resolved to external tensors
+ # in _resolve_grad_inputs.
+ # TODO(srbs): Mark GradientsHelper as public?
+ grad_outs = gradients_impl._GradientsHelper(
+ ys, xs, grad_ys=grad_ys, src_graph=func_graph)
+
+ assert all([g is not None for g in grad_outs])
+ counter = args[0]
+ total_iters = args[1]
+ return [counter + 1, total_iters] + grad_outs
+
+
+def _get_intermediates(func_graph):
+ """Returns all tensors in `func_graph` that should be accumulated."""
+ # We currently accumulate output tensors of most ops in the function and rely
+ # on the pruning pass to get rid of the unused accumulators at runtime.
+ # However, this can bloat the GraphDef and make debugging harder so we perform
+ # some optimizations.
+ #
+ # Optimization we currently perform:
+ # 1. We do not accumulate tensors which already have an accumulator
+ # in the loop body.
+ # 2. We do not accumulate outputs of Identity nodes. When building the
+ # FuncGraph, we add an Identity node for each output (see
+ # `AutomaticControlDependencies.mark_as_return`). Accumulating outputs
+ # of all these nodes bloats the GraphDef quite a bit so we remove those.
+ # Since the gradient of an Identity node does not rely on its forward op's
+ # input this is safe to do.
+ #
+ # Other possible optimizations:
+ # 1. Only accumulate tensors that will be required by the backward pass.
+ # This will require running the gradient pass and hence would increase the
+ # graph building time for the forward pass.
+ # 2. Do not accumulate Const nodes created inside the loop body.
+ # 3. Do not accumulate inputs that are passed as-is, e.g. loop invariants.
+ # TODO(srbs): 2 and 3 may be hard optimizations for the runtime optimizer
+ # since it requires knowledge of the while loop semantics. If so, consider
+ # doing those here.
+ intermediates = []
+
+ for op in func_graph.get_operations():
+ if op.type == "Identity":
+ continue
+ for o in op.outputs:
+      if (o != func_graph.inputs[0] and  # Not the loop counter.
+          _get_accumulator(o) is None):  # Not already accumulated.
+ intermediates.append(o)
+ return intermediates
+
+
+def _get_accumulator(tensor):
+ r"""Returns TensorList if any containing accumulated values of tensor.
+
+ We try to find a pattern of the form:
+
+ input_tl tensor
+ \ /
+ (TensorListPushBack)
+ |
+ output_tl
+
+ which satisfies the following conditions:
+
+ 1. input_tl must be in tensor.graph.inputs.
+ 2. output_tl or Identity(output_tl) must be in tensor.graph.outputs.
+  3. tensor.graph.input_index(input_tl) == tensor.graph.output_index(output_tl).
+
+ output_tl or Identity(output_tl) (whichever is in tensor.graph.outputs) is
+ returned if such a pattern is found else None is returned.
+
+ Args:
+ tensor: The Tensor to be accumulated.
+
+ Returns:
+ A variant tensor in the same graph as `tensor` or None if no accumulator is
+ found.
+ """
+ assert isinstance(tensor.graph, function.FuncGraph)
+
+ def get_func_graph_output(t):
+ """Returns t or Identity(t) whichever exists in graph outputs else None."""
+ if t in tensor.graph.outputs:
+ return t
+ # tf.defun adds an Identity for each output, check whether that is the case.
+ identity_op = t.consumers()[0]
+ if (identity_op.type == "Identity" and
+ identity_op.outputs[0] in tensor.graph.outputs):
+ return identity_op.outputs[0]
+ return None
+
+ for consumer in tensor.consumers():
+ # Find the consumer that is a TensorListPushBack node whose TensorList input
+ # is in the list of function inputs.
+ if (consumer.type != "TensorListPushBack" or
+ consumer.inputs[0] not in tensor.graph.inputs):
+ continue
+
+ output = get_func_graph_output(consumer.outputs[0])
+ if output is None:
+ # The TensorList output of `consumer` is not in the list of function
+ # outputs.
+ continue
+
+ accum_input_idx = tensor.graph.inputs.index(consumer.inputs[0])
+ accum_output_idx = tensor.graph.outputs.index(output)
+ if accum_input_idx == accum_output_idx:
+ return output
+ return None
+
+
+# TODO(srbs): Add to common utils for cond_v2 and while_v2.
+def _get_unique_name(name):
+ """Returns a name that is unique in the root graph of `func_graph`.
+
+ Args:
+ name: String to uniquify.
+
+ Returns:
+ A string.
+ """
+ with ops.init_scope():
+ return ops.get_default_graph().unique_name(name)
+
+
+class _WhileBodyGradFuncGraph(function.FuncGraph):
+ """FuncGraph for the gradient function of the body of a While op.
+
+ Contains the logic for capturing the tensors from the body of the forward
+ While op which is as follows:
+ 1. Find the accumulator for that tensor.
+ 2. Capture the forward While op output tensor corresponding to the
+ accumulator in this FuncGraph.
+ 3. Pop a value from the captured placeholder and use it as the captured value
+ for the forward pass tensor.
+
+ This only allows capturing tensors in the forward graph. A ValueError is
+ raised if an attempt is made to capture a tensor not in the forward graph.
+  To manually capture a tensor that is not in the forward graph, call
+ `capture` with `whitelisted=True`.
+
+ Note: The `captures` dict does not contain the forward tensor since it is not
+ directly captured. It contains the accumulator corresponding to this forward
+ tensor.
+
+ Attributes:
+ popped_tensor_lists: Dict from the captured accumulator placeholder to the
+ TensorList obtained after popping the intermediate tensor from it. The
+ values of this dict need to be added to the list of outputs.
+ """
+
+ def __init__(self, name, forward_graph):
+ super(_WhileBodyGradFuncGraph, self).__init__(name)
+ self.popped_tensor_lists = {}
+ # FuncGraph for the body of the forward While op.
+ self._forward_graph = forward_graph
+ # Dict from forward intermediate tensor to the corresponding "popped" tensor
+ # in this graph.
+ self._indirect_captures = {}
+ # Dict from forward graph tensor to the While op output corresponding to its
+ # accumulator.
+ self._tensor_to_accumulator = {}
+
+ def capture(self, tensor, name=None, whitelisted=False):
+ """Selectively captures external tensors.
+
+ If `whitelisted` is False only allows capturing tensors in the
+ `_forward_graph`.
+
+ Args:
+ tensor: Tensor. May be from this FuncGraph or a different graph.
+ name: Optional name if a placeholder is created.
+ whitelisted: If False (default), only allows capturing tensors from the
+ forward graph.
+
+ Returns:
+ The placeholder in this graph for the tensor.
+
+ Raises:
+ ValueError: If attempting to capture an external tensor not in the forward
+ graph with `whitelisted` set to False.
+ """
+ if (not whitelisted and tensor.graph is not self and
+ tensor.graph != self._forward_graph):
+ raise ValueError("Attempting to capture tensor", str(tensor),
+ " which is not in the forward graph but in ",
+ _graph_name(tensor.graph), ".")
+ return super(_WhileBodyGradFuncGraph, self).capture(tensor, name)
+
+ def _capture_helper(self, tensor, name):
+ if tensor.graph is not self._forward_graph:
+ return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)
+
+ captured_tensor = self._indirect_captures.get(tensor)
+ if captured_tensor is not None:
+ # For GradientTape housekeeping.
+ assert self._tensor_to_accumulator[tensor] in self.captures
+ super(_WhileBodyGradFuncGraph, self)._capture_helper(
+ self._tensor_to_accumulator[tensor], name)
+ return captured_tensor
+
+ assert tensor not in self._tensor_to_accumulator
+
+    # Find the TensorList that was used to accumulate the values of this
+    # intermediate tensor.
+    accumulator = _get_accumulator(tensor)
+ if accumulator is None:
+ raise ValueError("Reference to un-accumulated intermediate tensor: ",
+ tensor.name)
+ assert accumulator.graph == self._forward_graph
+ # Get the While op output corresponding to the accumulator.
+ accumulator = self._forward_graph._while.outputs[self._forward_graph.outputs
+ .index(accumulator)]
+
+ assert accumulator.graph == self._forward_graph.outer_graph
+ self._tensor_to_accumulator[tensor] = accumulator
+
+ # Capture the `accumulator`.
+ accumulator_ph = super(_WhileBodyGradFuncGraph, self)._capture_helper(
+ accumulator, name)
+ new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
+ accumulator_ph, element_dtype=tensor.dtype)
+ self._indirect_captures[tensor] = captured_tensor
+ self.popped_tensor_lists[accumulator_ph] = new_tensor_list
+ return captured_tensor
+
+
+def _copy_handle_data(src_tensors, tgt_tensors):
+ for src_t, tgt_t in zip(src_tensors, tgt_tensors):
+ function._copy_handle_data(src_t, tgt_t)
+
+
+# TODO(srbs): Move to common utils for cond_v2 and while_v2.
+def _maybe_set_lowering_attr(op):
+ """Sets the flag to enable lowering on the `While` op if necessary.
+
+ Lowering allows while_v2 to avoid some of the limitations of Functions,
+ allowing users to specify devices & colocation inside of while_v2
+ branches, and enabling non-strict evaluation & partial pruning of while_v2
+ branches. This brings while_v2 closer to feature parity with
+ tf.while_loop.
+
+ However, we do not lower `While` in the XLA context because it is easier
+ for XLA to apply its own optimizations when dealing with un-lowered
+ `While` operators than with low-level control flow primitives.
+
+ Args:
+ op: The While op.
+ """
+ if not control_flow_util.IsInXLAContext(op):
+ # pylint: disable=protected-access
+ op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))
+ # pylint: enable=protected-access
+
+
+def _get_tensor_convertible_shape(shape):
+ assert isinstance(shape, tensor_shape.TensorShape)
+ if shape.is_fully_defined():
+ return shape
+ if not shape: # Unknown shape.
+ return -1
+ # Partially defined shape.
+ shape_list = shape.as_list()
+ shape_list = [s if s is not None else -1 for s in shape_list]
+ return ops.convert_to_tensor(shape_list)
+
+
+def _graph_name(graph):
+ if isinstance(graph, function.FuncGraph):
+ return graph.name
+ return "Base"
+
+
+# pylint: enable=protected-access
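
To close, a minimal driver for the new module (not part of the patch): it builds one While op via `while_v2.while_loop` and differentiates through it with the `_WhileGrad` registered above. Behavior assumes this branch's state of while_v2, which does not yet handle nesting or external control dependencies.

```python
import tensorflow as tf
from tensorflow.python.ops import while_v2

x = tf.constant(2.0)
# Repeatedly squares x until it reaches 100: 2 -> 4 -> 16 -> 256.
y = while_v2.while_loop(lambda v: v < 100.0, lambda v: v * v, [x])
dy_dx, = tf.gradients(y, x)  # y = x**8, so dy/dx = 8 * x**7 = 1024

with tf.Session() as sess:
  print(sess.run([y, dy_dx]))  # [256.0, 1024.0]
```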