diff options
author | 2018-01-02 13:01:59 -0800 | |
---|---|---|
committer | 2018-01-02 13:05:54 -0800 | |
commit | 5bf26acd87d3d44183fc28cb9576cda10c0255ca (patch) | |
tree | 5486c0ab9077496e30b594d35c089a45f7b7e18c | |
parent | 97a843db78745fe3a8e418b3b1e93ef79fbfff12 (diff) |
Automated g4 rollback of changelist 180000981
PiperOrigin-RevId: 180581912
44 files changed, 1532 insertions, 7 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 5732e95575..78777f3c96 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -241,6 +241,21 @@ tf_xla_py_test( ) tf_xla_py_test( + name = "fft_test", + size = "medium", + srcs = ["fft_test.py"], + shard_count = 3, + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + "//tensorflow/python:spectral_ops", + ], +) + +tf_xla_py_test( name = "slice_ops_test", size = "small", srcs = ["slice_ops_test.py"], diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py new file mode 100644 index 0000000000..bdc38be48c --- /dev/null +++ b/tensorflow/compiler/tests/fft_test.py @@ -0,0 +1,179 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for FFT via the XLA JIT.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import spectral_ops +from tensorflow.python.platform import googletest + +BATCH_DIMS = (3, 5) +RTOL = 0.02 # Eigen/cuFFT differ widely from np, especially for FFT3D +ATOL = 1e-3 + + +def pick_10(x): + x = list(x) + np.random.seed(123) + np.random.shuffle(x) + return x[:10] + + +def to_32bit(x): + if x.dtype == np.complex128: + return x.astype(np.complex64) + if x.dtype == np.float64: + return x.astype(np.float32) + return x + + +POWS_OF_2 = 2**np.arange(3, 12) +INNER_DIMS_1D = list((x,) for x in POWS_OF_2) +POWS_OF_2 = 2**np.arange(3, 8) # To avoid OOM on GPU. +INNER_DIMS_2D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2)) +INNER_DIMS_3D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2, POWS_OF_2)) + + +class FFTTest(XLATestCase): + + def _VerifyFftMethod(self, inner_dims, complex_to_input, input_to_expected, + tf_method): + for indims in inner_dims: + print("nfft =", indims) + shape = BATCH_DIMS + indims + data = np.arange(np.prod(shape) * 2) / np.prod(indims) + np.random.seed(123) + np.random.shuffle(data) + data = np.reshape(data.astype(np.float32).view(np.complex64), shape) + data = to_32bit(complex_to_input(data)) + expected = to_32bit(input_to_expected(data)) + with self.test_session() as sess: + with self.test_scope(): + ph = array_ops.placeholder( + dtypes.as_dtype(data.dtype), shape=data.shape) + out = tf_method(ph) + value = sess.run(out, {ph: data}) + self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL) + + def testFFT(self): + self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.fft, + spectral_ops.fft) + + def testFFT2D(self): + self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.fft2, + spectral_ops.fft2d) + + def testFFT3D(self): + self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x, + lambda x: np.fft.fftn(x, axes=(-3, -2, -1)), + spectral_ops.fft3d) + + def testIFFT(self): + self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.ifft, + spectral_ops.ifft) + + def testIFFT2D(self): + self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.ifft2, + spectral_ops.ifft2d) + + def testIFFT3D(self): + self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x, + lambda x: np.fft.ifftn(x, axes=(-3, -2, -1)), + spectral_ops.ifft3d) + + def testRFFT(self): + self._VerifyFftMethod( + INNER_DIMS_1D, np.real, lambda x: np.fft.rfft(x, n=x.shape[-1]), + lambda x: spectral_ops.rfft(x, fft_length=[x.shape[-1].value])) + + def testRFFT2D(self): + + def _tf_fn(x): + return spectral_ops.rfft2d( + x, fft_length=[x.shape[-2].value, x.shape[-1].value]) + + self._VerifyFftMethod( + INNER_DIMS_2D, np.real, + lambda x: np.fft.rfft2(x, s=[x.shape[-2], x.shape[-1]]), _tf_fn) + + def testRFFT3D(self): + + def _to_expected(x): + return np.fft.rfftn( + x, axes=(-3, -2, -1), s=[x.shape[-3], x.shape[-2], x.shape[-1]]) + + def _tf_fn(x): + return spectral_ops.rfft3d( + x, + fft_length=[x.shape[-3].value, x.shape[-2].value, x.shape[-1].value]) + + self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn) + + def testIRFFT(self): + + def _tf_fn(x): + return spectral_ops.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)]) + + self._VerifyFftMethod( + INNER_DIMS_1D, lambda x: np.fft.rfft(np.real(x), n=x.shape[-1]), + lambda x: np.fft.irfft(x, n=2 * (x.shape[-1] - 1)), _tf_fn) + + def testIRFFT2D(self): + + def _tf_fn(x): + return spectral_ops.irfft2d( + x, fft_length=[x.shape[-2].value, 2 * (x.shape[-1].value - 1)]) + + self._VerifyFftMethod( + INNER_DIMS_2D, + lambda x: np.fft.rfft2(np.real(x), s=[x.shape[-2], x.shape[-1]]), + lambda x: np.fft.irfft2(x, s=[x.shape[-2], 2 * (x.shape[-1] - 1)]), + _tf_fn) + + def testIRFFT3D(self): + + def _to_input(x): + return np.fft.rfftn( + np.real(x), + axes=(-3, -2, -1), + s=[x.shape[-3], x.shape[-2], x.shape[-1]]) + + def _to_expected(x): + return np.fft.irfftn( + x, + axes=(-3, -2, -1), + s=[x.shape[-3], x.shape[-2], 2 * (x.shape[-1] - 1)]) + + def _tf_fn(x): + return spectral_ops.irfft3d( + x, + fft_length=[ + x.shape[-3].value, x.shape[-2].value, 2 * (x.shape[-1].value - 1) + ]) + + self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 2ab013b7fa..e72dd4eea9 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -93,11 +93,11 @@ class OpTestBuilder { public: explicit OpTestBuilder(const string& op_name); - // Adds an input 'tensor'. + // Adds an input 'tensor' as a Placeholder node. OpTestBuilder& Input(const Tensor& tensor); - // Adds a random input tensor with 'type'. If 'dims' is not provided, - // RandomDims() is used. + // Adds a random input tensor with 'type' as a Placeholder node. + // If 'dims' is not provided, RandomDims() is used. OpTestBuilder& RandomInput(DataType type); OpTestBuilder& RandomInput(DataType type, std::vector<int64> dims); @@ -1375,6 +1375,121 @@ TEST_F(OpTest, Conj) { }); } +TEST_F(OpTest, FFT) { + Repeatedly([this]() { + std::vector<int64> dims = RandomDims(1, kDefaultMaxRank); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("FFT").RandomInput(DT_COMPLEX64, dims)); + }); +} + +TEST_F(OpTest, FFT2D) { + Repeatedly([this]() { + std::vector<int64> dims = RandomDims(2, kDefaultMaxRank); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("FFT2D").RandomInput(DT_COMPLEX64, dims)); + }); +} + +TEST_F(OpTest, FFT3D) { + Repeatedly([this]() { + std::vector<int64> dims = RandomDims(3, kDefaultMaxRank); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("FFT3D").RandomInput(DT_COMPLEX64, dims)); + }); +} + +TEST_F(OpTest, IFFT) { + Repeatedly([this]() { + std::vector<int64> dims = RandomDims(1, kDefaultMaxRank); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("IFFT").RandomInput(DT_COMPLEX64, dims)); + }); +} + +TEST_F(OpTest, IFFT2D) { + Repeatedly([this]() { + std::vector<int64> dims = RandomDims(2, kDefaultMaxRank); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("IFFT2D").RandomInput(DT_COMPLEX64, dims)); + }); +} + +TEST_F(OpTest, IFFT3D) { + Repeatedly([this]() { + std::vector<int64> dims = RandomDims(3, kDefaultMaxRank); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("IFFT3D").RandomInput(DT_COMPLEX64, dims)); + }); +} + +TEST_F(OpTest, RFFT) { + Repeatedly([this]() { + std::vector<int64> dims = RandomDims(1, kDefaultMaxRank, 3); + Tensor fft_shape = test::AsTensor<int32>(AsInt32s({dims[dims.size() - 1]})); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("RFFT").RandomInput(DT_FLOAT, dims).Input(fft_shape)); + }); +} + +TEST_F(OpTest, RFFT2D) { + Repeatedly([this]() { + std::vector<int64> dims = RandomDims(2, kDefaultMaxRank, 3); + Tensor fft_shape = test::AsTensor<int32>( + AsInt32s({dims[dims.size() - 2], dims[dims.size() - 1]})); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("RFFT2D").RandomInput(DT_FLOAT, dims).Input(fft_shape)); + }); +} + +TEST_F(OpTest, RFFT3D) { + Repeatedly([this]() { + std::vector<int64> dims = RandomDims(3, kDefaultMaxRank, 3); + Tensor fft_shape = test::AsTensor<int32>(AsInt32s( + {dims[dims.size() - 3], dims[dims.size() - 2], dims[dims.size() - 1]})); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("RFFT3D").RandomInput(DT_FLOAT, dims).Input(fft_shape)); + }); +} + +TEST_F(OpTest, IRFFT) { + Repeatedly([this]() { + std::vector<int64> dims = RandomDims(1, kDefaultMaxRank, 3); + int64 orig_size = dims[dims.size() - 1]; + dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1; + Tensor fft_shape = test::AsTensor<int32>(AsInt32s({orig_size})); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT") + .RandomInput(DT_COMPLEX64, dims) + .Input(fft_shape)); + }); +} + +TEST_F(OpTest, IRFFT2D) { + Repeatedly([this]() { + std::vector<int64> dims = RandomDims(2, kDefaultMaxRank, 3); + std::vector<int64> orig_size = {dims[dims.size() - 2], + dims[dims.size() - 1]}; + dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1; + Tensor fft_shape = test::AsTensor<int32>(AsInt32s({orig_size})); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT2D") + .RandomInput(DT_COMPLEX64, dims) + .Input(fft_shape)); + }); +} + +TEST_F(OpTest, IRFFT3D) { + Repeatedly([this]() { + std::vector<int64> dims = RandomDims(3, kDefaultMaxRank, 3); + std::vector<int64> orig_size = { + dims[dims.size() - 3], dims[dims.size() - 2], dims[dims.size() - 1]}; + dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1; + Tensor fft_shape = test::AsTensor<int32>(AsInt32s({orig_size})); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT3D") + .RandomInput(DT_COMPLEX64, dims) + .Input(fft_shape)); + }); +} + TEST_F(OpTest, Conv2D) { Repeatedly([this]() { WindowedSpatialDims d = ChooseWindowedSpatialDims(2); diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 092d852fe3..0dd95de0ce 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -30,6 +30,7 @@ tf_kernel_library( "diag_op.cc", "dynamic_stitch_op.cc", "elu_op.cc", + "fft_ops.cc", "fill_op.cc", "function_ops.cc", "gather_op.cc", @@ -105,6 +106,7 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:linalg_ops_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:spectral_ops_op_lib", "//tensorflow/core:stateless_random_ops_op_lib", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:concat_lib", diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc new file mode 100644 index 0000000000..a4f3c1c3ad --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -0,0 +1,122 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Ops for FFT. + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/conv_grad_ops.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +namespace { + +using xla::FftType; + +class GenericFftOp : public XlaOpKernel { + public: + explicit GenericFftOp(OpKernelConstruction* ctx, FftType fft_type, + int fft_rank) + : XlaOpKernel(ctx), fft_type_(fft_type), fft_rank_(fft_rank) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES( + ctx, TensorShapeUtils::IsVectorOrHigher(input_shape), + errors::InvalidArgument("input must be at least 1 dimensional")); + + std::vector<int64> fft_length; + if (fft_type_ == FftType::RFFT || fft_type_ == FftType::IRFFT) { + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &fft_length)); + OP_REQUIRES(ctx, fft_length.size() == fft_rank_, + errors::InvalidArgument("fft_length must be length ", + fft_rank_, " vector")); + } else { + // Innermost axis provides the FFT length. + for (int i = 0; i < fft_rank_; i++) { + fft_length.push_back( + input_shape.dim_size(input_shape.dims() - fft_rank_ + i)); + } + } + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle fft = + b->Fft(ctx->Input(0), fft_type_, fft_length); + ctx->SetOutput(0, fft); + } + + protected: + const FftType fft_type_; + const int fft_rank_; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(GenericFftOp); +}; + +template <int FFTRank> +class FFTOp : public GenericFftOp { + public: + explicit FFTOp(OpKernelConstruction* ctx) + : GenericFftOp(ctx, /*fft_type=*/FftType::FFT, /*fft_rank=*/FFTRank) {} +}; +REGISTER_XLA_OP(Name("FFT"), FFTOp<1>); +REGISTER_XLA_OP(Name("FFT2D"), FFTOp<2>); +REGISTER_XLA_OP(Name("FFT3D"), FFTOp<3>); + +template <int FFTRank> +class IFFTOp : public GenericFftOp { + public: + explicit IFFTOp(OpKernelConstruction* ctx) + : GenericFftOp(ctx, /*fft_type=*/FftType::IFFT, /*fft_rank=*/FFTRank) {} +}; +REGISTER_XLA_OP(Name("IFFT"), IFFTOp<1>); +REGISTER_XLA_OP(Name("IFFT2D"), IFFTOp<2>); +REGISTER_XLA_OP(Name("IFFT3D"), IFFTOp<3>); + +template <int FFTRank> +class RFFTOp : public GenericFftOp { + public: + explicit RFFTOp(OpKernelConstruction* ctx) + : GenericFftOp(ctx, /*fft_type=*/FftType::RFFT, /*fft_rank=*/FFTRank) {} +}; +REGISTER_XLA_OP(Name("RFFT").CompileTimeConstInput("fft_length"), RFFTOp<1>); +REGISTER_XLA_OP(Name("RFFT2D").CompileTimeConstInput("fft_length"), RFFTOp<2>); +REGISTER_XLA_OP(Name("RFFT3D").CompileTimeConstInput("fft_length"), RFFTOp<3>); + +template <int FFTRank> +class IRFFTOp : public GenericFftOp { + public: + explicit IRFFTOp(OpKernelConstruction* ctx) + : GenericFftOp(ctx, /*fft_type=*/FftType::IRFFT, /*fft_rank=*/FFTRank) {} +}; +REGISTER_XLA_OP(Name("IRFFT").CompileTimeConstInput("fft_length"), IRFFTOp<1>); +REGISTER_XLA_OP(Name("IRFFT2D").CompileTimeConstInput("fft_length"), + IRFFTOp<2>); +REGISTER_XLA_OP(Name("IRFFT3D").CompileTimeConstInput("fft_length"), + IRFFTOp<3>); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 97bb100fb1..0dde6a986c 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -161,7 +161,14 @@ void XlaOpRegistry::RegisterCompilationKernels() { const string& op_name = op.first; const std::unique_ptr<OpRegistration>& op_registration = op.second; const OpDef* op_def; - TF_CHECK_OK(op_registry->LookUpOpDef(op_name, &op_def)); + Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def); + if (!lookup_status.ok()) { + LOG(ERROR) << lookup_status.error_message(); + XLA_LOG_LINES( + ERROR, "Ops registered: \n" + + dynamic_cast<OpRegistry*>(op_registry)->DebugString(true)); + } + TF_CHECK_OK(lookup_status); std::unordered_set<string> type_attrs; for (const OpDef::AttrDef& attr_def : op_def->attr()) { diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 317dcb4e41..1c0669c1d4 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -855,6 +855,31 @@ ComputationDataHandle ComputationBuilder::ConvGeneralDilated( return ParseOpResponse(s, &response); } +ComputationDataHandle ComputationBuilder::Fft( + const ComputationDataHandle& operand, const FftType fft_type, + const tensorflow::gtl::ArraySlice<int64> fft_length) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + FftRequest request; + *request.mutable_operand() = operand; + request.set_fft_type(fft_type); + for (int64 dim_len : fft_length) { + request.add_fft_length(dim_len); + } + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_fft_request() = request; + AddCommonFieldsToOpRequest(&op_request); + OpResponse response; + + VLOG(2) << "making fft op request"; + Status s = client_->stub()->Op(&op_request, &response); + + return ParseOpResponse(s, &response); +} + ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape, const string& config) { if (!first_error_.ok() || !PrepareComputation().ok()) { diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 7293b35c0f..1a3a54d91e 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -410,6 +410,12 @@ class ComputationBuilder { tensorflow::gtl::ArraySlice<int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers); + // Enqueues an FFT instruction onto the computation, of the given type and + // with the given FFT length. + ComputationDataHandle Fft(const ComputationDataHandle& operand, + FftType fft_type, + tensorflow::gtl::ArraySlice<int64> fft_length); + // Enqueues an infeed instruction onto the computation, which writes data of // the given shape to the infeed buffer of the device. ComputationDataHandle Infeed(const Shape& shape, const string& config = ""); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index e35d947525..9754f8d9af 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -165,6 +165,7 @@ cc_library( ":external_constant_pool", ":orc_jit_memory_mapper", ":runtime_conv2d", + ":runtime_fft", ":runtime_fork_join", ":runtime_matmul", ":runtime_single_threaded_conv2d", @@ -504,6 +505,24 @@ cc_library( ) cc_library( + name = "runtime_fft", + srcs = [ + "runtime_fft.cc", + "runtime_fft_impl.h", + ], + hdrs = ["runtime_fft.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:framework", + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + +cc_library( name = "runtime_matvec", srcs = ["runtime_matvec.cc"], hdrs = ["runtime_matvec.h"], diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index ba208d7249..9a5e146b5a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -705,6 +705,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend( ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); } + XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module)); + // JIT compile the LLVM IR module to in-memory machine code. jit->AddModule(std::move(llvm_module)); cpu_executable.reset(new CpuExecutable( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 7908dc173d..1ef45dbec3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -37,6 +37,7 @@ extern const char* const kEigenMatMulF64SymbolName = "__xla_cpu_runtime_EigenMatMulF64"; extern const char* const kEigenConvF32SymbolName = "__xla_cpu_runtime_EigenConvF32"; +extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft"; extern const char* const kEigenSingleThreadedMatMulF32SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF32"; extern const char* const kEigenSingleThreadedMatMulF64SymbolName = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index 2ade455b8a..3e1f080711 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -44,6 +44,7 @@ namespace runtime { extern const char* const kEigenMatMulF32SymbolName; extern const char* const kEigenMatMulF64SymbolName; extern const char* const kEigenConvF32SymbolName; +extern const char* const kEigenFftSymbolName; extern const char* const kEigenSingleThreadedMatMulF32SymbolName; extern const char* const kEigenSingleThreadedMatMulF64SymbolName; extern const char* const kEigenSingleThreadedConvF32SymbolName; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 4165f920d2..26bd9ad326 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -1135,6 +1135,55 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { }); } +Status IrEmitter::HandleFft(HloInstruction* fft) { + auto operand = fft->operand(0); + TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( + /*instruction=*/*fft, /*operands=*/{operand}, + /*supported_types=*/{F32, C64})); + TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout())); + TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout())); + VLOG(3) << "operand=" << ShapeUtil::HumanStringWithLayout(operand->shape()); + VLOG(3) << "fft=" << ShapeUtil::HumanStringWithLayout(fft->shape()); + + llvm::Value* operand_address = GetEmittedValueFor(operand); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fft)); + + const std::vector<int64>& fft_length = fft->fft_length(); + int64 input_batch = 1; + for (int i = 0; i < fft->shape().dimensions_size() - fft_length.size(); i++) { + input_batch *= fft->shape().dimensions(i); + } + + // Args have been computed, make the call. + llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo(); + llvm::Type* int32_type = ir_builder_.getInt32Ty(); + llvm::Type* int64_type = ir_builder_.getInt64Ty(); + llvm::FunctionType* fft_type = llvm::FunctionType::get( + ir_builder_.getVoidTy(), + {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type, + int64_type, int64_type, int64_type, int64_type}, + /*isVarArg=*/false); + const char* fn_name = runtime::kEigenFftSymbolName; + llvm::Function* fft_func = llvm::cast<llvm::Function>( + module_->getOrInsertFunction(fn_name, fft_type)); + fft_func->setCallingConv(llvm::CallingConv::C); + fft_func->setDoesNotThrow(); + fft_func->setOnlyAccessesInaccessibleMemOrArgMem(); + const int fft_rank = fft_length.size(); + ir_builder_.CreateCall( + fft_func, + {GetExecutableRunOptionsArgument(), + ir_builder_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type), + ir_builder_.CreateBitCast(operand_address, int8_ptr_type), + ir_builder_.getInt32(fft->fft_type()), ir_builder_.getInt32(fft_rank), + ir_builder_.getInt64(input_batch), + ir_builder_.getInt64(fft_rank > 0 ? fft_length[0] : 0), + ir_builder_.getInt64(fft_rank > 1 ? fft_length[1] : 0), + ir_builder_.getInt64(fft_rank > 2 ? fft_length[2] : 0)}); + + return Status::OK(); +} + Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { if (hlo_module_config_.replica_count() == 1) { // When there is a single replica, a cross replica sum is the identity diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 2341e3ea72..b8d71eba18 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -158,6 +158,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleSelect(HloInstruction* select) override; Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; + Status HandleFft(HloInstruction* fft) override; Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 4b44ac8941..deb21bf4ef 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -126,7 +126,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( HloInstruction* instruction) { // Currently, we do not assign parallel tasks to instructions with at least // one of the following properties: - // *) Internal threading (library calls to kConv, kDot, and kCustomCall). + // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall). // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). // *) Tuple-shaped. // TODO(b/27458679) Parallelize instructions which are skipped here. @@ -137,6 +137,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( instruction->opcode() == HloOpcode::kSelectAndScatter || instruction->opcode() == HloOpcode::kGetTupleElement || instruction->opcode() == HloOpcode::kBitcast || + instruction->opcode() == HloOpcode::kFft || (instruction->opcode() == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction)) || PotentiallyImplementedAsEigenDot(*instruction) || diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft.cc b/tensorflow/compiler/xla/service/cpu/runtime_fft.cc new file mode 100644 index 0000000000..848d2d2241 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft.cc @@ -0,0 +1,37 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_fft.h" + +#define EIGEN_USE_THREADS + +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int32; +using tensorflow::int64; + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenFft( + const void* run_options_ptr, void* out, void* operand, int32 fft_type, + int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + const xla::ExecutableRunOptions* run_options = + static_cast<const xla::ExecutableRunOptions*>(run_options_ptr); + tensorflow::xla::EigenFftImpl(*run_options->intra_op_thread_pool(), out, + operand, fft_type, fft_rank, input_batch, + fft_length0, fft_length1, fft_length2); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft.h b/tensorflow/compiler/xla/service/cpu/runtime_fft.h new file mode 100644 index 0000000000..f20c5aa0aa --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft.h @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_H_ + +#include "tensorflow/core/platform/types.h" + +extern "C" { + +extern void __xla_cpu_runtime_EigenFft( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out, + void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank, + tensorflow::int64 input_batch, tensorflow::int64 fft_length0, + tensorflow::int64 fft_length1, tensorflow::int64 fft_length2); + +} // extern "C" + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h new file mode 100644 index 0000000000..c7c701ca97 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h @@ -0,0 +1,238 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_ + +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/types.h" + +// 'tensorflow' namespace is used so that int64 and other types don't require +// qualification. +namespace tensorflow { +namespace xla { + +namespace internal { + +// Computes either a forward or reverse complex-to-complex FFT. +template <bool Forward, int FFTRank, typename EigenDevice> +void EigenFftC2C(const EigenDevice& device, complex64* out, complex64* operand, + int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + // Create the axes (which are always trailing). + const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank); + constexpr auto direction = Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE; + + const std::array<int64, 3> fft_shape = { + {fft_length0, fft_length1, fft_length2}}; + + Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> dims; + dims[0] = input_batch; + for (int i = 0; i < FFTRank; i++) { + dims[i + 1] = fft_shape[i]; + } + const Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>, + Eigen::Aligned> + input(operand, dims); + Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>, + Eigen::Aligned> + output(out, dims); + output.device(device) = input.template fft<Eigen::BothParts, direction>(axes); +} + +// Computes a forward real->complex FFT, slicing out redundant negative +// frequencies from the innermost dimension. +template <int FFTRank, typename EigenDevice> +void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, + int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + const std::array<int64, 3> fft_shape = { + {fft_length0, fft_length1, fft_length2}}; + + Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> in_dims; + in_dims[0] = input_batch; + Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> out_dims; + out_dims[0] = input_batch; + TensorShape temp_shape{input_batch}; + for (int i = 0; i < FFTRank; i++) { + in_dims[i + 1] = fft_shape[i]; + out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; + temp_shape.AddDim(fft_shape[i]); + } + const Eigen::TensorMap<Eigen::Tensor<float, FFTRank + 1, Eigen::RowMajor>, + Eigen::Aligned> + input(operand, in_dims); + Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>, + Eigen::Aligned> + output(out, out_dims); + + // Create the axes (which are always trailing). + const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank); + + // Compute the full FFT using a temporary tensor. + Tensor temp(DataTypeToEnum<complex64>::v(), temp_shape); + auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>(); + const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices; + full_fft.device(device) = + input.template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes); + + // Slice away the negative frequency components. + output.device(device) = full_fft.slice(zero_start_indices, out_dims); +} + +// Computes a reverse complex->real FFT, reconstructing redundant negative +// frequencies using reverse conjugate on innermost dimension after doing IFFT +// on outer dimensions. +template <int FFTRank, typename EigenDevice> +void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, + int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + const std::array<int64, 3> fft_shape = { + {fft_length0, fft_length1, fft_length2}}; + + Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> in_dims; + in_dims[0] = input_batch; + Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> out_dims; + out_dims[0] = input_batch; + TensorShape temp_shape{input_batch}; + for (int i = 0; i < FFTRank; i++) { + in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; + out_dims[i + 1] = fft_shape[i]; + temp_shape.AddDim(fft_shape[i]); + } + const Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>, + Eigen::Aligned> + input(operand, in_dims); + Eigen::TensorMap<Eigen::Tensor<float, FFTRank + 1, Eigen::RowMajor>, + Eigen::Aligned> + output(out, out_dims); + + // Calculate the shape of the temporary tensor for the full FFT and the + // region we will slice from input given fft_shape. We slice input to + // fft_shape on its inner-most dimensions, except the last (which we + // slice to fft_shape[-1] / 2 + 1). + Tensor temp(DataTypeToEnum<complex64>::v(), temp_shape); + auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>(); + + // Calculate the starting point and range of the source of + // negative frequency part. + auto neg_sizes = in_dims; + neg_sizes[FFTRank] = fft_shape[FFTRank - 1] - in_dims[FFTRank]; + Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices; + neg_target_indices[FFTRank] = in_dims[FFTRank]; + + const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices; + Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices; + neg_start_indices[FFTRank] = 1; + + full_fft.slice(zero_start_indices, in_dims).device(device) = input; + + // First, conduct IFFTs on outer dimensions. We save computation (and + // avoid touching uninitialized memory) by slicing full_fft to the + // subregion we wrote input to. + if (FFTRank > 1) { + const auto outer_axes = + Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1); + full_fft.slice(zero_start_indices, in_dims).device(device) = + full_fft.slice(zero_start_indices, in_dims) + .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(outer_axes); + } + + // Reconstruct the full FFT by appending reversed and conjugated + // spectrum as the negative frequency part. + Eigen::array<bool, FFTRank + 1> reverse_last_axis; + for (auto i = 0; i <= FFTRank; i++) { + reverse_last_axis[i] = i == FFTRank; + } + + if (neg_sizes[FFTRank] != 0) { + full_fft.slice(neg_target_indices, neg_sizes).device(device) = + full_fft.slice(neg_start_indices, neg_sizes) + .reverse(reverse_last_axis) + .conjugate(); + } + + auto inner_axis = Eigen::array<int, 1>{FFTRank}; + output.device(device) = + full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(inner_axis); +} + +template <int FFTRank, typename EigenDevice> +void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, + int32 fft_type, int64 input_batch, int64 fft_length0, + int64 fft_length1, int64 fft_length2) { + CHECK(::xla::FftType_IsValid(fft_type)) << fft_type; + switch (fft_type) { + case ::xla::FftType::FFT: + EigenFftC2C<true, FFTRank, EigenDevice>( + device, static_cast<complex64*>(out), + static_cast<complex64*>(operand), input_batch, fft_length0, + fft_length1, fft_length2); + break; + case ::xla::FftType::IFFT: + EigenFftC2C<false, FFTRank, EigenDevice>( + device, static_cast<complex64*>(out), + static_cast<complex64*>(operand), input_batch, fft_length0, + fft_length1, fft_length2); + break; + case ::xla::FftType::RFFT: + EigenFftR2C<FFTRank, EigenDevice>( + device, static_cast<complex64*>(out), static_cast<float*>(operand), + input_batch, fft_length0, fft_length1, fft_length2); + break; + case ::xla::FftType::IRFFT: + EigenFftC2R<FFTRank, EigenDevice>( + device, static_cast<float*>(out), static_cast<complex64*>(operand), + input_batch, fft_length0, fft_length1, fft_length2); + break; + default: + LOG(FATAL) << "Unsupported FFT type: " << fft_type; + } +} + +} // namespace internal + +template <typename EigenDevice> +void EigenFftImpl(const EigenDevice& device, void* out, void* operand, + int32 fft_type, int32 fft_rank, int64 input_batch, + int64 fft_length0, int64 fft_length1, int64 fft_length2) { + switch (fft_rank) { + case 1: + internal::EigenFftWithRank<1, EigenDevice>( + device, out, operand, fft_type, input_batch, fft_length0, 0, 0); + break; + case 2: + internal::EigenFftWithRank<2, EigenDevice>(device, out, operand, fft_type, + input_batch, fft_length0, + fft_length1, 0); + break; + case 3: + internal::EigenFftWithRank<3, EigenDevice>(device, out, operand, fft_type, + input_batch, fft_length0, + fft_length1, fft_length2); + break; + default: + LOG(FATAL) << "Unsupported FFT rank " << fft_rank; + } +} + +} // namespace xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 65da61805a..43ab2ec524 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" @@ -208,6 +209,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue); REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation); REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenFft); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 0d54e325e6..a803b3171f 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -103,6 +103,7 @@ class DfsHloVisitorBase { return HandleElementwiseBinary(hlo); } virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; + virtual Status HandleFft(HloInstructionPtr fft) = 0; virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 133aa25094..170adb3d24 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -85,6 +85,9 @@ class DfsHloVisitorWithDefaultBase Status HandleConvolution(HloInstructionPtr convolution) override { return DefaultAction(convolution); } + Status HandleFft(HloInstructionPtr fft) override { + return DefaultAction(fft); + } Status HandleCrossReplicaSum(HloInstructionPtr crs) override { return DefaultAction(crs); } diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index d5d89de6d4..e4832b2ee6 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -220,6 +220,7 @@ cc_library( "convolution_thunk.cc", "copy_thunk.cc", "cudnn_batchnorm_thunk.cc", + "fft_thunk.cc", "for_thunk.cc", "gemm_thunk.cc", "gpu_executable.cc", @@ -234,6 +235,7 @@ cc_library( "convolution_thunk.h", "copy_thunk.h", "cudnn_batchnorm_thunk.h", + "fft_thunk.h", "for_thunk.h", "gemm_thunk.h", "gpu_executable.h", @@ -272,6 +274,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/platform/default/build_config:cublas_plugin", "//tensorflow/core/platform/default/build_config:cudnn_plugin", + "//tensorflow/core/platform/default/build_config:cufft_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep ], ) diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc new file mode 100644 index 0000000000..66931bdc8b --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -0,0 +1,234 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h" + +#include <string> + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace se = ::perftools::gputools; + +namespace xla { +namespace gpu { + +FftScratchAllocator::FftScratchAllocator( + int device_ordinal, DeviceMemoryAllocator* memory_allocator) + : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} + +FftScratchAllocator::~FftScratchAllocator() { + for (auto& allocated_buffer : allocated_buffers_) { + if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) + .ok()) { + // The program can still continue with failed deallocation. + LOG(ERROR) << "Failed to deallocate the allocated buffer: " + << allocated_buffer.opaque(); + } + } +} + +int64 FftScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) { + constexpr int64 kFftScratchSize = 1LL << 32; // 4GB by default. + return kFftScratchSize; +} + +se::port::StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes( + se::Stream* stream, int64 byte_size) { + CHECK_GE(byte_size, 0) << "byte_size must be positive."; + if (byte_size > GetMemoryLimitInBytes(stream)) { + return se::port::Status( + se::port::error::RESOURCE_EXHAUSTED, + tensorflow::strings::Printf( + "Allocating %lld bytes exceeds the memory limit of %lld bytes.", + byte_size, GetMemoryLimitInBytes(stream))); + } + + auto status_or_memory = + memory_allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false); + if (!status_or_memory.ok()) { + return tensorflow::errors::ResourceExhausted( + "Failed to allocate %lld bytes on device %d.", byte_size, + device_ordinal_); + } + se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); + allocated_buffers_.push_back(allocated_buffer); + total_allocated_bytes_ += byte_size; + return se::DeviceMemory<uint8>(allocated_buffer); +} + +namespace { + +se::fft::Type FftTypeToSeType(FftType type) { + switch (type) { + case FftType::FFT: + return se::fft::Type::kC2CForward; + case FftType::IFFT: + return se::fft::Type::kC2CInverse; + case FftType::IRFFT: + return se::fft::Type::kC2R; + case FftType::RFFT: + return se::fft::Type::kR2C; + default: + LOG(FATAL) << "unsupported fft type"; + } +} + +string FftTypeToString(se::fft::Type type) { + switch (type) { + case se::fft::Type::kC2CForward: + return "FFT"; + case se::fft::Type::kC2CInverse: + return "IFFT"; + case se::fft::Type::kC2R: + return "IRFFT"; + case se::fft::Type::kR2C: + return "RFFT"; + default: + LOG(FATAL) << "unknown fft type"; + } +} + +} // namespace + +FftThunk::FftThunk(FftType fft_type, + tensorflow::gtl::ArraySlice<int64> fft_length, + const BufferAllocation::Slice& input_buffer, + const BufferAllocation::Slice& output_buffer, + const Shape& input_shape, const Shape& output_shape, + const HloInstruction* hlo) + : Thunk(Kind::kFft, hlo), + fft_type_(FftTypeToSeType(fft_type)), + fft_length_(fft_length.begin(), fft_length.end()), + scale_factor_(1.0f), + input_buffer_(input_buffer), + output_buffer_(output_buffer), + input_shape_(input_shape), + output_shape_(output_shape) {} + +tensorflow::Status FftThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + VLOG(3) << "FFT type: " << FftTypeToString(fft_type_); + VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_); + VLOG(3) << "Output shape: " + << ShapeUtil::HumanStringWithLayout(output_shape_); + + FftScratchAllocator scratch_allocator(buffer_allocations.device_ordinal(), + buffer_allocations.memory_allocator()); + + if (fft_plan_ == nullptr) { + const int64 fft_rank = fft_length_.size(); + CHECK_LE(fft_rank, 3); + int batch_size = 1; + for (int i = 0; i < input_shape_.dimensions_size() - fft_rank; ++i) { + batch_size *= input_shape_.dimensions(i); + } + uint64 fft_length[3]; + uint64 input_embed[3]; + const uint64 input_stride = 1; + uint64 input_distance = 1; + uint64 output_embed[3]; + const uint64 output_stride = 1; + uint64 output_distance = 1; + + for (int i = 0; i < fft_rank; ++i) { + auto dim_offset = input_shape_.dimensions_size() - fft_rank + i; + fft_length[i] = static_cast<uint64>(fft_length_[i]); + input_embed[i] = input_shape_.dimensions(dim_offset); + input_distance *= input_shape_.dimensions(dim_offset); + output_embed[i] = output_shape_.dimensions(dim_offset); + output_distance *= output_shape_.dimensions(dim_offset); + } + + constexpr bool kInPlaceFft = false; + fft_plan_ = + stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator( + stream, fft_rank, fft_length, input_embed, input_stride, + input_distance, output_embed, output_stride, output_distance, + fft_type_, kInPlaceFft, batch_size, &scratch_allocator); + scale_factor_ = 1.0f / output_distance; + } else { + stream->parent()->AsFft()->UpdatePlanWithScratchAllocator( + stream, fft_plan_.get(), &scratch_allocator); + } + + bool launch_ok; + switch (fft_type_) { + case se::fft::Type::kC2CForward: { + se::DeviceMemory<complex64> input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory<complex64> output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + launch_ok = + stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok(); + break; + } + case se::fft::Type::kC2CInverse: { + se::DeviceMemory<complex64> input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory<complex64> output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + launch_ok = + stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok(); + if (launch_ok) { + launch_ok = + stream + ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape_), + complex64(scale_factor_), &output_data, 1) + .ok(); + } + break; + } + case se::fft::Type::kR2C: { + se::DeviceMemory<float> input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory<complex64> output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + launch_ok = + stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok(); + break; + } + case se::fft::Type::kC2R: { + se::DeviceMemory<complex64> input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory<float> output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + launch_ok = + stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok(); + if (launch_ok) { + launch_ok = stream + ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape_), + scale_factor_, &output_data, 1) + .ok(); + } + break; + } + default: + LOG(FATAL) << "unsupported fft type"; + } + if (launch_ok) { + return tensorflow::Status::OK(); + } + return InternalError("Unable to launch fft for thunk %p with type %s", this, + FftTypeToString(fft_type_).c_str()); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h new file mode 100644 index 0000000000..52fb8c376d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -0,0 +1,98 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_ + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// A one-time scratch allocator for FFT. The scratch buffers allocated are +// released on destruction. +// +// Not thread-safe in that AllocateBytes, destructor are not locked. +class FftScratchAllocator : public perftools::gputools::ScratchAllocator { + public: + FftScratchAllocator(int device_ordinal, + DeviceMemoryAllocator* memory_allocator); + + ~FftScratchAllocator() override; + + int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override; + + int64 TotalAllocatedBytes() { return total_allocated_bytes_; } + + perftools::gputools::port::StatusOr<perftools::gputools::DeviceMemory<uint8>> + AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override; + + private: + const int device_ordinal_; + DeviceMemoryAllocator* memory_allocator_; + std::vector<perftools::gputools::DeviceMemoryBase> allocated_buffers_; + int64 total_allocated_bytes_ = 0; +}; + +// This class stores everything that StreamExecutor needs to launch an FFT. +// It is generated by IrEmitter. +// +// This is thread-compatible. +class FftThunk : public Thunk { + public: + // Constructs a thunk for launching an FFT on a stream. + // Semantics of null hlo_instruction argument are as in Thunk. + FftThunk(FftType fft_type, tensorflow::gtl::ArraySlice<int64> fft_length, + const BufferAllocation::Slice& input_buffer, + const BufferAllocation::Slice& output_buffer, + const Shape& input_shape, const Shape& output_shape, + const HloInstruction* hlo); + + FftThunk(const FftThunk&) = delete; // Cannot share fft_plan_ + FftThunk& operator=(const FftThunk&) = delete; // Cannot share fft_plan_ + + // Does the FFT for the thunk on "stream". + tensorflow::Status ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + const perftools::gputools::fft::Type fft_type_; + const std::vector<int64> fft_length_; + + float scale_factor_; + + std::unique_ptr<perftools::gputools::fft::Plan> fft_plan_; + + const BufferAllocation::Slice input_buffer_; + const BufferAllocation::Slice output_buffer_; + + const Shape input_shape_; + const Shape output_shape_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index d85d6aae12..d6d0e1e116 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -605,6 +605,14 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { "Hit a case for convolution that is not implemented on GPU."); } +Status IrEmitter::HandleFft(HloInstruction* fft) { + if (ShapeUtil::HasZeroElements(fft->shape())) { + // Emit no code for an empty output. + return Status::OK(); + } + return Unimplemented("Hit a case for fft that is not implemented on GPU."); +} + Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { // TODO(b/33011107): Support cross replica sum on GPU. return Unimplemented( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 41d013c13d..af43895c23 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -79,6 +79,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; + Status HandleFft(HloInstruction* fft) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; Status HandleInfeed(HloInstruction* infeed) override; Status HandleOutfeed(HloInstruction* outfeed) override; @@ -242,6 +243,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleConvolution(HloInstruction* convolution) override; Status HandleCustomCall(HloInstruction* custom_call) override; Status HandleDot(HloInstruction* dot) override; + Status HandleFft(HloInstruction* fft) override; Status HandleFusion(HloInstruction* fusion) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleReduce(HloInstruction* reduce) override; @@ -332,6 +334,9 @@ class IrEmitterUnnested : public IrEmitter { // Returns a ConvolutionThunk that calls DNN to implement `inst`. std::unique_ptr<Thunk> BuildConvolutionThunk(const HloInstruction* inst); + // Returns a FftThunk that calls cuFFT to implement `inst`. + std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst); + // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs // to make sure `inst` outlives the lifetime of the returned Thunk object. std::unique_ptr<Thunk> BuildGemmThunk(const HloInstruction* inst); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 1be3473f52..1aa506a3a9 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h" #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" @@ -373,6 +374,14 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { return IrEmitter::HandleCustomCall(custom_call); } +Status IrEmitterUnnested::HandleFft(HloInstruction* fft) { + TF_RET_CHECK( + LayoutUtil::IsMonotonicWithDim0Major(fft->operand(0)->shape().layout())); + TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout())); + thunk_sequence_->emplace_back(BuildFftThunk(fft)); + return Status::OK(); +} + Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { HloInstruction* root = fusion->fused_expression_root(); // HandleFusion specializes reduction from a multi-dimensional array to a 1D @@ -1855,6 +1864,16 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildConvolutionThunk( } } +std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk( + const HloInstruction* inst) { + const HloInstruction* operand = inst->operand(0); + return MakeUnique<FftThunk>(inst->fft_type(), inst->fft_length(), + /*input_buffer=*/GetAllocationSlice(*operand), + /*output_buffer=*/GetAllocationSlice(*inst), + /*input_shape=*/operand->shape(), + /*output_shape=*/inst->shape(), inst); +} + Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo, KernelThunk* thunk) { bool fused = HloOpcode::kFusion == hlo->opcode(); diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 191c7675c6..625c3f8bea 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -46,6 +46,7 @@ class Thunk { kCudnnBatchNormBackward, kCudnnBatchNormForwardInference, kCudnnBatchNormForwardTraining, + kFft, kGemm, kInfeed, kKernel, diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index e4aed7593c..0e9a852788 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -123,6 +123,12 @@ message HloInstructionProto { // Describes the dimension numbers used for a dot operation xla.DotDimensionNumbers dot_dimension_numbers = 30; + + // FFT type (FFT, IFFT, etc). + xla.FftType fft_type = 31; + + // FFT length. + repeated int64 fft_length = 32; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index b933695b82..cd54eb74d1 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -392,6 +392,21 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { return Status::OK(); } +Status HloCostAnalysis::HandleFft(const HloInstruction* fft) { + auto real_shape = + ShapeUtil::IsTuple(fft->operand(0)->shape()) + ? ShapeUtil::GetTupleElementShape(fft->operand(0)->shape(), 0) + : fft->operand(0)->shape(); + constexpr int kFmaPerComplexMul = 4; + int64 log_factors = 1; + for (int64 dim : fft->fft_length()) { + log_factors *= tensorflow::Log2Floor(dim); + } + current_properties_[kFlopsKey] = kFmaFlops * kFmaPerComplexMul * log_factors * + ShapeUtil::ElementsIn(real_shape); + return Status::OK(); +} + Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { // We assume 2 replicas, so that each output element is the sum of two input // elements. diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index fade19522c..e5783539e5 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -67,6 +67,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleCopy(const HloInstruction* copy) override; Status HandleDot(const HloInstruction* dot) override; Status HandleConvolution(const HloInstruction* convolution) override; + Status HandleFft(const HloInstruction* fft) override; Status HandleCrossReplicaSum(const HloInstruction* crs) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 023aec96ec..f7c6435002 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -961,6 +961,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kGreen; case HloOpcode::kConvolution: case HloOpcode::kDot: + case HloOpcode::kFft: return kDarkBlue; case HloOpcode::kReducePrecision: return kRed; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 9805818e4c..138fb9a190 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -144,6 +144,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( instruction->infeed_config_ = proto.infeed_config(); instruction->custom_call_target_ = proto.custom_call_target(); instruction->outfeed_shape_ = proto.outfeed_shape(); + instruction->fft_type_ = proto.fft_type(); + for (int64 fft_len : proto.fft_length()) { + instruction->fft_length_.push_back(fft_len); + } return std::move(instruction); } @@ -334,6 +338,16 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, return instruction; } +/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft( + const Shape& shape, HloInstruction* operand, FftType fft_type, + tensorflow::gtl::ArraySlice<int64> fft_length) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFft, shape)); + instruction->AppendOperand(operand); + instruction->fft_type_ = fft_type; + instruction->fft_length_.assign(fft_length.begin(), fft_length.end()); + return instruction; +} + /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dimension_numbers) { @@ -1167,6 +1181,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( clone = CreateDot(shape, new_operands[0], new_operands[1], *dot_dimension_numbers_); break; + case HloOpcode::kFft: + CHECK_EQ(new_operands.size(), 1); + return CreateFft(shape, new_operands[0], fft_type_, fft_length_); case HloOpcode::kCrossReplicaSum: clone = CreateCrossReplicaSum(shape, new_operands); break; @@ -1631,6 +1648,11 @@ bool HloInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals(dot_dimension_numbers(), other.dot_dimension_numbers()); + // FFT has various types & lengths. + case HloOpcode::kFft: + return fft_type() == other.fft_type() && + fft_length() == other.fft_length(); + // Reduction results are determined by the reduction dimension and the // reduction computation. case HloOpcode::kReduce: @@ -2055,6 +2077,10 @@ std::vector<string> HloInstruction::ExtraAttributesToString( if (dot_dimension_numbers_ != nullptr) { extra.push_back(DotDimensionNumbersToString()); } + if (opcode() == HloOpcode::kFft) { + extra.push_back(StrCat("fft_type=", FftType_Name(fft_type()))); + extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}")); + } if (options.print_subcomputation_references()) { if (opcode() == HloOpcode::kWhile) { @@ -2206,6 +2232,10 @@ HloInstructionProto HloInstruction::ToProto() const { proto.set_infeed_config(infeed_config_); proto.set_custom_call_target(custom_call_target_); *proto.mutable_outfeed_shape() = outfeed_shape_; + proto.set_fft_type(fft_type_); + for (int64 fft_len : fft_length_) { + proto.add_fft_length(fft_len); + } return proto; } @@ -2421,6 +2451,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) { return visitor->HandleSelect(this); case HloOpcode::kConvolution: return visitor->HandleConvolution(this); + case HloOpcode::kFft: + return visitor->HandleFft(this); case HloOpcode::kCrossReplicaSum: return visitor->HandleCrossReplicaSum(this); case HloOpcode::kTuple: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index d455cfc3f1..c5cab92ac9 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -261,6 +261,11 @@ class HloInstruction { const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); + // Creates an FFT op, of the type indicated by fft_type. + static std::unique_ptr<HloInstruction> CreateFft( + const Shape& shape, HloInstruction* operand, FftType fft_type, + tensorflow::gtl::ArraySlice<int64> fft_length); + // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch // dimensions specified in 'dimension_numbers'. static std::unique_ptr<HloInstruction> CreateDot( @@ -1031,6 +1036,16 @@ class HloInstruction { return *convolution_dimension_numbers_; } + FftType fft_type() const { + CHECK_EQ(HloOpcode::kFft, opcode_); + return fft_type_; + } + + const std::vector<int64>& fft_length() const { + CHECK_EQ(HloOpcode::kFft, opcode_); + return fft_length_; + } + // Returns the dump string of the convolution dimension numbers. string ConvolutionDimensionNumbersToString() const; @@ -1303,6 +1318,12 @@ class HloInstruction { // Describes the dimension numbers used for a dot. std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_; + // Describes FFT type for an FFT instruction. + FftType fft_type_ = FftType::FFT; + + // Indicates the FFT length for an FFT instruction. + std::vector<int64> fft_length_; + // Describes the [begin, end) index range for a slice. std::vector<int64> slice_starts_; std::vector<int64> slice_limits_; diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index f3f7935758..3d64523a79 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -73,6 +73,7 @@ namespace xla { V(kDynamicUpdateSlice, "dynamic-update-slice") \ V(kEq, "equal-to", kHloOpcodeIsComparison) \ V(kExp, "exponential") \ + V(kFft, "fft") \ V(kFloor, "floor") \ V(kFusion, "fusion", kHloOpcodeIsVariadic) \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index d963a8a2f4..9d5ca6673a 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -92,6 +92,14 @@ class ShapeVerifier : public DfsHloVisitor { return CheckShape(convolution, expected); } + Status HandleFft(HloInstruction* fft) override { + TF_ASSIGN_OR_RETURN( + const Shape expected, + ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(), + fft->fft_length())); + return CheckShape(fft, expected); + } + Status HandleCrossReplicaSum(HloInstruction* crs) override { std::vector<const Shape*> operand_shapes; for (const HloInstruction* operand : crs->operands()) { diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index ba901b99e4..90e1f0acdc 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -100,6 +100,7 @@ namespace xla { case HloOpcode::kDivide: case HloOpcode::kDot: case HloOpcode::kExp: + case HloOpcode::kFft: case HloOpcode::kFusion: case HloOpcode::kLog: case HloOpcode::kMap: diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 0b98714168..40613cc75b 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -1406,6 +1406,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddDynamicUpdateSliceInstruction( arg->dynamic_update_slice_request()); break; + case OpRequest::kFftRequest: + handle_status = computation->AddFftInstruction(arg->fft_request()); + break; case OpRequest::kGetTupleElementRequest: handle_status = computation->AddGetTupleElementInstruction( arg->get_tuple_element_request()); @@ -1518,8 +1521,6 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddRecvInstruction(arg->recv_request()); break; } - case OpRequest::kFftRequest: - return Unimplemented("FftRequest not implemented in XLA service."); case OpRequest::OP_NOT_SET: return InvalidArgument("XLA service received OpRequest with OP_NOT_SET"); default: diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 9c1b951d01..6dc49ffe4c 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1715,6 +1715,78 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return ShapeUtil::MakeShape(lhs.element_type(), dimensions); } +/* static */ StatusOr<Shape> ShapeInference::InferFftShape( + const Shape& in, const FftType fft_type, + const tensorflow::gtl::ArraySlice<int64> fft_length) { + const int64 fft_rank = fft_length.size(); + if (fft_rank < 1 || fft_rank > 3) { + return InvalidArgument("FFT only supports ranks 1-3, but got %lld", + fft_rank); + } + switch (fft_type) { + case FFT: + case IFFT: + if (in.element_type() != C64) { + return InvalidArgument("%s requires C64 input type, found %s", + FftType_Name(fft_type).c_str(), + PrimitiveType_Name(in.element_type()).c_str()); + } + return in; + case RFFT: { + if (in.element_type() != F32) { + return InvalidArgument("RFFT requires F32 input type, found %s", + PrimitiveType_Name(in.element_type()).c_str()); + } + for (int i = 0; i < fft_rank; i++) { + if (in.dimensions(in.dimensions_size() - fft_rank + i) != + fft_length[i]) { + return InvalidArgument( + "RFFT requires innermost dimensions match fft_length but " + "dimension %lld is %lld and should be %lld", + in.dimensions_size() - fft_rank + i, + in.dimensions(in.dimensions_size() - fft_rank + i), + fft_length[i]); + } + } + Shape result = ShapeUtil::ChangeElementType(in, C64); + result.set_dimensions(result.dimensions_size() - 1, + fft_length[fft_rank - 1] / 2 + 1); + return result; + } + case IRFFT: { + if (in.element_type() != C64) { + return InvalidArgument("IRFFT requires C64 input type, found %s", + PrimitiveType_Name(in.element_type()).c_str()); + } + Shape result = ShapeUtil::ChangeElementType(in, F32); + for (int i = 0; i < fft_rank - 1; i++) { + if (in.dimensions(in.dimensions_size() - fft_rank + i) != + fft_length[i]) { + return InvalidArgument( + "IRFFT requires all but one innermost dimensions match " + "fft_length, but dimension %lld is %lld and should be %lld", + in.dimensions_size() - fft_rank + i, + in.dimensions(in.dimensions_size() - fft_rank + i), + fft_length[i]); + } + } + if (in.dimensions(in.dimensions_size() - 1) != + fft_length[fft_rank - 1] / 2 + 1) { + return InvalidArgument( + "IRFFT requires innermost dimension matches fft_length/2+1, but " + "dimension %d is %lld and should be %lld", + in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1), + fft_length[fft_rank - 1] / 2 + 1); + } + result.set_dimensions(result.dimensions_size() - 1, + fft_length[fft_rank - 1]); + return result; + } + default: + LOG(FATAL) << "Unexpected fft_type: " << fft_type; + } +} + /* static */ StatusOr<Shape> ShapeInference::InferCrossReplicaSumShape( tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) { for (const Shape* operand_shape : operand_shapes) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index c06340d2d5..b39151ebbc 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -109,6 +109,11 @@ class ShapeInference { const Shape& lhs, const Shape& rhs, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); + // Infers the shape produced by the given FFT type on the given operand. + static StatusOr<Shape> InferFftShape( + const Shape& in, FftType fft_type, + tensorflow::gtl::ArraySlice<int64> fft_length); + // Infers the shape produced a cross replica sum with the given operand // shapes. static StatusOr<Shape> InferCrossReplicaSumShape( diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index e42cbfa976..9d941fb770 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -1121,6 +1121,31 @@ StatusOr<ComputationDataHandle> UserComputation::AddConvolveInstruction( return handle; } +StatusOr<ComputationDataHandle> UserComputation::AddFftInstruction( + const FftRequest& fft_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookUpRequest(fft_request.operand())); + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferFftShape( + operand->output_shape(), fft_request.fft_type(), + AsInt64Slice(fft_request.fft_length()))); + + const ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_fft_request() = fft_request; + + VLOG(1) << "AddFftInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << fft_request.ShortDebugString(); + return handle; +} + StatusOr<ComputationDataHandle> UserComputation::AddCrossReplicaSumInstruction( const CrossReplicaSumRequest& cross_replica_sum_request) { tensorflow::mutex_lock lock(mutex_); @@ -1675,6 +1700,13 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kFftRequest: { + const FftRequest& fft_request = request.request().fft_request(); + PureFunctionalVisitor(session_computation, fft_request.operand(), + num_parameters, visited, is_functional); + break; + } + case OpRequest::kCrossReplicaSumRequest: { // TODO(b/33009255): Implmement constant folding for cross replica sum. *is_functional = false; @@ -2406,6 +2438,12 @@ static void ForEachOperand( break; } + case OpRequest::kFftRequest: { + const FftRequest& fft_request = request.request().fft_request(); + apply(fft_request.operand()); + break; + } + case OpRequest::kBatchNormTrainingRequest: { const BatchNormTrainingRequest& batch_norm_training_request = request.request().batch_norm_training_request(); @@ -2880,6 +2918,15 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kFftRequest: { + const FftRequest& fft_request = request.request().fft_request(); + HloInstruction* operand = lookup_instruction(fft_request.operand()); + hlo_instruction = add_instruction(HloInstruction::CreateFft( + request.output_shape(), operand, fft_request.fft_type(), + AsInt64Slice(fft_request.fft_length()))); + break; + } + case OpRequest::kDotRequest: { const DotRequest& dot_request = request.request().dot_request(); HloInstruction* lhs = lookup_instruction(dot_request.lhs()); diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index 8a78d520e1..8be639c784 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -133,6 +133,10 @@ class UserComputation { StatusOr<ComputationDataHandle> AddConvolveInstruction( const ConvolveRequest& convolve_request); + // Enqueues an FFT instruction onto this user computation. + StatusOr<ComputationDataHandle> AddFftInstruction( + const FftRequest& fft_request); + // Enqueues a cross replica sum instruction onto this user computation. StatusOr<ComputationDataHandle> AddCrossReplicaSumInstruction( const CrossReplicaSumRequest& cross_replica_sum_request); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 58139ee0c1..3caa465769 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -99,6 +99,7 @@ class HloParser { kString, kBracedInt64List, kHloComputation, + kFftType, kWindow, kConvolutionDimensionNumbers, kSharding, @@ -178,6 +179,7 @@ class HloParser { bool ParseString(string* result); bool ParseShape(Shape* result); bool ParseOpcode(HloOpcode* result); + bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); bool ParseInt64(int64* result); @@ -685,6 +687,20 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums)); break; } + case HloOpcode::kFft: { + optional<FftType> fft_type; + optional<std::vector<int64>> fft_length; + attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type}; + attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List, + &fft_length}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateFft( + shape, operands[0], *fft_type, *fft_length)); + break; + } case HloOpcode::kBroadcast: { optional<std::vector<int64>> broadcast_dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, @@ -1672,6 +1688,14 @@ bool HloParser::ParseAttributeHelper( static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result); return true; } + case AttrTy::kFftType: { + FftType result; + if (!ParseFftType(&result)) { + return false; + } + static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result); + return true; + } case AttrTy::kWindow: { Window result; if (!ParseWindow(&result)) { @@ -2289,6 +2313,19 @@ bool HloParser::ParseOpcode(HloOpcode* result) { return true; } +bool HloParser::ParseFftType(FftType* result) { + VLOG(1) << "ParseFftType"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects fft type"); + } + string val = lexer_.GetStrVal(); + if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) { + return TokenError(Printf("expects fft type but sees: %s", val.c_str())); + } + lexer_.Lex(); + return true; +} + bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { VLOG(1) << "ParseFusionKind"; if (lexer_.GetKind() != TokKind::kIdent) { diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 69c59ad6c7..ce11d2b43d 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -582,6 +582,54 @@ ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], varia )" }, +// fft +{ +"Fft", +R"(HloModule Fft_module + +ENTRY %Fft (input: c64[8,32]) -> c64[8,32] { + %input = c64[8,32]{1,0} parameter(0) + ROOT %fft = c64[8,32]{1,0} fft(c64[8,32]{1,0} %input), fft_type=FFT, fft_length={32} +} + +)" +}, +// ifft +{ +"Ifft2d", +R"(HloModule Ifft2d_module + +ENTRY %Ifft2d (input: c64[5,8,32]) -> c64[5,8,32] { + %input = c64[5,8,32]{2,1,0} parameter(0) + ROOT %fft = c64[5,8,32]{2,1,0} fft(c64[5,8,32]{2,1,0} %input), fft_type=IFFT, fft_length={8,32} +} + +)" +}, +// rfft2d +{ +"Rfft2d", +R"(HloModule Rfft2d_module + +ENTRY %Rfft2d (input: f32[5,64,32]) -> c64[5,64,17] { + %input = f32[5,64,32]{2,1,0} parameter(0) + ROOT %fft = c64[5,64,17]{2,1,0} fft(f32[5,64,32]{2,1,0} %input), fft_type=RFFT, fft_length={64,32} +} + +)" +}, +// irfft3d +{ +"Irfft3d", +R"(HloModule Irfft3d_module + +ENTRY %Irfft3d (input: c64[5,64,128,33]) -> f32[5,64,128,64] { + %input = c64[5,64,128,33]{3,2,1,0} parameter(0) + ROOT %fft = f32[5,64,128,64]{3,2,1,0} fft(c64[5,64,128,33]{3,2,1,0} %input), fft_type=IRFFT, fft_length={64,128,64} +} + +)" +}, // pad { "Pad", |