aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/signal
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-28 09:03:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-28 09:08:21 -0700
commit457bc31afdbc4f11181a93fed3ac8a404610be2a (patch)
treedd073a40c8b18a01f0dd708f5c84fbfd4f1c45d3 /tensorflow/contrib/signal
parent728e238d26669a358ff296364b83325ce0e14c34 (diff)
Compute static GCD where possible.
PiperOrigin-RevId: 170350852
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r--tensorflow/contrib/signal/python/ops/util_ops.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/contrib/signal/python/ops/util_ops.py b/tensorflow/contrib/signal/python/ops/util_ops.py
index eee829d799..817c9b97d6 100644
--- a/tensorflow/contrib/signal/python/ops/util_ops.py
+++ b/tensorflow/contrib/signal/python/ops/util_ops.py
@@ -18,7 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import fractions
+
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
@@ -51,6 +54,13 @@ def gcd(a, b, name=None):
if not b.dtype.is_integer:
raise ValueError('b must be an integer type. Got: %s' % b.dtype)
+ # TPU requires static shape inference. GCD is used for subframe size
+ # computation, so we should prefer static computation where possible.
+ const_a = tensor_util.constant_value(a)
+ const_b = tensor_util.constant_value(b)
+ if const_a is not None and const_b is not None:
+ return ops.convert_to_tensor(fractions.gcd(const_a, const_b))
+
cond = lambda _, b: math_ops.greater(b, array_ops.zeros_like(b))
body = lambda a, b: [b, math_ops.mod(a, b)]
a, b = control_flow_ops.while_loop(cond, body, [a, b], back_prop=False)