aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h')
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h20
1 files changed, 7 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h
index 984cb0616e..0bf693edd0 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h
+++ b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h
@@ -21,8 +21,6 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/numeric_types.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/types.h"
// 'tensorflow' namespace is used so that int64 and other types don't require
@@ -71,11 +69,9 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand,
in_dims[0] = input_batch;
Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> out_dims;
out_dims[0] = input_batch;
- TensorShape temp_shape{input_batch};
for (int i = 0; i < FFTRank; i++) {
in_dims[i + 1] = fft_shape[i];
out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
- temp_shape.AddDim(fft_shape[i]);
}
const Eigen::TensorMap<Eigen::Tensor<float, FFTRank + 1, Eigen::RowMajor>,
Eigen::Aligned>
@@ -88,8 +84,8 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand,
const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
// Compute the full FFT using a temporary tensor.
- Tensor temp(DataTypeToEnum<complex64>::v(), temp_shape);
- auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();
+ Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor> full_fft(in_dims);
+
const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
full_fft.device(device) =
input.template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
@@ -112,11 +108,9 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand,
in_dims[0] = input_batch;
Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> out_dims;
out_dims[0] = input_batch;
- TensorShape temp_shape{input_batch};
for (int i = 0; i < FFTRank; i++) {
in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
out_dims[i + 1] = fft_shape[i];
- temp_shape.AddDim(fft_shape[i]);
}
const Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
Eigen::Aligned>
@@ -129,8 +123,7 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand,
// region we will slice from input given fft_shape. We slice input to
// fft_shape on its inner-most dimensions, except the last (which we
// slice to fft_shape[-1] / 2 + 1).
- Tensor temp(DataTypeToEnum<complex64>::v(), temp_shape);
- auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();
+ Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor> full_fft(out_dims);
// Calculate the starting point and range of the source of
// negative frequency part.
@@ -179,7 +172,6 @@ template <int FFTRank, typename EigenDevice>
void EigenFftWithRank(const EigenDevice& device, void* out, void* operand,
int32 fft_type, int64 input_batch, int64 fft_length0,
int64 fft_length1, int64 fft_length2) {
- CHECK(::xla::FftType_IsValid(fft_type)) << fft_type;
switch (fft_type) {
case ::xla::FftType::FFT:
EigenFftC2C<true, FFTRank, EigenDevice>(
@@ -204,7 +196,8 @@ void EigenFftWithRank(const EigenDevice& device, void* out, void* operand,
input_batch, fft_length0, fft_length1, fft_length2);
break;
default:
- LOG(FATAL) << "Unsupported FFT type: " << fft_type;
+ // Unsupported FFT type
+ abort();
}
}
@@ -230,7 +223,8 @@ void EigenFftImpl(const EigenDevice& device, void* out, void* operand,
fft_length1, fft_length2);
break;
default:
- LOG(FATAL) << "Unsupported FFT rank " << fft_rank;
+ // Unsupported FFT rank
+ abort();
}
}