aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-09-30 09:54:02 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-30 11:03:14 -0700
commita790eadfeddcc99c96b23bb0ed51613737477691 (patch)
tree4512a930e65f81dedfdb817999f727e1654b7ea1
parent4b8be072eaf0c66298d0f9b1657bc166399d1108 (diff)
tf.range: change capabilities and arguments to better match np.arange
* allow user to specify dtype * if dtype is not specified, it is inferred * delta < 0 is allowed Change: 134801946
-rw-r--r--tensorflow/core/kernels/sequence_ops.cc76
-rw-r--r--tensorflow/core/ops/math_ops.cc63
-rw-r--r--tensorflow/core/ops/math_ops_test.cc18
-rw-r--r--tensorflow/python/kernel_tests/init_ops_test.py43
-rw-r--r--tensorflow/python/ops/math_ops.py43
5 files changed, 186 insertions, 57 deletions
diff --git a/tensorflow/core/kernels/sequence_ops.cc b/tensorflow/core/kernels/sequence_ops.cc
index 3cbd9691d1..88a163600c 100644
--- a/tensorflow/core/kernels/sequence_ops.cc
+++ b/tensorflow/core/kernels/sequence_ops.cc
@@ -15,6 +15,8 @@ limitations under the License.
// See docs in ../ops/math_ops.cc.
+#include <cmath>
+
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -43,45 +45,69 @@ class RangeOp : public OpKernel {
OP_REQUIRES(context, IsLegacyScalar(delta_in.shape()),
errors::InvalidArgument("delta must be a scalar, not shape ",
delta_in.shape().DebugString()));
- const int32 start = GetValue(start_in.scalar<T>()());
- const int32 limit = GetValue(limit_in.scalar<T>()());
- OP_REQUIRES(context, start <= limit,
- errors::InvalidArgument("Requires start <= limit: ", start, "/",
- limit));
- const int32 delta = GetValue(delta_in.scalar<T>()());
- OP_REQUIRES(context, delta > 0,
- errors::InvalidArgument("Requires delta > 0: ", delta));
- int32 size = (limit - start + delta - 1) / delta;
+ const T start = start_in.scalar<T>()();
+ const T limit = limit_in.scalar<T>()();
+ const T delta = delta_in.scalar<T>()();
+ OP_REQUIRES(context, delta != 0,
+ errors::InvalidArgument("Requires delta != 0: ", delta));
+ if (delta > 0) {
+ OP_REQUIRES(
+ context, start <= limit,
+ errors::InvalidArgument("Requires start <= limit when delta > 0: ",
+ start, "/", limit));
+ } else {
+ OP_REQUIRES(
+ context, start >= limit,
+ errors::InvalidArgument("Requires start >= limit when delta < 0: ",
+ start, "/", limit));
+ }
+ int64 size = (std::is_integral<T>::value
+ ? ((std::abs(limit - start) + std::abs(delta) - 1) /
+ std::abs(delta))
+ : std::ceil(std::abs((limit - start) / delta)));
Tensor* out = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({size}), &out));
auto flat = out->flat<T>();
- int32 val = start;
- for (int32 i = 0; i < size; ++i) {
+ T val = start;
+ for (int64 i = 0; i < size; ++i) {
flat(i) = T(val);
val += delta;
}
}
};
-REGISTER_KERNEL_BUILDER(Name("Range")
- .Device(DEVICE_CPU)
- .HostMemory("start")
- .HostMemory("limit")
- .HostMemory("delta")
- .HostMemory("output"),
- RangeOp<int32>);
+#define REGISTER_KERNEL(DEV, TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("Range") \
+ .Device(DEV) \
+ .HostMemory("start") \
+ .HostMemory("limit") \
+ .HostMemory("delta") \
+ .HostMemory("output") \
+ .TypeConstraint<TYPE>("Tidx"), \
+ RangeOp<TYPE>);
+
+#define REGISTER_CPU_KERNEL(T) REGISTER_KERNEL(DEVICE_CPU, T)
+#define REGISTER_GPU_KERNEL(T) REGISTER_KERNEL(DEVICE_GPU, T)
+
+TF_CALL_float(REGISTER_CPU_KERNEL);
+TF_CALL_double(REGISTER_CPU_KERNEL);
+TF_CALL_int32(REGISTER_CPU_KERNEL);
+TF_CALL_int64(REGISTER_CPU_KERNEL);
#if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(Name("Range")
- .Device(DEVICE_GPU)
- .HostMemory("start")
- .HostMemory("limit")
- .HostMemory("delta")
- .HostMemory("output"),
- RangeOp<int32>);
+
+TF_CALL_float(REGISTER_GPU_KERNEL);
+TF_CALL_double(REGISTER_GPU_KERNEL);
+TF_CALL_int32(REGISTER_GPU_KERNEL);
+TF_CALL_int64(REGISTER_GPU_KERNEL);
+
#endif // GOOGLE_CUDA
+#undef REGISTER_KERNEL
+#undef REGISTER_CPU_KERNEL
+#undef REGISTER_GPU_KERNEL
+
template <typename T>
class LinSpaceOp : public OpKernel {
public:
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 0034301690..e498d9757a 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -1699,12 +1699,42 @@ output: The reduced tensor.
// --------------------------------------------------------------------------
+namespace {
+
+template <typename T>
+Status RangeSize(const Tensor* start_t, const Tensor* limit_t,
+ const Tensor* delta_t, InferenceContext* const c) {
+ T start = start_t->scalar<T>()();
+ T limit = limit_t->scalar<T>()();
+ T delta = delta_t->scalar<T>()();
+ if (start > limit && delta > 0) {
+ return errors::InvalidArgument("Requires start <= limit when delta > 0: ",
+ start, "/", limit);
+ }
+ if (start < limit && delta < 0) {
+ return errors::InvalidArgument("Requires start >= limit when delta < 0: ",
+ start, "/", limit);
+ }
+ if (delta == 0) {
+ return errors::InvalidArgument("Requires delta != 0");
+ }
+
+ int64 size =
+ (std::is_integral<T>::value
+ ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta))
+ : std::ceil(std::abs((limit - start) / delta)));
+ c->set_output(0, c->Vector(size));
+ return Status::OK();
+}
+
+} // namespace
+
REGISTER_OP("Range")
.Input("start: Tidx")
.Input("limit: Tidx")
.Input("delta: Tidx")
.Output("output: Tidx")
- .Attr("Tidx: {int32, int64} = DT_INT32")
+ .Attr("Tidx: {float, double, int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
@@ -1716,36 +1746,27 @@ REGISTER_OP("Range")
const Tensor* start_t = c->input_tensor(0);
const Tensor* limit_t = c->input_tensor(1);
const Tensor* delta_t = c->input_tensor(2);
+ DataType dtype;
+ TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype));
if (start_t == nullptr || limit_t == nullptr || delta_t == nullptr) {
c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
return Status::OK();
}
- // TODO
- int64 start, limit, delta;
- if (start_t->dtype() == DT_INT32) {
- start = start_t->scalar<int32>()();
- limit = limit_t->scalar<int32>()();
- delta = delta_t->scalar<int32>()();
+ if (dtype == DT_INT32) {
+ return RangeSize<int32>(start_t, limit_t, delta_t, c);
+ } else if (dtype == DT_INT64) {
+ return RangeSize<int64>(start_t, limit_t, delta_t, c);
+ } else if (dtype == DT_FLOAT) {
+ return RangeSize<float>(start_t, limit_t, delta_t, c);
} else {
- start = start_t->scalar<int64>()();
- limit = limit_t->scalar<int64>()();
- delta = delta_t->scalar<int64>()();
- }
- if (start > limit) {
- return errors::InvalidArgument("Requires start <= limit: ", start, "/",
- limit);
- }
- if (delta <= 0) {
- return errors::InvalidArgument("Requires delta > 0: ", delta);
+ return RangeSize<double>(start_t, limit_t, delta_t, c);
}
- const int64 size = (limit - start + delta - 1) / delta;
- c->set_output(0, c->Vector(size));
return Status::OK();
})
.Doc(R"doc(
-Creates a sequence of integers.
+Creates a sequence of numbers.
-This operation creates a sequence of integers that begins at `start` and
+This operation creates a sequence of numbers that begins at `start` and
extends by increments of `delta` up to but not including `limit`.
For example:
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index b091c8dac0..5d0b75c579 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -217,6 +217,14 @@ TEST(MathOpsTest, Select_ShapeFn) {
TEST(MathOpsTest, Range_ShapeFn) {
ShapeInferenceTestOp op("Range");
+
+ TF_ASSERT_OK(NodeDefBuilder("test", "Range")
+ .Input({"start", {}, DT_INT32})
+ .Input({"limit", {}, DT_INT32})
+ .Input({"delta", {}, DT_INT32})
+ .Attr("Tidx", DT_INT32)
+ .Finalize(&op.node_def));
+
op.input_tensors.resize(3);
INFER_OK(op, "?;?;?", "[?]");
INFER_ERROR("Shape must be rank 0 but is rank 2", op, "[1,2];?;?");
@@ -240,11 +248,17 @@ TEST(MathOpsTest, Range_ShapeFn) {
INFER_OK(op, "?;?;?", "[0]");
delta_t = test::AsScalar(0);
- INFER_ERROR("Requires delta > 0: 0", op, "?;?;?");
+ INFER_ERROR("Requires delta != 0", op, "?;?;?");
delta_t = test::AsScalar(3);
limit_t = test::AsScalar(-1);
- INFER_ERROR("Requires start <= limit: 1/-1", op, "?;?;?");
+ INFER_ERROR("Requires start <= limit when delta > 0: 1/-1", op, "?;?;?");
+
+ delta_t = test::AsScalar(-1);
+ INFER_OK(op, "?;?;?", "[2]");
+
+ limit_t = test::AsScalar(4);
+ INFER_ERROR("Requires start >= limit when delta < 0: 1/4", op, "?;?;?");
limit_t = test::AsScalar(100);
start_t = test::AsScalar(2);
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index 24df91c34f..4ef11ed914 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -308,7 +308,8 @@ class RangeTest(tf.test.TestCase):
def _Range(self, start, limit, delta):
with self.test_session(use_gpu=True):
tf_ans = tf.range(start, limit, delta, name="range")
- self.assertEqual([len(range(start, limit, delta))], tf_ans.get_shape())
+ self.assertEqual([len(np.arange(start, limit, delta))],
+ tf_ans.get_shape())
return tf_ans.eval()
def testBasic(self):
@@ -332,6 +333,46 @@ class RangeTest(tf.test.TestCase):
for start in 0, 5:
self.assertTrue(np.array_equal(self._Range(start, start, 1), []))
+ def testNonInteger(self):
+ self.assertTrue(
+ np.allclose(self._Range(0, 2, 0.5), np.array([0, 0.5, 1, 1.5])))
+ self.assertTrue(np.allclose(self._Range(0, 5, 2.5), np.array([0, 2.5])))
+ self.assertTrue(
+ np.allclose(self._Range(0, 3, 0.9), np.array([0, 0.9, 1.8, 2.7])))
+ self.assertTrue(
+ np.allclose(
+ self._Range(100., 500., 100.), np.array([100, 200, 300, 400])))
+ self.assertEqual(tf.range(0., 5., 1.).dtype, tf.float32)
+
+ def testNegativeDelta(self):
+ self.assertTrue(
+ np.array_equal(self._Range(5, -1, -1), np.array([5, 4, 3, 2, 1, 0])))
+ self.assertTrue(
+ np.allclose(self._Range(2.5, 0, -0.5), np.array([2.5, 2, 1.5, 1, 0.5])))
+ self.assertTrue(
+ np.array_equal(self._Range(-5, -10, -3), np.array([-5, -8])))
+
+ def testDType(self):
+ zero_int32 = tf.cast(0, tf.int32)
+ zero_int64 = tf.cast(0, tf.int64)
+ zero_float32 = tf.cast(0, tf.float32)
+ zero_float64 = tf.cast(0, tf.float64)
+
+ self.assertEqual(tf.range(zero_int32, 0, 1).dtype, tf.int32)
+ self.assertEqual(tf.range(zero_int64, 0, 1).dtype, tf.int64)
+ self.assertEqual(tf.range(zero_float32, 0, 1).dtype, tf.float32)
+ self.assertEqual(tf.range(zero_float64, 0, 1).dtype, tf.float64)
+
+ self.assertEqual(tf.range(zero_int32, zero_int64, 1).dtype, tf.int64)
+ self.assertEqual(tf.range(zero_int64, zero_float32, 1).dtype, tf.float32)
+ self.assertEqual(tf.range(zero_float32, zero_float64, 1).dtype, tf.float64)
+ self.assertEqual(tf.range(zero_float64, zero_int32, 1).dtype, tf.float64)
+
+ self.assertEqual(tf.range(0, 0, 1, dtype=tf.int32).dtype, tf.int32)
+ self.assertEqual(tf.range(0, 0, 1, dtype=tf.int64).dtype, tf.int64)
+ self.assertEqual(tf.range(0, 0, 1, dtype=tf.float32).dtype, tf.float32)
+ self.assertEqual(tf.range(0, 0, 1, dtype=tf.float64).dtype, tf.float64)
+
# TODO(vrv): move to sequence_ops_test?
class LinSpaceTest(tf.test.TestCase):
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 2c025d951c..9785d11943 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -951,12 +951,15 @@ ops.Tensor._override_operator("__gt__", gen_math_ops.greater)
ops.Tensor._override_operator("__ge__", gen_math_ops.greater_equal)
-def range(start, limit=None, delta=1, name="range"):
- """Creates a sequence of integers.
+def range(start, limit=None, delta=1, dtype=None, name="range"):
+ """Creates a sequence of numbers.
- Creates a sequence of integers that begins at `start` and extends by
+ Creates a sequence of numbers that begins at `start` and extends by
increments of `delta` up to but not including `limit`.
+ The dtype of the resulting tensor is inferred from the inputs unless
+ it is provided explicitly.
+
Like the Python builtin `range`, `start` defaults to 0, so that
`range(n) = range(0, n)`.
@@ -968,27 +971,51 @@ def range(start, limit=None, delta=1, name="range"):
# 'delta' is 3
tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]
+ # 'start' is 3
+ # 'limit' is 1
+ # 'delta' is -0.5
+ tf.range(start, limit, delta) ==> [3, 2.5, 2, 1.5]
+
# 'limit' is 5
tf.range(limit) ==> [0, 1, 2, 3, 4]
```
Args:
- start: A 0-D (scalar) of type `int32`. Acts as first entry in the range if
+ start: A 0-D `Tensor` (scalar). Acts as first entry in the range if
`limit` is not None; otherwise, acts as range limit and first entry
defaults to 0.
- limit: A 0-D (scalar) of type `int32`. Upper limit of sequence,
+ limit: A 0-D `Tensor` (scalar). Upper limit of sequence,
exclusive. If None, defaults to the value of `start` while the first
entry of the range defaults to 0.
- delta: A 0-D `Tensor` (scalar) of type `int32`. Number that increments
+ delta: A 0-D `Tensor` (scalar). Number that increments
`start`. Defaults to 1.
+ dtype: The type of the elements of the resulting tensor.
name: A name for the operation. Defaults to "range".
Returns:
- An 1-D `int32` `Tensor`.
+ An 1-D `Tensor` of type `dtype`.
"""
if limit is None:
start, limit = 0, start
- return gen_math_ops._range(start, limit, delta, name=name)
+
+ with ops.name_scope(name, "Range", [start, limit, delta]) as name:
+ start = ops.convert_to_tensor(start, dtype=dtype, name="start")
+ limit = ops.convert_to_tensor(limit, dtype=dtype, name="limit")
+ delta = ops.convert_to_tensor(delta, dtype=dtype, name="delta")
+
+ # infer dtype if not explicitly provided
+ if dtype is None:
+ dtype_hierarchy = [dtypes.int32, dtypes.int64, dtypes.float32,
+ dtypes.float64]
+ assert all(arg.dtype in dtype_hierarchy for arg in [start, limit, delta])
+ inferred_dtype = max([arg.dtype for arg in [start, limit, delta]],
+ key=dtype_hierarchy.index)
+
+ start = cast(start, inferred_dtype)
+ limit = cast(limit, inferred_dtype)
+ delta = cast(delta, inferred_dtype)
+
+ return gen_math_ops._range(start, limit, delta, name=name)
@ops.RegisterShape("Range")