author     2018-05-02 17:57:27 -0700
committer  2018-05-02 18:00:06 -0700
commit     8f0a90b711480c12716d1a3b1094cc8b34939f2d
tree       fc13cfb0c8bbd942a5cc46be8cd426f9f3a32b02
parent     7833890a0da5226e4c409b1020155f1718c0edb2
Add complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D.
NumPy automatically upcasts to complex128 when computing FFTs, which leads to mismatches against TensorFlow's previously complex64-only FFT ops (see #10749).
This change allows users to choose between 32-bit and 64-bit precision FFTs on CPU and GPU.
PiperOrigin-RevId: 195183206
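
As an aside, a minimal sketch of the user-facing effect (assuming a TensorFlow build that includes this change; tf.spectral.fft is the 1.x endpoint for the FFT op):

import numpy as np
import tensorflow as tf

# NumPy computes FFTs in complex128; TensorFlow previously accepted only
# complex64, forcing a lossy downcast.
x = np.random.uniform(size=8) + 1j * np.random.uniform(size=8)  # complex128

with tf.Session() as sess:
  y128 = sess.run(tf.spectral.fft(tf.constant(x, dtype=tf.complex128)))
  y64 = sess.run(tf.spectral.fft(tf.constant(x.astype(np.complex64))))

# The 64-bit path now matches NumPy to double precision; the 32-bit path
# keeps its usual single-precision error.
print(np.max(np.abs(y128 - np.fft.fft(x))))  # ~1e-15
print(np.max(np.abs(y64 - np.fft.fft(x))))   # ~1e-7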
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/fft_ops.cc                            |  17
-rw-r--r--  tensorflow/core/kernels/fft_ops.cc                                       |  78
-rw-r--r--  tensorflow/core/ops/spectral_ops.cc                                      |  30
-rw-r--r--  tensorflow/python/kernel_tests/fft_ops_test.py                           | 145
-rw-r--r--  tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py |   6
-rw-r--r--  tensorflow/python/ops/spectral_grad.py                                   |  30
6 files changed, 196 insertions, 110 deletions
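
A note before the per-file diffs: the enabling change is in spectral_ops.cc below, where each op's fixed complex64 signature becomes a type attr Tcomplex defaulting to DT_COMPLEX64, so existing GraphDefs stay valid while new graphs infer the attr from the input dtype. A small illustrative sketch (TF 1.x Python API; behavior as registered below):

import tensorflow as tf

a = tf.constant([1 + 2j, 3 + 4j], dtype=tf.complex64)
b = tf.constant([1 + 2j, 3 + 4j], dtype=tf.complex128)

# The same "FFT" op now carries a Tcomplex attr inferred from its input.
print(tf.spectral.fft(a).dtype)  # complex64  (Tcomplex left at its default)
print(tf.spectral.fft(b).dtype)  # complex128 (Tcomplex=DT_COMPLEX128)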
diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
index fcb927dab0..933924cad1 100644
--- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
@@ -81,9 +81,11 @@ class FFTOp : public GenericFftOp {
   explicit FFTOp(OpKernelConstruction* ctx)
       : GenericFftOp(ctx, /*fft_type=*/FftType::FFT, /*fft_rank=*/FFTRank) {}
 };
-REGISTER_XLA_OP(Name("FFT"), FFTOp<1>);
-REGISTER_XLA_OP(Name("FFT2D"), FFTOp<2>);
-REGISTER_XLA_OP(Name("FFT3D"), FFTOp<3>);
+REGISTER_XLA_OP(Name("FFT").TypeConstraint("Tcomplex", DT_COMPLEX64), FFTOp<1>);
+REGISTER_XLA_OP(Name("FFT2D").TypeConstraint("Tcomplex", DT_COMPLEX64),
+                FFTOp<2>);
+REGISTER_XLA_OP(Name("FFT3D").TypeConstraint("Tcomplex", DT_COMPLEX64),
+                FFTOp<3>);
 
 template <int FFTRank>
 class IFFTOp : public GenericFftOp {
@@ -91,9 +93,12 @@ class IFFTOp : public GenericFftOp {
   explicit IFFTOp(OpKernelConstruction* ctx)
       : GenericFftOp(ctx, /*fft_type=*/FftType::IFFT, /*fft_rank=*/FFTRank) {}
 };
-REGISTER_XLA_OP(Name("IFFT"), IFFTOp<1>);
-REGISTER_XLA_OP(Name("IFFT2D"), IFFTOp<2>);
-REGISTER_XLA_OP(Name("IFFT3D"), IFFTOp<3>);
+REGISTER_XLA_OP(Name("IFFT").TypeConstraint("Tcomplex", DT_COMPLEX64),
+                IFFTOp<1>);
+REGISTER_XLA_OP(Name("IFFT2D").TypeConstraint("Tcomplex", DT_COMPLEX64),
+                IFFTOp<2>);
+REGISTER_XLA_OP(Name("IFFT3D").TypeConstraint("Tcomplex", DT_COMPLEX64),
+                IFFTOp<3>);
 
 template <int FFTRank>
 class RFFTOp : public GenericFftOp {
diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc
index 661bf5fc5f..d7105a71bb 100644
--- a/tensorflow/core/kernels/fft_ops.cc
+++ b/tensorflow/core/kernels/fft_ops.cc
@@ -129,13 +129,23 @@ class FFTCPU : public FFTBase {
     auto device = ctx->eigen_device<CPUDevice>();
 
     if (!IsReal()) {
-      auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
-      // Compute the FFT using eigen.
-      auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
+      // Compute the FFT using Eigen.
       constexpr auto direction =
           Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
-      output.device(device) =
-          input.template fft<Eigen::BothParts, direction>(axes);
+      if (in.dtype() == DT_COMPLEX64) {
+        DCHECK_EQ(out->dtype(), DT_COMPLEX64);
+        auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
+        auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
+        output.device(device) =
+            input.template fft<Eigen::BothParts, direction>(axes);
+      } else {
+        DCHECK_EQ(DT_COMPLEX128, in.dtype());
+        DCHECK_EQ(DT_COMPLEX128, out->dtype());
+        auto input = Tensor(in).flat_inner_dims<complex128, FFTRank + 1>();
+        auto output = out->flat_inner_dims<complex128, FFTRank + 1>();
+        output.device(device) =
+            input.template fft<Eigen::BothParts, direction>(axes);
+      }
     } else {
       if (IsForward()) {
         auto input = Tensor(in).flat_inner_dims<float, FFTRank + 1>();
@@ -392,10 +402,16 @@ class FFTGPUBase : public FFTBase {
     }
 
     constexpr bool kInPlaceFft = false;
+    const bool is_complex128 = in.dtype() == DT_COMPLEX128;
+    // complex128 real FFT is not supported yet.
+    DCHECK(!IsReal() || !is_complex128);
+
     const auto kFftType =
         IsReal() ? (IsForward() ? se::fft::Type::kR2C : se::fft::Type::kC2R)
-                 : (IsForward() ? se::fft::Type::kC2CForward
-                                : se::fft::Type::kC2CInverse);
+                 : (IsForward() ? (is_complex128 ? se::fft::Type::kZ2ZForward
+                                                 : se::fft::Type::kC2CForward)
+                                : (is_complex128 ? se::fft::Type::kZ2ZInverse
+                                                 : se::fft::Type::kC2CInverse));
 
     CufftScratchAllocator scratch_allocator(CufftScratchSize, ctx);
     auto plan =
@@ -428,20 +444,42 @@ class FFTGPUBase : public FFTBase {
                 input_shape.DebugString()));
       }
     } else {
-      auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
-      auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
-      OP_REQUIRES(
-          ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
-          errors::Internal("fft failed : type=", static_cast<int>(kFftType),
-                           " in.shape=", input_shape.DebugString()));
-      if (!IsForward()) {
-        auto alpha = complex64(1.f / output_distance);
+      if (!is_complex128) {
+        DCHECK_EQ(in.dtype(), DT_COMPLEX64);
+        DCHECK_EQ(out->dtype(), DT_COMPLEX64);
+        auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
+        auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
         OP_REQUIRES(
-            ctx,
-            stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
-                .ok(),
-            errors::Internal("BlasScal failed : in.shape=",
-                             input_shape.DebugString()));
+            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
+            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
+                             " in.shape=", input_shape.DebugString()));
+        if (!IsForward()) {
+          float alpha = 1.f / output_distance;
+          OP_REQUIRES(
+              ctx,
+              stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
+                  .ok(),
+              errors::Internal("BlasScal failed : in.shape=",
+                               input_shape.DebugString()));
+        }
+      } else {
+        DCHECK_EQ(in.dtype(), DT_COMPLEX128);
+        DCHECK_EQ(out->dtype(), DT_COMPLEX128);
+        auto src = AsDeviceMemory<complex128>(in.flat<complex128>().data());
+        auto dst = AsDeviceMemory<complex128>(out->flat<complex128>().data());
+        OP_REQUIRES(
+            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
+            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
+                             " in.shape=", input_shape.DebugString()));
+        if (!IsForward()) {
+          double alpha = 1.0 / output_distance;
+          OP_REQUIRES(
+              ctx,
+              stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
+                  .ok(),
+              errors::Internal("BlasScal failed : in.shape=",
+                               input_shape.DebugString()));
+        }
       }
     }
   }
diff --git a/tensorflow/core/ops/spectral_ops.cc b/tensorflow/core/ops/spectral_ops.cc
index 2790aee37e..b1ae7040f0 100644
--- a/tensorflow/core/ops/spectral_ops.cc
+++ b/tensorflow/core/ops/spectral_ops.cc
@@ -25,43 +25,49 @@ using shape_inference::InferenceContext;
 using shape_inference::ShapeHandle;
 
 REGISTER_OP("FFT")
-    .Input("input: complex64")
-    .Output("output: complex64")
+    .Input("input: Tcomplex")
+    .Output("output: Tcomplex")
+    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
     .SetShapeFn([](InferenceContext* c) {
       return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
     });
 
 REGISTER_OP("IFFT")
-    .Input("input: complex64")
-    .Output("output: complex64")
+    .Input("input: Tcomplex")
+    .Output("output: Tcomplex")
+    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
     .SetShapeFn([](InferenceContext* c) {
       return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
     });
 
 REGISTER_OP("FFT2D")
-    .Input("input: complex64")
-    .Output("output: complex64")
+    .Input("input: Tcomplex")
+    .Output("output: Tcomplex")
+    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
     .SetShapeFn([](InferenceContext* c) {
       return shape_inference::UnchangedShapeWithRankAtLeast(c, 2);
     });
 
 REGISTER_OP("IFFT2D")
-    .Input("input: complex64")
-    .Output("output: complex64")
+    .Input("input: Tcomplex")
+    .Output("output: Tcomplex")
+    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
     .SetShapeFn([](InferenceContext* c) {
      return shape_inference::UnchangedShapeWithRankAtLeast(c, 2);
     });
 
 REGISTER_OP("FFT3D")
-    .Input("input: complex64")
-    .Output("output: complex64")
+    .Input("input: Tcomplex")
+    .Output("output: Tcomplex")
+    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
     .SetShapeFn([](InferenceContext* c) {
       return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
     });
 
 REGISTER_OP("IFFT3D")
-    .Input("input: complex64")
-    .Output("output: complex64")
+    .Input("input: Tcomplex")
+    .Output("output: Tcomplex")
+    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
     .SetShapeFn([](InferenceContext* c) {
       return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
     });
diff --git a/tensorflow/python/kernel_tests/fft_ops_test.py b/tensorflow/python/kernel_tests/fft_ops_test.py
index b9e2aa1f3a..629acedda5 100644
--- a/tensorflow/python/kernel_tests/fft_ops_test.py
+++ b/tensorflow/python/kernel_tests/fft_ops_test.py
@@ -38,11 +38,13 @@ VALID_FFT_RANKS = (1, 2, 3)
 
 class BaseFFTOpsTest(test.TestCase):
 
-  def _compare(self, x, rank, fft_length=None, use_placeholder=False):
-    self._compareForward(x, rank, fft_length, use_placeholder)
-    self._compareBackward(x, rank, fft_length, use_placeholder)
+  def _compare(self, x, rank, fft_length=None, use_placeholder=False,
+               rtol=1e-4, atol=1e-4):
+    self._compareForward(x, rank, fft_length, use_placeholder, rtol, atol)
+    self._compareBackward(x, rank, fft_length, use_placeholder, rtol, atol)
 
-  def _compareForward(self, x, rank, fft_length=None, use_placeholder=False):
+  def _compareForward(self, x, rank, fft_length=None, use_placeholder=False,
+                      rtol=1e-4, atol=1e-4):
     x_np = self._npFFT(x, rank, fft_length)
     if use_placeholder:
       x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype))
@@ -50,9 +52,10 @@ class BaseFFTOpsTest(test.TestCase):
     else:
       x_tf = self._tfFFT(x, rank, fft_length)
 
-    self.assertAllClose(x_np, x_tf, rtol=1e-4, atol=1e-4)
+    self.assertAllClose(x_np, x_tf, rtol=rtol, atol=atol)
 
-  def _compareBackward(self, x, rank, fft_length=None, use_placeholder=False):
+  def _compareBackward(self, x, rank, fft_length=None, use_placeholder=False,
+                       rtol=1e-4, atol=1e-4):
     x_np = self._npIFFT(x, rank, fft_length)
     if use_placeholder:
       x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype))
@@ -60,7 +63,7 @@ class BaseFFTOpsTest(test.TestCase):
     else:
       x_tf = self._tfIFFT(x, rank, fft_length)
 
-    self.assertAllClose(x_np, x_tf, rtol=1e-4, atol=1e-4)
+    self.assertAllClose(x_np, x_tf, rtol=rtol, atol=atol)
 
   def _checkMemoryFail(self, x, rank):
     config = config_pb2.ConfigProto()
@@ -68,7 +71,8 @@ class BaseFFTOpsTest(test.TestCase):
     with self.test_session(config=config, force_gpu=True):
       self._tfFFT(x, rank, fft_length=None)
 
-  def _checkGradComplex(self, func, x, y, result_is_complex=True):
+  def _checkGradComplex(self, func, x, y, result_is_complex=True,
+                        rtol=1e-2, atol=1e-2):
     with self.test_session(use_gpu=True):
       inx = ops.convert_to_tensor(x)
       iny = ops.convert_to_tensor(y)
@@ -85,10 +89,10 @@ class BaseFFTOpsTest(test.TestCase):
           x_init_value=[x, y],
           delta=1e-2)
 
-      self.assertAllClose(x_jacob_t, x_jacob_n, rtol=1e-2, atol=1e-2)
-      self.assertAllClose(y_jacob_t, y_jacob_n, rtol=1e-2, atol=1e-2)
+      self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
+      self.assertAllClose(y_jacob_t, y_jacob_n, rtol=rtol, atol=atol)
 
-  def _checkGradReal(self, func, x):
+  def _checkGradReal(self, func, x, rtol=1e-2, atol=1e-2):
     with self.test_session(use_gpu=True):
       inx = ops.convert_to_tensor(x)
       # func is a forward RFFT function (batched or unbatched).
@@ -98,7 +102,7 @@ class BaseFFTOpsTest(test.TestCase):
       x_jacob_t, x_jacob_n = test.compute_gradient(
          inx, list(x.shape), loss, [1], x_init_value=x, delta=1e-2)
 
-      self.assertAllClose(x_jacob_t, x_jacob_n, rtol=1e-2, atol=1e-2)
+      self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
 
 
 class FFTOpsTest(BaseFFTOpsTest):
@@ -155,27 +159,30 @@ class FFTOpsTest(BaseFFTOpsTest):
 
   def testEmpty(self):
     with spectral_ops_test_util.fft_kernel_label_map():
-      for rank in VALID_FFT_RANKS:
-        for dims in xrange(rank, rank + 3):
-          x = np.zeros((0,) * dims).astype(np.complex64)
-          self.assertEqual(x.shape, self._tfFFT(x, rank).shape)
-          self.assertEqual(x.shape, self._tfIFFT(x, rank).shape)
+      for np_type in (np.complex64, np.complex128):
+        for rank in VALID_FFT_RANKS:
+          for dims in xrange(rank, rank + 3):
+            x = np.zeros((0,) * dims).astype(np_type)
+            self.assertEqual(x.shape, self._tfFFT(x, rank).shape)
+            self.assertEqual(x.shape, self._tfIFFT(x, rank).shape)
 
   def testBasic(self):
     with spectral_ops_test_util.fft_kernel_label_map():
-      for rank in VALID_FFT_RANKS:
-        for dims in xrange(rank, rank + 3):
-          self._compare(
-              np.mod(np.arange(np.power(4, dims)), 10).reshape(
-                  (4,) * dims).astype(np.complex64), rank)
+      for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 1e-8)):
+        for rank in VALID_FFT_RANKS:
+          for dims in xrange(rank, rank + 3):
+            self._compare(
+                np.mod(np.arange(np.power(4, dims)), 10).reshape(
+                    (4,) * dims).astype(np_type), rank, rtol=tol, atol=tol)
 
   def testLargeBatch(self):
     if test.is_gpu_available(cuda_only=True):
       rank = 1
       for dims in xrange(rank, rank + 3):
-        self._compare(
-            np.mod(np.arange(np.power(128, dims)), 10).reshape(
-                (128,) * dims).astype(np.complex64), rank)
+        for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 1e-5)):
+          self._compare(
+              np.mod(np.arange(np.power(128, dims)), 10).reshape(
+                  (128,) * dims).astype(np_type), rank, rtol=tol, atol=tol)
 
   # TODO(yangzihao): Disable before we can figure out a way to
   # properly test memory fail for large batch fft.
@@ -189,27 +196,49 @@ class FFTOpsTest(BaseFFTOpsTest):
 
   def testBasicPlaceholder(self):
     with spectral_ops_test_util.fft_kernel_label_map():
-      for rank in VALID_FFT_RANKS:
-        for dims in xrange(rank, rank + 3):
-          self._compare(
-              np.mod(np.arange(np.power(4, dims)), 10).reshape(
-                  (4,) * dims).astype(np.complex64),
-              rank,
-              use_placeholder=True)
+      for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 1e-8)):
+        for rank in VALID_FFT_RANKS:
+          for dims in xrange(rank, rank + 3):
+            self._compare(
+                np.mod(np.arange(np.power(4, dims)), 10).reshape(
+                    (4,) * dims).astype(np_type),
+                rank, use_placeholder=True, rtol=tol, atol=tol)
 
   def testRandom(self):
     with spectral_ops_test_util.fft_kernel_label_map():
-      np.random.seed(12345)
+      for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 5e-6)):
+        def gen(shape):
+          n = np.prod(shape)
+          re = np.random.uniform(size=n)
+          im = np.random.uniform(size=n)
+          return (re + im * 1j).reshape(shape)
 
-      def gen(shape):
-        n = np.prod(shape)
-        re = np.random.uniform(size=n)
-        im = np.random.uniform(size=n)
-        return (re + im * 1j).reshape(shape)
+        for rank in VALID_FFT_RANKS:
+          for dims in xrange(rank, rank + 3):
+            self._compare(gen((4,) * dims).astype(np_type), rank,
+                          rtol=tol, atol=tol)
 
-      for rank in VALID_FFT_RANKS:
-        for dims in xrange(rank, rank + 3):
-          self._compare(gen((4,) * dims), rank)
+  def testRandom1D(self):
+    with spectral_ops_test_util.fft_kernel_label_map():
+      for np_type in (np.complex64, np.complex128):
+        has_gpu = test.is_gpu_available(cuda_only=True)
+        tol = {(np.complex64, True): 1e-4,
+               (np.complex64, False): 1e-2,
+               (np.complex128, True): 1e-4,
+               (np.complex128, False): 1e-2}[(np_type, has_gpu)]
+        def gen(shape):
+          n = np.prod(shape)
+          re = np.random.uniform(size=n)
+          im = np.random.uniform(size=n)
+          return (re + im * 1j).reshape(shape)
+
+        # Check a variety of power-of-2 FFT sizes.
+        for dim in (128, 256, 512, 1024):
+          self._compare(gen((dim,)).astype(np_type), 1, rtol=tol, atol=tol)
+
+        # Check a variety of non-power-of-2 FFT sizes.
+        for dim in (127, 255, 511, 1023):
+          self._compare(gen((dim,)).astype(np_type), 1, rtol=tol, atol=tol)
 
   def testError(self):
     for rank in VALID_FFT_RANKS:
@@ -224,22 +253,27 @@ class FFTOpsTest(BaseFFTOpsTest):
 
   def testGrad_Simple(self):
     with spectral_ops_test_util.fft_kernel_label_map():
-      for rank in VALID_FFT_RANKS:
-        for dims in xrange(rank, rank + 2):
-          re = np.ones(shape=(4,) * dims, dtype=np.float32) / 10.0
-          im = np.zeros(shape=(4,) * dims, dtype=np.float32)
-          self._checkGradComplex(self._tfFFTForRank(rank), re, im)
-          self._checkGradComplex(self._tfIFFTForRank(rank), re, im)
+      for np_type, tol in ((np.float32, 1e-4), (np.float64, 1e-10)):
+        for rank in VALID_FFT_RANKS:
+          for dims in xrange(rank, rank + 2):
+            re = np.ones(shape=(4,) * dims, dtype=np_type) / 10.0
+            im = np.zeros(shape=(4,) * dims, dtype=np_type)
+            self._checkGradComplex(self._tfFFTForRank(rank), re, im,
+                                   rtol=tol, atol=tol)
+            self._checkGradComplex(self._tfIFFTForRank(rank), re, im,
+                                   rtol=tol, atol=tol)
 
   def testGrad_Random(self):
     with spectral_ops_test_util.fft_kernel_label_map():
-      np.random.seed(54321)
-      for rank in VALID_FFT_RANKS:
-        for dims in xrange(rank, rank + 2):
-          re = np.random.rand(*((3,) * dims)).astype(np.float32) * 2 - 1
-          im = np.random.rand(*((3,) * dims)).astype(np.float32) * 2 - 1
-          self._checkGradComplex(self._tfFFTForRank(rank), re, im)
-          self._checkGradComplex(self._tfIFFTForRank(rank), re, im)
+      for np_type, tol in ((np.float32, 1e-2), (np.float64, 1e-10)):
+        for rank in VALID_FFT_RANKS:
+          for dims in xrange(rank, rank + 2):
+            re = np.random.rand(*((3,) * dims)).astype(np_type) * 2 - 1
+            im = np.random.rand(*((3,) * dims)).astype(np_type) * 2 - 1
+            self._checkGradComplex(self._tfFFTForRank(rank), re, im,
+                                   rtol=tol, atol=tol)
+            self._checkGradComplex(self._tfIFFTForRank(rank), re, im,
+                                   rtol=tol, atol=tol)
 
 
 class RFFTOpsTest(BaseFFTOpsTest):
@@ -395,8 +429,6 @@ class RFFTOpsTest(BaseFFTOpsTest):
 
   def testRandom(self):
     with spectral_ops_test_util.fft_kernel_label_map():
-      np.random.seed(12345)
-
       def gen_real(shape):
         n = np.prod(shape)
         re = np.random.uniform(size=n)
@@ -491,7 +523,6 @@ class RFFTOpsTest(BaseFFTOpsTest):
 
   def testGrad_Random(self):
     with spectral_ops_test_util.fft_kernel_label_map():
-      np.random.seed(54321)
       for rank in VALID_FFT_RANKS:
         # rfft3d/irfft3d do not have gradients yet.
         if rank == 3:
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
index e7f2f1c12b..5713d16969 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
@@ -73,7 +73,7 @@ class LinearOperatorCirculantBaseTest(object):
       x = np.zeros([domain_dimension])
       # x is a basis vector.
       x[m] = 1.0
-      fft_x = math_ops.fft(x)
+      fft_x = math_ops.fft(x.astype(np.complex64))
       h_convolve_x = math_ops.ifft(spectrum * fft_x)
       matrix_rows.append(h_convolve_x)
     matrix = array_ops.stack(matrix_rows, axis=-1)
@@ -91,7 +91,7 @@ class LinearOperatorCirculantTestSelfAdjointOperator(
 
   @property
   def _dtypes_to_test(self):
-    # This operator will always be complex because, although the specturm is
+    # This operator will always be complex because, although the spectrum is
     # real, the matrix will not be real.
     return [dtypes.complex64]
 
@@ -408,7 +408,7 @@ class LinearOperatorCirculant2DBaseTest(object):
         x = np.zeros(block_shape)
         # x is a basis vector.
         x[n0, n1] = 1.0
-        fft_x = math_ops.fft2d(x)
+        fft_x = math_ops.fft2d(x.astype(np.complex64))
         h_convolve_x = math_ops.ifft2d(spectrum * fft_x)
         # We want the flat version of the action of the operator on a basis
         # vector, not the block version.
diff --git a/tensorflow/python/ops/spectral_grad.py b/tensorflow/python/ops/spectral_grad.py
index deb0a57178..0af24114ac 100644
--- a/tensorflow/python/ops/spectral_grad.py
+++ b/tensorflow/python/ops/spectral_grad.py
@@ -32,38 +32,44 @@ def _FFTSizeForGrad(grad, rank):
 
 @ops.RegisterGradient("FFT")
 def _FFTGrad(_, grad):
-  size = math_ops.cast(_FFTSizeForGrad(grad, 1), dtypes.float32)
-  return spectral_ops.ifft(grad) * math_ops.complex(size, 0.)
+  size = math_ops.cast(_FFTSizeForGrad(grad, 1), grad.dtype)
+  return spectral_ops.ifft(grad) * size
 
 
 @ops.RegisterGradient("IFFT")
 def _IFFTGrad(_, grad):
-  rsize = 1. / math_ops.cast(_FFTSizeForGrad(grad, 1), dtypes.float32)
-  return spectral_ops.fft(grad) * math_ops.complex(rsize, 0.)
+  rsize = math_ops.cast(
+      1. / math_ops.cast(_FFTSizeForGrad(grad, 1), grad.dtype.real_dtype),
+      grad.dtype)
+  return spectral_ops.fft(grad) * rsize
 
 
 @ops.RegisterGradient("FFT2D")
 def _FFT2DGrad(_, grad):
-  size = math_ops.cast(_FFTSizeForGrad(grad, 2), dtypes.float32)
-  return spectral_ops.ifft2d(grad) * math_ops.complex(size, 0.)
+  size = math_ops.cast(_FFTSizeForGrad(grad, 2), grad.dtype)
+  return spectral_ops.ifft2d(grad) * size
 
 
 @ops.RegisterGradient("IFFT2D")
 def _IFFT2DGrad(_, grad):
-  rsize = 1. / math_ops.cast(_FFTSizeForGrad(grad, 2), dtypes.float32)
-  return spectral_ops.fft2d(grad) * math_ops.complex(rsize, 0.)
+  rsize = math_ops.cast(
+      1. / math_ops.cast(_FFTSizeForGrad(grad, 2), grad.dtype.real_dtype),
+      grad.dtype)
+  return spectral_ops.fft2d(grad) * rsize
 
 
 @ops.RegisterGradient("FFT3D")
 def _FFT3DGrad(_, grad):
-  size = math_ops.cast(_FFTSizeForGrad(grad, 3), dtypes.float32)
-  return spectral_ops.ifft3d(grad) * math_ops.complex(size, 0.)
+  size = math_ops.cast(_FFTSizeForGrad(grad, 3), grad.dtype)
+  return spectral_ops.ifft3d(grad) * size
 
 
 @ops.RegisterGradient("IFFT3D")
 def _IFFT3DGrad(_, grad):
-  rsize = 1. / math_ops.cast(_FFTSizeForGrad(grad, 3), dtypes.float32)
-  return spectral_ops.fft3d(grad) * math_ops.complex(rsize, 0.)
+  rsize = math_ops.cast(
+      1. / math_ops.cast(_FFTSizeForGrad(grad, 3), grad.dtype.real_dtype),
+      grad.dtype)
+  return spectral_ops.fft3d(grad) * rsize
 
 
 def _RFFTGradHelper(rank, irfft_fn):
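
A closing note on the spectral_grad.py hunk: the backprop rule for y = fft(x) is grad_x = ifft(grad_y) * n, because n * ifft is exactly the adjoint (conjugate transpose) of the DFT matrix; the rewrite only changes where that scale factor is cast, using grad.dtype instead of a hard-coded float32 so the same rule also works for complex128 gradients. A NumPy sketch of the adjoint identity (illustrative only, not TensorFlow code):

import numpy as np

def fft_grad(upstream):
  # TF-style rule for y = fft(x): grad_x = ifft(grad_y) * n.
  n = upstream.shape[-1]
  return np.fft.ifft(upstream) * n

rng = np.random.RandomState(0)
g = rng.randn(8) + 1j * rng.randn(8)   # upstream gradient dL/dy
dx = rng.randn(8) + 1j * rng.randn(8)  # arbitrary perturbation of x

# Adjoint check: <grad_x, dx> == <g, fft(dx)> under the Hermitian inner
# product (np.vdot conjugates its first argument); exact up to roundoff
# since fft is linear.
print(abs(np.vdot(fft_grad(g), dx) - np.vdot(g, np.fft.fft(dx))))  # ~1e-14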