diff options
-rw-r--r-- | tensorflow/compiler/tests/sort_ops_test.py | 25 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/topk_op.cc | 108 |
2 files changed, 34 insertions, 99 deletions
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index 8ae579abda..9e2ef964a1 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -64,20 +64,29 @@ class XlaSortOpTest(xla_test.XLATestCase): if self.device in ["XLA_CPU", "XLA_GPU"]: return - # Only bfloat16 is implemented. - bfloat16 = dtypes.bfloat16.as_numpy_dtype - if bfloat16 in self.numeric_types: - for x in [np.arange(20)]: + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) + for dtype in supported_types.intersection(self.numeric_types): + # Use small input size for bfloat16. Otherwise, we'll get duplicate values + # after conversion to bfloat16, so the possible resulting index array is + # no longer unique. + if dtype == dtypes.bfloat16.as_numpy_dtype: + array_size = 20 + k_options = [0, 1, 2, 10, 20] + else: + array_size = 200 * 1000 + k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000] + for x in [np.arange(array_size)]: np.random.shuffle(x) - for k in [0, 1, 2, 10, 20]: + for k in k_options: indices = x.argsort()[::-1][:k] def topk(v, k=k): return nn_ops.top_k(v, k=k, sorted=True) self._assertOpOutputMatchesExpected( - topk, [x.astype(bfloat16)], - expected=[x[indices].astype(bfloat16), indices]) + topk, [x.astype(dtype)], + expected=[x[indices].astype(dtype), indices]) def testTopKZeros(self): """Tests that positive and negative zeros sort correctly.""" @@ -99,7 +108,7 @@ class XlaSortOpTest(xla_test.XLATestCase): {p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)}) self.assertAllEqual( np.array([3., 0., 0., 0.], dtype=bfloat16), results[0]) - self.assertEqual(list([3, 0, 1, 2]), list(results[1])) + self.assertEqual(list([3, 0, 2, 6]), list(results[1])) def testTopKInfinities(self): """Tests that positive and negative infinity sort correctly.""" diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index 8a1377fc38..9962f1207d 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -52,107 +52,33 @@ class TopKOp : public XlaOpKernel { errors::Unimplemented("TopK is implemented for 1-D inputs, got shape ", input_shape.DebugString())); - const int64 n = input_shape.dim_size(0); - OP_REQUIRES(context, n < (1 << 16), - errors::Unimplemented( - "TopK is implemented for sizes up to 2**16, got shape ", - input_shape.DebugString())); - xla::XlaBuilder* const b = context->builder(); if (input_shape.dim_size(0) < k) { k = input_shape.dim_size(0); } - const xla::XlaOp input_bf16 = context->Input(0); - xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, n); - - // TODO(b/73891930): add a key-value sort to HLO, rather than using - // bit-packing tricks here. - - xla::XlaOp zero = xla::ConstantR0<int32>(b, 0); - - // max can either be 0x7FFFFFFF or 0x8000000. Neither choice is totally - // ideal. The implications of the choice are: - // - // 0x7FFFFFFF - // 1. +0.0 > -0.0 - // 2. The elements of the inputs and outputs are bitwise identical. - // 3. The sort is unstable since a later +0.0 will appear before an earlier - // -0.0. - // - // 0x8000000 - // 1. +0.0 == -0.0 - // 2. All -0.0 in the input are replaced with +0.0 in the output. - // 3. The sort is stable. - xla::XlaOp max = xla::ConstantR0<int32>(b, 0x80000000); - xla::XlaOp index_mask = xla::ConstantR0<int32>(b, 0x0000FFFF); - xla::XlaOp value_mask = xla::ConstantR0<int32>(b, 0xFFFF0000); - - // Convert to from bf16 to f32. The lower 16-bits are zero due to the - // definition of bf16. - xla::XlaOp input_f32 = xla::ConvertElementType(input_bf16, xla::F32); - - // Negate the input to reverse sort it. The lower 16-bits are zero, because - // negating a float is just inverting the high-bit. - xla::XlaOp negative_input_f32 = xla::Neg(input_f32); - - // Convert to a sign magnitude integer. The lower 16-bits are zero, since - // bitcast convert doesn't change any bits. - xla::XlaOp negative_input_sm32 = - xla::BitcastConvertType(negative_input_f32, xla::S32); - - // Convert from sign magnitude integer to two's complement integer. The - // lower 16-bits are zero on both sides of the select. On the false side, - // the value is unchanged, and on the true side, the lower 16-bits of max - // are all zero, so the lower 16-bits of the result of the subtraction will - // also be zero. - xla::XlaOp negative_input_s32 = - xla::Select(xla::Lt(negative_input_sm32, zero), - xla::Sub(max, negative_input_sm32), negative_input_sm32); - - // In order for the Or with iota_s32 to to work properly, the lower 16-bits - // of negative_input_32 must be zero. - - // Pack elements as: - // * upper 16 bits are the value - // * lower 16 bits are the index. - xla::XlaOp packed_s32 = xla::Or(negative_input_s32, iota_s32); - - // TODO(phawkins): use a more efficient algorithm that does not require a - // full sort. - xla::XlaOp sorted_s32 = xla::Slice(xla::Sort(packed_s32), - /*start_indices=*/{0}, - /*limit_indices=*/{k}, - /*strides=*/{1}); - - // Unpack the value/index. - xla::XlaOp indices_s32 = xla::And(sorted_s32, index_mask); - xla::XlaOp negative_values_s32 = xla::And(sorted_s32, value_mask); - - // Convert from two's complement integer to sign magnitude integer. - xla::XlaOp negative_values_sm32 = - xla::Select(xla::Lt(negative_values_s32, zero), - xla::Sub(max, negative_values_s32), negative_values_s32); - - xla::XlaOp negative_values_f32 = - xla::BitcastConvertType(negative_values_sm32, xla::F32); - - // Negate the values to get back the original inputs. - xla::XlaOp values_f32 = xla::Neg(negative_values_f32); - - // Convert from f32 to bf16. - xla::XlaOp values_bf16 = xla::ConvertElementType(values_f32, xla::BF16); - - context->SetOutput(0, values_bf16); - context->SetOutput(1, indices_s32); + const xla::XlaOp input = context->Input(0); + xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, input_shape.dim_size(0)); + xla::XlaOp sort_result = xla::Sort(xla::Neg(input), iota_s32); + xla::XlaOp values = + xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0), + /*start_indices=*/{0}, + /*limit_indices=*/{k}, + /*strides=*/{1})); + xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1), + /*start_indices=*/{0}, + /*limit_indices=*/{k}, + /*strides=*/{1}); + context->SetOutput(0, values); + context->SetOutput(1, indices); } private: bool sorted_; }; -REGISTER_XLA_OP( - Name("TopKV2").CompileTimeConstInput("k").TypeConstraint("T", DT_BFLOAT16), - TopKOp); +REGISTER_XLA_OP(Name("TopKV2").CompileTimeConstInput("k").TypeConstraint( + "T", {DT_UINT32, DT_INT32, DT_FLOAT, DT_BFLOAT16}), + TopKOp); } // namespace } // namespace tensorflow |