aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-09-24 18:52:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 18:56:01 -0700
commitec2cc9122cca5fdec52d6c1ec42b771b8082d298 (patch)
treea89df7e35858632e5ea70fd66658b89694b97f1b /tensorflow/python/autograph
parentd5c5f8ecc124ee9a866318f2bd7082df9e38ebf2 (diff)
Ensure tf.range has semantics consistent with range, which allows start and end indices that would result in an empty range. tf.range errors out at graph construction time in that case.
PiperOrigin-RevId: 214369488
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r--tensorflow/python/autograph/operators/py_builtins.py7
-rw-r--r--tensorflow/python/autograph/operators/py_builtins_test.py7
2 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/operators/py_builtins.py b/tensorflow/python/autograph/operators/py_builtins.py
index 1d37ae72d3..91a2a22cc2 100644
--- a/tensorflow/python/autograph/operators/py_builtins.py
+++ b/tensorflow/python/autograph/operators/py_builtins.py
@@ -193,11 +193,18 @@ def range_(start_or_stop, stop=UNDEFINED, step=UNDEFINED):
def _tf_range(start_or_stop, stop, step):
+ # Note: for static inputs (e.g. constants), tf.range errors out at graph
+ # construction time, instead of returning an empty tensor. Preventing the
+ # graph construction error aligns the semantics with Python.
+
# TODO(mdan): We should optimize this when a full tensor is not required.
if step is not UNDEFINED:
+ # TODO(mdan): Add argument coercion similar to other cases.
return math_ops.range(start_or_stop, stop, step)
if stop is not UNDEFINED:
+ stop = math_ops.maximum(start_or_stop, stop)
return math_ops.range(start_or_stop, stop)
+ start_or_stop = math_ops.maximum(start_or_stop, 0)
return math_ops.range(start_or_stop)
diff --git a/tensorflow/python/autograph/operators/py_builtins_test.py b/tensorflow/python/autograph/operators/py_builtins_test.py
index d64d31cc79..c94a918d5a 100644
--- a/tensorflow/python/autograph/operators/py_builtins_test.py
+++ b/tensorflow/python/autograph/operators/py_builtins_test.py
@@ -126,6 +126,13 @@ class PyBuiltinsTest(test.TestCase):
r = py_builtins.range_(2, 0, constant_op.constant(-1))
self.assertAllEqual(sess.run(r), [2, 1])
+ def test_range_tensor_empty_range(self):
+ with self.test_session() as sess:
+ r = py_builtins.range_(constant_op.constant(-3))
+ self.assertAllEqual(sess.run(r), [])
+ r = py_builtins.range_(5, constant_op.constant(2))
+ self.assertAllEqual(sess.run(r), [])
+
if __name__ == '__main__':
test.main()