author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-01-02 13:01:59 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-01-02 13:05:54 -0800
commit     5bf26acd87d3d44183fc28cb9576cda10c0255ca (patch)
tree       5486c0ab9077496e30b594d35c089a45f7b7e18c
parent     97a843db78745fe3a8e418b3b1e93ef79fbfff12 (diff)
Automated g4 rollback of changelist 180000981
PiperOrigin-RevId: 180581912
-rw-r--r--  tensorflow/compiler/tests/BUILD | 15
-rw-r--r--  tensorflow/compiler/tests/fft_test.py | 179
-rw-r--r--  tensorflow/compiler/tests/randomized_tests.cc | 121
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/BUILD | 2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/fft_ops.cc | 122
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_registry.cc | 9
-rw-r--r--  tensorflow/compiler/xla/client/computation_builder.cc | 25
-rw-r--r--  tensorflow/compiler/xla/client/computation_builder.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD | 19
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_runtime.cc | 1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_runtime.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 49
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/cpu/runtime_fft.cc | 37
-rw-r--r--  tensorflow/compiler/xla/service/cpu/runtime_fft.h | 31
-rw-r--r--  tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h | 238
-rw-r--r--  tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h | 3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD | 3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/fft_thunk.cc | 234
-rw-r--r--  tensorflow/compiler/xla/service/gpu/fft_thunk.h | 98
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter.h | 5
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 19
-rw-r--r--  tensorflow/compiler/xla/service/gpu/thunk.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo.proto | 6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis.cc | 15
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc | 32
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h | 21
-rw-r--r--  tensorflow/compiler/xla/service/hlo_opcode.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/instruction_fusion.cc | 1
-rw-r--r--  tensorflow/compiler/xla/service/service.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc | 72
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.h | 5
-rw-r--r--  tensorflow/compiler/xla/service/user_computation.cc | 47
-rw-r--r--  tensorflow/compiler/xla/service/user_computation.h | 4
-rw-r--r--  tensorflow/compiler/xla/tools/parser/hlo_parser.cc | 37
-rw-r--r--  tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc | 48
44 files changed, 1532 insertions(+), 7 deletions(-)
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 5732e95575..78777f3c96 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -241,6 +241,21 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "fft_test",
+ size = "medium",
+ srcs = ["fft_test.py"],
+ shard_count = 3,
+ tags = ["optonly"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:spectral_ops",
+ ],
+)
+
+tf_xla_py_test(
name = "slice_ops_test",
size = "small",
srcs = ["slice_ops_test.py"],
diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py
new file mode 100644
index 0000000000..bdc38be48c
--- /dev/null
+++ b/tensorflow/compiler/tests/fft_test.py
@@ -0,0 +1,179 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for FFT via the XLA JIT."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import spectral_ops
+from tensorflow.python.platform import googletest
+
+BATCH_DIMS = (3, 5)
+RTOL = 0.02 # Eigen/cuFFT differ widely from np, especially for FFT3D
+ATOL = 1e-3
+
+
+def pick_10(x):
+ x = list(x)
+ np.random.seed(123)
+ np.random.shuffle(x)
+ return x[:10]
+
+
+def to_32bit(x):
+ if x.dtype == np.complex128:
+ return x.astype(np.complex64)
+ if x.dtype == np.float64:
+ return x.astype(np.float32)
+ return x
+
+
+POWS_OF_2 = 2**np.arange(3, 12)
+INNER_DIMS_1D = list((x,) for x in POWS_OF_2)
+POWS_OF_2 = 2**np.arange(3, 8) # To avoid OOM on GPU.
+INNER_DIMS_2D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2))
+INNER_DIMS_3D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2, POWS_OF_2))
+
+
+class FFTTest(XLATestCase):
+
+ def _VerifyFftMethod(self, inner_dims, complex_to_input, input_to_expected,
+ tf_method):
+ for indims in inner_dims:
+ print("nfft =", indims)
+ shape = BATCH_DIMS + indims
+ data = np.arange(np.prod(shape) * 2) / np.prod(indims)
+ np.random.seed(123)
+ np.random.shuffle(data)
+ data = np.reshape(data.astype(np.float32).view(np.complex64), shape)
+ data = to_32bit(complex_to_input(data))
+ expected = to_32bit(input_to_expected(data))
+ with self.test_session() as sess:
+ with self.test_scope():
+ ph = array_ops.placeholder(
+ dtypes.as_dtype(data.dtype), shape=data.shape)
+ out = tf_method(ph)
+ value = sess.run(out, {ph: data})
+ self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL)
+
+ def testFFT(self):
+ self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.fft,
+ spectral_ops.fft)
+
+ def testFFT2D(self):
+ self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.fft2,
+ spectral_ops.fft2d)
+
+ def testFFT3D(self):
+ self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x,
+ lambda x: np.fft.fftn(x, axes=(-3, -2, -1)),
+ spectral_ops.fft3d)
+
+ def testIFFT(self):
+ self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.ifft,
+ spectral_ops.ifft)
+
+ def testIFFT2D(self):
+ self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.ifft2,
+ spectral_ops.ifft2d)
+
+ def testIFFT3D(self):
+ self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x,
+ lambda x: np.fft.ifftn(x, axes=(-3, -2, -1)),
+ spectral_ops.ifft3d)
+
+ def testRFFT(self):
+ self._VerifyFftMethod(
+ INNER_DIMS_1D, np.real, lambda x: np.fft.rfft(x, n=x.shape[-1]),
+ lambda x: spectral_ops.rfft(x, fft_length=[x.shape[-1].value]))
+
+ def testRFFT2D(self):
+
+ def _tf_fn(x):
+ return spectral_ops.rfft2d(
+ x, fft_length=[x.shape[-2].value, x.shape[-1].value])
+
+ self._VerifyFftMethod(
+ INNER_DIMS_2D, np.real,
+ lambda x: np.fft.rfft2(x, s=[x.shape[-2], x.shape[-1]]), _tf_fn)
+
+ def testRFFT3D(self):
+
+ def _to_expected(x):
+ return np.fft.rfftn(
+ x, axes=(-3, -2, -1), s=[x.shape[-3], x.shape[-2], x.shape[-1]])
+
+ def _tf_fn(x):
+ return spectral_ops.rfft3d(
+ x,
+ fft_length=[x.shape[-3].value, x.shape[-2].value, x.shape[-1].value])
+
+ self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn)
+
+ def testIRFFT(self):
+
+ def _tf_fn(x):
+ return spectral_ops.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)])
+
+ self._VerifyFftMethod(
+ INNER_DIMS_1D, lambda x: np.fft.rfft(np.real(x), n=x.shape[-1]),
+ lambda x: np.fft.irfft(x, n=2 * (x.shape[-1] - 1)), _tf_fn)
+
+ def testIRFFT2D(self):
+
+ def _tf_fn(x):
+ return spectral_ops.irfft2d(
+ x, fft_length=[x.shape[-2].value, 2 * (x.shape[-1].value - 1)])
+
+ self._VerifyFftMethod(
+ INNER_DIMS_2D,
+ lambda x: np.fft.rfft2(np.real(x), s=[x.shape[-2], x.shape[-1]]),
+ lambda x: np.fft.irfft2(x, s=[x.shape[-2], 2 * (x.shape[-1] - 1)]),
+ _tf_fn)
+
+ def testIRFFT3D(self):
+
+ def _to_input(x):
+ return np.fft.rfftn(
+ np.real(x),
+ axes=(-3, -2, -1),
+ s=[x.shape[-3], x.shape[-2], x.shape[-1]])
+
+ def _to_expected(x):
+ return np.fft.irfftn(
+ x,
+ axes=(-3, -2, -1),
+ s=[x.shape[-3], x.shape[-2], 2 * (x.shape[-1] - 1)])
+
+ def _tf_fn(x):
+ return spectral_ops.irfft3d(
+ x,
+ fft_length=[
+ x.shape[-3].value, x.shape[-2].value, 2 * (x.shape[-1].value - 1)
+ ])
+
+ self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn)
+
+
+if __name__ == "__main__":
+ googletest.main()
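The RFFT and IRFFT cases above lean on numpy's half-spectrum conventions: a length-n real input produces n // 2 + 1 complex bins, and the inverse needs fft_length = 2 * (bins - 1) to recover the original length, which is exactly the expression testIRFFT passes. A minimal standalone numpy sketch (not part of the patch; n is an illustrative size):

    import numpy as np

    n = 8                                      # illustrative even FFT length
    x = np.random.randn(n).astype(np.float32)

    spectrum = np.fft.rfft(x)                  # real -> complex, redundant half dropped
    assert spectrum.shape[-1] == n // 2 + 1

    # Recover the original length the same way testIRFFT computes fft_length.
    recovered = np.fft.irfft(spectrum, n=2 * (spectrum.shape[-1] - 1))
    assert np.allclose(x, recovered, atol=1e-5)
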
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index 2ab013b7fa..e72dd4eea9 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -93,11 +93,11 @@ class OpTestBuilder {
public:
explicit OpTestBuilder(const string& op_name);
- // Adds an input 'tensor'.
+ // Adds an input 'tensor' as a Placeholder node.
OpTestBuilder& Input(const Tensor& tensor);
- // Adds a random input tensor with 'type'. If 'dims' is not provided,
- // RandomDims() is used.
+ // Adds a random input tensor with 'type' as a Placeholder node.
+ // If 'dims' is not provided, RandomDims() is used.
OpTestBuilder& RandomInput(DataType type);
OpTestBuilder& RandomInput(DataType type, std::vector<int64> dims);
@@ -1375,6 +1375,121 @@ TEST_F(OpTest, Conj) {
});
}
+TEST_F(OpTest, FFT) {
+ Repeatedly([this]() {
+ std::vector<int64> dims = RandomDims(1, kDefaultMaxRank);
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("FFT").RandomInput(DT_COMPLEX64, dims));
+ });
+}
+
+TEST_F(OpTest, FFT2D) {
+ Repeatedly([this]() {
+ std::vector<int64> dims = RandomDims(2, kDefaultMaxRank);
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("FFT2D").RandomInput(DT_COMPLEX64, dims));
+ });
+}
+
+TEST_F(OpTest, FFT3D) {
+ Repeatedly([this]() {
+ std::vector<int64> dims = RandomDims(3, kDefaultMaxRank);
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("FFT3D").RandomInput(DT_COMPLEX64, dims));
+ });
+}
+
+TEST_F(OpTest, IFFT) {
+ Repeatedly([this]() {
+ std::vector<int64> dims = RandomDims(1, kDefaultMaxRank);
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("IFFT").RandomInput(DT_COMPLEX64, dims));
+ });
+}
+
+TEST_F(OpTest, IFFT2D) {
+ Repeatedly([this]() {
+ std::vector<int64> dims = RandomDims(2, kDefaultMaxRank);
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("IFFT2D").RandomInput(DT_COMPLEX64, dims));
+ });
+}
+
+TEST_F(OpTest, IFFT3D) {
+ Repeatedly([this]() {
+ std::vector<int64> dims = RandomDims(3, kDefaultMaxRank);
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("IFFT3D").RandomInput(DT_COMPLEX64, dims));
+ });
+}
+
+TEST_F(OpTest, RFFT) {
+ Repeatedly([this]() {
+ std::vector<int64> dims = RandomDims(1, kDefaultMaxRank, 3);
+ Tensor fft_shape = test::AsTensor<int32>(AsInt32s({dims[dims.size() - 1]}));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("RFFT").RandomInput(DT_FLOAT, dims).Input(fft_shape));
+ });
+}
+
+TEST_F(OpTest, RFFT2D) {
+ Repeatedly([this]() {
+ std::vector<int64> dims = RandomDims(2, kDefaultMaxRank, 3);
+ Tensor fft_shape = test::AsTensor<int32>(
+ AsInt32s({dims[dims.size() - 2], dims[dims.size() - 1]}));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("RFFT2D").RandomInput(DT_FLOAT, dims).Input(fft_shape));
+ });
+}
+
+TEST_F(OpTest, RFFT3D) {
+ Repeatedly([this]() {
+ std::vector<int64> dims = RandomDims(3, kDefaultMaxRank, 3);
+ Tensor fft_shape = test::AsTensor<int32>(AsInt32s(
+ {dims[dims.size() - 3], dims[dims.size() - 2], dims[dims.size() - 1]}));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("RFFT3D").RandomInput(DT_FLOAT, dims).Input(fft_shape));
+ });
+}
+
+TEST_F(OpTest, IRFFT) {
+ Repeatedly([this]() {
+ std::vector<int64> dims = RandomDims(1, kDefaultMaxRank, 3);
+ int64 orig_size = dims[dims.size() - 1];
+ dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1;
+ Tensor fft_shape = test::AsTensor<int32>(AsInt32s({orig_size}));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT")
+ .RandomInput(DT_COMPLEX64, dims)
+ .Input(fft_shape));
+ });
+}
+
+TEST_F(OpTest, IRFFT2D) {
+ Repeatedly([this]() {
+ std::vector<int64> dims = RandomDims(2, kDefaultMaxRank, 3);
+ std::vector<int64> orig_size = {dims[dims.size() - 2],
+ dims[dims.size() - 1]};
+ dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1;
+ Tensor fft_shape = test::AsTensor<int32>(AsInt32s({orig_size}));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT2D")
+ .RandomInput(DT_COMPLEX64, dims)
+ .Input(fft_shape));
+ });
+}
+
+TEST_F(OpTest, IRFFT3D) {
+ Repeatedly([this]() {
+ std::vector<int64> dims = RandomDims(3, kDefaultMaxRank, 3);
+ std::vector<int64> orig_size = {
+ dims[dims.size() - 3], dims[dims.size() - 2], dims[dims.size() - 1]};
+ dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1;
+ Tensor fft_shape = test::AsTensor<int32>(AsInt32s({orig_size}));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT3D")
+ .RandomInput(DT_COMPLEX64, dims)
+ .Input(fft_shape));
+ });
+}
+
TEST_F(OpTest, Conv2D) {
Repeatedly([this]() {
WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
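The IRFFT variants above draw random real-signal dims, shrink the innermost dimension to the half-spectrum size the complex input must carry, and pass the original length as the fft_shape constant. The same bookkeeping re-expressed in Python (dims is a hypothetical draw, not taken from the test):

    dims = [3, 5, 8]                  # as if returned by RandomDims
    orig_size = dims[-1]              # length the IRFFT should produce
    dims[-1] = dims[-1] // 2 + 1      # complex input holds n/2 + 1 bins
    fft_shape = [orig_size]           # fed to the op as a constant input
    assert (dims, fft_shape) == ([3, 5, 5], [8])
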
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 092d852fe3..0dd95de0ce 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -30,6 +30,7 @@ tf_kernel_library(
"diag_op.cc",
"dynamic_stitch_op.cc",
"elu_op.cc",
+ "fft_ops.cc",
"fill_op.cc",
"function_ops.cc",
"gather_op.cc",
@@ -105,6 +106,7 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:linalg_ops_op_lib",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:spectral_ops_op_lib",
"//tensorflow/core:stateless_random_ops_op_lib",
"//tensorflow/core/kernels:bounds_check",
"//tensorflow/core/kernels:concat_lib",
diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
new file mode 100644
index 0000000000..a4f3c1c3ad
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
@@ -0,0 +1,122 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Ops for FFT.
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+
+namespace {
+
+using xla::FftType;
+
+class GenericFftOp : public XlaOpKernel {
+ public:
+ explicit GenericFftOp(OpKernelConstruction* ctx, FftType fft_type,
+ int fft_rank)
+ : XlaOpKernel(ctx), fft_type_(fft_type), fft_rank_(fft_rank) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape input_shape = ctx->InputShape(0);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsVectorOrHigher(input_shape),
+ errors::InvalidArgument("input must be at least 1 dimensional"));
+
+ std::vector<int64> fft_length;
+ if (fft_type_ == FftType::RFFT || fft_type_ == FftType::IRFFT) {
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &fft_length));
+ OP_REQUIRES(ctx, fft_length.size() == fft_rank_,
+ errors::InvalidArgument("fft_length must be length ",
+ fft_rank_, " vector"));
+ } else {
+ // Innermost axis provides the FFT length.
+ for (int i = 0; i < fft_rank_; i++) {
+ fft_length.push_back(
+ input_shape.dim_size(input_shape.dims() - fft_rank_ + i));
+ }
+ }
+
+ xla::ComputationBuilder* b = ctx->builder();
+ xla::ComputationDataHandle fft =
+ b->Fft(ctx->Input(0), fft_type_, fft_length);
+ ctx->SetOutput(0, fft);
+ }
+
+ protected:
+ const FftType fft_type_;
+ const int fft_rank_;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(GenericFftOp);
+};
+
+template <int FFTRank>
+class FFTOp : public GenericFftOp {
+ public:
+ explicit FFTOp(OpKernelConstruction* ctx)
+ : GenericFftOp(ctx, /*fft_type=*/FftType::FFT, /*fft_rank=*/FFTRank) {}
+};
+REGISTER_XLA_OP(Name("FFT"), FFTOp<1>);
+REGISTER_XLA_OP(Name("FFT2D"), FFTOp<2>);
+REGISTER_XLA_OP(Name("FFT3D"), FFTOp<3>);
+
+template <int FFTRank>
+class IFFTOp : public GenericFftOp {
+ public:
+ explicit IFFTOp(OpKernelConstruction* ctx)
+ : GenericFftOp(ctx, /*fft_type=*/FftType::IFFT, /*fft_rank=*/FFTRank) {}
+};
+REGISTER_XLA_OP(Name("IFFT"), IFFTOp<1>);
+REGISTER_XLA_OP(Name("IFFT2D"), IFFTOp<2>);
+REGISTER_XLA_OP(Name("IFFT3D"), IFFTOp<3>);
+
+template <int FFTRank>
+class RFFTOp : public GenericFftOp {
+ public:
+ explicit RFFTOp(OpKernelConstruction* ctx)
+ : GenericFftOp(ctx, /*fft_type=*/FftType::RFFT, /*fft_rank=*/FFTRank) {}
+};
+REGISTER_XLA_OP(Name("RFFT").CompileTimeConstInput("fft_length"), RFFTOp<1>);
+REGISTER_XLA_OP(Name("RFFT2D").CompileTimeConstInput("fft_length"), RFFTOp<2>);
+REGISTER_XLA_OP(Name("RFFT3D").CompileTimeConstInput("fft_length"), RFFTOp<3>);
+
+template <int FFTRank>
+class IRFFTOp : public GenericFftOp {
+ public:
+ explicit IRFFTOp(OpKernelConstruction* ctx)
+ : GenericFftOp(ctx, /*fft_type=*/FftType::IRFFT, /*fft_rank=*/FFTRank) {}
+};
+REGISTER_XLA_OP(Name("IRFFT").CompileTimeConstInput("fft_length"), IRFFTOp<1>);
+REGISTER_XLA_OP(Name("IRFFT2D").CompileTimeConstInput("fft_length"),
+ IRFFTOp<2>);
+REGISTER_XLA_OP(Name("IRFFT3D").CompileTimeConstInput("fft_length"),
+ IRFFTOp<3>);
+
+} // namespace
+} // namespace tensorflow
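GenericFftOp::Compile resolves fft_length in two ways: RFFT/IRFFT read it from the compile-time-constant fft_length input, while FFT/IFFT derive it from the trailing fft_rank dimensions of the operand. A Python sketch of that branch (the function name and shapes are illustrative, not from the patch):

    def infer_fft_length(input_shape, fft_rank, fft_type, fft_length_input=None):
      if fft_type in ("RFFT", "IRFFT"):
        # Must be a rank-sized vector, mirroring the OP_REQUIRES check above.
        assert len(fft_length_input) == fft_rank
        return list(fft_length_input)
      # Innermost fft_rank axes provide the FFT length.
      return list(input_shape[-fft_rank:])

    assert infer_fft_length([3, 5, 16, 16], 2, "FFT") == [16, 16]
    assert infer_fft_length([3, 5, 16], 1, "RFFT", [16]) == [16]
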
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index 97bb100fb1..0dde6a986c 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -161,7 +161,14 @@ void XlaOpRegistry::RegisterCompilationKernels() {
const string& op_name = op.first;
const std::unique_ptr<OpRegistration>& op_registration = op.second;
const OpDef* op_def;
- TF_CHECK_OK(op_registry->LookUpOpDef(op_name, &op_def));
+ Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def);
+ if (!lookup_status.ok()) {
+ LOG(ERROR) << lookup_status.error_message();
+ XLA_LOG_LINES(
+ ERROR, "Ops registered: \n" +
+ dynamic_cast<OpRegistry*>(op_registry)->DebugString(true));
+ }
+ TF_CHECK_OK(lookup_status);
std::unordered_set<string> type_attrs;
for (const OpDef::AttrDef& attr_def : op_def->attr()) {
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index 317dcb4e41..1c0669c1d4 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -855,6 +855,31 @@ ComputationDataHandle ComputationBuilder::ConvGeneralDilated(
return ParseOpResponse(s, &response);
}
+ComputationDataHandle ComputationBuilder::Fft(
+ const ComputationDataHandle& operand, const FftType fft_type,
+ const tensorflow::gtl::ArraySlice<int64> fft_length) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ FftRequest request;
+ *request.mutable_operand() = operand;
+ request.set_fft_type(fft_type);
+ for (int64 dim_len : fft_length) {
+ request.add_fft_length(dim_len);
+ }
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_fft_request() = request;
+ AddCommonFieldsToOpRequest(&op_request);
+ OpResponse response;
+
+ VLOG(2) << "making fft op request";
+ Status s = client_->stub()->Op(&op_request, &response);
+
+ return ParseOpResponse(s, &response);
+}
+
ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape,
const string& config) {
if (!first_error_.ok() || !PrepareComputation().ok()) {
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index 7293b35c0f..1a3a54d91e 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -410,6 +410,12 @@ class ComputationBuilder {
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers);
+ // Enqueues an FFT instruction onto the computation, of the given type and
+ // with the given FFT length.
+ ComputationDataHandle Fft(const ComputationDataHandle& operand,
+ FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length);
+
// Enqueues an infeed instruction onto the computation, which writes data of
// the given shape to the infeed buffer of the device.
ComputationDataHandle Infeed(const Shape& shape, const string& config = "");
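The output shape Fft produces depends on fft_type. Sketched in Python below; this mirrors the in/out dimension arithmetic visible in runtime_fft_impl.h later in this patch (the authoritative rules live in shape_inference.cc, whose hunks are not shown here):

    def fft_output_shape(input_shape, fft_type, fft_length):
      shape = list(input_shape)
      if fft_type == "RFFT":                  # real in, half-spectrum complex out
        shape[-1] = fft_length[-1] // 2 + 1
      elif fft_type == "IRFFT":               # half-spectrum in, real out
        shape[-1] = fft_length[-1]
      return shape                            # FFT/IFFT preserve the shape

    assert fft_output_shape([3, 5, 16], "RFFT", [16]) == [3, 5, 9]
    assert fft_output_shape([3, 5, 9], "IRFFT", [16]) == [3, 5, 16]
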
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index e35d947525..9754f8d9af 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -165,6 +165,7 @@ cc_library(
":external_constant_pool",
":orc_jit_memory_mapper",
":runtime_conv2d",
+ ":runtime_fft",
":runtime_fork_join",
":runtime_matmul",
":runtime_single_threaded_conv2d",
@@ -504,6 +505,24 @@ cc_library(
)
cc_library(
+ name = "runtime_fft",
+ srcs = [
+ "runtime_fft.cc",
+ "runtime_fft_impl.h",
+ ],
+ hdrs = ["runtime_fft.h"],
+ copts = runtime_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_lite",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_library(
name = "runtime_matvec",
srcs = ["runtime_matvec.cc"],
hdrs = ["runtime_matvec.h"],
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index ba208d7249..9a5e146b5a 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -705,6 +705,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
}
+ XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module));
+
// JIT compile the LLVM IR module to in-memory machine code.
jit->AddModule(std::move(llvm_module));
cpu_executable.reset(new CpuExecutable(
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 7908dc173d..1ef45dbec3 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -37,6 +37,7 @@ extern const char* const kEigenMatMulF64SymbolName =
"__xla_cpu_runtime_EigenMatMulF64";
extern const char* const kEigenConvF32SymbolName =
"__xla_cpu_runtime_EigenConvF32";
+extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft";
extern const char* const kEigenSingleThreadedMatMulF32SymbolName =
"__xla_cpu_runtime_EigenSingleThreadedMatMulF32";
extern const char* const kEigenSingleThreadedMatMulF64SymbolName =
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
index 2ade455b8a..3e1f080711 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -44,6 +44,7 @@ namespace runtime {
extern const char* const kEigenMatMulF32SymbolName;
extern const char* const kEigenMatMulF64SymbolName;
extern const char* const kEigenConvF32SymbolName;
+extern const char* const kEigenFftSymbolName;
extern const char* const kEigenSingleThreadedMatMulF32SymbolName;
extern const char* const kEigenSingleThreadedMatMulF64SymbolName;
extern const char* const kEigenSingleThreadedConvF32SymbolName;
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 4165f920d2..26bd9ad326 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -1135,6 +1135,55 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
});
}
+Status IrEmitter::HandleFft(HloInstruction* fft) {
+ auto operand = fft->operand(0);
+ TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
+ /*instruction=*/*fft, /*operands=*/{operand},
+ /*supported_types=*/{F32, C64}));
+ TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
+ TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout()));
+ VLOG(3) << "operand=" << ShapeUtil::HumanStringWithLayout(operand->shape());
+ VLOG(3) << "fft=" << ShapeUtil::HumanStringWithLayout(fft->shape());
+
+ llvm::Value* operand_address = GetEmittedValueFor(operand);
+ TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fft));
+
+ const std::vector<int64>& fft_length = fft->fft_length();
+ int64 input_batch = 1;
+ for (int i = 0; i < fft->shape().dimensions_size() - fft_length.size(); i++) {
+ input_batch *= fft->shape().dimensions(i);
+ }
+
+ // Args have been computed, make the call.
+ llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo();
+ llvm::Type* int32_type = ir_builder_.getInt32Ty();
+ llvm::Type* int64_type = ir_builder_.getInt64Ty();
+ llvm::FunctionType* fft_type = llvm::FunctionType::get(
+ ir_builder_.getVoidTy(),
+ {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type,
+ int64_type, int64_type, int64_type, int64_type},
+ /*isVarArg=*/false);
+ const char* fn_name = runtime::kEigenFftSymbolName;
+ llvm::Function* fft_func = llvm::cast<llvm::Function>(
+ module_->getOrInsertFunction(fn_name, fft_type));
+ fft_func->setCallingConv(llvm::CallingConv::C);
+ fft_func->setDoesNotThrow();
+ fft_func->setOnlyAccessesInaccessibleMemOrArgMem();
+ const int fft_rank = fft_length.size();
+ ir_builder_.CreateCall(
+ fft_func,
+ {GetExecutableRunOptionsArgument(),
+ ir_builder_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type),
+ ir_builder_.CreateBitCast(operand_address, int8_ptr_type),
+ ir_builder_.getInt32(fft->fft_type()), ir_builder_.getInt32(fft_rank),
+ ir_builder_.getInt64(input_batch),
+ ir_builder_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
+ ir_builder_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
+ ir_builder_.getInt64(fft_rank > 2 ? fft_length[2] : 0)});
+
+ return Status::OK();
+}
+
Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
if (hlo_module_config_.replica_count() == 1) {
// When there is a single replica, a cross replica sum is the identity
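HandleFft above folds every dimension ahead of the trailing FFT axes into a single input_batch before calling the runtime symbol, so the Eigen kernel only ever sees a rank-(fft_rank + 1) tensor. The same fold in Python (shapes illustrative):

    import functools
    import operator

    def input_batch(shape, fft_rank):
      # Product of all dimensions that are not transformed.
      return functools.reduce(operator.mul, shape[:len(shape) - fft_rank], 1)

    assert input_batch([3, 5, 128], 1) == 15      # e.g. batch dims (3, 5), 1-D FFT
    assert input_batch([3, 5, 16, 16], 2) == 15
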
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 2341e3ea72..b8d71eba18 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -158,6 +158,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleSelect(HloInstruction* select) override;
Status HandleDot(HloInstruction* dot) override;
Status HandleConvolution(HloInstruction* convolution) override;
+ Status HandleFft(HloInstruction* fft) override;
Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override;
Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
Status HandleCrossReplicaSum(HloInstruction* crs) override;
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index 4b44ac8941..deb21bf4ef 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -126,7 +126,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
HloInstruction* instruction) {
// Currently, we do not assign parallel tasks to instructions with at least
// one of the following properties:
- // *) Internal threading (library calls to kConv, kDot, and kCustomCall).
+ // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall).
// *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot).
// *) Tuple-shaped.
// TODO(b/27458679) Parallelize instructions which are skipped here.
@@ -137,6 +137,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
instruction->opcode() == HloOpcode::kSelectAndScatter ||
instruction->opcode() == HloOpcode::kGetTupleElement ||
instruction->opcode() == HloOpcode::kBitcast ||
+ instruction->opcode() == HloOpcode::kFft ||
(instruction->opcode() == HloOpcode::kConvolution &&
PotentiallyImplementedAsEigenConvolution(*instruction)) ||
PotentiallyImplementedAsEigenDot(*instruction) ||
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft.cc b/tensorflow/compiler/xla/service/cpu/runtime_fft.cc
new file mode 100644
index 0000000000..848d2d2241
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_fft.cc
@@ -0,0 +1,37 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/runtime_fft.h"
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h"
+#include "tensorflow/core/platform/dynamic_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+using tensorflow::int32;
+using tensorflow::int64;
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenFft(
+ const void* run_options_ptr, void* out, void* operand, int32 fft_type,
+ int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1,
+ int64 fft_length2) {
+ const xla::ExecutableRunOptions* run_options =
+ static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
+ tensorflow::xla::EigenFftImpl(*run_options->intra_op_thread_pool(), out,
+ operand, fft_type, fft_rank, input_batch,
+ fft_length0, fft_length1, fft_length2);
+}
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft.h b/tensorflow/compiler/xla/service/cpu/runtime_fft.h
new file mode 100644
index 0000000000..f20c5aa0aa
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_fft.h
@@ -0,0 +1,31 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_H_
+
+#include "tensorflow/core/platform/types.h"
+
+extern "C" {
+
+extern void __xla_cpu_runtime_EigenFft(
+ const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out,
+ void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank,
+ tensorflow::int64 input_batch, tensorflow::int64 fft_length0,
+ tensorflow::int64 fft_length1, tensorflow::int64 fft_length2);
+
+} // extern "C"
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_H_
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h
new file mode 100644
index 0000000000..c7c701ca97
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h
@@ -0,0 +1,238 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
+
+#include "third_party/eigen3/Eigen/Core"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/numeric_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/platform/types.h"
+
+// 'tensorflow' namespace is used so that int64 and other types don't require
+// qualification.
+namespace tensorflow {
+namespace xla {
+
+namespace internal {
+
+// Computes either a forward or reverse complex-to-complex FFT.
+template <bool Forward, int FFTRank, typename EigenDevice>
+void EigenFftC2C(const EigenDevice& device, complex64* out, complex64* operand,
+ int64 input_batch, int64 fft_length0, int64 fft_length1,
+ int64 fft_length2) {
+ // Create the axes (which are always trailing).
+ const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
+ constexpr auto direction = Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
+
+ const std::array<int64, 3> fft_shape = {
+ {fft_length0, fft_length1, fft_length2}};
+
+ Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> dims;
+ dims[0] = input_batch;
+ for (int i = 0; i < FFTRank; i++) {
+ dims[i + 1] = fft_shape[i];
+ }
+ const Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
+ Eigen::Aligned>
+ input(operand, dims);
+ Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
+ Eigen::Aligned>
+ output(out, dims);
+ output.device(device) = input.template fft<Eigen::BothParts, direction>(axes);
+}
+
+// Computes a forward real->complex FFT, slicing out redundant negative
+// frequencies from the innermost dimension.
+template <int FFTRank, typename EigenDevice>
+void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand,
+ int64 input_batch, int64 fft_length0, int64 fft_length1,
+ int64 fft_length2) {
+ const std::array<int64, 3> fft_shape = {
+ {fft_length0, fft_length1, fft_length2}};
+
+ Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> in_dims;
+ in_dims[0] = input_batch;
+ Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> out_dims;
+ out_dims[0] = input_batch;
+ TensorShape temp_shape{input_batch};
+ for (int i = 0; i < FFTRank; i++) {
+ in_dims[i + 1] = fft_shape[i];
+ out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
+ temp_shape.AddDim(fft_shape[i]);
+ }
+ const Eigen::TensorMap<Eigen::Tensor<float, FFTRank + 1, Eigen::RowMajor>,
+ Eigen::Aligned>
+ input(operand, in_dims);
+ Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
+ Eigen::Aligned>
+ output(out, out_dims);
+
+ // Create the axes (which are always trailing).
+ const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
+
+ // Compute the full FFT using a temporary tensor.
+ Tensor temp(DataTypeToEnum<complex64>::v(), temp_shape);
+ auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();
+ const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
+ full_fft.device(device) =
+ input.template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
+
+ // Slice away the negative frequency components.
+ output.device(device) = full_fft.slice(zero_start_indices, out_dims);
+}
+
+// Computes a reverse complex->real FFT, reconstructing redundant negative
+// frequencies using reverse conjugate on innermost dimension after doing IFFT
+// on outer dimensions.
+template <int FFTRank, typename EigenDevice>
+void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand,
+ int64 input_batch, int64 fft_length0, int64 fft_length1,
+ int64 fft_length2) {
+ const std::array<int64, 3> fft_shape = {
+ {fft_length0, fft_length1, fft_length2}};
+
+ Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> in_dims;
+ in_dims[0] = input_batch;
+ Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> out_dims;
+ out_dims[0] = input_batch;
+ TensorShape temp_shape{input_batch};
+ for (int i = 0; i < FFTRank; i++) {
+ in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
+ out_dims[i + 1] = fft_shape[i];
+ temp_shape.AddDim(fft_shape[i]);
+ }
+ const Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
+ Eigen::Aligned>
+ input(operand, in_dims);
+ Eigen::TensorMap<Eigen::Tensor<float, FFTRank + 1, Eigen::RowMajor>,
+ Eigen::Aligned>
+ output(out, out_dims);
+
+ // Calculate the shape of the temporary tensor for the full FFT and the
+ // region we will slice from input given fft_shape. We slice input to
+ // fft_shape on its inner-most dimensions, except the last (which we
+ // slice to fft_shape[-1] / 2 + 1).
+ Tensor temp(DataTypeToEnum<complex64>::v(), temp_shape);
+ auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();
+
+ // Calculate the starting point and range of the source of
+ // negative frequency part.
+ auto neg_sizes = in_dims;
+ neg_sizes[FFTRank] = fft_shape[FFTRank - 1] - in_dims[FFTRank];
+ Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
+ neg_target_indices[FFTRank] = in_dims[FFTRank];
+
+ const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
+ Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
+ neg_start_indices[FFTRank] = 1;
+
+ full_fft.slice(zero_start_indices, in_dims).device(device) = input;
+
+ // First, conduct IFFTs on outer dimensions. We save computation (and
+ // avoid touching uninitialized memory) by slicing full_fft to the
+ // subregion we wrote input to.
+ if (FFTRank > 1) {
+ const auto outer_axes =
+ Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
+ full_fft.slice(zero_start_indices, in_dims).device(device) =
+ full_fft.slice(zero_start_indices, in_dims)
+ .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(outer_axes);
+ }
+
+ // Reconstruct the full FFT by appending reversed and conjugated
+ // spectrum as the negative frequency part.
+ Eigen::array<bool, FFTRank + 1> reverse_last_axis;
+ for (auto i = 0; i <= FFTRank; i++) {
+ reverse_last_axis[i] = i == FFTRank;
+ }
+
+ if (neg_sizes[FFTRank] != 0) {
+ full_fft.slice(neg_target_indices, neg_sizes).device(device) =
+ full_fft.slice(neg_start_indices, neg_sizes)
+ .reverse(reverse_last_axis)
+ .conjugate();
+ }
+
+ auto inner_axis = Eigen::array<int, 1>{FFTRank};
+ output.device(device) =
+ full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(inner_axis);
+}
+
+template <int FFTRank, typename EigenDevice>
+void EigenFftWithRank(const EigenDevice& device, void* out, void* operand,
+ int32 fft_type, int64 input_batch, int64 fft_length0,
+ int64 fft_length1, int64 fft_length2) {
+ CHECK(::xla::FftType_IsValid(fft_type)) << fft_type;
+ switch (fft_type) {
+ case ::xla::FftType::FFT:
+ EigenFftC2C<true, FFTRank, EigenDevice>(
+ device, static_cast<complex64*>(out),
+ static_cast<complex64*>(operand), input_batch, fft_length0,
+ fft_length1, fft_length2);
+ break;
+ case ::xla::FftType::IFFT:
+ EigenFftC2C<false, FFTRank, EigenDevice>(
+ device, static_cast<complex64*>(out),
+ static_cast<complex64*>(operand), input_batch, fft_length0,
+ fft_length1, fft_length2);
+ break;
+ case ::xla::FftType::RFFT:
+ EigenFftR2C<FFTRank, EigenDevice>(
+ device, static_cast<complex64*>(out), static_cast<float*>(operand),
+ input_batch, fft_length0, fft_length1, fft_length2);
+ break;
+ case ::xla::FftType::IRFFT:
+ EigenFftC2R<FFTRank, EigenDevice>(
+ device, static_cast<float*>(out), static_cast<complex64*>(operand),
+ input_batch, fft_length0, fft_length1, fft_length2);
+ break;
+ default:
+ LOG(FATAL) << "Unsupported FFT type: " << fft_type;
+ }
+}
+
+} // namespace internal
+
+template <typename EigenDevice>
+void EigenFftImpl(const EigenDevice& device, void* out, void* operand,
+ int32 fft_type, int32 fft_rank, int64 input_batch,
+ int64 fft_length0, int64 fft_length1, int64 fft_length2) {
+ switch (fft_rank) {
+ case 1:
+ internal::EigenFftWithRank<1, EigenDevice>(
+ device, out, operand, fft_type, input_batch, fft_length0, 0, 0);
+ break;
+ case 2:
+ internal::EigenFftWithRank<2, EigenDevice>(device, out, operand, fft_type,
+ input_batch, fft_length0,
+ fft_length1, 0);
+ break;
+ case 3:
+ internal::EigenFftWithRank<3, EigenDevice>(device, out, operand, fft_type,
+ input_batch, fft_length0,
+ fft_length1, fft_length2);
+ break;
+ default:
+ LOG(FATAL) << "Unsupported FFT rank " << fft_rank;
+ }
+}
+
+} // namespace xla
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
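The reverse().conjugate() step in EigenFftC2R exploits Hermitian symmetry: for a real signal of length n, the bins above n/2 are the reversed conjugates of bins 1 .. n/2 - 1, so the full spectrum can be rebuilt from the stored half before an ordinary complex inverse FFT. A numpy demonstration of the same reconstruction (independent of the Eigen code; sizes illustrative):

    import numpy as np

    n = 8                                     # even length; Nyquist bin is self-conjugate
    x = np.random.randn(n)
    half = np.fft.rfft(x)                     # only n // 2 + 1 = 5 bins kept

    neg = np.conj(half[1:-1])[::-1]           # reversed conjugates of interior bins
    full = np.concatenate([half, neg])        # full n-bin spectrum restored
    assert np.allclose(full, np.fft.fft(x))
    assert np.allclose(np.fft.ifft(full).real, x)
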
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index 65da61805a..43ab2ec524 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_fft.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
@@ -208,6 +209,7 @@ bool RegisterKnownJITSymbols() {
REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue);
REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation);
REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32);
+ REGISTER_CPU_RUNTIME_SYMBOL(EigenFft);
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32);
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 0d54e325e6..a803b3171f 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -103,6 +103,7 @@ class DfsHloVisitorBase {
return HandleElementwiseBinary(hlo);
}
virtual Status HandleConvolution(HloInstructionPtr hlo) = 0;
+ virtual Status HandleFft(HloInstructionPtr fft) = 0;
virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0;
virtual Status HandleCompare(HloInstructionPtr hlo) {
return HandleElementwiseBinary(hlo);
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index 133aa25094..170adb3d24 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -85,6 +85,9 @@ class DfsHloVisitorWithDefaultBase
Status HandleConvolution(HloInstructionPtr convolution) override {
return DefaultAction(convolution);
}
+ Status HandleFft(HloInstructionPtr fft) override {
+ return DefaultAction(fft);
+ }
Status HandleCrossReplicaSum(HloInstructionPtr crs) override {
return DefaultAction(crs);
}
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index d5d89de6d4..e4832b2ee6 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -220,6 +220,7 @@ cc_library(
"convolution_thunk.cc",
"copy_thunk.cc",
"cudnn_batchnorm_thunk.cc",
+ "fft_thunk.cc",
"for_thunk.cc",
"gemm_thunk.cc",
"gpu_executable.cc",
@@ -234,6 +235,7 @@ cc_library(
"convolution_thunk.h",
"copy_thunk.h",
"cudnn_batchnorm_thunk.h",
+ "fft_thunk.h",
"for_thunk.h",
"gemm_thunk.h",
"gpu_executable.h",
@@ -272,6 +274,7 @@ cc_library(
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/platform/default/build_config:cublas_plugin",
"//tensorflow/core/platform/default/build_config:cudnn_plugin",
+ "//tensorflow/core/platform/default/build_config:cufft_plugin",
"//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
new file mode 100644
index 0000000000..66931bdc8b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
@@ -0,0 +1,234 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
+
+#include <string>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace gpu {
+
+FftScratchAllocator::FftScratchAllocator(
+ int device_ordinal, DeviceMemoryAllocator* memory_allocator)
+ : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
+
+FftScratchAllocator::~FftScratchAllocator() {
+ for (auto& allocated_buffer : allocated_buffers_) {
+ if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer)
+ .ok()) {
+ // The program can still continue with failed deallocation.
+ LOG(ERROR) << "Failed to deallocate the allocated buffer: "
+ << allocated_buffer.opaque();
+ }
+ }
+}
+
+int64 FftScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) {
+ constexpr int64 kFftScratchSize = 1LL << 32; // 4GB by default.
+ return kFftScratchSize;
+}
+
+se::port::StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes(
+ se::Stream* stream, int64 byte_size) {
+ CHECK_GE(byte_size, 0) << "byte_size must be positive.";
+ if (byte_size > GetMemoryLimitInBytes(stream)) {
+ return se::port::Status(
+ se::port::error::RESOURCE_EXHAUSTED,
+ tensorflow::strings::Printf(
+ "Allocating %lld bytes exceeds the memory limit of %lld bytes.",
+ byte_size, GetMemoryLimitInBytes(stream)));
+ }
+
+ auto status_or_memory =
+ memory_allocator_->Allocate(device_ordinal_, byte_size,
+ /*retry_on_failure=*/false);
+ if (!status_or_memory.ok()) {
+ return tensorflow::errors::ResourceExhausted(
+ "Failed to allocate %lld bytes on device %d.", byte_size,
+ device_ordinal_);
+ }
+ se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie();
+ allocated_buffers_.push_back(allocated_buffer);
+ total_allocated_bytes_ += byte_size;
+ return se::DeviceMemory<uint8>(allocated_buffer);
+}
+
+namespace {
+
+se::fft::Type FftTypeToSeType(FftType type) {
+ switch (type) {
+ case FftType::FFT:
+ return se::fft::Type::kC2CForward;
+ case FftType::IFFT:
+ return se::fft::Type::kC2CInverse;
+ case FftType::IRFFT:
+ return se::fft::Type::kC2R;
+ case FftType::RFFT:
+ return se::fft::Type::kR2C;
+ default:
+ LOG(FATAL) << "unsupported fft type";
+ }
+}
+
+string FftTypeToString(se::fft::Type type) {
+ switch (type) {
+ case se::fft::Type::kC2CForward:
+ return "FFT";
+ case se::fft::Type::kC2CInverse:
+ return "IFFT";
+ case se::fft::Type::kC2R:
+ return "IRFFT";
+ case se::fft::Type::kR2C:
+ return "RFFT";
+ default:
+ LOG(FATAL) << "unknown fft type";
+ }
+}
+
+} // namespace
+
+FftThunk::FftThunk(FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length,
+ const BufferAllocation::Slice& input_buffer,
+ const BufferAllocation::Slice& output_buffer,
+ const Shape& input_shape, const Shape& output_shape,
+ const HloInstruction* hlo)
+ : Thunk(Kind::kFft, hlo),
+ fft_type_(FftTypeToSeType(fft_type)),
+ fft_length_(fft_length.begin(), fft_length.end()),
+ scale_factor_(1.0f),
+ input_buffer_(input_buffer),
+ output_buffer_(output_buffer),
+ input_shape_(input_shape),
+ output_shape_(output_shape) {}
+
+tensorflow::Status FftThunk::ExecuteOnStream(
+ const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ VLOG(3) << "FFT type: " << FftTypeToString(fft_type_);
+ VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_);
+ VLOG(3) << "Output shape: "
+ << ShapeUtil::HumanStringWithLayout(output_shape_);
+
+ FftScratchAllocator scratch_allocator(buffer_allocations.device_ordinal(),
+ buffer_allocations.memory_allocator());
+
+ if (fft_plan_ == nullptr) {
+ const int64 fft_rank = fft_length_.size();
+ CHECK_LE(fft_rank, 3);
+ int batch_size = 1;
+ for (int i = 0; i < input_shape_.dimensions_size() - fft_rank; ++i) {
+ batch_size *= input_shape_.dimensions(i);
+ }
+ uint64 fft_length[3];
+ uint64 input_embed[3];
+ const uint64 input_stride = 1;
+ uint64 input_distance = 1;
+ uint64 output_embed[3];
+ const uint64 output_stride = 1;
+ uint64 output_distance = 1;
+
+ for (int i = 0; i < fft_rank; ++i) {
+ auto dim_offset = input_shape_.dimensions_size() - fft_rank + i;
+ fft_length[i] = static_cast<uint64>(fft_length_[i]);
+ input_embed[i] = input_shape_.dimensions(dim_offset);
+ input_distance *= input_shape_.dimensions(dim_offset);
+ output_embed[i] = output_shape_.dimensions(dim_offset);
+ output_distance *= output_shape_.dimensions(dim_offset);
+ }
+
+ constexpr bool kInPlaceFft = false;
+ fft_plan_ =
+ stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
+ stream, fft_rank, fft_length, input_embed, input_stride,
+ input_distance, output_embed, output_stride, output_distance,
+ fft_type_, kInPlaceFft, batch_size, &scratch_allocator);
+ scale_factor_ = 1.0f / output_distance;
+ } else {
+ stream->parent()->AsFft()->UpdatePlanWithScratchAllocator(
+ stream, fft_plan_.get(), &scratch_allocator);
+ }
+
+ bool launch_ok;
+ switch (fft_type_) {
+ case se::fft::Type::kC2CForward: {
+ se::DeviceMemory<complex64> input_data(
+ buffer_allocations.GetDeviceAddress(input_buffer_));
+ se::DeviceMemory<complex64> output_data(
+ buffer_allocations.GetDeviceAddress(output_buffer_));
+ launch_ok =
+ stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+ break;
+ }
+ case se::fft::Type::kC2CInverse: {
+ se::DeviceMemory<complex64> input_data(
+ buffer_allocations.GetDeviceAddress(input_buffer_));
+ se::DeviceMemory<complex64> output_data(
+ buffer_allocations.GetDeviceAddress(output_buffer_));
+ launch_ok =
+ stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+ if (launch_ok) {
+ launch_ok =
+ stream
+ ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
+ complex64(scale_factor_), &output_data, 1)
+ .ok();
+ }
+ break;
+ }
+ case se::fft::Type::kR2C: {
+ se::DeviceMemory<float> input_data(
+ buffer_allocations.GetDeviceAddress(input_buffer_));
+ se::DeviceMemory<complex64> output_data(
+ buffer_allocations.GetDeviceAddress(output_buffer_));
+ launch_ok =
+ stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+ break;
+ }
+ case se::fft::Type::kC2R: {
+ se::DeviceMemory<complex64> input_data(
+ buffer_allocations.GetDeviceAddress(input_buffer_));
+ se::DeviceMemory<float> output_data(
+ buffer_allocations.GetDeviceAddress(output_buffer_));
+ launch_ok =
+ stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+ if (launch_ok) {
+ launch_ok = stream
+ ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
+ scale_factor_, &output_data, 1)
+ .ok();
+ }
+ break;
+ }
+ default:
+ LOG(FATAL) << "unsupported fft type";
+ }
+ if (launch_ok) {
+ return tensorflow::Status::OK();
+ }
+ return InternalError("Unable to launch fft for thunk %p with type %s", this,
+ FftTypeToString(fft_type_).c_str());
+}
+
+} // namespace gpu
+} // namespace xla
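FftThunk multiplies inverse-transform outputs by scale_factor_ = 1 / output_distance (the product of the FFT dimensions) because cuFFT transforms are unnormalized: the inverse returns N times the original signal. A numpy sketch of that off-by-N behavior; numpy's own ifft is already normalized, so the unnormalized inverse is emulated here through the identity ifft(X) = conj(fft(conj(X))) / N:

    import numpy as np

    n = 16
    x = np.random.randn(n) + 1j * np.random.randn(n)
    spectrum = np.fft.fft(x)

    # Emulate an unnormalized inverse transform, as cuFFT computes it.
    unnormalized = np.conj(np.fft.fft(np.conj(spectrum)))
    assert np.allclose(unnormalized, n * x)        # off by a factor of N
    assert np.allclose(unnormalized / n, x)        # fixed by the 1/N scale
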
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
new file mode 100644
index 0000000000..52fb8c376d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
@@ -0,0 +1,98 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_
+
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/optional.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// A one-time scratch allocator for FFT. The scratch buffers allocated are
+// released on destruction.
+//
+// Not thread-safe in that AllocateBytes, destructor are not locked.
+class FftScratchAllocator : public perftools::gputools::ScratchAllocator {
+ public:
+ FftScratchAllocator(int device_ordinal,
+ DeviceMemoryAllocator* memory_allocator);
+
+ ~FftScratchAllocator() override;
+
+ int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override;
+
+ int64 TotalAllocatedBytes() { return total_allocated_bytes_; }
+
+ perftools::gputools::port::StatusOr<perftools::gputools::DeviceMemory<uint8>>
+ AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override;
+
+ private:
+ const int device_ordinal_;
+ DeviceMemoryAllocator* memory_allocator_;
+ std::vector<perftools::gputools::DeviceMemoryBase> allocated_buffers_;
+ int64 total_allocated_bytes_ = 0;
+};
+
+// This class stores everything that StreamExecutor needs to launch an FFT.
+// It is generated by IrEmitter.
+//
+// This is thread-compatible.
+class FftThunk : public Thunk {
+ public:
+ // Constructs a thunk for launching an FFT on a stream.
+ // Semantics of a null hlo_instruction argument are as in Thunk.
+ FftThunk(FftType fft_type, tensorflow::gtl::ArraySlice<int64> fft_length,
+ const BufferAllocation::Slice& input_buffer,
+ const BufferAllocation::Slice& output_buffer,
+ const Shape& input_shape, const Shape& output_shape,
+ const HloInstruction* hlo);
+
+ FftThunk(const FftThunk&) = delete; // Cannot share fft_plan_
+ FftThunk& operator=(const FftThunk&) = delete; // Cannot share fft_plan_
+
+ // Does the FFT for the thunk on "stream".
+ tensorflow::Status ExecuteOnStream(
+ const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) override;
+
+ private:
+ const perftools::gputools::fft::Type fft_type_;
+ const std::vector<int64> fft_length_;
+
+ float scale_factor_;
+
+ std::unique_ptr<perftools::gputools::fft::Plan> fft_plan_;
+
+ const BufferAllocation::Slice input_buffer_;
+ const BufferAllocation::Slice output_buffer_;
+
+ const Shape input_shape_;
+ const Shape output_shape_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_
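
For orientation, the thunk is constructed by IrEmitterUnnested::BuildFftThunk later in this patch; condensed here, with GetAllocationSlice standing in for the emitter's buffer-assignment lookup:

// Condensed from BuildFftThunk in ir_emitter_unnested.cc below.
const HloInstruction* operand = inst->operand(0);
auto thunk = MakeUnique<FftThunk>(
    inst->fft_type(), inst->fft_length(),
    /*input_buffer=*/GetAllocationSlice(*operand),
    /*output_buffer=*/GetAllocationSlice(*inst),
    /*input_shape=*/operand->shape(),
    /*output_shape=*/inst->shape(), inst);
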
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index d85d6aae12..d6d0e1e116 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -605,6 +605,14 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
"Hit a case for convolution that is not implemented on GPU.");
}
+Status IrEmitter::HandleFft(HloInstruction* fft) {
+ if (ShapeUtil::HasZeroElements(fft->shape())) {
+ // Emit no code for an empty output.
+ return Status::OK();
+ }
+ return Unimplemented("Hit a case for fft that is not implemented on GPU.");
+}
+
Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
// TODO(b/33011107): Support cross replica sum on GPU.
return Unimplemented(
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 41d013c13d..af43895c23 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -79,6 +79,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleDot(HloInstruction* dot) override;
Status HandleConvolution(HloInstruction* convolution) override;
+ Status HandleFft(HloInstruction* fft) override;
Status HandleCrossReplicaSum(HloInstruction* crs) override;
Status HandleInfeed(HloInstruction* infeed) override;
Status HandleOutfeed(HloInstruction* outfeed) override;
@@ -242,6 +243,7 @@ class IrEmitterUnnested : public IrEmitter {
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleCustomCall(HloInstruction* custom_call) override;
Status HandleDot(HloInstruction* dot) override;
+ Status HandleFft(HloInstruction* fft) override;
Status HandleFusion(HloInstruction* fusion) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleReduce(HloInstruction* reduce) override;
@@ -332,6 +334,9 @@ class IrEmitterUnnested : public IrEmitter {
// Returns a ConvolutionThunk that calls DNN to implement `inst`.
std::unique_ptr<Thunk> BuildConvolutionThunk(const HloInstruction* inst);
+ // Returns an FftThunk that calls cuFFT to implement `inst`.
+ std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst);
+
// Returns a GemmThunk that calls gemm to implement `inst`. The caller needs
// to make sure `inst` outlives the lifetime of the returned Thunk object.
std::unique_ptr<Thunk> BuildGemmThunk(const HloInstruction* inst);
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 1be3473f52..1aa506a3a9 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
@@ -373,6 +374,14 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
return IrEmitter::HandleCustomCall(custom_call);
}
+Status IrEmitterUnnested::HandleFft(HloInstruction* fft) {
+ TF_RET_CHECK(
+ LayoutUtil::IsMonotonicWithDim0Major(fft->operand(0)->shape().layout()));
+ TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout()));
+ thunk_sequence_->emplace_back(BuildFftThunk(fft));
+ return Status::OK();
+}
+
Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
HloInstruction* root = fusion->fused_expression_root();
// HandleFusion specializes reduction from a multi-dimensional array to a 1D
@@ -1855,6 +1864,16 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildConvolutionThunk(
}
}
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk(
+ const HloInstruction* inst) {
+ const HloInstruction* operand = inst->operand(0);
+ return MakeUnique<FftThunk>(inst->fft_type(), inst->fft_length(),
+ /*input_buffer=*/GetAllocationSlice(*operand),
+ /*output_buffer=*/GetAllocationSlice(*inst),
+ /*input_shape=*/operand->shape(),
+ /*output_shape=*/inst->shape(), inst);
+}
+
Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo,
KernelThunk* thunk) {
bool fused = HloOpcode::kFusion == hlo->opcode();
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index 191c7675c6..625c3f8bea 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -46,6 +46,7 @@ class Thunk {
kCudnnBatchNormBackward,
kCudnnBatchNormForwardInference,
kCudnnBatchNormForwardTraining,
+ kFft,
kGemm,
kInfeed,
kKernel,
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index e4aed7593c..0e9a852788 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -123,6 +123,12 @@ message HloInstructionProto {
// Describes the dimension numbers used for a dot operation
xla.DotDimensionNumbers dot_dimension_numbers = 30;
+
+ // FFT type (FFT, IFFT, etc.).
+ xla.FftType fft_type = 31;
+
+ // FFT length.
+ repeated int64 fft_length = 32;
}
// Serialization of HloComputation.
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index b933695b82..cd54eb74d1 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -392,6 +392,21 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
return Status::OK();
}
+Status HloCostAnalysis::HandleFft(const HloInstruction* fft) {
+ auto real_shape =
+ ShapeUtil::IsTuple(fft->operand(0)->shape())
+ ? ShapeUtil::GetTupleElementShape(fft->operand(0)->shape(), 0)
+ : fft->operand(0)->shape();
+ constexpr int kFmaPerComplexMul = 4;
+ int64 log_factors = 1;
+ for (int64 dim : fft->fft_length()) {
+ log_factors *= tensorflow::Log2Floor(dim);
+ }
+ current_properties_[kFlopsKey] = kFmaFlops * kFmaPerComplexMul * log_factors *
+ ShapeUtil::ElementsIn(real_shape);
+ return Status::OK();
+}
+
Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) {
// We assume 2 replicas, so that each output element is the sum of two input
// elements.
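
A worked instance of the FFT cost estimate above, assuming kFmaFlops is the analysis's per-FMA flop constant (2 in this codebase): for a c64[5,8,32] operand with fft_length={8,32}, log_factors = log2(8) * log2(32) = 15 over 1280 elements.

// Sketch of the estimate for c64[5,8,32], fft_length={8,32}; kFmaFlops
// is assumed to be 2 (the analysis's FMA constant).
#include <cstdint>

int64_t ExampleFftFlops() {
  const int64_t log_factors = 3 * 5;    // log2(8) * log2(32)
  const int64_t elements = 5 * 8 * 32;  // 1280
  return 2 /*kFmaFlops*/ * 4 /*kFmaPerComplexMul*/ * log_factors * elements;
  // == 153600
}
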
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index fade19522c..e5783539e5 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -67,6 +67,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleCopy(const HloInstruction* copy) override;
Status HandleDot(const HloInstruction* dot) override;
Status HandleConvolution(const HloInstruction* convolution) override;
+ Status HandleFft(const HloInstruction* fft) override;
Status HandleCrossReplicaSum(const HloInstruction* crs) override;
Status HandleInfeed(const HloInstruction* infeed) override;
Status HandleOutfeed(const HloInstruction* outfeed) override;
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 023aec96ec..f7c6435002 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -961,6 +961,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
return kGreen;
case HloOpcode::kConvolution:
case HloOpcode::kDot:
+ case HloOpcode::kFft:
return kDarkBlue;
case HloOpcode::kReducePrecision:
return kRed;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 9805818e4c..138fb9a190 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -144,6 +144,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->infeed_config_ = proto.infeed_config();
instruction->custom_call_target_ = proto.custom_call_target();
instruction->outfeed_shape_ = proto.outfeed_shape();
+ instruction->fft_type_ = proto.fft_type();
+ for (int64 fft_len : proto.fft_length()) {
+ instruction->fft_length_.push_back(fft_len);
+ }
return std::move(instruction);
}
@@ -334,6 +338,16 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
return instruction;
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
+ const Shape& shape, HloInstruction* operand, FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length) {
+ auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFft, shape));
+ instruction->AppendOperand(operand);
+ instruction->fft_type_ = fft_type;
+ instruction->fft_length_.assign(fft_length.begin(), fft_length.end());
+ return instruction;
+}
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dimension_numbers) {
@@ -1167,6 +1181,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
clone = CreateDot(shape, new_operands[0], new_operands[1],
*dot_dimension_numbers_);
break;
+ case HloOpcode::kFft:
+ CHECK_EQ(new_operands.size(), 1);
+ return CreateFft(shape, new_operands[0], fft_type_, fft_length_);
case HloOpcode::kCrossReplicaSum:
clone = CreateCrossReplicaSum(shape, new_operands);
break;
@@ -1631,6 +1648,11 @@ bool HloInstruction::IdenticalSlowPath(
return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
other.dot_dimension_numbers());
+ // An FFT instruction is characterized by its type and length.
+ case HloOpcode::kFft:
+ return fft_type() == other.fft_type() &&
+ fft_length() == other.fft_length();
+
// Reduction results are determined by the reduction dimension and the
// reduction computation.
case HloOpcode::kReduce:
@@ -2055,6 +2077,10 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
if (dot_dimension_numbers_ != nullptr) {
extra.push_back(DotDimensionNumbersToString());
}
+ if (opcode() == HloOpcode::kFft) {
+ extra.push_back(StrCat("fft_type=", FftType_Name(fft_type())));
+ extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}"));
+ }
if (options.print_subcomputation_references()) {
if (opcode() == HloOpcode::kWhile) {
@@ -2206,6 +2232,10 @@ HloInstructionProto HloInstruction::ToProto() const {
proto.set_infeed_config(infeed_config_);
proto.set_custom_call_target(custom_call_target_);
*proto.mutable_outfeed_shape() = outfeed_shape_;
+ proto.set_fft_type(fft_type_);
+ for (int64 fft_len : fft_length_) {
+ proto.add_fft_length(fft_len);
+ }
return proto;
}
@@ -2421,6 +2451,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleSelect(this);
case HloOpcode::kConvolution:
return visitor->HandleConvolution(this);
+ case HloOpcode::kFft:
+ return visitor->HandleFft(this);
case HloOpcode::kCrossReplicaSum:
return visitor->HandleCrossReplicaSum(this);
case HloOpcode::kTuple:
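
A minimal usage sketch of the new CreateFft factory, mirroring the "Fft" parser test added at the end of this patch:

// Sketch: builds fft(%input), fft_type=FFT, fft_length={32} for c64[8,32].
Shape shape = ShapeUtil::MakeShape(C64, {8, 32});
auto input = HloInstruction::CreateParameter(0, shape, "input");
auto fft = HloInstruction::CreateFft(shape, input.get(), FftType::FFT,
                                     /*fft_length=*/{32});
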
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index d455cfc3f1..c5cab92ac9 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -261,6 +261,11 @@ class HloInstruction {
const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers);
+ // Creates an FFT op of the type indicated by fft_type.
+ static std::unique_ptr<HloInstruction> CreateFft(
+ const Shape& shape, HloInstruction* operand, FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length);
+
// Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
// dimensions specified in 'dimension_numbers'.
static std::unique_ptr<HloInstruction> CreateDot(
@@ -1031,6 +1036,16 @@ class HloInstruction {
return *convolution_dimension_numbers_;
}
+ FftType fft_type() const {
+ CHECK_EQ(HloOpcode::kFft, opcode_);
+ return fft_type_;
+ }
+
+ const std::vector<int64>& fft_length() const {
+ CHECK_EQ(HloOpcode::kFft, opcode_);
+ return fft_length_;
+ }
+
// Returns the dump string of the convolution dimension numbers.
string ConvolutionDimensionNumbersToString() const;
@@ -1303,6 +1318,12 @@ class HloInstruction {
// Describes the dimension numbers used for a dot.
std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;
+ // Describes the FFT type for an FFT instruction.
+ FftType fft_type_ = FftType::FFT;
+
+ // Indicates the FFT length for an FFT instruction.
+ std::vector<int64> fft_length_;
+
// Describes the [begin, end) index range for a slice.
std::vector<int64> slice_starts_;
std::vector<int64> slice_limits_;
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index f3f7935758..3d64523a79 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -73,6 +73,7 @@ namespace xla {
V(kDynamicUpdateSlice, "dynamic-update-slice") \
V(kEq, "equal-to", kHloOpcodeIsComparison) \
V(kExp, "exponential") \
+ V(kFft, "fft") \
V(kFloor, "floor") \
V(kFusion, "fusion", kHloOpcodeIsVariadic) \
V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index d963a8a2f4..9d5ca6673a 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -92,6 +92,14 @@ class ShapeVerifier : public DfsHloVisitor {
return CheckShape(convolution, expected);
}
+ Status HandleFft(HloInstruction* fft) override {
+ TF_ASSIGN_OR_RETURN(
+ const Shape expected,
+ ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(),
+ fft->fft_length()));
+ return CheckShape(fft, expected);
+ }
+
Status HandleCrossReplicaSum(HloInstruction* crs) override {
std::vector<const Shape*> operand_shapes;
for (const HloInstruction* operand : crs->operands()) {
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index ba901b99e4..90e1f0acdc 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -100,6 +100,7 @@ namespace xla {
case HloOpcode::kDivide:
case HloOpcode::kDot:
case HloOpcode::kExp:
+ case HloOpcode::kFft:
case HloOpcode::kFusion:
case HloOpcode::kLog:
case HloOpcode::kMap:
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 0b98714168..40613cc75b 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -1406,6 +1406,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
handle_status = computation->AddDynamicUpdateSliceInstruction(
arg->dynamic_update_slice_request());
break;
+ case OpRequest::kFftRequest:
+ handle_status = computation->AddFftInstruction(arg->fft_request());
+ break;
case OpRequest::kGetTupleElementRequest:
handle_status = computation->AddGetTupleElementInstruction(
arg->get_tuple_element_request());
@@ -1518,8 +1521,6 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
handle_status = computation->AddRecvInstruction(arg->recv_request());
break;
}
- case OpRequest::kFftRequest:
- return Unimplemented("FftRequest not implemented in XLA service.");
case OpRequest::OP_NOT_SET:
return InvalidArgument("XLA service received OpRequest with OP_NOT_SET");
default:
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 9c1b951d01..6dc49ffe4c 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1715,6 +1715,78 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
return ShapeUtil::MakeShape(lhs.element_type(), dimensions);
}
+/* static */ StatusOr<Shape> ShapeInference::InferFftShape(
+ const Shape& in, const FftType fft_type,
+ const tensorflow::gtl::ArraySlice<int64> fft_length) {
+ const int64 fft_rank = fft_length.size();
+ if (fft_rank < 1 || fft_rank > 3) {
+ return InvalidArgument("FFT only supports ranks 1-3, but got %lld",
+ fft_rank);
+ }
+ switch (fft_type) {
+ case FFT:
+ case IFFT:
+ if (in.element_type() != C64) {
+ return InvalidArgument("%s requires C64 input type, found %s",
+ FftType_Name(fft_type).c_str(),
+ PrimitiveType_Name(in.element_type()).c_str());
+ }
+ return in;
+ case RFFT: {
+ if (in.element_type() != F32) {
+ return InvalidArgument("RFFT requires F32 input type, found %s",
+ PrimitiveType_Name(in.element_type()).c_str());
+ }
+ for (int i = 0; i < fft_rank; i++) {
+ if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
+ fft_length[i]) {
+ return InvalidArgument(
+ "RFFT requires innermost dimensions match fft_length but "
+ "dimension %lld is %lld and should be %lld",
+ in.dimensions_size() - fft_rank + i,
+ in.dimensions(in.dimensions_size() - fft_rank + i),
+ fft_length[i]);
+ }
+ }
+ Shape result = ShapeUtil::ChangeElementType(in, C64);
+ result.set_dimensions(result.dimensions_size() - 1,
+ fft_length[fft_rank - 1] / 2 + 1);
+ return result;
+ }
+ case IRFFT: {
+ if (in.element_type() != C64) {
+ return InvalidArgument("IRFFT requires C64 input type, found %s",
+ PrimitiveType_Name(in.element_type()).c_str());
+ }
+ Shape result = ShapeUtil::ChangeElementType(in, F32);
+ for (int i = 0; i < fft_rank - 1; i++) {
+ if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
+ fft_length[i]) {
+ return InvalidArgument(
+ "IRFFT requires all but one innermost dimensions match "
+ "fft_length, but dimension %lld is %lld and should be %lld",
+ in.dimensions_size() - fft_rank + i,
+ in.dimensions(in.dimensions_size() - fft_rank + i),
+ fft_length[i]);
+ }
+ }
+ if (in.dimensions(in.dimensions_size() - 1) !=
+ fft_length[fft_rank - 1] / 2 + 1) {
+ return InvalidArgument(
+ "IRFFT requires innermost dimension matches fft_length/2+1, but "
+ "dimension %d is %lld and should be %lld",
+ in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1),
+ fft_length[fft_rank - 1] / 2 + 1);
+ }
+ result.set_dimensions(result.dimensions_size() - 1,
+ fft_length[fft_rank - 1]);
+ return result;
+ }
+ default:
+ LOG(FATAL) << "Unexpected fft_type: " << fft_type;
+ }
+}
+
/* static */ StatusOr<Shape> ShapeInference::InferCrossReplicaSumShape(
tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
for (const Shape* operand_shape : operand_shapes) {
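
Concretely, these rules reproduce the shapes in the parser tests added below; a sketch of the new entry point:

// RFFT halves the innermost dimension: 32 reals -> 32/2 + 1 = 17 complexes.
Shape in = ShapeUtil::MakeShape(F32, {5, 64, 32});
StatusOr<Shape> out = ShapeInference::InferFftShape(in, RFFT, {64, 32});
// ok: c64[5,64,17]

// IRFFT inverts it: innermost 64/2 + 1 = 33 complexes -> 64 reals.
Shape cin = ShapeUtil::MakeShape(C64, {5, 64, 128, 33});
StatusOr<Shape> back =
    ShapeInference::InferFftShape(cin, IRFFT, {64, 128, 64});
// ok: f32[5,64,128,64]
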
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index c06340d2d5..b39151ebbc 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -109,6 +109,11 @@ class ShapeInference {
const Shape& lhs, const Shape& rhs, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers);
+ // Infers the shape produced by the given FFT type on the given operand.
+ static StatusOr<Shape> InferFftShape(
+ const Shape& in, FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length);
+
// Infers the shape produced by a cross replica sum with the given operand
// shapes.
static StatusOr<Shape> InferCrossReplicaSumShape(
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index e42cbfa976..9d941fb770 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -1121,6 +1121,31 @@ StatusOr<ComputationDataHandle> UserComputation::AddConvolveInstruction(
return handle;
}
+StatusOr<ComputationDataHandle> UserComputation::AddFftInstruction(
+ const FftRequest& fft_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+ LookUpRequest(fft_request.operand()));
+ TF_ASSIGN_OR_RETURN(Shape shape,
+ ShapeInference::InferFftShape(
+ operand->output_shape(), fft_request.fft_type(),
+ AsInt64Slice(fft_request.fft_length())));
+
+ const ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = shape;
+ *request.mutable_request()->mutable_fft_request() = fft_request;
+
+ VLOG(1) << "AddFftInstruction (" << GetVersionedHandleInternal()
+ << "), data handle " << handle.handle() << ": "
+ << fft_request.ShortDebugString();
+ return handle;
+}
+
StatusOr<ComputationDataHandle> UserComputation::AddCrossReplicaSumInstruction(
const CrossReplicaSumRequest& cross_replica_sum_request) {
tensorflow::mutex_lock lock(mutex_);
@@ -1675,6 +1700,13 @@ void PureFunctionalVisitor(const SessionComputation& session_computation,
break;
}
+ case OpRequest::kFftRequest: {
+ const FftRequest& fft_request = request.request().fft_request();
+ PureFunctionalVisitor(session_computation, fft_request.operand(),
+ num_parameters, visited, is_functional);
+ break;
+ }
+
case OpRequest::kCrossReplicaSumRequest: {
// TODO(b/33009255): Implement constant folding for cross replica sum.
*is_functional = false;
@@ -2406,6 +2438,12 @@ static void ForEachOperand(
break;
}
+ case OpRequest::kFftRequest: {
+ const FftRequest& fft_request = request.request().fft_request();
+ apply(fft_request.operand());
+ break;
+ }
+
case OpRequest::kBatchNormTrainingRequest: {
const BatchNormTrainingRequest& batch_norm_training_request =
request.request().batch_norm_training_request();
@@ -2880,6 +2918,15 @@ void ComputationLowerer::Visit(
break;
}
+ case OpRequest::kFftRequest: {
+ const FftRequest& fft_request = request.request().fft_request();
+ HloInstruction* operand = lookup_instruction(fft_request.operand());
+ hlo_instruction = add_instruction(HloInstruction::CreateFft(
+ request.output_shape(), operand, fft_request.fft_type(),
+ AsInt64Slice(fft_request.fft_length())));
+ break;
+ }
+
case OpRequest::kDotRequest: {
const DotRequest& dot_request = request.request().dot_request();
HloInstruction* lhs = lookup_instruction(dot_request.lhs());
diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h
index 8a78d520e1..8be639c784 100644
--- a/tensorflow/compiler/xla/service/user_computation.h
+++ b/tensorflow/compiler/xla/service/user_computation.h
@@ -133,6 +133,10 @@ class UserComputation {
StatusOr<ComputationDataHandle> AddConvolveInstruction(
const ConvolveRequest& convolve_request);
+ // Enqueues an FFT instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddFftInstruction(
+ const FftRequest& fft_request);
+
// Enqueues a cross replica sum instruction onto this user computation.
StatusOr<ComputationDataHandle> AddCrossReplicaSumInstruction(
const CrossReplicaSumRequest& cross_replica_sum_request);
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index 58139ee0c1..3caa465769 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -99,6 +99,7 @@ class HloParser {
kString,
kBracedInt64List,
kHloComputation,
+ kFftType,
kWindow,
kConvolutionDimensionNumbers,
kSharding,
@@ -178,6 +179,7 @@ class HloParser {
bool ParseString(string* result);
bool ParseShape(Shape* result);
bool ParseOpcode(HloOpcode* result);
+ bool ParseFftType(FftType* result);
bool ParseFusionKind(HloInstruction::FusionKind* result);
bool ParseRandomDistribution(RandomDistribution* result);
bool ParseInt64(int64* result);
@@ -685,6 +687,20 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums));
break;
}
+ case HloOpcode::kFft: {
+ optional<FftType> fft_type;
+ optional<std::vector<int64>> fft_length;
+ attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type};
+ attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &fft_length};
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(HloInstruction::CreateFft(
+ shape, operands[0], *fft_type, *fft_length));
+ break;
+ }
case HloOpcode::kBroadcast: {
optional<std::vector<int64>> broadcast_dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
@@ -1672,6 +1688,14 @@ bool HloParser::ParseAttributeHelper(
static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
return true;
}
+ case AttrTy::kFftType: {
+ FftType result;
+ if (!ParseFftType(&result)) {
+ return false;
+ }
+ static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
+ return true;
+ }
case AttrTy::kWindow: {
Window result;
if (!ParseWindow(&result)) {
@@ -2289,6 +2313,19 @@ bool HloParser::ParseOpcode(HloOpcode* result) {
return true;
}
+bool HloParser::ParseFftType(FftType* result) {
+ VLOG(1) << "ParseFftType";
+ if (lexer_.GetKind() != TokKind::kIdent) {
+ return TokenError("expects fft type");
+ }
+ string val = lexer_.GetStrVal();
+ if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) {
+ return TokenError(Printf("expects fft type but sees: %s", val.c_str()));
+ }
+ lexer_.Lex();
+ return true;
+}
+
bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) {
VLOG(1) << "ParseFusionKind";
if (lexer_.GetKind() != TokKind::kIdent) {
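
The textual form the parser now accepts, as exercised by the tests below:

%fft = c64[8,32]{1,0} fft(c64[8,32]{1,0} %input), fft_type=FFT, fft_length={32}
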
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index 69c59ad6c7..ce11d2b43d 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -582,6 +582,54 @@ ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], varia
)"
},
+// fft
+{
+"Fft",
+R"(HloModule Fft_module
+
+ENTRY %Fft (input: c64[8,32]) -> c64[8,32] {
+ %input = c64[8,32]{1,0} parameter(0)
+ ROOT %fft = c64[8,32]{1,0} fft(c64[8,32]{1,0} %input), fft_type=FFT, fft_length={32}
+}
+
+)"
+},
+// ifft
+{
+"Ifft2d",
+R"(HloModule Ifft2d_module
+
+ENTRY %Ifft2d (input: c64[5,8,32]) -> c64[5,8,32] {
+ %input = c64[5,8,32]{2,1,0} parameter(0)
+ ROOT %fft = c64[5,8,32]{2,1,0} fft(c64[5,8,32]{2,1,0} %input), fft_type=IFFT, fft_length={8,32}
+}
+
+)"
+},
+// rfft2d
+{
+"Rfft2d",
+R"(HloModule Rfft2d_module
+
+ENTRY %Rfft2d (input: f32[5,64,32]) -> c64[5,64,17] {
+ %input = f32[5,64,32]{2,1,0} parameter(0)
+ ROOT %fft = c64[5,64,17]{2,1,0} fft(f32[5,64,32]{2,1,0} %input), fft_type=RFFT, fft_length={64,32}
+}
+
+)"
+},
+// irfft3d
+{
+"Irfft3d",
+R"(HloModule Irfft3d_module
+
+ENTRY %Irfft3d (input: c64[5,64,128,33]) -> f32[5,64,128,64] {
+ %input = c64[5,64,128,33]{3,2,1,0} parameter(0)
+ ROOT %fft = f32[5,64,128,64]{3,2,1,0} fft(c64[5,64,128,33]{3,2,1,0} %input), fft_type=IRFFT, fft_length={64,128,64}
+}
+
+)"
+},
// pad
{
"Pad",