author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-01-02 13:01:59 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-01-02 13:05:54 -0800
commit     5bf26acd87d3d44183fc28cb9576cda10c0255ca (patch)
tree       5486c0ab9077496e30b594d35c089a45f7b7e18c /tensorflow/compiler/xla/service
parent     97a843db78745fe3a8e418b3b1e93ef79fbfff12 (diff)
Automated g4 rollback of changelist 180000981
PiperOrigin-RevId: 180581912
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD | 19
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_runtime.cc | 1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_runtime.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 49
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/cpu/runtime_fft.cc | 37
-rw-r--r--  tensorflow/compiler/xla/service/cpu/runtime_fft.h | 31
-rw-r--r--  tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h | 238
-rw-r--r--  tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h | 3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD | 3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/fft_thunk.cc | 234
-rw-r--r--  tensorflow/compiler/xla/service/gpu/fft_thunk.h | 98
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter.h | 5
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 19
-rw-r--r--  tensorflow/compiler/xla/service/gpu/thunk.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo.proto | 6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis.cc | 15
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc | 32
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h | 21
-rw-r--r--  tensorflow/compiler/xla/service/hlo_opcode.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/instruction_fusion.cc | 1
-rw-r--r--  tensorflow/compiler/xla/service/service.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc | 72
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.h | 5
-rw-r--r--  tensorflow/compiler/xla/service/user_computation.cc | 47
-rw-r--r--  tensorflow/compiler/xla/service/user_computation.h | 4
34 files changed, 972 insertions(+), 3 deletions(-)
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index e35d947525..9754f8d9af 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -165,6 +165,7 @@ cc_library(
":external_constant_pool",
":orc_jit_memory_mapper",
":runtime_conv2d",
+ ":runtime_fft",
":runtime_fork_join",
":runtime_matmul",
":runtime_single_threaded_conv2d",
@@ -504,6 +505,24 @@ cc_library(
)
cc_library(
+ name = "runtime_fft",
+ srcs = [
+ "runtime_fft.cc",
+ "runtime_fft_impl.h",
+ ],
+ hdrs = ["runtime_fft.h"],
+ copts = runtime_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_lite",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_library(
name = "runtime_matvec",
srcs = ["runtime_matvec.cc"],
hdrs = ["runtime_matvec.h"],
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index ba208d7249..9a5e146b5a 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -705,6 +705,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
}
+ XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module));
+
// JIT compile the LLVM IR module to in-memory machine code.
jit->AddModule(std::move(llvm_module));
cpu_executable.reset(new CpuExecutable(
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 7908dc173d..1ef45dbec3 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -37,6 +37,7 @@ extern const char* const kEigenMatMulF64SymbolName =
"__xla_cpu_runtime_EigenMatMulF64";
extern const char* const kEigenConvF32SymbolName =
"__xla_cpu_runtime_EigenConvF32";
+extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft";
extern const char* const kEigenSingleThreadedMatMulF32SymbolName =
"__xla_cpu_runtime_EigenSingleThreadedMatMulF32";
extern const char* const kEigenSingleThreadedMatMulF64SymbolName =
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
index 2ade455b8a..3e1f080711 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -44,6 +44,7 @@ namespace runtime {
extern const char* const kEigenMatMulF32SymbolName;
extern const char* const kEigenMatMulF64SymbolName;
extern const char* const kEigenConvF32SymbolName;
+extern const char* const kEigenFftSymbolName;
extern const char* const kEigenSingleThreadedMatMulF32SymbolName;
extern const char* const kEigenSingleThreadedMatMulF64SymbolName;
extern const char* const kEigenSingleThreadedConvF32SymbolName;
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 4165f920d2..26bd9ad326 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -1135,6 +1135,55 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
});
}
+Status IrEmitter::HandleFft(HloInstruction* fft) {
+ auto operand = fft->operand(0);
+ TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
+ /*instruction=*/*fft, /*operands=*/{operand},
+ /*supported_types=*/{F32, C64}));
+ TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
+ TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout()));
+ VLOG(3) << "operand=" << ShapeUtil::HumanStringWithLayout(operand->shape());
+ VLOG(3) << "fft=" << ShapeUtil::HumanStringWithLayout(fft->shape());
+
+ llvm::Value* operand_address = GetEmittedValueFor(operand);
+ TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fft));
+
+ const std::vector<int64>& fft_length = fft->fft_length();
+ int64 input_batch = 1;
+ for (int i = 0; i < fft->shape().dimensions_size() - fft_length.size(); i++) {
+ input_batch *= fft->shape().dimensions(i);
+ }
+
+ // Args have been computed, make the call.
+ llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo();
+ llvm::Type* int32_type = ir_builder_.getInt32Ty();
+ llvm::Type* int64_type = ir_builder_.getInt64Ty();
+ llvm::FunctionType* fft_type = llvm::FunctionType::get(
+ ir_builder_.getVoidTy(),
+ {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type,
+ int64_type, int64_type, int64_type, int64_type},
+ /*isVarArg=*/false);
+ const char* fn_name = runtime::kEigenFftSymbolName;
+ llvm::Function* fft_func = llvm::cast<llvm::Function>(
+ module_->getOrInsertFunction(fn_name, fft_type));
+ fft_func->setCallingConv(llvm::CallingConv::C);
+ fft_func->setDoesNotThrow();
+ fft_func->setOnlyAccessesInaccessibleMemOrArgMem();
+ const int fft_rank = fft_length.size();
+ ir_builder_.CreateCall(
+ fft_func,
+ {GetExecutableRunOptionsArgument(),
+ ir_builder_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type),
+ ir_builder_.CreateBitCast(operand_address, int8_ptr_type),
+ ir_builder_.getInt32(fft->fft_type()), ir_builder_.getInt32(fft_rank),
+ ir_builder_.getInt64(input_batch),
+ ir_builder_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
+ ir_builder_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
+ ir_builder_.getInt64(fft_rank > 2 ? fft_length[2] : 0)});
+
+ return Status::OK();
+}
+
Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
if (hlo_module_config_.replica_count() == 1) {
// When there is a single replica, a cross replica sum is the identity
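[Editor's note] For reference, the call emitted by HandleFft above targets a plain C entry point; the real declaration appears in runtime_fft.h later in this diff. A sketch of the prototype and of the argument packing for a hypothetical c64[10,20,30] operand:

extern "C" void __xla_cpu_runtime_EigenFft(
    const void* run_options, void* out, void* operand, int32_t fft_type,
    int32_t fft_rank, int64_t input_batch, int64_t fft_length0,
    int64_t fft_length1, int64_t fft_length2);

// For fft_length = {20, 30} on a c64[10,20,30] operand:
//   fft_rank    = 2
//   input_batch = 10   (product of the leading, non-transformed dims)
//   fft_length0 = 20, fft_length1 = 30, fft_length2 = 0 (unused, rank < 3)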
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 2341e3ea72..b8d71eba18 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -158,6 +158,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleSelect(HloInstruction* select) override;
Status HandleDot(HloInstruction* dot) override;
Status HandleConvolution(HloInstruction* convolution) override;
+ Status HandleFft(HloInstruction* fft) override;
Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override;
Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
Status HandleCrossReplicaSum(HloInstruction* crs) override;
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index 4b44ac8941..deb21bf4ef 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -126,7 +126,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
HloInstruction* instruction) {
// Currently, we do not assign parallel tasks to instructions with at least
// one of the following properties:
- // *) Internal threading (library calls to kConv, kDot, and kCustomCall).
+ // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall).
// *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot).
// *) Tuple-shaped.
// TODO(b/27458679) Parallelize instructions which are skipped here.
@@ -137,6 +137,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
instruction->opcode() == HloOpcode::kSelectAndScatter ||
instruction->opcode() == HloOpcode::kGetTupleElement ||
instruction->opcode() == HloOpcode::kBitcast ||
+ instruction->opcode() == HloOpcode::kFft ||
(instruction->opcode() == HloOpcode::kConvolution &&
PotentiallyImplementedAsEigenConvolution(*instruction)) ||
PotentiallyImplementedAsEigenDot(*instruction) ||
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft.cc b/tensorflow/compiler/xla/service/cpu/runtime_fft.cc
new file mode 100644
index 0000000000..848d2d2241
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_fft.cc
@@ -0,0 +1,37 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/runtime_fft.h"
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h"
+#include "tensorflow/core/platform/dynamic_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+using tensorflow::int32;
+using tensorflow::int64;
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenFft(
+ const void* run_options_ptr, void* out, void* operand, int32 fft_type,
+ int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1,
+ int64 fft_length2) {
+ const xla::ExecutableRunOptions* run_options =
+ static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
+ tensorflow::xla::EigenFftImpl(*run_options->intra_op_thread_pool(), out,
+ operand, fft_type, fft_rank, input_batch,
+ fft_length0, fft_length1, fft_length2);
+}
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft.h b/tensorflow/compiler/xla/service/cpu/runtime_fft.h
new file mode 100644
index 0000000000..f20c5aa0aa
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_fft.h
@@ -0,0 +1,31 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_H_
+
+#include "tensorflow/core/platform/types.h"
+
+extern "C" {
+
+extern void __xla_cpu_runtime_EigenFft(
+ const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out,
+ void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank,
+ tensorflow::int64 input_batch, tensorflow::int64 fft_length0,
+ tensorflow::int64 fft_length1, tensorflow::int64 fft_length2);
+
+} // extern "C"
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_H_
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h
new file mode 100644
index 0000000000..c7c701ca97
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h
@@ -0,0 +1,238 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
+
+#include "third_party/eigen3/Eigen/Core"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/numeric_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/platform/types.h"
+
+// 'tensorflow' namespace is used so that int64 and other types don't require
+// qualification.
+namespace tensorflow {
+namespace xla {
+
+namespace internal {
+
+// Computes either a forward or reverse complex-to-complex FFT.
+template <bool Forward, int FFTRank, typename EigenDevice>
+void EigenFftC2C(const EigenDevice& device, complex64* out, complex64* operand,
+ int64 input_batch, int64 fft_length0, int64 fft_length1,
+ int64 fft_length2) {
+ // Create the axes (which are always trailing).
+ const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
+ constexpr auto direction = Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
+
+ const std::array<int64, 3> fft_shape = {
+ {fft_length0, fft_length1, fft_length2}};
+
+ Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> dims;
+ dims[0] = input_batch;
+ for (int i = 0; i < FFTRank; i++) {
+ dims[i + 1] = fft_shape[i];
+ }
+ const Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
+ Eigen::Aligned>
+ input(operand, dims);
+ Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
+ Eigen::Aligned>
+ output(out, dims);
+ output.device(device) = input.template fft<Eigen::BothParts, direction>(axes);
+}
+
+// Computes a forward real->complex FFT, slicing out redundant negative
+// frequencies from the innermost dimension.
+template <int FFTRank, typename EigenDevice>
+void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand,
+ int64 input_batch, int64 fft_length0, int64 fft_length1,
+ int64 fft_length2) {
+ const std::array<int64, 3> fft_shape = {
+ {fft_length0, fft_length1, fft_length2}};
+
+ Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> in_dims;
+ in_dims[0] = input_batch;
+ Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> out_dims;
+ out_dims[0] = input_batch;
+ TensorShape temp_shape{input_batch};
+ for (int i = 0; i < FFTRank; i++) {
+ in_dims[i + 1] = fft_shape[i];
+ out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
+ temp_shape.AddDim(fft_shape[i]);
+ }
+ const Eigen::TensorMap<Eigen::Tensor<float, FFTRank + 1, Eigen::RowMajor>,
+ Eigen::Aligned>
+ input(operand, in_dims);
+ Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
+ Eigen::Aligned>
+ output(out, out_dims);
+
+ // Create the axes (which are always trailing).
+ const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
+
+ // Compute the full FFT using a temporary tensor.
+ Tensor temp(DataTypeToEnum<complex64>::v(), temp_shape);
+ auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();
+ const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
+ full_fft.device(device) =
+ input.template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
+
+ // Slice away the negative frequency components.
+ output.device(device) = full_fft.slice(zero_start_indices, out_dims);
+}
+
+// Computes a reverse complex->real FFT, reconstructing the redundant
+// negative frequencies from a reversed conjugate of the innermost dimension
+// after running the IFFT on the outer dimensions.
+template <int FFTRank, typename EigenDevice>
+void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand,
+ int64 input_batch, int64 fft_length0, int64 fft_length1,
+ int64 fft_length2) {
+ const std::array<int64, 3> fft_shape = {
+ {fft_length0, fft_length1, fft_length2}};
+
+ Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> in_dims;
+ in_dims[0] = input_batch;
+ Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> out_dims;
+ out_dims[0] = input_batch;
+ TensorShape temp_shape{input_batch};
+ for (int i = 0; i < FFTRank; i++) {
+ in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
+ out_dims[i + 1] = fft_shape[i];
+ temp_shape.AddDim(fft_shape[i]);
+ }
+ const Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
+ Eigen::Aligned>
+ input(operand, in_dims);
+ Eigen::TensorMap<Eigen::Tensor<float, FFTRank + 1, Eigen::RowMajor>,
+ Eigen::Aligned>
+ output(out, out_dims);
+
+ // Calculate the shape of the temporary tensor for the full FFT and the
+ // region we will slice from input given fft_shape. We slice input to
+ // fft_shape on its inner-most dimensions, except the last (which we
+ // slice to fft_shape[-1] / 2 + 1).
+ Tensor temp(DataTypeToEnum<complex64>::v(), temp_shape);
+ auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();
+
+ // Calculate the starting point and range of the source of
+ // negative frequency part.
+ auto neg_sizes = in_dims;
+ neg_sizes[FFTRank] = fft_shape[FFTRank - 1] - in_dims[FFTRank];
+ Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
+ neg_target_indices[FFTRank] = in_dims[FFTRank];
+
+ const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
+ Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
+ neg_start_indices[FFTRank] = 1;
+
+ full_fft.slice(zero_start_indices, in_dims).device(device) = input;
+
+ // First, conduct IFFTs on outer dimensions. We save computation (and
+ // avoid touching uninitialized memory) by slicing full_fft to the
+ // subregion we wrote input to.
+ if (FFTRank > 1) {
+ const auto outer_axes =
+ Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
+ full_fft.slice(zero_start_indices, in_dims).device(device) =
+ full_fft.slice(zero_start_indices, in_dims)
+ .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(outer_axes);
+ }
+
+ // Reconstruct the full FFT by appending reversed and conjugated
+ // spectrum as the negative frequency part.
+ Eigen::array<bool, FFTRank + 1> reverse_last_axis;
+ for (auto i = 0; i <= FFTRank; i++) {
+ reverse_last_axis[i] = i == FFTRank;
+ }
+
+ if (neg_sizes[FFTRank] != 0) {
+ full_fft.slice(neg_target_indices, neg_sizes).device(device) =
+ full_fft.slice(neg_start_indices, neg_sizes)
+ .reverse(reverse_last_axis)
+ .conjugate();
+ }
+
+ auto inner_axis = Eigen::array<int, 1>{FFTRank};
+ output.device(device) =
+ full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(inner_axis);
+}
+
+template <int FFTRank, typename EigenDevice>
+void EigenFftWithRank(const EigenDevice& device, void* out, void* operand,
+ int32 fft_type, int64 input_batch, int64 fft_length0,
+ int64 fft_length1, int64 fft_length2) {
+ CHECK(::xla::FftType_IsValid(fft_type)) << fft_type;
+ switch (fft_type) {
+ case ::xla::FftType::FFT:
+ EigenFftC2C<true, FFTRank, EigenDevice>(
+ device, static_cast<complex64*>(out),
+ static_cast<complex64*>(operand), input_batch, fft_length0,
+ fft_length1, fft_length2);
+ break;
+ case ::xla::FftType::IFFT:
+ EigenFftC2C<false, FFTRank, EigenDevice>(
+ device, static_cast<complex64*>(out),
+ static_cast<complex64*>(operand), input_batch, fft_length0,
+ fft_length1, fft_length2);
+ break;
+ case ::xla::FftType::RFFT:
+ EigenFftR2C<FFTRank, EigenDevice>(
+ device, static_cast<complex64*>(out), static_cast<float*>(operand),
+ input_batch, fft_length0, fft_length1, fft_length2);
+ break;
+ case ::xla::FftType::IRFFT:
+ EigenFftC2R<FFTRank, EigenDevice>(
+ device, static_cast<float*>(out), static_cast<complex64*>(operand),
+ input_batch, fft_length0, fft_length1, fft_length2);
+ break;
+ default:
+ LOG(FATAL) << "Unsupported FFT type: " << fft_type;
+ }
+}
+
+} // namespace internal
+
+template <typename EigenDevice>
+void EigenFftImpl(const EigenDevice& device, void* out, void* operand,
+ int32 fft_type, int32 fft_rank, int64 input_batch,
+ int64 fft_length0, int64 fft_length1, int64 fft_length2) {
+ switch (fft_rank) {
+ case 1:
+ internal::EigenFftWithRank<1, EigenDevice>(
+ device, out, operand, fft_type, input_batch, fft_length0, 0, 0);
+ break;
+ case 2:
+ internal::EigenFftWithRank<2, EigenDevice>(device, out, operand, fft_type,
+ input_batch, fft_length0,
+ fft_length1, 0);
+ break;
+ case 3:
+ internal::EigenFftWithRank<3, EigenDevice>(device, out, operand, fft_type,
+ input_batch, fft_length0,
+ fft_length1, fft_length2);
+ break;
+ default:
+ LOG(FATAL) << "Unsupported FFT rank " << fft_rank;
+ }
+}
+
+} // namespace xla
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
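[Editor's note] EigenFftImpl above is header-only and parameterized on the device type, so it can be exercised stand-alone. A minimal sketch, assuming Eigen::DefaultDevice as a single-threaded stand-in (the production call in runtime_fft.cc passes the Eigen thread-pool device from ExecutableRunOptions instead):

#include <vector>

#include "tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h"

int main() {
  constexpr tensorflow::int64 kBatch = 2, kLength = 8;
  // Two batched, rank-1, complex-to-complex transforms of length 8.
  std::vector<tensorflow::complex64> in(kBatch * kLength);
  std::vector<tensorflow::complex64> out(kBatch * kLength);
  in[1] = tensorflow::complex64(1.0f, 0.0f);  // impulse in batch row 0
  Eigen::DefaultDevice device;                // single-threaded for the sketch
  tensorflow::xla::EigenFftImpl(device, out.data(), in.data(),
                                ::xla::FftType::FFT, /*fft_rank=*/1,
                                /*input_batch=*/kBatch,
                                /*fft_length0=*/kLength,
                                /*fft_length1=*/0, /*fft_length2=*/0);
  return 0;
}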
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index 65da61805a..43ab2ec524 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_fft.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
@@ -208,6 +209,7 @@ bool RegisterKnownJITSymbols() {
REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue);
REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation);
REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32);
+ REGISTER_CPU_RUNTIME_SYMBOL(EigenFft);
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32);
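[Editor's note] The REGISTER_CPU_RUNTIME_SYMBOL(EigenFft) line above is what lets the ORC JIT resolve, by name, the call that IrEmitter::HandleFft emits. Schematically (hypothetical registry shape for illustration only; the real macro and resolver live in simple_orc_jit.cc):

#include <map>
#include <string>

#include "tensorflow/compiler/xla/service/cpu/runtime_fft.h"

// The JIT's symbol resolver conceptually consults a name -> address map:
std::map<std::string, void*> known_symbols = {
    {"__xla_cpu_runtime_EigenFft",
     reinterpret_cast<void*>(&__xla_cpu_runtime_EigenFft)},
};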
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 0d54e325e6..a803b3171f 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -103,6 +103,7 @@ class DfsHloVisitorBase {
return HandleElementwiseBinary(hlo);
}
virtual Status HandleConvolution(HloInstructionPtr hlo) = 0;
+ virtual Status HandleFft(HloInstructionPtr fft) = 0;
virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0;
virtual Status HandleCompare(HloInstructionPtr hlo) {
return HandleElementwiseBinary(hlo);
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index 133aa25094..170adb3d24 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -85,6 +85,9 @@ class DfsHloVisitorWithDefaultBase
Status HandleConvolution(HloInstructionPtr convolution) override {
return DefaultAction(convolution);
}
+ Status HandleFft(HloInstructionPtr fft) override {
+ return DefaultAction(fft);
+ }
Status HandleCrossReplicaSum(HloInstructionPtr crs) override {
return DefaultAction(crs);
}
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index d5d89de6d4..e4832b2ee6 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -220,6 +220,7 @@ cc_library(
"convolution_thunk.cc",
"copy_thunk.cc",
"cudnn_batchnorm_thunk.cc",
+ "fft_thunk.cc",
"for_thunk.cc",
"gemm_thunk.cc",
"gpu_executable.cc",
@@ -234,6 +235,7 @@ cc_library(
"convolution_thunk.h",
"copy_thunk.h",
"cudnn_batchnorm_thunk.h",
+ "fft_thunk.h",
"for_thunk.h",
"gemm_thunk.h",
"gpu_executable.h",
@@ -272,6 +274,7 @@ cc_library(
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/platform/default/build_config:cublas_plugin",
"//tensorflow/core/platform/default/build_config:cudnn_plugin",
+ "//tensorflow/core/platform/default/build_config:cufft_plugin",
"//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
new file mode 100644
index 0000000000..66931bdc8b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
@@ -0,0 +1,234 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
+
+#include <string>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace gpu {
+
+FftScratchAllocator::FftScratchAllocator(
+ int device_ordinal, DeviceMemoryAllocator* memory_allocator)
+ : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
+
+FftScratchAllocator::~FftScratchAllocator() {
+ for (auto& allocated_buffer : allocated_buffers_) {
+ if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer)
+ .ok()) {
+ // The program can still continue with failed deallocation.
+ LOG(ERROR) << "Failed to deallocate the allocated buffer: "
+ << allocated_buffer.opaque();
+ }
+ }
+}
+
+int64 FftScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) {
+ constexpr int64 kFftScratchSize = 1LL << 32; // 4GB by default.
+ return kFftScratchSize;
+}
+
+se::port::StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes(
+ se::Stream* stream, int64 byte_size) {
+  CHECK_GE(byte_size, 0) << "byte_size must be non-negative.";
+ if (byte_size > GetMemoryLimitInBytes(stream)) {
+ return se::port::Status(
+ se::port::error::RESOURCE_EXHAUSTED,
+ tensorflow::strings::Printf(
+ "Allocating %lld bytes exceeds the memory limit of %lld bytes.",
+ byte_size, GetMemoryLimitInBytes(stream)));
+ }
+
+ auto status_or_memory =
+ memory_allocator_->Allocate(device_ordinal_, byte_size,
+ /*retry_on_failure=*/false);
+ if (!status_or_memory.ok()) {
+ return tensorflow::errors::ResourceExhausted(
+        "Failed to allocate ", byte_size, " bytes on device ",
+        device_ordinal_, ".");
+ }
+ se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie();
+ allocated_buffers_.push_back(allocated_buffer);
+ total_allocated_bytes_ += byte_size;
+ return se::DeviceMemory<uint8>(allocated_buffer);
+}
+
+namespace {
+
+se::fft::Type FftTypeToSeType(FftType type) {
+ switch (type) {
+ case FftType::FFT:
+ return se::fft::Type::kC2CForward;
+ case FftType::IFFT:
+ return se::fft::Type::kC2CInverse;
+ case FftType::IRFFT:
+ return se::fft::Type::kC2R;
+ case FftType::RFFT:
+ return se::fft::Type::kR2C;
+ default:
+ LOG(FATAL) << "unsupported fft type";
+ }
+}
+
+string FftTypeToString(se::fft::Type type) {
+ switch (type) {
+ case se::fft::Type::kC2CForward:
+ return "FFT";
+ case se::fft::Type::kC2CInverse:
+ return "IFFT";
+ case se::fft::Type::kC2R:
+ return "IRFFT";
+ case se::fft::Type::kR2C:
+ return "RFFT";
+ default:
+ LOG(FATAL) << "unknown fft type";
+ }
+}
+
+} // namespace
+
+FftThunk::FftThunk(FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length,
+ const BufferAllocation::Slice& input_buffer,
+ const BufferAllocation::Slice& output_buffer,
+ const Shape& input_shape, const Shape& output_shape,
+ const HloInstruction* hlo)
+ : Thunk(Kind::kFft, hlo),
+ fft_type_(FftTypeToSeType(fft_type)),
+ fft_length_(fft_length.begin(), fft_length.end()),
+ scale_factor_(1.0f),
+ input_buffer_(input_buffer),
+ output_buffer_(output_buffer),
+ input_shape_(input_shape),
+ output_shape_(output_shape) {}
+
+tensorflow::Status FftThunk::ExecuteOnStream(
+ const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ VLOG(3) << "FFT type: " << FftTypeToString(fft_type_);
+ VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_);
+ VLOG(3) << "Output shape: "
+ << ShapeUtil::HumanStringWithLayout(output_shape_);
+
+ FftScratchAllocator scratch_allocator(buffer_allocations.device_ordinal(),
+ buffer_allocations.memory_allocator());
+
+ if (fft_plan_ == nullptr) {
+ const int64 fft_rank = fft_length_.size();
+ CHECK_LE(fft_rank, 3);
+ int batch_size = 1;
+ for (int i = 0; i < input_shape_.dimensions_size() - fft_rank; ++i) {
+ batch_size *= input_shape_.dimensions(i);
+ }
+ uint64 fft_length[3];
+ uint64 input_embed[3];
+ const uint64 input_stride = 1;
+ uint64 input_distance = 1;
+ uint64 output_embed[3];
+ const uint64 output_stride = 1;
+ uint64 output_distance = 1;
+
+ for (int i = 0; i < fft_rank; ++i) {
+ auto dim_offset = input_shape_.dimensions_size() - fft_rank + i;
+ fft_length[i] = static_cast<uint64>(fft_length_[i]);
+ input_embed[i] = input_shape_.dimensions(dim_offset);
+ input_distance *= input_shape_.dimensions(dim_offset);
+ output_embed[i] = output_shape_.dimensions(dim_offset);
+ output_distance *= output_shape_.dimensions(dim_offset);
+ }
+
+ constexpr bool kInPlaceFft = false;
+ fft_plan_ =
+ stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
+ stream, fft_rank, fft_length, input_embed, input_stride,
+ input_distance, output_embed, output_stride, output_distance,
+ fft_type_, kInPlaceFft, batch_size, &scratch_allocator);
+ scale_factor_ = 1.0f / output_distance;
+ } else {
+ stream->parent()->AsFft()->UpdatePlanWithScratchAllocator(
+ stream, fft_plan_.get(), &scratch_allocator);
+ }
+
+ bool launch_ok;
+ switch (fft_type_) {
+ case se::fft::Type::kC2CForward: {
+ se::DeviceMemory<complex64> input_data(
+ buffer_allocations.GetDeviceAddress(input_buffer_));
+ se::DeviceMemory<complex64> output_data(
+ buffer_allocations.GetDeviceAddress(output_buffer_));
+ launch_ok =
+ stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+ break;
+ }
+ case se::fft::Type::kC2CInverse: {
+ se::DeviceMemory<complex64> input_data(
+ buffer_allocations.GetDeviceAddress(input_buffer_));
+ se::DeviceMemory<complex64> output_data(
+ buffer_allocations.GetDeviceAddress(output_buffer_));
+ launch_ok =
+ stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+ if (launch_ok) {
+ launch_ok =
+ stream
+ ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
+ complex64(scale_factor_), &output_data, 1)
+ .ok();
+ }
+ break;
+ }
+ case se::fft::Type::kR2C: {
+ se::DeviceMemory<float> input_data(
+ buffer_allocations.GetDeviceAddress(input_buffer_));
+ se::DeviceMemory<complex64> output_data(
+ buffer_allocations.GetDeviceAddress(output_buffer_));
+ launch_ok =
+ stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+ break;
+ }
+ case se::fft::Type::kC2R: {
+ se::DeviceMemory<complex64> input_data(
+ buffer_allocations.GetDeviceAddress(input_buffer_));
+ se::DeviceMemory<float> output_data(
+ buffer_allocations.GetDeviceAddress(output_buffer_));
+ launch_ok =
+ stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+ if (launch_ok) {
+ launch_ok = stream
+ ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
+ scale_factor_, &output_data, 1)
+ .ok();
+ }
+ break;
+ }
+ default:
+ LOG(FATAL) << "unsupported fft type";
+ }
+ if (launch_ok) {
+ return tensorflow::Status::OK();
+ }
+ return InternalError("Unable to launch fft for thunk %p with type %s", this,
+ FftTypeToString(fft_type_).c_str());
+}
+
+} // namespace gpu
+} // namespace xla
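[Editor's note] One detail worth calling out in ExecuteOnStream above: cuFFT transforms are unnormalized, so the thunk folds a 1/N factor into the inverse paths (kC2CInverse and kC2R) via ThenBlasScal, where N is output_distance, i.e. the product of the transformed output dimensions. An illustrative helper (not part of the patch) showing the convention:

#include <cstdint>
#include <vector>

// IFFT(FFT(x)) == x only if the inverse output is scaled by 1/N; this
// mirrors scale_factor_ = 1.0f / output_distance above.
float InverseFftScale(const std::vector<int64_t>& fft_length) {
  int64_t n = 1;
  for (int64_t dim : fft_length) n *= dim;
  return 1.0f / static_cast<float>(n);
}
// e.g. fft_length = {16, 32}  ->  InverseFftScale == 1.0f / 512.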
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
new file mode 100644
index 0000000000..52fb8c376d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
@@ -0,0 +1,98 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_
+
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/optional.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// A one-time scratch allocator for FFT. The scratch buffers allocated are
+// released on destruction.
+//
+// Not thread-safe: neither AllocateBytes nor the destructor is locked.
+class FftScratchAllocator : public perftools::gputools::ScratchAllocator {
+ public:
+ FftScratchAllocator(int device_ordinal,
+ DeviceMemoryAllocator* memory_allocator);
+
+ ~FftScratchAllocator() override;
+
+ int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override;
+
+ int64 TotalAllocatedBytes() { return total_allocated_bytes_; }
+
+ perftools::gputools::port::StatusOr<perftools::gputools::DeviceMemory<uint8>>
+ AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override;
+
+ private:
+ const int device_ordinal_;
+ DeviceMemoryAllocator* memory_allocator_;
+ std::vector<perftools::gputools::DeviceMemoryBase> allocated_buffers_;
+ int64 total_allocated_bytes_ = 0;
+};
+
+// This class stores everything that StreamExecutor needs to launch an FFT.
+// It is generated by IrEmitter.
+//
+// This is thread-compatible.
+class FftThunk : public Thunk {
+ public:
+ // Constructs a thunk for launching an FFT on a stream.
+ // Semantics of null hlo_instruction argument are as in Thunk.
+ FftThunk(FftType fft_type, tensorflow::gtl::ArraySlice<int64> fft_length,
+ const BufferAllocation::Slice& input_buffer,
+ const BufferAllocation::Slice& output_buffer,
+ const Shape& input_shape, const Shape& output_shape,
+ const HloInstruction* hlo);
+
+ FftThunk(const FftThunk&) = delete; // Cannot share fft_plan_
+ FftThunk& operator=(const FftThunk&) = delete; // Cannot share fft_plan_
+
+ // Does the FFT for the thunk on "stream".
+ tensorflow::Status ExecuteOnStream(
+ const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) override;
+
+ private:
+ const perftools::gputools::fft::Type fft_type_;
+ const std::vector<int64> fft_length_;
+
+ float scale_factor_;
+
+ std::unique_ptr<perftools::gputools::fft::Plan> fft_plan_;
+
+ const BufferAllocation::Slice input_buffer_;
+ const BufferAllocation::Slice output_buffer_;
+
+ const Shape input_shape_;
+ const Shape output_shape_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index d85d6aae12..d6d0e1e116 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -605,6 +605,14 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
"Hit a case for convolution that is not implemented on GPU.");
}
+Status IrEmitter::HandleFft(HloInstruction* fft) {
+ if (ShapeUtil::HasZeroElements(fft->shape())) {
+ // Emit no code for an empty output.
+ return Status::OK();
+ }
+ return Unimplemented("Hit a case for fft that is not implemented on GPU.");
+}
+
Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
// TODO(b/33011107): Support cross replica sum on GPU.
return Unimplemented(
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 41d013c13d..af43895c23 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -79,6 +79,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleDot(HloInstruction* dot) override;
Status HandleConvolution(HloInstruction* convolution) override;
+ Status HandleFft(HloInstruction* fft) override;
Status HandleCrossReplicaSum(HloInstruction* crs) override;
Status HandleInfeed(HloInstruction* infeed) override;
Status HandleOutfeed(HloInstruction* outfeed) override;
@@ -242,6 +243,7 @@ class IrEmitterUnnested : public IrEmitter {
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleCustomCall(HloInstruction* custom_call) override;
Status HandleDot(HloInstruction* dot) override;
+ Status HandleFft(HloInstruction* fft) override;
Status HandleFusion(HloInstruction* fusion) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleReduce(HloInstruction* reduce) override;
@@ -332,6 +334,9 @@ class IrEmitterUnnested : public IrEmitter {
// Returns a ConvolutionThunk that calls DNN to implement `inst`.
std::unique_ptr<Thunk> BuildConvolutionThunk(const HloInstruction* inst);
+  // Returns an FftThunk that calls cuFFT to implement `inst`.
+ std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst);
+
// Returns a GemmThunk that calls gemm to implement `inst`. The caller needs
// to make sure `inst` outlives the lifetime of the returned Thunk object.
std::unique_ptr<Thunk> BuildGemmThunk(const HloInstruction* inst);
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 1be3473f52..1aa506a3a9 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
@@ -373,6 +374,14 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
return IrEmitter::HandleCustomCall(custom_call);
}
+Status IrEmitterUnnested::HandleFft(HloInstruction* fft) {
+ TF_RET_CHECK(
+ LayoutUtil::IsMonotonicWithDim0Major(fft->operand(0)->shape().layout()));
+ TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout()));
+ thunk_sequence_->emplace_back(BuildFftThunk(fft));
+ return Status::OK();
+}
+
Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
HloInstruction* root = fusion->fused_expression_root();
// HandleFusion specializes reduction from a multi-dimensional array to a 1D
@@ -1855,6 +1864,16 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildConvolutionThunk(
}
}
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk(
+ const HloInstruction* inst) {
+ const HloInstruction* operand = inst->operand(0);
+ return MakeUnique<FftThunk>(inst->fft_type(), inst->fft_length(),
+ /*input_buffer=*/GetAllocationSlice(*operand),
+ /*output_buffer=*/GetAllocationSlice(*inst),
+ /*input_shape=*/operand->shape(),
+ /*output_shape=*/inst->shape(), inst);
+}
+
Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo,
KernelThunk* thunk) {
bool fused = HloOpcode::kFusion == hlo->opcode();
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index 191c7675c6..625c3f8bea 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -46,6 +46,7 @@ class Thunk {
kCudnnBatchNormBackward,
kCudnnBatchNormForwardInference,
kCudnnBatchNormForwardTraining,
+ kFft,
kGemm,
kInfeed,
kKernel,
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index e4aed7593c..0e9a852788 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -123,6 +123,12 @@ message HloInstructionProto {
// Describes the dimension numbers used for a dot operation
xla.DotDimensionNumbers dot_dimension_numbers = 30;
+
+ // FFT type (FFT, IFFT, etc).
+ xla.FftType fft_type = 31;
+
+ // FFT length.
+ repeated int64 fft_length = 32;
}
// Serialization of HloComputation.
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index b933695b82..cd54eb74d1 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -392,6 +392,21 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
return Status::OK();
}
+Status HloCostAnalysis::HandleFft(const HloInstruction* fft) {
+ auto real_shape =
+ ShapeUtil::IsTuple(fft->operand(0)->shape())
+ ? ShapeUtil::GetTupleElementShape(fft->operand(0)->shape(), 0)
+ : fft->operand(0)->shape();
+ constexpr int kFmaPerComplexMul = 4;
+ int64 log_factors = 1;
+ for (int64 dim : fft->fft_length()) {
+ log_factors *= tensorflow::Log2Floor(dim);
+ }
+ current_properties_[kFlopsKey] = kFmaFlops * kFmaPerComplexMul * log_factors *
+ ShapeUtil::ElementsIn(real_shape);
+ return Status::OK();
+}
+
Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) {
// We assume 2 replicas, so that each output element is the sum of two input
// elements.
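[Editor's note] As a worked instance of the HandleFft cost model above: for fft_length={1024} on a c64[8,1024] operand, log_factors = Log2Floor(1024) = 10 and ElementsIn(real_shape) = 8192, so the estimate is kFmaFlops * 4 * 10 * 8192 (655,360 flops, taking kFmaFlops at its conventional value of 2). Note that for multi-dimensional transforms the model multiplies the per-dimension log2 factors together; this is a heuristic, since a Cooley-Tukey implementation costs closer to N * (log2 N1 + log2 N2 + ...).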
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index fade19522c..e5783539e5 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -67,6 +67,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleCopy(const HloInstruction* copy) override;
Status HandleDot(const HloInstruction* dot) override;
Status HandleConvolution(const HloInstruction* convolution) override;
+ Status HandleFft(const HloInstruction* fft) override;
Status HandleCrossReplicaSum(const HloInstruction* crs) override;
Status HandleInfeed(const HloInstruction* infeed) override;
Status HandleOutfeed(const HloInstruction* outfeed) override;
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 023aec96ec..f7c6435002 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -961,6 +961,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
return kGreen;
case HloOpcode::kConvolution:
case HloOpcode::kDot:
+ case HloOpcode::kFft:
return kDarkBlue;
case HloOpcode::kReducePrecision:
return kRed;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 9805818e4c..138fb9a190 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -144,6 +144,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->infeed_config_ = proto.infeed_config();
instruction->custom_call_target_ = proto.custom_call_target();
instruction->outfeed_shape_ = proto.outfeed_shape();
+ instruction->fft_type_ = proto.fft_type();
+ for (int64 fft_len : proto.fft_length()) {
+ instruction->fft_length_.push_back(fft_len);
+ }
return std::move(instruction);
}
@@ -334,6 +338,16 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
return instruction;
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
+ const Shape& shape, HloInstruction* operand, FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length) {
+ auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFft, shape));
+ instruction->AppendOperand(operand);
+ instruction->fft_type_ = fft_type;
+ instruction->fft_length_.assign(fft_length.begin(), fft_length.end());
+ return instruction;
+}
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dimension_numbers) {
@@ -1167,6 +1181,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
clone = CreateDot(shape, new_operands[0], new_operands[1],
*dot_dimension_numbers_);
break;
+ case HloOpcode::kFft:
+ CHECK_EQ(new_operands.size(), 1);
+ return CreateFft(shape, new_operands[0], fft_type_, fft_length_);
case HloOpcode::kCrossReplicaSum:
clone = CreateCrossReplicaSum(shape, new_operands);
break;
@@ -1631,6 +1648,11 @@ bool HloInstruction::IdenticalSlowPath(
return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
other.dot_dimension_numbers());
+    // FFT instructions are identical only if their types and lengths match.
+ case HloOpcode::kFft:
+ return fft_type() == other.fft_type() &&
+ fft_length() == other.fft_length();
+
// Reduction results are determined by the reduction dimension and the
// reduction computation.
case HloOpcode::kReduce:
@@ -2055,6 +2077,10 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
if (dot_dimension_numbers_ != nullptr) {
extra.push_back(DotDimensionNumbersToString());
}
+ if (opcode() == HloOpcode::kFft) {
+ extra.push_back(StrCat("fft_type=", FftType_Name(fft_type())));
+ extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}"));
+ }
if (options.print_subcomputation_references()) {
if (opcode() == HloOpcode::kWhile) {
@@ -2206,6 +2232,10 @@ HloInstructionProto HloInstruction::ToProto() const {
proto.set_infeed_config(infeed_config_);
proto.set_custom_call_target(custom_call_target_);
*proto.mutable_outfeed_shape() = outfeed_shape_;
+ proto.set_fft_type(fft_type_);
+ for (int64 fft_len : fft_length_) {
+ proto.add_fft_length(fft_len);
+ }
return proto;
}
@@ -2421,6 +2451,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleSelect(this);
case HloOpcode::kConvolution:
return visitor->HandleConvolution(this);
+ case HloOpcode::kFft:
+ return visitor->HandleFft(this);
case HloOpcode::kCrossReplicaSum:
return visitor->HandleCrossReplicaSum(this);
case HloOpcode::kTuple:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index d455cfc3f1..c5cab92ac9 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -261,6 +261,11 @@ class HloInstruction {
const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers);
+  // Creates an FFT op of the type indicated by fft_type.
+ static std::unique_ptr<HloInstruction> CreateFft(
+ const Shape& shape, HloInstruction* operand, FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length);
+
// Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
// dimensions specified in 'dimension_numbers'.
static std::unique_ptr<HloInstruction> CreateDot(
@@ -1031,6 +1036,16 @@ class HloInstruction {
return *convolution_dimension_numbers_;
}
+ FftType fft_type() const {
+ CHECK_EQ(HloOpcode::kFft, opcode_);
+ return fft_type_;
+ }
+
+ const std::vector<int64>& fft_length() const {
+ CHECK_EQ(HloOpcode::kFft, opcode_);
+ return fft_length_;
+ }
+
// Returns the dump string of the convolution dimension numbers.
string ConvolutionDimensionNumbersToString() const;
@@ -1303,6 +1318,12 @@ class HloInstruction {
// Describes the dimension numbers used for a dot.
std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;
+ // Describes FFT type for an FFT instruction.
+ FftType fft_type_ = FftType::FFT;
+
+ // Indicates the FFT length for an FFT instruction.
+ std::vector<int64> fft_length_;
+
// Describes the [begin, end) index range for a slice.
std::vector<int64> slice_starts_;
std::vector<int64> slice_limits_;
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index f3f7935758..3d64523a79 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -73,6 +73,7 @@ namespace xla {
V(kDynamicUpdateSlice, "dynamic-update-slice") \
V(kEq, "equal-to", kHloOpcodeIsComparison) \
V(kExp, "exponential") \
+ V(kFft, "fft") \
V(kFloor, "floor") \
V(kFusion, "fusion", kHloOpcodeIsVariadic) \
V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index d963a8a2f4..9d5ca6673a 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -92,6 +92,14 @@ class ShapeVerifier : public DfsHloVisitor {
return CheckShape(convolution, expected);
}
+ Status HandleFft(HloInstruction* fft) override {
+ TF_ASSIGN_OR_RETURN(
+ const Shape expected,
+ ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(),
+ fft->fft_length()));
+ return CheckShape(fft, expected);
+ }
+
Status HandleCrossReplicaSum(HloInstruction* crs) override {
std::vector<const Shape*> operand_shapes;
for (const HloInstruction* operand : crs->operands()) {
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index ba901b99e4..90e1f0acdc 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -100,6 +100,7 @@ namespace xla {
case HloOpcode::kDivide:
case HloOpcode::kDot:
case HloOpcode::kExp:
+ case HloOpcode::kFft:
case HloOpcode::kFusion:
case HloOpcode::kLog:
case HloOpcode::kMap:
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 0b98714168..40613cc75b 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -1406,6 +1406,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
handle_status = computation->AddDynamicUpdateSliceInstruction(
arg->dynamic_update_slice_request());
break;
+ case OpRequest::kFftRequest:
+ handle_status = computation->AddFftInstruction(arg->fft_request());
+ break;
case OpRequest::kGetTupleElementRequest:
handle_status = computation->AddGetTupleElementInstruction(
arg->get_tuple_element_request());
@@ -1518,8 +1521,6 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
handle_status = computation->AddRecvInstruction(arg->recv_request());
break;
}
- case OpRequest::kFftRequest:
- return Unimplemented("FftRequest not implemented in XLA service.");
case OpRequest::OP_NOT_SET:
return InvalidArgument("XLA service received OpRequest with OP_NOT_SET");
default:
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 9c1b951d01..6dc49ffe4c 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1715,6 +1715,78 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
return ShapeUtil::MakeShape(lhs.element_type(), dimensions);
}
+/* static */ StatusOr<Shape> ShapeInference::InferFftShape(
+ const Shape& in, const FftType fft_type,
+ const tensorflow::gtl::ArraySlice<int64> fft_length) {
+ const int64 fft_rank = fft_length.size();
+ if (fft_rank < 1 || fft_rank > 3) {
+ return InvalidArgument("FFT only supports ranks 1-3, but got %lld",
+ fft_rank);
+ }
+ switch (fft_type) {
+ case FFT:
+ case IFFT:
+ if (in.element_type() != C64) {
+ return InvalidArgument("%s requires C64 input type, found %s",
+ FftType_Name(fft_type).c_str(),
+ PrimitiveType_Name(in.element_type()).c_str());
+ }
+ return in;
+ case RFFT: {
+ if (in.element_type() != F32) {
+ return InvalidArgument("RFFT requires F32 input type, found %s",
+ PrimitiveType_Name(in.element_type()).c_str());
+ }
+ for (int i = 0; i < fft_rank; i++) {
+ if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
+ fft_length[i]) {
+ return InvalidArgument(
+              "RFFT requires the innermost dimensions to match fft_length, "
+              "but dimension %lld is %lld and should be %lld",
+ in.dimensions_size() - fft_rank + i,
+ in.dimensions(in.dimensions_size() - fft_rank + i),
+ fft_length[i]);
+ }
+ }
+ Shape result = ShapeUtil::ChangeElementType(in, C64);
+ result.set_dimensions(result.dimensions_size() - 1,
+ fft_length[fft_rank - 1] / 2 + 1);
+ return result;
+ }
+ case IRFFT: {
+ if (in.element_type() != C64) {
+ return InvalidArgument("IRFFT requires C64 input type, found %s",
+ PrimitiveType_Name(in.element_type()).c_str());
+ }
+ Shape result = ShapeUtil::ChangeElementType(in, F32);
+ for (int i = 0; i < fft_rank - 1; i++) {
+ if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
+ fft_length[i]) {
+ return InvalidArgument(
+            "IRFFT requires all but the innermost dimension to match "
+            "fft_length, but dimension %lld is %lld and should be %lld",
+ in.dimensions_size() - fft_rank + i,
+ in.dimensions(in.dimensions_size() - fft_rank + i),
+ fft_length[i]);
+ }
+ }
+ if (in.dimensions(in.dimensions_size() - 1) !=
+ fft_length[fft_rank - 1] / 2 + 1) {
+ return InvalidArgument(
+ "IRFFT requires innermost dimension matches fft_length/2+1, but "
+ "dimension %d is %lld and should be %lld",
+ in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1),
+ fft_length[fft_rank - 1] / 2 + 1);
+ }
+ result.set_dimensions(result.dimensions_size() - 1,
+ fft_length[fft_rank - 1]);
+ return result;
+ }
+ default:
+ LOG(FATAL) << "Unexpected fft_type: " << fft_type;
+ }
+}
+
/* static */ StatusOr<Shape> ShapeInference::InferCrossReplicaSumShape(
tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
for (const Shape* operand_shape : operand_shapes) {
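[Editor's note] Concretely, the InferFftShape rules above give: FFT/IFFT on c64[3,8] with fft_length={8} keep the shape c64[3,8]; RFFT on f32[3,8] with fft_length={8} yields c64[3,5] (8/2 + 1 = 5 non-redundant frequencies); and IRFFT on that c64[3,5] with fft_length={8} restores f32[3,8].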
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index c06340d2d5..b39151ebbc 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -109,6 +109,11 @@ class ShapeInference {
const Shape& lhs, const Shape& rhs, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers);
+ // Infers the shape produced by the given FFT type on the given operand.
+ static StatusOr<Shape> InferFftShape(
+ const Shape& in, FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length);
+
// Infers the shape produced by a cross replica sum with the given operand
// shapes.
static StatusOr<Shape> InferCrossReplicaSumShape(
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index e42cbfa976..9d941fb770 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -1121,6 +1121,31 @@ StatusOr<ComputationDataHandle> UserComputation::AddConvolveInstruction(
return handle;
}
+StatusOr<ComputationDataHandle> UserComputation::AddFftInstruction(
+ const FftRequest& fft_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+ LookUpRequest(fft_request.operand()));
+ TF_ASSIGN_OR_RETURN(Shape shape,
+ ShapeInference::InferFftShape(
+ operand->output_shape(), fft_request.fft_type(),
+ AsInt64Slice(fft_request.fft_length())));
+
+ const ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = shape;
+ *request.mutable_request()->mutable_fft_request() = fft_request;
+
+ VLOG(1) << "AddFftInstruction (" << GetVersionedHandleInternal()
+ << "), data handle " << handle.handle() << ": "
+ << fft_request.ShortDebugString();
+ return handle;
+}
+
StatusOr<ComputationDataHandle> UserComputation::AddCrossReplicaSumInstruction(
const CrossReplicaSumRequest& cross_replica_sum_request) {
tensorflow::mutex_lock lock(mutex_);
@@ -1675,6 +1700,13 @@ void PureFunctionalVisitor(const SessionComputation& session_computation,
break;
}
+ case OpRequest::kFftRequest: {
+ const FftRequest& fft_request = request.request().fft_request();
+ PureFunctionalVisitor(session_computation, fft_request.operand(),
+ num_parameters, visited, is_functional);
+ break;
+ }
+
case OpRequest::kCrossReplicaSumRequest: {
// TODO(b/33009255): Implement constant folding for cross replica sum.
*is_functional = false;
@@ -2406,6 +2438,12 @@ static void ForEachOperand(
break;
}
+ case OpRequest::kFftRequest: {
+ const FftRequest& fft_request = request.request().fft_request();
+ apply(fft_request.operand());
+ break;
+ }
+
case OpRequest::kBatchNormTrainingRequest: {
const BatchNormTrainingRequest& batch_norm_training_request =
request.request().batch_norm_training_request();
@@ -2880,6 +2918,15 @@ void ComputationLowerer::Visit(
break;
}
+ case OpRequest::kFftRequest: {
+ const FftRequest& fft_request = request.request().fft_request();
+ HloInstruction* operand = lookup_instruction(fft_request.operand());
+ hlo_instruction = add_instruction(HloInstruction::CreateFft(
+ request.output_shape(), operand, fft_request.fft_type(),
+ AsInt64Slice(fft_request.fft_length())));
+ break;
+ }
+
case OpRequest::kDotRequest: {
const DotRequest& dot_request = request.request().dot_request();
HloInstruction* lhs = lookup_instruction(dot_request.lhs());
diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h
index 8a78d520e1..8be639c784 100644
--- a/tensorflow/compiler/xla/service/user_computation.h
+++ b/tensorflow/compiler/xla/service/user_computation.h
@@ -133,6 +133,10 @@ class UserComputation {
StatusOr<ComputationDataHandle> AddConvolveInstruction(
const ConvolveRequest& convolve_request);
+ // Enqueues an FFT instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddFftInstruction(
+ const FftRequest& fft_request);
+
// Enqueues a cross replica sum instruction onto this user computation.
StatusOr<ComputationDataHandle> AddCrossReplicaSumInstruction(
const CrossReplicaSumRequest& cross_replica_sum_request);
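[Editor's note] For context on where the plumbing in this commit leads: a client-built FFT op arrives at Service::Op as an OpRequest::kFftRequest, is recorded by UserComputation::AddFftInstruction, and is lowered by ComputationLowerer::Visit into HloInstruction::CreateFft. A hypothetical client-side sketch (ComputationBuilder::Fft is outside this diff and is shown only as an assumption about the era's client API):

#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

void BuildRfftExample(xla::Client* client) {
  xla::ComputationBuilder builder(client, "rfft_example");
  auto x = builder.Parameter(
      0, xla::ShapeUtil::MakeShape(xla::F32, {3, 8}), "x");
  // Lands in UserComputation::AddFftInstruction above; InferFftShape
  // assigns the result shape c64[3,5].
  builder.Fft(x, xla::FftType::RFFT, /*fft_length=*/{8});
}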