aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/signal/python/ops/reconstruction_ops.py')
-rw-r--r--tensorflow/contrib/signal/python/ops/reconstruction_ops.py26
1 files changed, 16 insertions, 10 deletions
diff --git a/tensorflow/contrib/signal/python/ops/reconstruction_ops.py b/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
index 653c030a04..4db8dc2ca0 100644
--- a/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
+++ b/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
@@ -90,22 +90,28 @@ def overlap_and_add(signal, frame_step, name=None):
raise ValueError("frame_step must be an integer. Got %s" %
frame_step.dtype)
- # If frame_length and frame_step are known at graph construction time, check
- # frame_step is less than or equal to frame_length.
- frame_step_static = tensor_util.constant_value(frame_step)
- if (frame_step_static is not None and signal.shape.ndims is not None and
- signal.shape[-1].value is not None and
- frame_step_static > signal.shape[-1].value):
- raise ValueError(
- "frame_step (%d) must be less than or equal to frame_length (%d)" % (
- frame_step_static, signal.shape[-1].value))
-
signal_shape = array_ops.shape(signal)
# All dimensions that are not part of the overlap-and-add. Can be empty for
# rank 2 inputs.
outer_dimensions = signal_shape[:-2]
+ # If frame_length and frame_step are known at graph construction time, check
+ # frame_step is less than or equal to frame_length.
+ frame_step_static = tensor_util.constant_value(frame_step)
+ if (frame_step_static is not None and signal.shape.ndims is not None and
+ signal.shape[-1].value is not None):
+ if frame_step_static > signal.shape[-1].value:
+ raise ValueError(
+ "frame_step (%d) must be less than or equal to "
+ "frame_length (%d)" % (
+ frame_step_static, signal.shape[-1].value))
+ # If frame_length is equal to frame_step, there's no overlap so just
+ # reshape the tensor.
+ if frame_step_static == signal.shape[-1].value:
+ return array_ops.reshape(signal, array_ops.concat(
+ [outer_dimensions, [-1]], 0))
+
signal_rank = array_ops.rank(signal)
frames = signal_shape[-2]
frame_length = signal_shape[-1]