aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-16 15:14:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-16 15:18:15 -0700
commit3618796b3bee7bd0eb06425d6a069d28b95e6f42 (patch)
tree122530f5f932aba1ae183a2ecf84fa5e311be69c /tensorflow/compiler/xla/service/shape_inference.cc
parent409c0673b6a98a02be19adba5e64a489bd28b703 (diff)
Implement lgamma for XLA
Add support for Real and Imag for real floating point types. Compute the Lgamma function using Lanczos' approximation from "A Precision Approximation of the Gamma Function". SIAM Journal on Numerical Analysis series B. Vol. 1: lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z) t(z) = z + kLanczosGamma + 1/2 A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k)) PiperOrigin-RevId: 204815805
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc11
1 files changed, 7 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 70edf7883f..214146cf68 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -222,13 +222,16 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return shape;
case HloOpcode::kReal:
case HloOpcode::kImag:
- if (!ShapeUtil::ElementIsComplex(shape)) {
+ if (ShapeUtil::ElementIsComplex(shape)) {
+ return ShapeUtil::ComplexComponentShape(shape);
+ } else if (ShapeUtil::ElementIsFloating(shape)) {
+ return shape;
+ } else {
return InvalidArgument(
- "Expected element type in shape to be complex for real/imag "
- "operation; got %s.",
+ "Expected element type in shape to be floating or complex for "
+ "real/imag operation; got %s.",
PrimitiveType_Name(shape.element_type()).c_str());
}
- return ShapeUtil::ChangeElementType(shape, F32);
case HloOpcode::kAbs:
if (ShapeUtil::ElementIsComplex(shape)) {
return ShapeUtil::ChangeElementType(