diff options
Diffstat (limited to 'tensorflow/python/ops/embedding_ops.py')
-rw-r--r-- | tensorflow/python/ops/embedding_ops.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index f4561d1a83..8c1ccc6840 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -191,9 +191,12 @@ def _embedding_lookup_and_transform(params, (flat_ids - extras) // ids_per_partition) # Emulate a conditional using a boolean indicator tensor - new_ids = array_ops.where(p_assignments < extras, - flat_ids % (ids_per_partition + 1), - (flat_ids - extras) % ids_per_partition) + is_in_first_extras_partitions = math_ops.cast(p_assignments < extras, + flat_ids.dtype) + new_ids = (is_in_first_extras_partitions * (flat_ids % + (ids_per_partition + 1)) + + (1 - is_in_first_extras_partitions) * + ((flat_ids - extras) % ids_per_partition)) else: raise ValueError("Unrecognized partition strategy: " + partition_strategy) |