aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/signal
diff options
context:
space:
mode:
authorGravatar RJ Ryan <rjryan@google.com>2018-05-01 12:02:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-01 12:05:51 -0700
commit8e918c3d202bb0eed6b423eb78a6ef45629f952e (patch)
treec12438ba89e55812f39e24dad5a06f160baca800 /tensorflow/contrib/signal
parent07c58859c2ec62757f110dc56da9946d415b72ee (diff)
Improve shape inference for tf.contrib.signal.frame.
PiperOrigin-RevId: 194972934
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r--tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py2
-rw-r--r--tensorflow/contrib/signal/python/ops/shape_ops.py14
2 files changed, 8 insertions, 8 deletions
diff --git a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
index 64cc8c7ea5..f132050153 100644
--- a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
@@ -119,7 +119,7 @@ class FrameTest(test.TestCase):
frame_step = 1
result = shape_ops.frame(signal, frame_length, frame_step,
pad_end=True, pad_value=99, axis=1)
- self.assertEqual([1, None, None, 3, 4], result.shape.as_list())
+ self.assertEqual([1, 2, None, 3, 4], result.shape.as_list())
result = shape_ops.frame(signal, frame_length, frame_step,
pad_end=False, axis=1)
diff --git a/tensorflow/contrib/signal/python/ops/shape_ops.py b/tensorflow/contrib/signal/python/ops/shape_ops.py
index 1ddc2941ec..91862f0cc0 100644
--- a/tensorflow/contrib/signal/python/ops/shape_ops.py
+++ b/tensorflow/contrib/signal/python/ops/shape_ops.py
@@ -43,13 +43,13 @@ def _infer_frame_shape(signal, frame_length, frame_step, pad_end, axis):
outer_dimensions = signal_shape[:axis]
inner_dimensions = signal_shape[axis:][1:]
if signal_shape and frame_axis is not None:
- if frame_step and frame_length is not None:
- if pad_end:
- # Double negative is so that we round up.
- num_frames = -(-frame_axis // frame_step)
- else:
- num_frames = (frame_axis - frame_length + frame_step) // frame_step
- num_frames = max(0, num_frames)
+ if frame_step is not None and pad_end:
+ # Double negative is so that we round up.
+ num_frames = max(0, -(-frame_axis // frame_step))
+ elif frame_step is not None and frame_length is not None:
+ assert not pad_end
+ num_frames = max(
+ 0, (frame_axis - frame_length + frame_step) // frame_step)
return outer_dimensions + [num_frames, frame_length] + inner_dimensions