author     Alexandre Passos <apassos@google.com>          2018-03-26 15:39:54 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org> 2018-03-26 15:42:46 -0700
commit     c83a54adcface7d4bb666d7c4fd3968ba980a50d (patch)
tree       5010155c5e2a46ef47bb9f9933e3bbf0b4628a7e
parent     290632966fae0619db30c1ba777634db9a43b757 (diff)
Makes tf.gather not silently snapshot resource variables.
PiperOrigin-RevId: 190537320
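
In effect, tf.gather on a ResourceVariable now dispatches to the variable's sparse_read method instead of first reading the entire variable into a tensor (a snapshot) and gathering from that copy. A minimal sketch of the resulting behavior, assuming the TF 1.x graph-mode API; use_resource=True is what makes this a resource variable, and the variable name and shapes are illustrative:

    import tensorflow as tf

    # A large embedding table backed by a resource variable.
    table = tf.get_variable("table", shape=[1000000, 128], use_resource=True)
    indices = tf.constant([3, 17, 42])

    # After this commit, tf.gather on a resource variable becomes a
    # sparse_read: only the requested rows are read at execution time,
    # rather than materializing the full [1000000, 128] snapshot first.
    rows = tf.gather(table, indices)

    # Equivalent explicit form:
    rows_explicit = table.sparse_read(indices)

Reading only the needed rows also avoids the subtle race conditions noted in the array_ops.py change below, where a concurrent assignment to the variable could land between the snapshot and the gather.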
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py | 29
-rw-r--r--  tensorflow/python/ops/array_ops.py                                        | 17
-rw-r--r--  tensorflow/python/ops/embedding_ops.py                                    | 29
3 files changed, 33 insertions(+), 42 deletions(-)
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index c4139dde49..07b3ad71d4 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -785,26 +785,31 @@ class AttentionWrapperTest(test.TestCase):
wrapper.BahdanauAttention, wrapper.LuongAttention)
expected_final_output = BasicDecoderOutput(
- rnn_output=ResultSummary(
- shape=(5, 3, 20), dtype=dtype('float32'), mean=0.11798714846372604),
- sample_id=ResultSummary(
- shape=(5, 3), dtype=dtype('int32'), mean=7.933333333333334))
+ rnn_output=ResultSummary(shape=(5, 3, 20),
+ dtype=dtype('float32'),
+ mean=0.11723966),
+ sample_id=ResultSummary(shape=(5, 3),
+ dtype=dtype('int32'),
+ mean=9.2666666666666675))
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
- c=ResultSummary(
- shape=(5, 9), dtype=dtype('float32'), mean=-0.0036486709),
- h=ResultSummary(
- shape=(5, 9), dtype=dtype('float32'), mean=-0.0018835809)),
- attention=ResultSummary(
- shape=(5, 20), dtype=dtype('float32'), mean=0.11798714846372604),
+ c=ResultSummary(shape=(5, 9),
+ dtype=dtype('float32'),
+ mean=-0.003545674),
+ h=ResultSummary(shape=(5, 9),
+ dtype=dtype('float32'),
+ mean=-0.0018327223)),
+ attention=ResultSummary(shape=(5, 20),
+ dtype=dtype('float32'),
+ mean=0.11728073),
time=3,
alignments=(
ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
+ alignment_history=(),
attention_state=(
ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
- ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
- alignment_history=())
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)))
expected_final_alignment_history = (
ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125),
ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125))
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index ec7c14f7d8..9106461c60 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -2691,12 +2691,17 @@ reverse_sequence.__doc__ = deprecation.rewrite_argument_docstring(
@tf_export("gather")
def gather(params, indices, validate_indices=None, name=None, axis=0):
- # TODO(rjryan): Remove "Gather" creation in favor of GatherV2 once the forward
- # compatibility 3 week period has passed.
- if axis == 0:
- return gen_array_ops.gather(
- params, indices, validate_indices=validate_indices, name=name)
- else:
+ del validate_indices
+ if axis != 0:
+ # Note that we do a sparse_read here to avoid snapshotting the entire
+ # resource variable and doing a gather, which can be inefficient and lead to
+ # subtle race conditions. TODO(apassos) implement axis != 0 on sparse_read
+ return gen_array_ops.gather_v2(params, indices, axis, name=name)
+ try:
+ # TODO(apassos) find a less bad way of detecting resource variables without
+ # introducing a circular dependency.
+ return params.sparse_read(indices, name=name)
+ except AttributeError:
return gen_array_ops.gather_v2(params, indices, axis, name=name)
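
The try/except above is how gather detects resource variables without importing resource_variable_ops, which (per the TODO) would create a circular dependency: it duck-types on sparse_read and falls back to the GatherV2 kernel on AttributeError. A self-contained sketch of the same dispatch pattern, with illustrative names that are not part of the TensorFlow API:

    class SparseReadable(object):
        # Stands in for a resource variable exposing sparse_read.

        def __init__(self, rows):
            self._rows = rows

        def sparse_read(self, indices, name=None):
            # Reads only the requested rows; never copies all of self._rows.
            return [self._rows[i] for i in indices]

    def gather_like(params, indices, name=None):
        # Mirrors the new tf.gather: try sparse_read first; objects without
        # it (plain Python lists here) raise AttributeError and fall
        # through to a dense gather.
        try:
            return params.sparse_read(indices, name=name)
        except AttributeError:
            return [params[i] for i in indices]

    table = [[0, 0], [1, 1], [2, 2]]
    print(gather_like(SparseReadable(table), [2, 0]))  # [[2, 2], [0, 0]]
    print(gather_like(table, [1]))                     # [[1, 1]]

One caveat of duck typing this way: an AttributeError raised from inside sparse_read itself would silently trigger the dense fallback, which is presumably why the TODO still asks for a less bad way of detecting resource variables.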
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 20e4a28b9c..f0120f2957 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -35,34 +35,14 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export
-def _gather(params, ids, name=None):
- """Helper function for _embedding_lookup_and_transform.
-
- This function gathers embeddings from a single tensor. The gather deals with
- resource variables specially.
-
- Args:
- params: A `Tensor` of embeddings.
- ids: A `Tensor` indexing the embeddings to be retrieved from `params`.
- name: A name for the operation (optional).
-
- Returns:
- A `Tensor` with the same type as `params`.
- """
- if isinstance(params, resource_variable_ops.ResourceVariable):
- return params.sparse_read(ids, name=name)
- else:
- return array_ops.gather(params, ids, name=name)
-
-
def _clip(params, ids, max_norm):
"""Helper function for _embedding_lookup_and_transform.
This function optionally clips embeddings to an l2-norm of max_norm.
Args:
- params: A `Tensor` of embeddings retrieved by `_gather`.
- ids: The `ids` argument that was passed to `_gather`.
+ params: A `Tensor` of embeddings retrieved by `gather`.
+ ids: The `ids` argument that was passed to `gather`.
max_norm: If provided, the embeddings are l2-normalized to the value of
max_norm.
@@ -148,7 +128,8 @@ def _embedding_lookup_and_transform(params,
ids = ops.convert_to_tensor(ids, name="ids")
if np == 1 and (not transform_fn or ids.get_shape().ndims == 1):
with ops.colocate_with(params[0]):
- result = _clip(_gather(params[0], ids, name=name), ids, max_norm)
+ result = _clip(array_ops.gather(params[0], ids, name=name),
+ ids, max_norm)
if transform_fn:
result = transform_fn(result)
return result
@@ -212,7 +193,7 @@ def _embedding_lookup_and_transform(params,
for p in xrange(np):
pids = gather_ids[p]
with ops.colocate_with(params[p]):
- result = _gather(params[p], pids)
+ result = array_ops.gather(params[p], pids)
if transform_fn:
# If transform_fn is provided, the clip_by_norm precedes
# the transform and hence must be co-located. See below
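
With the _gather helper removed, embedding_lookup leans on array_ops.gather to do the resource-variable dispatch itself, so the public embedding_lookup interface is unchanged. A hedged usage sketch under the same TF 1.x assumptions as above (variable name and shapes are illustrative):

    import tensorflow as tf

    params = tf.get_variable(
        "embedding_table", shape=[10000, 64], use_resource=True)
    ids = tf.constant([[12, 7], [99, 3]])

    # embedding_lookup now routes through tf.gather, which for a resource
    # variable becomes a sparse_read of just the looked-up rows.
    embedded = tf.nn.embedding_lookup(params, ids)  # shape [2, 2, 64]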