diff options
author | 2017-09-28 09:03:07 -0700 | |
---|---|---|
committer | 2017-09-28 09:08:21 -0700 | |
commit | 457bc31afdbc4f11181a93fed3ac8a404610be2a (patch) | |
tree | dd073a40c8b18a01f0dd708f5c84fbfd4f1c45d3 /tensorflow/contrib/signal | |
parent | 728e238d26669a358ff296364b83325ce0e14c34 (diff) |
Compute static GCD where possible.
PiperOrigin-RevId: 170350852
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r-- | tensorflow/contrib/signal/python/ops/util_ops.py | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/contrib/signal/python/ops/util_ops.py b/tensorflow/contrib/signal/python/ops/util_ops.py index eee829d799..817c9b97d6 100644 --- a/tensorflow/contrib/signal/python/ops/util_ops.py +++ b/tensorflow/contrib/signal/python/ops/util_ops.py @@ -18,7 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import fractions + from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -51,6 +54,13 @@ def gcd(a, b, name=None): if not b.dtype.is_integer: raise ValueError('b must be an integer type. Got: %s' % b.dtype) + # TPU requires static shape inference. GCD is used for subframe size + # computation, so we should prefer static computation where possible. + const_a = tensor_util.constant_value(a) + const_b = tensor_util.constant_value(b) + if const_a is not None and const_b is not None: + return ops.convert_to_tensor(fractions.gcd(const_a, const_b)) + cond = lambda _, b: math_ops.greater(b, array_ops.zeros_like(b)) body = lambda a, b: [b, math_ops.mod(a, b)] a, b = control_flow_ops.while_loop(cond, body, [a, b], back_prop=False) |