diff options
author | Adrian Kuegel <akuegel@google.com> | 2018-09-19 01:55:55 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-19 01:59:05 -0700 |
commit | 0d8942fcbcc9cb3a05be8acc843d1fc4b6dfc9f1 (patch) | |
tree | 28cccc167c27fa13e207860a4f339d9c28cd3e74 | |
parent | 6967287715a097c8b009b52010c53247ab658232 (diff) |
Implement sort op for CPU.
Also don't allow parallelization for the sort op in parallel_task_assignment.
PiperOrigin-RevId: 213592046
11 files changed, 605 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 8cc522a59e..b3e4fab727 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -180,6 +180,7 @@ cc_library( ":runtime_conv2d_mkl", ":runtime_fft", ":runtime_fork_join", + ":runtime_key_value_sort", ":runtime_matmul", ":runtime_matmul_mkl", ":runtime_single_threaded_conv2d", @@ -624,6 +625,18 @@ cc_library( ) cc_library( + name = "runtime_key_value_sort", + srcs = ["runtime_key_value_sort.cc"], + hdrs = ["runtime_key_value_sort.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:lib", + "//third_party/eigen3", + ], +) + +cc_library( name = "runtime_fork_join", srcs = ["runtime_fork_join.cc"], hdrs = ["runtime_fork_join.h"], diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 8a44c384bb..7e1590955a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -74,6 +74,30 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName = "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation"; extern const char* const kParallelForkJoinSymbolName = "__xla_cpu_runtime_ParallelForkJoin"; +extern const char* const kKeyValueSortPREDSymbolName = + "__xla_cpu_runtime_KeyValueSortPRED"; +extern const char* const kKeyValueSortS8SymbolName = + "__xla_cpu_runtime_KeyValueSortS8"; +extern const char* const kKeyValueSortU8SymbolName = + "__xla_cpu_runtime_KeyValueSortU8"; +extern const char* const kKeyValueSortS16SymbolName = + "__xla_cpu_runtime_KeyValueSortS16"; +extern const char* const kKeyValueSortU16SymbolName = + "__xla_cpu_runtime_KeyValueSortU16"; +extern const char* const kKeyValueSortF16SymbolName = + "__xla_cpu_runtime_KeyValueSortF16"; +extern const char* const kKeyValueSortS32SymbolName = + 
"__xla_cpu_runtime_KeyValueSortS32"; +extern const char* const kKeyValueSortU32SymbolName = + "__xla_cpu_runtime_KeyValueSortU32"; +extern const char* const kKeyValueSortF32SymbolName = + "__xla_cpu_runtime_KeyValueSortF32"; +extern const char* const kKeyValueSortS64SymbolName = + "__xla_cpu_runtime_KeyValueSortS64"; +extern const char* const kKeyValueSortU64SymbolName = + "__xla_cpu_runtime_KeyValueSortU64"; +extern const char* const kKeyValueSortF64SymbolName = + "__xla_cpu_runtime_KeyValueSortF64"; extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; } // namespace runtime diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index aa0e967123..e6345e0344 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -63,6 +63,18 @@ extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName; extern const char* const kAcquireOutfeedBufferForPopulationSymbolName; extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName; extern const char* const kParallelForkJoinSymbolName; +extern const char* const kKeyValueSortPREDSymbolName; +extern const char* const kKeyValueSortS8SymbolName; +extern const char* const kKeyValueSortU8SymbolName; +extern const char* const kKeyValueSortS16SymbolName; +extern const char* const kKeyValueSortU16SymbolName; +extern const char* const kKeyValueSortF16SymbolName; +extern const char* const kKeyValueSortS32SymbolName; +extern const char* const kKeyValueSortU32SymbolName; +extern const char* const kKeyValueSortF32SymbolName; +extern const char* const kKeyValueSortS64SymbolName; +extern const char* const kKeyValueSortU64SymbolName; +extern const char* const kKeyValueSortF64SymbolName; // All symbol names for XLA CPU runtime functions need to start with this // prefix. 
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index df8c2a636b..7e82375cc3 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -495,8 +495,149 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { } Status IrEmitter::HandleSort(HloInstruction* sort) { - // TODO(b/26783907): Implement sort on CPU. - return Unimplemented("Sort is not implemented on CPU."); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort)); + auto keys = sort->operand(0); + auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr; + ShapeIndex keys_shape_index({}); + ShapeIndex values_shape_index({}); + if (values != nullptr) { + keys_shape_index = ShapeIndex({0}); + values_shape_index = ShapeIndex({1}); + } + auto keys_destination = GetAllocationSlice(*sort, keys_shape_index); + auto keys_destination_address = + EmitBufferPointer(keys_destination, keys->shape()); + auto values_destination = GetAllocationSlice(*sort, values_shape_index); + llvm::Value* values_destination_address = nullptr; + + // The sort is implemented in-place, therefore we first copy the operand + // buffer to the output buffer if they are not the same. 
+ if (keys_destination != GetAllocationSlice(*keys)) { + int64 primitive_type_size = + ShapeUtil::ByteSizeOfPrimitiveType(keys->shape().element_type()); + auto source_buffer = GetEmittedValueFor(keys); + int64 keys_size = ByteSizeOf(keys->shape()); + MemCpy(keys_destination_address, /*DstAlign=*/primitive_type_size, + source_buffer, + /*SrcAlign=*/primitive_type_size, keys_size); + } + if (values != nullptr) { + values_destination_address = + EmitBufferPointer(values_destination, values->shape()); + if (values_destination != GetAllocationSlice(*values)) { + int64 primitive_type_size = + ShapeUtil::ByteSizeOfPrimitiveType(values->shape().element_type()); + auto source_buffer = GetEmittedValueFor(values); + int64 values_size = ByteSizeOf(values->shape()); + MemCpy(values_destination_address, /*DstAlign=*/primitive_type_size, + source_buffer, + /*SrcAlign=*/primitive_type_size, values_size); + } + } + + // Normalize the shape and the dimension to sort. + Shape normalized_keys_shape = + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + keys->shape()); + int64 physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical( + keys->shape().layout())[sort->dimensions(0)]; + + int64 sort_dimension_elements = + normalized_keys_shape.dimensions(physical_dimension_to_sort); + int64 higher_dimensions = 1; + for (int64 i = 0; i < physical_dimension_to_sort; ++i) { + higher_dimensions *= normalized_keys_shape.dimensions(i); + } + int64 lower_dimensions = 1; + for (int64 i = ShapeUtil::Rank(normalized_keys_shape) - 1; + i > physical_dimension_to_sort; --i) { + lower_dimensions *= normalized_keys_shape.dimensions(i); + } + + PrimitiveType keys_type = keys->shape().element_type(); + const char* fn_name = nullptr; + llvm::Type* keys_native_type = nullptr; + switch (keys_type) { + case PRED: + fn_name = runtime::kKeyValueSortPREDSymbolName; + keys_native_type = b_.getInt8PtrTy(); + break; + case S8: + fn_name = runtime::kKeyValueSortS8SymbolName; + keys_native_type = 
b_.getInt8PtrTy(); + break; + case U8: + fn_name = runtime::kKeyValueSortU8SymbolName; + keys_native_type = b_.getInt8PtrTy(); + break; + case S16: + fn_name = runtime::kKeyValueSortS16SymbolName; + keys_native_type = b_.getInt16Ty()->getPointerTo(); + break; + case U16: + fn_name = runtime::kKeyValueSortU16SymbolName; + keys_native_type = b_.getInt16Ty()->getPointerTo(); + break; + case F16: + fn_name = runtime::kKeyValueSortF16SymbolName; + keys_native_type = b_.getHalfTy()->getPointerTo(); + break; + case S32: + fn_name = runtime::kKeyValueSortS32SymbolName; + keys_native_type = b_.getInt32Ty()->getPointerTo(); + break; + case U32: + fn_name = runtime::kKeyValueSortU32SymbolName; + keys_native_type = b_.getInt32Ty()->getPointerTo(); + break; + case F32: + fn_name = runtime::kKeyValueSortF32SymbolName; + keys_native_type = b_.getFloatTy()->getPointerTo(); + break; + case S64: + fn_name = runtime::kKeyValueSortS64SymbolName; + keys_native_type = b_.getInt64Ty()->getPointerTo(); + break; + case U64: + fn_name = runtime::kKeyValueSortU64SymbolName; + keys_native_type = b_.getInt64Ty()->getPointerTo(); + break; + case F64: + fn_name = runtime::kKeyValueSortF64SymbolName; + keys_native_type = b_.getDoubleTy()->getPointerTo(); + break; + default: + DLOG(FATAL) << "Element type " << PrimitiveType_Name(keys_type) + << " not supported in the Sort op on CPU."; + } + + llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get( + b_.getVoidTy(), + {keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(), + b_.getInt8PtrTy(), b_.getInt32Ty()}, + /*isVarArg=*/false); + auto* key_value_sort_func = llvm::cast<llvm::Function>( + module_->getOrInsertFunction(fn_name, key_value_sort_type)); + key_value_sort_func->setCallingConv(llvm::CallingConv::C); + key_value_sort_func->setDoesNotThrow(); + key_value_sort_func->setOnlyAccessesArgMemory(); + Call(key_value_sort_func, + {PointerCast(keys_destination_address, keys_native_type), + b_.getInt64(higher_dimensions), 
b_.getInt64(sort_dimension_elements), + b_.getInt64(lower_dimensions), + values != nullptr + ? PointerCast(values_destination_address, b_.getInt8PtrTy()) + : llvm::Constant::getNullValue(b_.getInt8PtrTy()), + b_.getInt32(values != nullptr ? ShapeUtil::ByteSizeOfPrimitiveType( + values->shape().element_type()) + : 0)}); + + if (values != nullptr) { + llvm_ir::EmitTuple(GetIrArrayFor(sort), + {keys_destination_address, values_destination_address}, + &b_, module_); + } + return Status::OK(); } Status IrEmitter::HandleTuple(HloInstruction* tuple) { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 3df99464ba..daafef4eb3 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -163,6 +163,12 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status Preprocess(HloInstruction* hlo) override; Status Postprocess(HloInstruction* hlo) override; + // A convenient helper for calling BufferAssignment::GetUniqueSlice. + BufferAllocation::Slice GetAllocationSlice( + const HloInstruction& hlo, const ShapeIndex& index = {}) const { + return assignment_.GetUniqueSlice(&hlo, index).ConsumeValueOrDie(); + } + private: // Private helper to initialize an IR function for the computation. 
void InitializeIrFunction(const string& function_name); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index b4c0c09ec0..ede7f433ca 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -142,6 +142,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast || opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng || + opcode == HloOpcode::kSort || (opcode == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction, target_machine_features_)) || diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc new file mode 100644 index 0000000000..cef5420f00 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc @@ -0,0 +1,237 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h" + +#include <algorithm> +#include <cmath> +#include <cstring> +#include <memory> +#include <string> +#include <utility> + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace { +using tensorflow::int16; +using tensorflow::int32; +using tensorflow::int64; +using tensorflow::int8; +using tensorflow::uint16; +using tensorflow::uint32; +using tensorflow::uint64; +using tensorflow::uint8; + +template <typename KeyType> +void KeyValueSort(std::pair<KeyType, int64>* row_to_sort, int64 num_elements) { + std::sort(row_to_sort, row_to_sort + num_elements); +} + +// For floating point numbers, we want a total order comparator. -NaN and NaN +// should appear at the beginning and end of the ordering, and -0.0 should +// appear before 0.0. Also we want to have a stable sort, so if the keys are the +// same, we compare the index values. +template <typename KeyType> +bool LessThan(KeyType lhs, int64 lhs_index, KeyType rhs, int64 rhs_index) { + bool lhs_is_negative = std::signbit(lhs); + bool rhs_is_negative = std::signbit(rhs); + // If the signs are different, we can just compare the signs. + if (lhs_is_negative != rhs_is_negative) { + return lhs_is_negative && !rhs_is_negative; + } + bool lhs_nan = std::isnan(lhs); + bool rhs_nan = std::isnan(rhs); + // Exactly one number is nan? 
+ if (lhs_nan != rhs_nan) { + if (lhs_nan) { + return lhs_is_negative; + } + return !rhs_is_negative; + } + if (lhs != rhs) { + return lhs < rhs; + } + return lhs_index < rhs_index; +} + +template <> +void KeyValueSort(std::pair<double, int64>* row_to_sort, int64 num_elements) { + std::sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair<double, int64>& lhs, + const std::pair<double, int64>& rhs) -> bool { + return LessThan(lhs.first, lhs.second, rhs.first, rhs.second); + }); +} + +template <> +void KeyValueSort(std::pair<float, int64>* row_to_sort, int64 num_elements) { + std::sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair<float, int64>& lhs, + const std::pair<float, int64>& rhs) -> bool { + return LessThan(lhs.first, lhs.second, rhs.first, rhs.second); + }); +} + +template <> +void KeyValueSort(std::pair<Eigen::half, int64>* row_to_sort, + int64 num_elements) { + std::sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair<Eigen::half, int64>& lhs, + const std::pair<Eigen::half, int64>& rhs) -> bool { + return LessThan( + Eigen::half_impl::half_to_float(lhs.first), lhs.second, + Eigen::half_impl::half_to_float(rhs.first), rhs.second); + }); +} + +template <typename KeyType> +void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + // High-level idea of the iteration/sorting logic: + // Conceptually we have a 3-dimensional shape [a, b, c]. b corresponds to the + // dimension to sort, c is the product of the more minor dimensions (set to 1 + // if b is the most minor dimension), and a is the product of the more major + // dimensions (set to 1 if b is the most major dimension). There are a * c + // many rows that we need to sort. We iterate through these, calculate a + // 'base_offset' value which points to the first element in that row, and add + // i * c for accessing the 'i'-th element in that row. 
+ + int64 sort_dimension_elements = b; + int64 num_iteration_elements = a * c; + int64 sort_dimension_offset = c; + + std::unique_ptr<std::pair<KeyType, int64>[]> row_to_sort( + new std::pair<KeyType, int64>[sort_dimension_elements]); + std::unique_ptr<std::string[]> reordered_values( + new std::string[sort_dimension_elements]); + for (int64 index = 0; index < num_iteration_elements; ++index) { + // 'index' can be split into two values which index into the 'c' dimension + // and the 'a' dimension, respectively. 'index' % 'c' is the index into the + // 'c' dimension, 'index' / 'c' is the index into the 'a' dimension. When + // calculating the base offset, we need to multiply the index into the 'a' + // dimension with 'b' * 'c'. + // 'index' / 'c' * 'c' * 'b' = ('index' - 'index' % 'c') * 'b'. + int64 base_offset = + index % sort_dimension_offset + + (index - index % sort_dimension_offset) * sort_dimension_elements; + // TODO(b/26783907): We could define a custom iterator class that references + // both arrays. Then we could avoid the intermediate copy. However this + // would become more complicated, and it is not clear if the benefit is high + // enough. + for (int64 i = 0; i < sort_dimension_elements; ++i) { + row_to_sort[i] = + std::make_pair(keys[base_offset + i * sort_dimension_offset], i); + } + KeyValueSort(row_to_sort.get(), sort_dimension_elements); + for (int64 i = 0; i < sort_dimension_elements; ++i) { + keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first; + } + if (values == nullptr) { + continue; + } + + // Reorder the values according to the order defined by the keys. 
+ for (int64 i = 0; i < sort_dimension_elements; ++i) { + int64 memory_index = + (base_offset + row_to_sort[i].second * sort_dimension_offset) * + values_primitive_type_size_in_bytes; + + reordered_values[i] = std::string(values + memory_index, + values_primitive_type_size_in_bytes); + } + for (int64 i = 0; i < sort_dimension_elements; ++i) { + int64 memory_index = (base_offset + i * sort_dimension_offset) * + values_primitive_type_size_in_bytes; + memcpy(values + memory_index, reordered_values[i].c_str(), + values_primitive_type_size_in_bytes); + } + } +} +} // namespace + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED( + bool* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8( + int8* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8( + uint8* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16( + int16* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16( + uint16* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16( + Eigen::half* keys, int64 a, int64 b, 
int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32( + int32* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32( + uint32* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32( + float* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64( + int64* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64( + uint64* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64( + double* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h new file mode 100644 index 0000000000..28e35e82c1 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h @@ -0,0 
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/platform/types.h"

extern "C" {

// Runtime entry points for the sort op, one per key element type.
//
// 'keys' represents a 3-dimensional shape with dimensions [a, b, c]. The 'b'
// dimension of 'keys' is sorted in place into ascending order: 'a' is the
// product of the dimensions more major than the sort dimension and 'c' the
// product of the more minor ones (each is 1 if there are none), so a * c rows
// of 'b' elements are sorted independently.
//
// 'values' can be nullptr. If 'values' is not nullptr, it is addressed as a
// byte array laid out like 'keys', with elements that are
// 'values_primitive_type_size_in_bytes' bytes wide. The elements in 'values'
// are reordered in such a way that if the element at index 'i' in 'keys' was
// moved to index 'j', the element at index 'i' in 'values' is also moved to
// index 'j' (which means that the same elements correspond to each other as
// before).
extern void __xla_cpu_runtime_KeyValueSortPRED(
    bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
    char* values, tensorflow::int32 values_primitive_type_size_in_bytes);

extern void __xla_cpu_runtime_KeyValueSortS8(
    tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b,
    tensorflow::int64 c, char* values,
    tensorflow::int32 values_primitive_type_size_in_bytes);

extern void __xla_cpu_runtime_KeyValueSortU8(
    tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b,
    tensorflow::int64 c, char* values,
    tensorflow::int32 values_primitive_type_size_in_bytes);

extern void __xla_cpu_runtime_KeyValueSortS16(
    tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b,
    tensorflow::int64 c, char* values,
    tensorflow::int32 values_primitive_type_size_in_bytes);

extern void __xla_cpu_runtime_KeyValueSortU16(
    tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b,
    tensorflow::int64 c, char* values,
    tensorflow::int32 values_primitive_type_size_in_bytes);

extern void __xla_cpu_runtime_KeyValueSortF16(
    Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b,
    tensorflow::int64 c, char* values,
    tensorflow::int32 values_primitive_type_size_in_bytes);

extern void __xla_cpu_runtime_KeyValueSortS32(
    tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b,
    tensorflow::int64 c, char* values,
    tensorflow::int32 values_primitive_type_size_in_bytes);

extern void __xla_cpu_runtime_KeyValueSortU32(
    tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b,
    tensorflow::int64 c, char* values,
    tensorflow::int32 values_primitive_type_size_in_bytes);

extern void __xla_cpu_runtime_KeyValueSortF32(
    float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
    char* values, tensorflow::int32 values_primitive_type_size_in_bytes);

extern void __xla_cpu_runtime_KeyValueSortS64(
    tensorflow::int64* keys, tensorflow::int64 a, tensorflow::int64 b,
    tensorflow::int64 c, char* values,
    tensorflow::int32 values_primitive_type_size_in_bytes);

extern void __xla_cpu_runtime_KeyValueSortU64(
    tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b,
    tensorflow::int64 c, char* values,
    tensorflow::int32 values_primitive_type_size_in_bytes);

extern void __xla_cpu_runtime_KeyValueSortF64(
    double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
    char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
}

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU32); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF32); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS64); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU64); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF64); registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee)); registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee)); diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index c55206eee7..4b129c95d4 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -180,3 +180,17 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +tf_cc_test( + name = "cpu_key_value_sort_test", + srcs = ["cpu_key_value_sort_test.cc"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc new file mode 100644 index 0000000000..3934c03a04 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"

namespace xla {
namespace cpu {
namespace {
// Fixture for checking the IR that the CPU backend generates for sort.
class CpuKeyValueSortTest : public CpuCodegenTest {};

// Compiles a rank-1 f32 sort without values ahead of time and verifies that
// the generated IR contains a call to a KeyValueSort runtime function.
TEST_F(CpuKeyValueSortTest, SortR1) {
  const string hlo_text = R"(
HloModule KeyValueSort

ENTRY main {
  a = f32[10] parameter(0)

  ROOT result = f32[10] sort(f32[10] a), dimensions={0}
}
)";

  // Only the common symbol prefix is matched, so the pattern does not depend
  // on the element-type suffix of the selected runtime function.
  string filecheck_pattern = R"(
CHECK: call void @__xla_cpu_runtime_KeyValueSort
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_text));

  CpuAotCompilationOptions options{
      /*triple=*/"x86_64", /*cpu_name=*/"", /*features=*/"",
      /*entry_point_name=*/"entry",
      /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static};

  CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern,
                                /*match_optimized_ir=*/true);
}

}  // namespace
}  // namespace cpu
}  // namespace xla