diff options
author | RJ Ryan <rjryan@google.com> | 2018-05-01 12:02:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-01 12:05:51 -0700 |
commit | 8e918c3d202bb0eed6b423eb78a6ef45629f952e (patch) | |
tree | c12438ba89e55812f39e24dad5a06f160baca800 /tensorflow/contrib/signal | |
parent | 07c58859c2ec62757f110dc56da9946d415b72ee (diff) |
Improve shape inference for tf.contrib.signal.frame.
PiperOrigin-RevId: 194972934
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r-- | tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/signal/python/ops/shape_ops.py | 14 |
2 files changed, 8 insertions, 8 deletions
diff --git a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py index 64cc8c7ea5..f132050153 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py @@ -119,7 +119,7 @@ class FrameTest(test.TestCase): frame_step = 1 result = shape_ops.frame(signal, frame_length, frame_step, pad_end=True, pad_value=99, axis=1) - self.assertEqual([1, None, None, 3, 4], result.shape.as_list()) + self.assertEqual([1, 2, None, 3, 4], result.shape.as_list()) result = shape_ops.frame(signal, frame_length, frame_step, pad_end=False, axis=1) diff --git a/tensorflow/contrib/signal/python/ops/shape_ops.py b/tensorflow/contrib/signal/python/ops/shape_ops.py index 1ddc2941ec..91862f0cc0 100644 --- a/tensorflow/contrib/signal/python/ops/shape_ops.py +++ b/tensorflow/contrib/signal/python/ops/shape_ops.py @@ -43,13 +43,13 @@ def _infer_frame_shape(signal, frame_length, frame_step, pad_end, axis): outer_dimensions = signal_shape[:axis] inner_dimensions = signal_shape[axis:][1:] if signal_shape and frame_axis is not None: - if frame_step and frame_length is not None: - if pad_end: - # Double negative is so that we round up. - num_frames = -(-frame_axis // frame_step) - else: - num_frames = (frame_axis - frame_length + frame_step) // frame_step - num_frames = max(0, num_frames) + if frame_step is not None and pad_end: + # Double negative is so that we round up. + num_frames = max(0, -(-frame_axis // frame_step)) + elif frame_step is not None and frame_length is not None: + assert not pad_end + num_frames = max( + 0, (frame_axis - frame_length + frame_step) // frame_step) return outer_dimensions + [num_frames, frame_length] + inner_dimensions |