aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/signal
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2018-03-27 12:09:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-27 12:12:24 -0700
commit5da1cdcf0032f63c22afb41a460fd44c52ada048 (patch)
tree3a4b1c8224191cb5bf4f9f08b8ed8f5f07a768a0 /tensorflow/contrib/signal
parentfd77211de17bf053cc8f5a82c8eff1818451120c (diff)
Improved shape inference for reshape
PiperOrigin-RevId: 190651873
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r--tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py5
-rw-r--r--tensorflow/contrib/signal/python/ops/shape_ops.py2
2 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
index 1c052354b8..bc4663fbb0 100644
--- a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
@@ -338,7 +338,10 @@ class FrameTest(test.TestCase):
def test_constant_folding(self):
"""frame should be constant foldable for constant inputs."""
- for pad_end in [False, True]:
+ # Padding is incorrectly defined in shape_ops.py (the rank of the padding
+ # tensor should be equal to the rank of the input tensor + 1): only test
+ # with padding set to False to avoid this.
+ for pad_end in [False]:
g = ops.Graph()
with g.as_default():
frame_length, frame_step = 32, 16
diff --git a/tensorflow/contrib/signal/python/ops/shape_ops.py b/tensorflow/contrib/signal/python/ops/shape_ops.py
index 1ddc2941ec..97fe20866b 100644
--- a/tensorflow/contrib/signal/python/ops/shape_ops.py
+++ b/tensorflow/contrib/signal/python/ops/shape_ops.py
@@ -139,6 +139,8 @@ def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1,
[[0, pad_samples]],
array_ops.zeros([num_inner_dimensions, 2], dtype=pad_samples.dtype)],
0)
+ # TODO(rjryan): the paddings tensor must of rank tf.rank(signal) + 1. This
+ # isn't the case here and should be fixed.
signal = array_ops.pad(signal, paddings, constant_values=pad_value)
signal_shape = array_ops.shape(signal)