aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/embedding_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/embedding_ops.py')
-rw-r--r--tensorflow/python/ops/embedding_ops.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index f4561d1a83..8c1ccc6840 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -191,9 +191,12 @@ def _embedding_lookup_and_transform(params,
(flat_ids - extras) // ids_per_partition)
# Emulate a conditional using a boolean indicator tensor
- new_ids = array_ops.where(p_assignments < extras,
- flat_ids % (ids_per_partition + 1),
- (flat_ids - extras) % ids_per_partition)
+ is_in_first_extras_partitions = math_ops.cast(p_assignments < extras,
+ flat_ids.dtype)
+ new_ids = (is_in_first_extras_partitions * (flat_ids %
+ (ids_per_partition + 1)) +
+ (1 - is_in_first_extras_partitions) *
+ ((flat_ids - extras) % ids_per_partition))
else:
raise ValueError("Unrecognized partition strategy: " +
partition_strategy)