aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar RJ Ryan <rjryan@google.com>2018-05-02 17:57:27 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-02 18:00:06 -0700
commit8f0a90b711480c12716d1a3b1094cc8b34939f2d (patch)
treefc13cfb0c8bbd942a5cc46be8cd426f9f3a32b02
parent7833890a0da5226e4c409b1020155f1718c0edb2 (diff)
Add complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D.
NumPy automatically upcasts to complex128 when computing FFTs, leading to issues like: #10749 This change allows users to choose between 32-bit and 64-bit precision FFTs on CPU and GPU. PiperOrigin-RevId: 195183206
-rw-r--r--tensorflow/compiler/tf2xla/kernels/fft_ops.cc17
-rw-r--r--tensorflow/core/kernels/fft_ops.cc78
-rw-r--r--tensorflow/core/ops/spectral_ops.cc30
-rw-r--r--tensorflow/python/kernel_tests/fft_ops_test.py145
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py6
-rw-r--r--tensorflow/python/ops/spectral_grad.py30
6 files changed, 196 insertions, 110 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
index fcb927dab0..933924cad1 100644
--- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
@@ -81,9 +81,11 @@ class FFTOp : public GenericFftOp {
explicit FFTOp(OpKernelConstruction* ctx)
: GenericFftOp(ctx, /*fft_type=*/FftType::FFT, /*fft_rank=*/FFTRank) {}
};
-REGISTER_XLA_OP(Name("FFT"), FFTOp<1>);
-REGISTER_XLA_OP(Name("FFT2D"), FFTOp<2>);
-REGISTER_XLA_OP(Name("FFT3D"), FFTOp<3>);
+REGISTER_XLA_OP(Name("FFT").TypeConstraint("Tcomplex", DT_COMPLEX64), FFTOp<1>);
+REGISTER_XLA_OP(Name("FFT2D").TypeConstraint("Tcomplex", DT_COMPLEX64),
+ FFTOp<2>);
+REGISTER_XLA_OP(Name("FFT3D").TypeConstraint("Tcomplex", DT_COMPLEX64),
+ FFTOp<3>);
template <int FFTRank>
class IFFTOp : public GenericFftOp {
@@ -91,9 +93,12 @@ class IFFTOp : public GenericFftOp {
explicit IFFTOp(OpKernelConstruction* ctx)
: GenericFftOp(ctx, /*fft_type=*/FftType::IFFT, /*fft_rank=*/FFTRank) {}
};
-REGISTER_XLA_OP(Name("IFFT"), IFFTOp<1>);
-REGISTER_XLA_OP(Name("IFFT2D"), IFFTOp<2>);
-REGISTER_XLA_OP(Name("IFFT3D"), IFFTOp<3>);
+REGISTER_XLA_OP(Name("IFFT").TypeConstraint("Tcomplex", DT_COMPLEX64),
+ IFFTOp<1>);
+REGISTER_XLA_OP(Name("IFFT2D").TypeConstraint("Tcomplex", DT_COMPLEX64),
+ IFFTOp<2>);
+REGISTER_XLA_OP(Name("IFFT3D").TypeConstraint("Tcomplex", DT_COMPLEX64),
+ IFFTOp<3>);
template <int FFTRank>
class RFFTOp : public GenericFftOp {
diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc
index 661bf5fc5f..d7105a71bb 100644
--- a/tensorflow/core/kernels/fft_ops.cc
+++ b/tensorflow/core/kernels/fft_ops.cc
@@ -129,13 +129,23 @@ class FFTCPU : public FFTBase {
auto device = ctx->eigen_device<CPUDevice>();
if (!IsReal()) {
- auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
- // Compute the FFT using eigen.
- auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
+ // Compute the FFT using Eigen.
constexpr auto direction =
Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
- output.device(device) =
- input.template fft<Eigen::BothParts, direction>(axes);
+ if (in.dtype() == DT_COMPLEX64) {
+ DCHECK_EQ(out->dtype(), DT_COMPLEX64);
+ auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
+ auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
+ output.device(device) =
+ input.template fft<Eigen::BothParts, direction>(axes);
+ } else {
+ DCHECK_EQ(DT_COMPLEX128, in.dtype());
+ DCHECK_EQ(DT_COMPLEX128, out->dtype());
+ auto input = Tensor(in).flat_inner_dims<complex128, FFTRank + 1>();
+ auto output = out->flat_inner_dims<complex128, FFTRank + 1>();
+ output.device(device) =
+ input.template fft<Eigen::BothParts, direction>(axes);
+ }
} else {
if (IsForward()) {
auto input = Tensor(in).flat_inner_dims<float, FFTRank + 1>();
@@ -392,10 +402,16 @@ class FFTGPUBase : public FFTBase {
}
constexpr bool kInPlaceFft = false;
+ const bool is_complex128 = in.dtype() == DT_COMPLEX128;
+ // complex128 real FFT is not supported yet.
+ DCHECK(!IsReal() || !is_complex128);
+
const auto kFftType =
IsReal() ? (IsForward() ? se::fft::Type::kR2C : se::fft::Type::kC2R)
- : (IsForward() ? se::fft::Type::kC2CForward
- : se::fft::Type::kC2CInverse);
+ : (IsForward() ? (is_complex128 ? se::fft::Type::kZ2ZForward
+ : se::fft::Type::kC2CForward)
+ : (is_complex128 ? se::fft::Type::kZ2ZInverse
+ : se::fft::Type::kC2CInverse));
CufftScratchAllocator scratch_allocator(CufftScratchSize, ctx);
auto plan =
@@ -428,20 +444,42 @@ class FFTGPUBase : public FFTBase {
input_shape.DebugString()));
}
} else {
- auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
- auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
- OP_REQUIRES(
- ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
- errors::Internal("fft failed : type=", static_cast<int>(kFftType),
- " in.shape=", input_shape.DebugString()));
- if (!IsForward()) {
- auto alpha = complex64(1.f / output_distance);
+ if (!is_complex128) {
+ DCHECK_EQ(in.dtype(), DT_COMPLEX64);
+ DCHECK_EQ(out->dtype(), DT_COMPLEX64);
+ auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
+ auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
OP_REQUIRES(
- ctx,
- stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
- .ok(),
- errors::Internal("BlasScal failed : in.shape=",
- input_shape.DebugString()));
+ ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
+ errors::Internal("fft failed : type=", static_cast<int>(kFftType),
+ " in.shape=", input_shape.DebugString()));
+ if (!IsForward()) {
+ float alpha = 1.f / output_distance;
+ OP_REQUIRES(
+ ctx,
+ stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
+ .ok(),
+ errors::Internal("BlasScal failed : in.shape=",
+ input_shape.DebugString()));
+ }
+ } else {
+ DCHECK_EQ(in.dtype(), DT_COMPLEX128);
+ DCHECK_EQ(out->dtype(), DT_COMPLEX128);
+ auto src = AsDeviceMemory<complex128>(in.flat<complex128>().data());
+ auto dst = AsDeviceMemory<complex128>(out->flat<complex128>().data());
+ OP_REQUIRES(
+ ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
+ errors::Internal("fft failed : type=", static_cast<int>(kFftType),
+ " in.shape=", input_shape.DebugString()));
+ if (!IsForward()) {
+ double alpha = 1.0 / output_distance;
+ OP_REQUIRES(
+ ctx,
+ stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
+ .ok(),
+ errors::Internal("BlasScal failed : in.shape=",
+ input_shape.DebugString()));
+ }
}
}
}
diff --git a/tensorflow/core/ops/spectral_ops.cc b/tensorflow/core/ops/spectral_ops.cc
index 2790aee37e..b1ae7040f0 100644
--- a/tensorflow/core/ops/spectral_ops.cc
+++ b/tensorflow/core/ops/spectral_ops.cc
@@ -25,43 +25,49 @@ using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
REGISTER_OP("FFT")
- .Input("input: complex64")
- .Output("output: complex64")
+ .Input("input: Tcomplex")
+ .Output("output: Tcomplex")
+ .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
.SetShapeFn([](InferenceContext* c) {
return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
});
REGISTER_OP("IFFT")
- .Input("input: complex64")
- .Output("output: complex64")
+ .Input("input: Tcomplex")
+ .Output("output: Tcomplex")
+ .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
.SetShapeFn([](InferenceContext* c) {
return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
});
REGISTER_OP("FFT2D")
- .Input("input: complex64")
- .Output("output: complex64")
+ .Input("input: Tcomplex")
+ .Output("output: Tcomplex")
+ .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
.SetShapeFn([](InferenceContext* c) {
return shape_inference::UnchangedShapeWithRankAtLeast(c, 2);
});
REGISTER_OP("IFFT2D")
- .Input("input: complex64")
- .Output("output: complex64")
+ .Input("input: Tcomplex")
+ .Output("output: Tcomplex")
+ .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
.SetShapeFn([](InferenceContext* c) {
return shape_inference::UnchangedShapeWithRankAtLeast(c, 2);
});
REGISTER_OP("FFT3D")
- .Input("input: complex64")
- .Output("output: complex64")
+ .Input("input: Tcomplex")
+ .Output("output: Tcomplex")
+ .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
.SetShapeFn([](InferenceContext* c) {
return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
});
REGISTER_OP("IFFT3D")
- .Input("input: complex64")
- .Output("output: complex64")
+ .Input("input: Tcomplex")
+ .Output("output: Tcomplex")
+ .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
.SetShapeFn([](InferenceContext* c) {
return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
});
diff --git a/tensorflow/python/kernel_tests/fft_ops_test.py b/tensorflow/python/kernel_tests/fft_ops_test.py
index b9e2aa1f3a..629acedda5 100644
--- a/tensorflow/python/kernel_tests/fft_ops_test.py
+++ b/tensorflow/python/kernel_tests/fft_ops_test.py
@@ -38,11 +38,13 @@ VALID_FFT_RANKS = (1, 2, 3)
class BaseFFTOpsTest(test.TestCase):
- def _compare(self, x, rank, fft_length=None, use_placeholder=False):
- self._compareForward(x, rank, fft_length, use_placeholder)
- self._compareBackward(x, rank, fft_length, use_placeholder)
+ def _compare(self, x, rank, fft_length=None, use_placeholder=False,
+ rtol=1e-4, atol=1e-4):
+ self._compareForward(x, rank, fft_length, use_placeholder, rtol, atol)
+ self._compareBackward(x, rank, fft_length, use_placeholder, rtol, atol)
- def _compareForward(self, x, rank, fft_length=None, use_placeholder=False):
+ def _compareForward(self, x, rank, fft_length=None, use_placeholder=False,
+ rtol=1e-4, atol=1e-4):
x_np = self._npFFT(x, rank, fft_length)
if use_placeholder:
x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype))
@@ -50,9 +52,10 @@ class BaseFFTOpsTest(test.TestCase):
else:
x_tf = self._tfFFT(x, rank, fft_length)
- self.assertAllClose(x_np, x_tf, rtol=1e-4, atol=1e-4)
+ self.assertAllClose(x_np, x_tf, rtol=rtol, atol=atol)
- def _compareBackward(self, x, rank, fft_length=None, use_placeholder=False):
+ def _compareBackward(self, x, rank, fft_length=None, use_placeholder=False,
+ rtol=1e-4, atol=1e-4):
x_np = self._npIFFT(x, rank, fft_length)
if use_placeholder:
x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype))
@@ -60,7 +63,7 @@ class BaseFFTOpsTest(test.TestCase):
else:
x_tf = self._tfIFFT(x, rank, fft_length)
- self.assertAllClose(x_np, x_tf, rtol=1e-4, atol=1e-4)
+ self.assertAllClose(x_np, x_tf, rtol=rtol, atol=atol)
def _checkMemoryFail(self, x, rank):
config = config_pb2.ConfigProto()
@@ -68,7 +71,8 @@ class BaseFFTOpsTest(test.TestCase):
with self.test_session(config=config, force_gpu=True):
self._tfFFT(x, rank, fft_length=None)
- def _checkGradComplex(self, func, x, y, result_is_complex=True):
+ def _checkGradComplex(self, func, x, y, result_is_complex=True,
+ rtol=1e-2, atol=1e-2):
with self.test_session(use_gpu=True):
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
@@ -85,10 +89,10 @@ class BaseFFTOpsTest(test.TestCase):
x_init_value=[x, y],
delta=1e-2)
- self.assertAllClose(x_jacob_t, x_jacob_n, rtol=1e-2, atol=1e-2)
- self.assertAllClose(y_jacob_t, y_jacob_n, rtol=1e-2, atol=1e-2)
+ self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
+ self.assertAllClose(y_jacob_t, y_jacob_n, rtol=rtol, atol=atol)
- def _checkGradReal(self, func, x):
+ def _checkGradReal(self, func, x, rtol=1e-2, atol=1e-2):
with self.test_session(use_gpu=True):
inx = ops.convert_to_tensor(x)
# func is a forward RFFT function (batched or unbatched).
@@ -98,7 +102,7 @@ class BaseFFTOpsTest(test.TestCase):
x_jacob_t, x_jacob_n = test.compute_gradient(
inx, list(x.shape), loss, [1], x_init_value=x, delta=1e-2)
- self.assertAllClose(x_jacob_t, x_jacob_n, rtol=1e-2, atol=1e-2)
+ self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
class FFTOpsTest(BaseFFTOpsTest):
@@ -155,27 +159,30 @@ class FFTOpsTest(BaseFFTOpsTest):
def testEmpty(self):
with spectral_ops_test_util.fft_kernel_label_map():
- for rank in VALID_FFT_RANKS:
- for dims in xrange(rank, rank + 3):
- x = np.zeros((0,) * dims).astype(np.complex64)
- self.assertEqual(x.shape, self._tfFFT(x, rank).shape)
- self.assertEqual(x.shape, self._tfIFFT(x, rank).shape)
+ for np_type in (np.complex64, np.complex128):
+ for rank in VALID_FFT_RANKS:
+ for dims in xrange(rank, rank + 3):
+ x = np.zeros((0,) * dims).astype(np_type)
+ self.assertEqual(x.shape, self._tfFFT(x, rank).shape)
+ self.assertEqual(x.shape, self._tfIFFT(x, rank).shape)
def testBasic(self):
with spectral_ops_test_util.fft_kernel_label_map():
- for rank in VALID_FFT_RANKS:
- for dims in xrange(rank, rank + 3):
- self._compare(
- np.mod(np.arange(np.power(4, dims)), 10).reshape(
- (4,) * dims).astype(np.complex64), rank)
+ for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 1e-8)):
+ for rank in VALID_FFT_RANKS:
+ for dims in xrange(rank, rank + 3):
+ self._compare(
+ np.mod(np.arange(np.power(4, dims)), 10).reshape(
+ (4,) * dims).astype(np_type), rank, rtol=tol, atol=tol)
def testLargeBatch(self):
if test.is_gpu_available(cuda_only=True):
rank = 1
for dims in xrange(rank, rank + 3):
- self._compare(
- np.mod(np.arange(np.power(128, dims)), 10).reshape(
- (128,) * dims).astype(np.complex64), rank)
+ for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 1e-5)):
+ self._compare(
+ np.mod(np.arange(np.power(128, dims)), 10).reshape(
+ (128,) * dims).astype(np_type), rank, rtol=tol, atol=tol)
# TODO(yangzihao): Disable before we can figure out a way to
# properly test memory fail for large batch fft.
@@ -189,27 +196,49 @@ class FFTOpsTest(BaseFFTOpsTest):
def testBasicPlaceholder(self):
with spectral_ops_test_util.fft_kernel_label_map():
- for rank in VALID_FFT_RANKS:
- for dims in xrange(rank, rank + 3):
- self._compare(
- np.mod(np.arange(np.power(4, dims)), 10).reshape(
- (4,) * dims).astype(np.complex64),
- rank,
- use_placeholder=True)
+ for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 1e-8)):
+ for rank in VALID_FFT_RANKS:
+ for dims in xrange(rank, rank + 3):
+ self._compare(
+ np.mod(np.arange(np.power(4, dims)), 10).reshape(
+ (4,) * dims).astype(np_type),
+ rank, use_placeholder=True, rtol=tol, atol=tol)
def testRandom(self):
with spectral_ops_test_util.fft_kernel_label_map():
- np.random.seed(12345)
+ for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 5e-6)):
+ def gen(shape):
+ n = np.prod(shape)
+ re = np.random.uniform(size=n)
+ im = np.random.uniform(size=n)
+ return (re + im * 1j).reshape(shape)
- def gen(shape):
- n = np.prod(shape)
- re = np.random.uniform(size=n)
- im = np.random.uniform(size=n)
- return (re + im * 1j).reshape(shape)
+ for rank in VALID_FFT_RANKS:
+ for dims in xrange(rank, rank + 3):
+ self._compare(gen((4,) * dims).astype(np_type), rank,
+ rtol=tol, atol=tol)
- for rank in VALID_FFT_RANKS:
- for dims in xrange(rank, rank + 3):
- self._compare(gen((4,) * dims), rank)
+ def testRandom1D(self):
+ with spectral_ops_test_util.fft_kernel_label_map():
+ for np_type in (np.complex64, np.complex128):
+ has_gpu = test.is_gpu_available(cuda_only=True)
+ tol = {(np.complex64, True): 1e-4,
+ (np.complex64, False): 1e-2,
+ (np.complex128, True): 1e-4,
+ (np.complex128, False): 1e-2}[(np_type, has_gpu)]
+ def gen(shape):
+ n = np.prod(shape)
+ re = np.random.uniform(size=n)
+ im = np.random.uniform(size=n)
+ return (re + im * 1j).reshape(shape)
+
+ # Check a variety of power-of-2 FFT sizes.
+ for dim in (128, 256, 512, 1024):
+ self._compare(gen((dim,)).astype(np_type), 1, rtol=tol, atol=tol)
+
+ # Check a variety of non-power-of-2 FFT sizes.
+ for dim in (127, 255, 511, 1023):
+ self._compare(gen((dim,)).astype(np_type), 1, rtol=tol, atol=tol)
def testError(self):
for rank in VALID_FFT_RANKS:
@@ -224,22 +253,27 @@ class FFTOpsTest(BaseFFTOpsTest):
def testGrad_Simple(self):
with spectral_ops_test_util.fft_kernel_label_map():
- for rank in VALID_FFT_RANKS:
- for dims in xrange(rank, rank + 2):
- re = np.ones(shape=(4,) * dims, dtype=np.float32) / 10.0
- im = np.zeros(shape=(4,) * dims, dtype=np.float32)
- self._checkGradComplex(self._tfFFTForRank(rank), re, im)
- self._checkGradComplex(self._tfIFFTForRank(rank), re, im)
+ for np_type, tol in ((np.float32, 1e-4), (np.float64, 1e-10)):
+ for rank in VALID_FFT_RANKS:
+ for dims in xrange(rank, rank + 2):
+ re = np.ones(shape=(4,) * dims, dtype=np_type) / 10.0
+ im = np.zeros(shape=(4,) * dims, dtype=np_type)
+ self._checkGradComplex(self._tfFFTForRank(rank), re, im,
+ rtol=tol, atol=tol)
+ self._checkGradComplex(self._tfIFFTForRank(rank), re, im,
+ rtol=tol, atol=tol)
def testGrad_Random(self):
with spectral_ops_test_util.fft_kernel_label_map():
- np.random.seed(54321)
- for rank in VALID_FFT_RANKS:
- for dims in xrange(rank, rank + 2):
- re = np.random.rand(*((3,) * dims)).astype(np.float32) * 2 - 1
- im = np.random.rand(*((3,) * dims)).astype(np.float32) * 2 - 1
- self._checkGradComplex(self._tfFFTForRank(rank), re, im)
- self._checkGradComplex(self._tfIFFTForRank(rank), re, im)
+ for np_type, tol in ((np.float32, 1e-2), (np.float64, 1e-10)):
+ for rank in VALID_FFT_RANKS:
+ for dims in xrange(rank, rank + 2):
+ re = np.random.rand(*((3,) * dims)).astype(np_type) * 2 - 1
+ im = np.random.rand(*((3,) * dims)).astype(np_type) * 2 - 1
+ self._checkGradComplex(self._tfFFTForRank(rank), re, im,
+ rtol=tol, atol=tol)
+ self._checkGradComplex(self._tfIFFTForRank(rank), re, im,
+ rtol=tol, atol=tol)
class RFFTOpsTest(BaseFFTOpsTest):
@@ -395,8 +429,6 @@ class RFFTOpsTest(BaseFFTOpsTest):
def testRandom(self):
with spectral_ops_test_util.fft_kernel_label_map():
- np.random.seed(12345)
-
def gen_real(shape):
n = np.prod(shape)
re = np.random.uniform(size=n)
@@ -491,7 +523,6 @@ class RFFTOpsTest(BaseFFTOpsTest):
def testGrad_Random(self):
with spectral_ops_test_util.fft_kernel_label_map():
- np.random.seed(54321)
for rank in VALID_FFT_RANKS:
# rfft3d/irfft3d do not have gradients yet.
if rank == 3:
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
index e7f2f1c12b..5713d16969 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
@@ -73,7 +73,7 @@ class LinearOperatorCirculantBaseTest(object):
x = np.zeros([domain_dimension])
# x is a basis vector.
x[m] = 1.0
- fft_x = math_ops.fft(x)
+ fft_x = math_ops.fft(x.astype(np.complex64))
h_convolve_x = math_ops.ifft(spectrum * fft_x)
matrix_rows.append(h_convolve_x)
matrix = array_ops.stack(matrix_rows, axis=-1)
@@ -91,7 +91,7 @@ class LinearOperatorCirculantTestSelfAdjointOperator(
@property
def _dtypes_to_test(self):
- # This operator will always be complex because, although the specturm is
+ # This operator will always be complex because, although the spectrum is
# real, the matrix will not be real.
return [dtypes.complex64]
@@ -408,7 +408,7 @@ class LinearOperatorCirculant2DBaseTest(object):
x = np.zeros(block_shape)
# x is a basis vector.
x[n0, n1] = 1.0
- fft_x = math_ops.fft2d(x)
+ fft_x = math_ops.fft2d(x.astype(np.complex64))
h_convolve_x = math_ops.ifft2d(spectrum * fft_x)
# We want the flat version of the action of the operator on a basis
# vector, not the block version.
diff --git a/tensorflow/python/ops/spectral_grad.py b/tensorflow/python/ops/spectral_grad.py
index deb0a57178..0af24114ac 100644
--- a/tensorflow/python/ops/spectral_grad.py
+++ b/tensorflow/python/ops/spectral_grad.py
@@ -32,38 +32,44 @@ def _FFTSizeForGrad(grad, rank):
@ops.RegisterGradient("FFT")
def _FFTGrad(_, grad):
- size = math_ops.cast(_FFTSizeForGrad(grad, 1), dtypes.float32)
- return spectral_ops.ifft(grad) * math_ops.complex(size, 0.)
+ size = math_ops.cast(_FFTSizeForGrad(grad, 1), grad.dtype)
+ return spectral_ops.ifft(grad) * size
@ops.RegisterGradient("IFFT")
def _IFFTGrad(_, grad):
- rsize = 1. / math_ops.cast(_FFTSizeForGrad(grad, 1), dtypes.float32)
- return spectral_ops.fft(grad) * math_ops.complex(rsize, 0.)
+ rsize = math_ops.cast(
+ 1. / math_ops.cast(_FFTSizeForGrad(grad, 1), grad.dtype.real_dtype),
+ grad.dtype)
+ return spectral_ops.fft(grad) * rsize
@ops.RegisterGradient("FFT2D")
def _FFT2DGrad(_, grad):
- size = math_ops.cast(_FFTSizeForGrad(grad, 2), dtypes.float32)
- return spectral_ops.ifft2d(grad) * math_ops.complex(size, 0.)
+ size = math_ops.cast(_FFTSizeForGrad(grad, 2), grad.dtype)
+ return spectral_ops.ifft2d(grad) * size
@ops.RegisterGradient("IFFT2D")
def _IFFT2DGrad(_, grad):
- rsize = 1. / math_ops.cast(_FFTSizeForGrad(grad, 2), dtypes.float32)
- return spectral_ops.fft2d(grad) * math_ops.complex(rsize, 0.)
+ rsize = math_ops.cast(
+ 1. / math_ops.cast(_FFTSizeForGrad(grad, 2), grad.dtype.real_dtype),
+ grad.dtype)
+ return spectral_ops.fft2d(grad) * rsize
@ops.RegisterGradient("FFT3D")
def _FFT3DGrad(_, grad):
- size = math_ops.cast(_FFTSizeForGrad(grad, 3), dtypes.float32)
- return spectral_ops.ifft3d(grad) * math_ops.complex(size, 0.)
+ size = math_ops.cast(_FFTSizeForGrad(grad, 3), grad.dtype)
+ return spectral_ops.ifft3d(grad) * size
@ops.RegisterGradient("IFFT3D")
def _IFFT3DGrad(_, grad):
- rsize = 1. / math_ops.cast(_FFTSizeForGrad(grad, 3), dtypes.float32)
- return spectral_ops.fft3d(grad) * math_ops.complex(rsize, 0.)
+ rsize = math_ops.cast(
+ 1. / math_ops.cast(_FFTSizeForGrad(grad, 3), grad.dtype.real_dtype),
+ grad.dtype)
+ return spectral_ops.fft3d(grad) * rsize
def _RFFTGradHelper(rank, irfft_fn):