author    Michael Kuperstein <mkuper@google.com>  2018-06-29 11:38:39 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-06-29 11:44:44 -0700
commit    c8e967357ef0bf040e85e1fb1aa85af54e8d5689 (patch)
tree      e07674ae8413dda738633fe3d26e8577c79e058f
parent    e19d4924f20131b4e95cee711535125ee7902dba (diff)
[TF:XLA] A more generic TopK.
Use Sort to implement R1 TopK for an arbitrary dimension size, and more types.

PiperOrigin-RevId: 202681175
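In place of the old bit-packing trick, the rewritten kernel relies on XLA's key-value Sort: negate the input, sort it together with an iota of indices, slice the first k entries of both sorted outputs, and negate the sliced values back. As a rough illustration only, here is a minimal NumPy sketch of that idea (topk_via_sort is a hypothetical helper, not the XLA code; note that NumPy's stable sort treats -0.0 and +0.0 as equal, while the updated testTopKZeros expectations below suggest the XLA sort orders -0.0 keys ahead of +0.0):

import numpy as np

def topk_via_sort(x, k):
    # Sort the negated input together with an index vector (iota), then
    # slice the first k entries and undo the negation on the values.
    k = min(k, x.shape[0])
    neg = -x
    iota = np.arange(x.shape[0], dtype=np.int32)
    order = np.argsort(neg, kind="stable")  # ascending sort of the negated keys
    return -neg[order][:k], iota[order][:k]

values, indices = topk_via_sort(np.array([5, 1, 7, 3, 9], dtype=np.int32), 3)
# values -> [9, 7, 5], indices -> [4, 2, 0]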
-rw-r--r--  tensorflow/compiler/tests/sort_ops_test.py      25
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/topk_op.cc   108
2 files changed, 34 insertions(+), 99 deletions(-)
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index 8ae579abda..9e2ef964a1 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -64,20 +64,29 @@ class XlaSortOpTest(xla_test.XLATestCase):
if self.device in ["XLA_CPU", "XLA_GPU"]:
return
- # Only bfloat16 is implemented.
- bfloat16 = dtypes.bfloat16.as_numpy_dtype
- if bfloat16 in self.numeric_types:
- for x in [np.arange(20)]:
+ supported_types = set(
+ [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
+ for dtype in supported_types.intersection(self.numeric_types):
+ # Use small input size for bfloat16. Otherwise, we'll get duplicate values
+ # after conversion to bfloat16, so the possible resulting index array is
+ # no longer unique.
+ if dtype == dtypes.bfloat16.as_numpy_dtype:
+ array_size = 20
+ k_options = [0, 1, 2, 10, 20]
+ else:
+ array_size = 200 * 1000
+ k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
+ for x in [np.arange(array_size)]:
np.random.shuffle(x)
- for k in [0, 1, 2, 10, 20]:
+ for k in k_options:
indices = x.argsort()[::-1][:k]
def topk(v, k=k):
return nn_ops.top_k(v, k=k, sorted=True)
self._assertOpOutputMatchesExpected(
- topk, [x.astype(bfloat16)],
- expected=[x[indices].astype(bfloat16), indices])
+ topk, [x.astype(dtype)],
+ expected=[x[indices].astype(dtype), indices])
def testTopKZeros(self):
"""Tests that positive and negative zeros sort correctly."""
@@ -99,7 +108,7 @@ class XlaSortOpTest(xla_test.XLATestCase):
{p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)})
self.assertAllEqual(
np.array([3., 0., 0., 0.], dtype=bfloat16), results[0])
- self.assertEqual(list([3, 0, 1, 2]), list(results[1]))
+ self.assertEqual(list([3, 0, 2, 6]), list(results[1]))
def testTopKInfinities(self):
"""Tests that positive and negative infinity sort correctly."""
diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
index 8a1377fc38..9962f1207d 100644
--- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
@@ -52,107 +52,33 @@ class TopKOp : public XlaOpKernel {
errors::Unimplemented("TopK is implemented for 1-D inputs, got shape ",
input_shape.DebugString()));
- const int64 n = input_shape.dim_size(0);
- OP_REQUIRES(context, n < (1 << 16),
- errors::Unimplemented(
- "TopK is implemented for sizes up to 2**16, got shape ",
- input_shape.DebugString()));
-
xla::XlaBuilder* const b = context->builder();
if (input_shape.dim_size(0) < k) {
k = input_shape.dim_size(0);
}
- const xla::XlaOp input_bf16 = context->Input(0);
- xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, n);
-
- // TODO(b/73891930): add a key-value sort to HLO, rather than using
- // bit-packing tricks here.
-
- xla::XlaOp zero = xla::ConstantR0<int32>(b, 0);
-
- // max can either be 0x7FFFFFFF or 0x8000000. Neither choice is totally
- // ideal. The implications of the choice are:
- //
- // 0x7FFFFFFF
- // 1. +0.0 > -0.0
- // 2. The elements of the inputs and outputs are bitwise identical.
- // 3. The sort is unstable since a later +0.0 will appear before an earlier
- // -0.0.
- //
- // 0x8000000
- // 1. +0.0 == -0.0
- // 2. All -0.0 in the input are replaced with +0.0 in the output.
- // 3. The sort is stable.
- xla::XlaOp max = xla::ConstantR0<int32>(b, 0x80000000);
- xla::XlaOp index_mask = xla::ConstantR0<int32>(b, 0x0000FFFF);
- xla::XlaOp value_mask = xla::ConstantR0<int32>(b, 0xFFFF0000);
-
- // Convert to from bf16 to f32. The lower 16-bits are zero due to the
- // definition of bf16.
- xla::XlaOp input_f32 = xla::ConvertElementType(input_bf16, xla::F32);
-
- // Negate the input to reverse sort it. The lower 16-bits are zero, because
- // negating a float is just inverting the high-bit.
- xla::XlaOp negative_input_f32 = xla::Neg(input_f32);
-
- // Convert to a sign magnitude integer. The lower 16-bits are zero, since
- // bitcast convert doesn't change any bits.
- xla::XlaOp negative_input_sm32 =
- xla::BitcastConvertType(negative_input_f32, xla::S32);
-
- // Convert from sign magnitude integer to two's complement integer. The
- // lower 16-bits are zero on both sides of the select. On the false side,
- // the value is unchanged, and on the true side, the lower 16-bits of max
- // are all zero, so the lower 16-bits of the result of the subtraction will
- // also be zero.
- xla::XlaOp negative_input_s32 =
- xla::Select(xla::Lt(negative_input_sm32, zero),
- xla::Sub(max, negative_input_sm32), negative_input_sm32);
-
- // In order for the Or with iota_s32 to to work properly, the lower 16-bits
- // of negative_input_32 must be zero.
-
- // Pack elements as:
- // * upper 16 bits are the value
- // * lower 16 bits are the index.
- xla::XlaOp packed_s32 = xla::Or(negative_input_s32, iota_s32);
-
- // TODO(phawkins): use a more efficient algorithm that does not require a
- // full sort.
- xla::XlaOp sorted_s32 = xla::Slice(xla::Sort(packed_s32),
- /*start_indices=*/{0},
- /*limit_indices=*/{k},
- /*strides=*/{1});
-
- // Unpack the value/index.
- xla::XlaOp indices_s32 = xla::And(sorted_s32, index_mask);
- xla::XlaOp negative_values_s32 = xla::And(sorted_s32, value_mask);
-
- // Convert from two's complement integer to sign magnitude integer.
- xla::XlaOp negative_values_sm32 =
- xla::Select(xla::Lt(negative_values_s32, zero),
- xla::Sub(max, negative_values_s32), negative_values_s32);
-
- xla::XlaOp negative_values_f32 =
- xla::BitcastConvertType(negative_values_sm32, xla::F32);
-
- // Negate the values to get back the original inputs.
- xla::XlaOp values_f32 = xla::Neg(negative_values_f32);
-
- // Convert from f32 to bf16.
- xla::XlaOp values_bf16 = xla::ConvertElementType(values_f32, xla::BF16);
-
- context->SetOutput(0, values_bf16);
- context->SetOutput(1, indices_s32);
+ const xla::XlaOp input = context->Input(0);
+ xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, input_shape.dim_size(0));
+ xla::XlaOp sort_result = xla::Sort(xla::Neg(input), iota_s32);
+ xla::XlaOp values =
+ xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0),
+ /*start_indices=*/{0},
+ /*limit_indices=*/{k},
+ /*strides=*/{1}));
+ xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1),
+ /*start_indices=*/{0},
+ /*limit_indices=*/{k},
+ /*strides=*/{1});
+ context->SetOutput(0, values);
+ context->SetOutput(1, indices);
}
private:
bool sorted_;
};
-REGISTER_XLA_OP(
- Name("TopKV2").CompileTimeConstInput("k").TypeConstraint("T", DT_BFLOAT16),
- TopKOp);
+REGISTER_XLA_OP(Name("TopKV2").CompileTimeConstInput("k").TypeConstraint(
+ "T", {DT_UINT32, DT_INT32, DT_FLOAT, DT_BFLOAT16}),
+ TopKOp);
} // namespace
} // namespace tensorflow
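The registration at the end of topk_op.cc widens the compiled op from bfloat16-only to uint32, int32, float32, and bfloat16. As a hypothetical usage sketch (plain TF 1.x-era API, following the placeholder/session pattern of the test file; placement onto an XLA_* device is elided), calling top_k on one of the newly covered integer types:

import numpy as np
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops

with session.Session() as sess:
    p = array_ops.placeholder(dtypes.int32)
    values, indices = nn_ops.top_k(p, k=2, sorted=True)
    out = sess.run([values, indices],
                   {p: np.array([12, 40, 7, 33], dtype=np.int32)})
    # out[0] -> [40, 33], out[1] -> [1, 3]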