aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Adrian Kuegel <akuegel@google.com>2018-09-19 01:55:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-19 01:59:05 -0700
commit0d8942fcbcc9cb3a05be8acc843d1fc4b6dfc9f1 (patch)
tree28cccc167c27fa13e207860a4f339d9c28cd3e74
parent6967287715a097c8b009b52010c53247ab658232 (diff)
Implement sort op for CPU.
Also don't allow parallelization for the sort op in parallel_task_assignment. PiperOrigin-RevId: 213592046
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD13
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime.cc24
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime.h12
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc145
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc1
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc237
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h88
-rw-r--r--tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc13
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/BUILD14
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc54
11 files changed, 605 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 8cc522a59e..b3e4fab727 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -180,6 +180,7 @@ cc_library(
":runtime_conv2d_mkl",
":runtime_fft",
":runtime_fork_join",
+ ":runtime_key_value_sort",
":runtime_matmul",
":runtime_matmul_mkl",
":runtime_single_threaded_conv2d",
@@ -624,6 +625,18 @@ cc_library(
)
cc_library(
+ name = "runtime_key_value_sort",  # Key/value sort runtime functions for the XLA CPU backend.
+ srcs = ["runtime_key_value_sort.cc"],
+ hdrs = ["runtime_key_value_sort.h"],
+ copts = runtime_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",  # The F16 entry point takes Eigen::half keys.
+ ],
+)
+
+cc_library(
name = "runtime_fork_join",
srcs = ["runtime_fork_join.cc"],
hdrs = ["runtime_fork_join.h"],
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 8a44c384bb..7e1590955a 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -74,6 +74,30 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName =
"__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation";
extern const char* const kParallelForkJoinSymbolName =
"__xla_cpu_runtime_ParallelForkJoin";
+extern const char* const kKeyValueSortPREDSymbolName =  // Sort entry points, one per key element type; this one: PRED.
+ "__xla_cpu_runtime_KeyValueSortPRED";
+extern const char* const kKeyValueSortS8SymbolName =  // S8 keys.
+ "__xla_cpu_runtime_KeyValueSortS8";
+extern const char* const kKeyValueSortU8SymbolName =  // U8 keys.
+ "__xla_cpu_runtime_KeyValueSortU8";
+extern const char* const kKeyValueSortS16SymbolName =  // S16 keys.
+ "__xla_cpu_runtime_KeyValueSortS16";
+extern const char* const kKeyValueSortU16SymbolName =  // U16 keys.
+ "__xla_cpu_runtime_KeyValueSortU16";
+extern const char* const kKeyValueSortF16SymbolName =  // F16 keys.
+ "__xla_cpu_runtime_KeyValueSortF16";
+extern const char* const kKeyValueSortS32SymbolName =  // S32 keys.
+ "__xla_cpu_runtime_KeyValueSortS32";
+extern const char* const kKeyValueSortU32SymbolName =  // U32 keys.
+ "__xla_cpu_runtime_KeyValueSortU32";
+extern const char* const kKeyValueSortF32SymbolName =  // F32 keys.
+ "__xla_cpu_runtime_KeyValueSortF32";
+extern const char* const kKeyValueSortS64SymbolName =  // S64 keys.
+ "__xla_cpu_runtime_KeyValueSortS64";
+extern const char* const kKeyValueSortU64SymbolName =  // U64 keys.
+ "__xla_cpu_runtime_KeyValueSortU64";
+extern const char* const kKeyValueSortF64SymbolName =  // F64 keys.
+ "__xla_cpu_runtime_KeyValueSortF64";
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
} // namespace runtime
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
index aa0e967123..e6345e0344 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -63,6 +63,18 @@ extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName;
extern const char* const kAcquireOutfeedBufferForPopulationSymbolName;
extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName;
extern const char* const kParallelForkJoinSymbolName;
+extern const char* const kKeyValueSortPREDSymbolName;  // Key/value sort runtime symbols, one per key type: PRED.
+extern const char* const kKeyValueSortS8SymbolName;  // S8 keys.
+extern const char* const kKeyValueSortU8SymbolName;  // U8 keys.
+extern const char* const kKeyValueSortS16SymbolName;  // S16 keys.
+extern const char* const kKeyValueSortU16SymbolName;  // U16 keys.
+extern const char* const kKeyValueSortF16SymbolName;  // F16 keys.
+extern const char* const kKeyValueSortS32SymbolName;  // S32 keys.
+extern const char* const kKeyValueSortU32SymbolName;  // U32 keys.
+extern const char* const kKeyValueSortF32SymbolName;  // F32 keys.
+extern const char* const kKeyValueSortS64SymbolName;  // S64 keys.
+extern const char* const kKeyValueSortU64SymbolName;  // U64 keys.
+extern const char* const kKeyValueSortF64SymbolName;  // F64 keys.
// All symbol names for XLA CPU runtime functions need to start with this
// prefix.
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index df8c2a636b..7e82375cc3 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -495,8 +495,149 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
}
Status IrEmitter::HandleSort(HloInstruction* sort) {
- // TODO(b/26783907): Implement sort on CPU.
- return Unimplemented("Sort is not implemented on CPU.");
+  // Lowers a sort instruction to a call of one of the
+  // __xla_cpu_runtime_KeyValueSort* runtime functions, selected by the element
+  // type of the keys. Operand 0 holds the keys; an optional operand 1 holds
+  // values that are permuted in the same way as the keys. Returns an
+  // Unimplemented error if there is no runtime function for the key type.
+  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
+  auto keys = sort->operand(0);
+  auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
+  // Without values the output is addressed with shape index {}; with values
+  // the keys and values outputs live at shape indices {0} and {1}.
+  ShapeIndex keys_shape_index({});
+  ShapeIndex values_shape_index({});
+  if (values != nullptr) {
+    keys_shape_index = ShapeIndex({0});
+    values_shape_index = ShapeIndex({1});
+  }
+  auto keys_destination = GetAllocationSlice(*sort, keys_shape_index);
+  auto keys_destination_address =
+      EmitBufferPointer(keys_destination, keys->shape());
+  auto values_destination = GetAllocationSlice(*sort, values_shape_index);
+  llvm::Value* values_destination_address = nullptr;
+
+  // The sort is implemented in-place, therefore we first copy the operand
+  // buffer to the output buffer if they are not the same.
+  if (keys_destination != GetAllocationSlice(*keys)) {
+    int64 primitive_type_size =
+        ShapeUtil::ByteSizeOfPrimitiveType(keys->shape().element_type());
+    auto source_buffer = GetEmittedValueFor(keys);
+    int64 keys_size = ByteSizeOf(keys->shape());
+    MemCpy(keys_destination_address, /*DstAlign=*/primitive_type_size,
+           source_buffer,
+           /*SrcAlign=*/primitive_type_size, keys_size);
+  }
+  if (values != nullptr) {
+    values_destination_address =
+        EmitBufferPointer(values_destination, values->shape());
+    if (values_destination != GetAllocationSlice(*values)) {
+      int64 primitive_type_size =
+          ShapeUtil::ByteSizeOfPrimitiveType(values->shape().element_type());
+      auto source_buffer = GetEmittedValueFor(values);
+      int64 values_size = ByteSizeOf(values->shape());
+      MemCpy(values_destination_address, /*DstAlign=*/primitive_type_size,
+             source_buffer,
+             /*SrcAlign=*/primitive_type_size, values_size);
+    }
+  }
+
+  // Normalize the shape and the dimension to sort.
+  Shape normalized_keys_shape =
+      ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
+          keys->shape());
+  int64 physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical(
+      keys->shape().layout())[sort->dimensions(0)];
+
+  // Collapse the normalized shape into [higher, sort, lower]: the product of
+  // the dimensions before the sort dimension, the sort dimension itself, and
+  // the product of the dimensions after it.
+  int64 sort_dimension_elements =
+      normalized_keys_shape.dimensions(physical_dimension_to_sort);
+  int64 higher_dimensions = 1;
+  for (int64 i = 0; i < physical_dimension_to_sort; ++i) {
+    higher_dimensions *= normalized_keys_shape.dimensions(i);
+  }
+  int64 lower_dimensions = 1;
+  for (int64 i = ShapeUtil::Rank(normalized_keys_shape) - 1;
+       i > physical_dimension_to_sort; --i) {
+    lower_dimensions *= normalized_keys_shape.dimensions(i);
+  }
+
+  // Pick the runtime function and the LLVM pointer type of its first
+  // parameter based on the element type of the keys.
+  PrimitiveType keys_type = keys->shape().element_type();
+  const char* fn_name = nullptr;
+  llvm::Type* keys_native_type = nullptr;
+  switch (keys_type) {
+    case PRED:
+      fn_name = runtime::kKeyValueSortPREDSymbolName;
+      keys_native_type = b_.getInt8PtrTy();
+      break;
+    case S8:
+      fn_name = runtime::kKeyValueSortS8SymbolName;
+      keys_native_type = b_.getInt8PtrTy();
+      break;
+    case U8:
+      fn_name = runtime::kKeyValueSortU8SymbolName;
+      keys_native_type = b_.getInt8PtrTy();
+      break;
+    case S16:
+      fn_name = runtime::kKeyValueSortS16SymbolName;
+      keys_native_type = b_.getInt16Ty()->getPointerTo();
+      break;
+    case U16:
+      fn_name = runtime::kKeyValueSortU16SymbolName;
+      keys_native_type = b_.getInt16Ty()->getPointerTo();
+      break;
+    case F16:
+      fn_name = runtime::kKeyValueSortF16SymbolName;
+      keys_native_type = b_.getHalfTy()->getPointerTo();
+      break;
+    case S32:
+      fn_name = runtime::kKeyValueSortS32SymbolName;
+      keys_native_type = b_.getInt32Ty()->getPointerTo();
+      break;
+    case U32:
+      fn_name = runtime::kKeyValueSortU32SymbolName;
+      keys_native_type = b_.getInt32Ty()->getPointerTo();
+      break;
+    case F32:
+      fn_name = runtime::kKeyValueSortF32SymbolName;
+      keys_native_type = b_.getFloatTy()->getPointerTo();
+      break;
+    case S64:
+      fn_name = runtime::kKeyValueSortS64SymbolName;
+      keys_native_type = b_.getInt64Ty()->getPointerTo();
+      break;
+    case U64:
+      fn_name = runtime::kKeyValueSortU64SymbolName;
+      keys_native_type = b_.getInt64Ty()->getPointerTo();
+      break;
+    case F64:
+      fn_name = runtime::kKeyValueSortF64SymbolName;
+      keys_native_type = b_.getDoubleTy()->getPointerTo();
+      break;
+    default:
+      // A DLOG(FATAL) here would be compiled out in optimized builds and
+      // execution would continue with a null 'fn_name' and
+      // 'keys_native_type'. Report the unsupported type as an error instead.
+      return Unimplemented(
+          "Element type %s not supported in the Sort op on CPU.",
+          PrimitiveType_Name(keys_type).c_str());
+  }
+
+  // Signature of the runtime function:
+  //   void(keys, a, b, c, values, values_primitive_type_size_in_bytes).
+  llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get(
+      b_.getVoidTy(),
+      {keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(),
+       b_.getInt8PtrTy(), b_.getInt32Ty()},
+      /*isVarArg=*/false);
+  auto* key_value_sort_func = llvm::cast<llvm::Function>(
+      module_->getOrInsertFunction(fn_name, key_value_sort_type));
+  key_value_sort_func->setCallingConv(llvm::CallingConv::C);
+  key_value_sort_func->setDoesNotThrow();
+  key_value_sort_func->setOnlyAccessesArgMemory();
+  // For a keys-only sort, pass a null values pointer and an element size of 0.
+  Call(key_value_sort_func,
+       {PointerCast(keys_destination_address, keys_native_type),
+        b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
+        b_.getInt64(lower_dimensions),
+        values != nullptr
+            ? PointerCast(values_destination_address, b_.getInt8PtrTy())
+            : llvm::Constant::getNullValue(b_.getInt8PtrTy()),
+        b_.getInt32(values != nullptr ? ShapeUtil::ByteSizeOfPrimitiveType(
+                                            values->shape().element_type())
+                                      : 0)});
+
+  // With values, the result is a tuple of the two sorted buffers.
+  if (values != nullptr) {
+    llvm_ir::EmitTuple(GetIrArrayFor(sort),
+                       {keys_destination_address, values_destination_address},
+                       &b_, module_);
+  }
+  return Status::OK();
}
Status IrEmitter::HandleTuple(HloInstruction* tuple) {
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 3df99464ba..daafef4eb3 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -163,6 +163,12 @@ class IrEmitter : public DfsHloVisitorWithDefault,
Status Preprocess(HloInstruction* hlo) override;
Status Postprocess(HloInstruction* hlo) override;
+  // Looks up the unique buffer slice assigned to 'hlo' at shape index 'index'
+  // via BufferAssignment::GetUniqueSlice. Dies if the lookup returns an error.
+  BufferAllocation::Slice GetAllocationSlice(
+      const HloInstruction& hlo, const ShapeIndex& index = {}) const {
+    auto slice_or = assignment_.GetUniqueSlice(&hlo, index);
+    return slice_or.ConsumeValueOrDie();
+  }
+
private:
// Private helper to initialize an IR function for the computation.
void InitializeIrFunction(const string& function_name);
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index b4c0c09ec0..ede7f433ca 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -142,6 +142,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast ||
opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed ||
opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng ||
+ opcode == HloOpcode::kSort ||
(opcode == HloOpcode::kConvolution &&
PotentiallyImplementedAsEigenConvolution(*instruction,
target_machine_features_)) ||
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
new file mode 100644
index 0000000000..cef5420f00
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
@@ -0,0 +1,237 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/platform/dynamic_annotations.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace {
+using tensorflow::int16;
+using tensorflow::int32;
+using tensorflow::int64;
+using tensorflow::int8;
+using tensorflow::uint16;
+using tensorflow::uint32;
+using tensorflow::uint64;
+using tensorflow::uint8;
+
+// Sorts 'num_elements' (key, original index) pairs into ascending order using
+// std::pair's operator<, i.e. by key first and by index for equal keys.
+template <typename KeyType>
+void KeyValueSort(std::pair<KeyType, int64>* row_to_sort, int64 num_elements) {
+  std::pair<KeyType, int64>* const row_end = row_to_sort + num_elements;
+  std::sort(row_to_sort, row_end);
+}
+
+// Total-order "less than" for floating point keys, applied to (key, index)
+// pairs. Values with the sign bit set (including -NaN and -0.0) order before
+// values without it, -NaN orders before every other negative value and NaN
+// after every other positive value. To keep the sort stable, keys that are
+// equal -- including two NaNs of the same sign -- are ordered by their index.
+template <typename KeyType>
+bool LessThan(KeyType lhs, int64 lhs_index, KeyType rhs, int64 rhs_index) {
+  bool lhs_is_negative = std::signbit(lhs);
+  bool rhs_is_negative = std::signbit(rhs);
+  // If the signs are different, the negative one comes first.
+  if (lhs_is_negative != rhs_is_negative) {
+    return lhs_is_negative;
+  }
+  bool lhs_nan = std::isnan(lhs);
+  bool rhs_nan = std::isnan(rhs);
+  // Exactly one number is nan? Both have the same sign at this point, so a
+  // negative NaN is smaller and a positive NaN is larger than the other one.
+  if (lhs_nan != rhs_nan) {
+    if (lhs_nan) {
+      return lhs_is_negative;
+    }
+    return !rhs_is_negative;
+  }
+  // Two NaNs must not take this branch: NaN != NaN is true while NaN < NaN is
+  // false, which would skip the index tie-break and make the sort unstable.
+  if (!lhs_nan && lhs != rhs) {
+    return lhs < rhs;
+  }
+  return lhs_index < rhs_index;
+}
+
+// Specialization for double keys: orders the (key, index) pairs with the
+// total-order LessThan comparator instead of std::pair's operator<.
+template <>
+void KeyValueSort(std::pair<double, int64>* row_to_sort, int64 num_elements) {
+  using Entry = std::pair<double, int64>;
+  const auto total_order_less = [](const Entry& x, const Entry& y) -> bool {
+    return LessThan(x.first, x.second, y.first, y.second);
+  };
+  std::sort(row_to_sort, row_to_sort + num_elements, total_order_less);
+}
+
+// Specialization for float keys: orders the (key, index) pairs with the
+// total-order LessThan comparator instead of std::pair's operator<.
+template <>
+void KeyValueSort(std::pair<float, int64>* row_to_sort, int64 num_elements) {
+  using Entry = std::pair<float, int64>;
+  const auto total_order_less = [](const Entry& x, const Entry& y) -> bool {
+    return LessThan(x.first, x.second, y.first, y.second);
+  };
+  std::sort(row_to_sort, row_to_sort + num_elements, total_order_less);
+}
+
+// Specialization for Eigen::half keys: each key is converted to float and the
+// pairs are then ordered with the total-order LessThan comparator.
+template <>
+void KeyValueSort(std::pair<Eigen::half, int64>* row_to_sort,
+                  int64 num_elements) {
+  using Entry = std::pair<Eigen::half, int64>;
+  const auto total_order_less = [](const Entry& x, const Entry& y) -> bool {
+    const float x_key = Eigen::half_impl::half_to_float(x.first);
+    const float y_key = Eigen::half_impl::half_to_float(y.first);
+    return LessThan(x_key, x.second, y_key, y.second);
+  };
+  std::sort(row_to_sort, row_to_sort + num_elements, total_order_less);
+}
+
+// Sorts every row along the 'b' dimension of 'keys', viewed as a [a, b, c]
+// array, into ascending order. If 'values' is not nullptr, it is treated as an
+// array with the same [a, b, c] element layout whose elements are
+// 'values_primitive_type_size_in_bytes' bytes wide, and each row of it is
+// permuted in the same way as the corresponding row of 'keys'.
+template <typename KeyType>
+void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values,
+                      int32 values_primitive_type_size_in_bytes) {
+  // High-level idea of the iteration/sorting logic:
+  // Conceptually we have a 3-dimensional shape [a, b, c]. b corresponds to the
+  // dimension to sort, c is the product of the more minor dimensions (set to 1
+  // if b is the most minor dimension), and a is the product of the more major
+  // dimensions (set to 1 if b is the most major dimension). There are a * c
+  // many rows that we need to sort. We iterate through these, calculate a
+  // 'base_offset' value which points to the first element in that row, and add
+  // i * c for accessing the 'i'-th element in that row.
+
+  const int64 sort_dimension_elements = b;
+  const int64 num_iteration_elements = a * c;
+  const int64 sort_dimension_offset = c;
+
+  // Rows with at most one element are already sorted, and without rows there
+  // is nothing to do at all.
+  if (sort_dimension_elements <= 1 || num_iteration_elements <= 0) {
+    return;
+  }
+
+  std::unique_ptr<std::pair<KeyType, int64>[]> row_to_sort(
+      new std::pair<KeyType, int64>[sort_dimension_elements]);
+  // Scratch space holding one reordered row of values as raw bytes. It is only
+  // needed (and only allocated) when there are values to reorder.
+  const int64 value_size = values_primitive_type_size_in_bytes;
+  std::unique_ptr<char[]> reordered_values;
+  if (values != nullptr) {
+    reordered_values.reset(new char[sort_dimension_elements * value_size]);
+  }
+  for (int64 index = 0; index < num_iteration_elements; ++index) {
+    // 'index' can be split into two values which index into the 'c' dimension
+    // and the 'a' dimension, respectively. 'index' % 'c' is the index into the
+    // 'c' dimension, 'index' / 'c' is the index into the 'a' dimension. When
+    // calculating the base offset, we need to multiply the index into the 'a'
+    // dimension with 'b' * 'c'.
+    // 'index' / 'c' * 'c' * 'b' = ('index' - 'index' % 'c') * 'b'.
+    const int64 base_offset =
+        index % sort_dimension_offset +
+        (index - index % sort_dimension_offset) * sort_dimension_elements;
+    // TODO(b/26783907): We could define a custom iterator class that references
+    // both arrays. Then we could avoid the intermediate copy. However this
+    // would become more complicated, and it is not clear if the benefit is high
+    // enough.
+    for (int64 i = 0; i < sort_dimension_elements; ++i) {
+      row_to_sort[i] =
+          std::make_pair(keys[base_offset + i * sort_dimension_offset], i);
+    }
+    KeyValueSort(row_to_sort.get(), sort_dimension_elements);
+    for (int64 i = 0; i < sort_dimension_elements; ++i) {
+      keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first;
+    }
+    if (values == nullptr) {
+      continue;
+    }
+
+    // Reorder the values according to the order defined by the keys: first
+    // gather the row into the scratch buffer in sorted order, then write it
+    // back. 'row_to_sort[i].second' is the original position of the element
+    // that ended up at position 'i'.
+    for (int64 i = 0; i < sort_dimension_elements; ++i) {
+      const int64 memory_index =
+          (base_offset + row_to_sort[i].second * sort_dimension_offset) *
+          value_size;
+      std::memcpy(reordered_values.get() + i * value_size,
+                  values + memory_index, value_size);
+    }
+    for (int64 i = 0; i < sort_dimension_elements; ++i) {
+      const int64 memory_index =
+          (base_offset + i * sort_dimension_offset) * value_size;
+      std::memcpy(values + memory_index,
+                  reordered_values.get() + i * value_size, value_size);
+    }
+  }
+}
+} // namespace
+
+// Runtime entry points, one per key element type. Each one forwards its
+// arguments unchanged to the KeyValueSortImpl instantiation for its key type.
+
+// PRED keys.
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED(
+    bool* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl<bool>(keys, a, b, c, values,
+                         values_primitive_type_size_in_bytes);
+}
+
+// S8 keys.
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8(
+    int8* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl<int8>(keys, a, b, c, values,
+                         values_primitive_type_size_in_bytes);
+}
+
+// U8 keys.
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8(
+    uint8* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl<uint8>(keys, a, b, c, values,
+                          values_primitive_type_size_in_bytes);
+}
+
+// S16 keys.
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16(
+    int16* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl<int16>(keys, a, b, c, values,
+                          values_primitive_type_size_in_bytes);
+}
+
+// U16 keys.
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16(
+    uint16* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl<uint16>(keys, a, b, c, values,
+                           values_primitive_type_size_in_bytes);
+}
+
+// F16 keys.
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16(
+    Eigen::half* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl<Eigen::half>(keys, a, b, c, values,
+                                values_primitive_type_size_in_bytes);
+}
+
+// S32 keys.
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32(
+    int32* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl<int32>(keys, a, b, c, values,
+                          values_primitive_type_size_in_bytes);
+}
+
+// U32 keys.
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32(
+    uint32* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl<uint32>(keys, a, b, c, values,
+                           values_primitive_type_size_in_bytes);
+}
+
+// F32 keys.
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32(
+    float* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl<float>(keys, a, b, c, values,
+                          values_primitive_type_size_in_bytes);
+}
+
+// S64 keys.
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64(
+    int64* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl<int64>(keys, a, b, c, values,
+                          values_primitive_type_size_in_bytes);
+}
+
+// U64 keys.
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64(
+    uint64* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl<uint64>(keys, a, b, c, values,
+                           values_primitive_type_size_in_bytes);
+}
+
+// F64 keys.
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64(
+    double* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl<double>(keys, a, b, c, values,
+                           values_primitive_type_size_in_bytes);
+}
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h
new file mode 100644
index 0000000000..28e35e82c1
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h
@@ -0,0 +1,88 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/platform/types.h"
+
+extern "C" {
+
+// Key/value sort entry points, one per key element type.
+//
+// 'keys' represents a 3-dimensional shape with dimensions [a, b, c]; the 'b'
+// dimension of 'keys' is sorted into ascending order. 'values' may be nullptr.
+// Otherwise its elements, each 'values_primitive_type_size_in_bytes' bytes
+// wide, are permuted together with the keys: whenever the key at index 'i' is
+// moved to index 'j', the value at index 'i' is moved to index 'j' as well, so
+// every key keeps the value it was paired with before the sort.
+
+// PRED keys.
+void __xla_cpu_runtime_KeyValueSortPRED(
+    bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
+    char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+
+// S8 keys.
+void __xla_cpu_runtime_KeyValueSortS8(
+    tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+// U8 keys.
+void __xla_cpu_runtime_KeyValueSortU8(
+    tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+// S16 keys.
+void __xla_cpu_runtime_KeyValueSortS16(
+    tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+// U16 keys.
+void __xla_cpu_runtime_KeyValueSortU16(
+    tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+// F16 keys.
+void __xla_cpu_runtime_KeyValueSortF16(
+    Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+// S32 keys.
+void __xla_cpu_runtime_KeyValueSortS32(
+    tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+// U32 keys.
+void __xla_cpu_runtime_KeyValueSortU32(
+    tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+// F32 keys.
+void __xla_cpu_runtime_KeyValueSortF32(
+    float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
+    char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+
+// S64 keys.
+void __xla_cpu_runtime_KeyValueSortS64(
+    tensorflow::int64* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+// U64 keys.
+void __xla_cpu_runtime_KeyValueSortU64(
+    tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+// F64 keys.
+void __xla_cpu_runtime_KeyValueSortF64(
+    double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
+    char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+}
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index bf98064647..9ec0c8f657 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/runtime_fft.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
@@ -202,6 +203,18 @@ bool RegisterKnownJITSymbols() {
REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin);
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortPRED);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS8);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU8);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS16);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU16);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF16);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS32);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU32);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF32);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS64);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU64);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF64);
registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee));
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee));
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index c55206eee7..4b129c95d4 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -180,3 +180,17 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
+
+tf_cc_test(
+ name = "cpu_key_value_sort_test",  # Codegen test for the CPU lowering of the sort op.
+ srcs = ["cpu_key_value_sort_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",  # The test parses its module from HLO text.
+ "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
+ "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc
new file mode 100644
index 0000000000..3934c03a04
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc
@@ -0,0 +1,54 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
+#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+
+namespace xla {
+namespace cpu {
+namespace {
+class CpuKeyValueSortTest : public CpuCodegenTest {};  // Fixture for the sort codegen tests below.
+
+// Compiles a rank-1 f32 sort ahead of time and verifies that the optimized IR
+// calls the F32 key/value sort runtime function.
+TEST_F(CpuKeyValueSortTest, SortR1) {
+  const string hlo_text = R"(
+HloModule KeyValueSort
+
+ENTRY main {
+  a = f32[10] parameter(0)
+
+  ROOT result = f32[10] sort(f32[10] a), dimensions={0}
+}
+)";
+
+  // Match the F32 entry point specifically: the keys are f32, and matching
+  // only the common "KeyValueSort" prefix would also accept a call to the
+  // runtime function of a different key type.
+  const string filecheck_pattern = R"(
+CHECK: call void @__xla_cpu_runtime_KeyValueSortF32
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(hlo_text));
+
+  CpuAotCompilationOptions options{
+      /*triple=*/"x86_64", /*cpu_name=*/"", /*features=*/"",
+      /*entry_point_name=*/"entry",
+      /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static};
+
+  CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern,
+                                /*match_optimized_ir=*/true);
+}
+
+} // namespace
+} // namespace cpu
+} // namespace xla