aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Olivia Nordquist <nolivia@google.com>2016-10-14 10:06:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-14 11:19:22 -0700
commit221a9154e64070d7dbdc7f4cea8cf4a3620c9a9c (patch)
tree8640d1be76cec5c9d028ee0ce849c4a286171a74
parent5c93dbdb336e762910717ecdadd7d48a8e629115 (diff)
Create OP "Round" which uses banker's rounding. Not used in Python API tf.round yet.
Change: 136175631
-rw-r--r--tensorflow/core/ops/math_ops.cc7
-rw-r--r--tensorflow/python/ops/math_ops.py11
-rw-r--r--tensorflow/python/ops/math_ops_test.py4
3 files changed, 19 insertions, 3 deletions
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index ab062c7601..97515cbb72 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -259,6 +259,13 @@ Computes reciprocal of square root of x element-wise.
I.e., \\(y = 1 / \sqrt{x}\\).
)doc");
+REGISTER_OP("Round").UNARY().Doc(R"doc(
+Rounds the values of a tensor to the nearest integer, element-wise.
+
+Rounds half to even. Also known as bankers rounding. If you want to round
+according to the current system rounding mode use std::cint.
+)doc");
+
REGISTER_OP("RsqrtGrad").UNARY_GRADIENT_COMPLEX().Doc(R"doc(
Computes the gradient for the rsqrt of `x` wrt its input.
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 23f141039a..c035673745 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -556,11 +556,13 @@ def imag(input, name=None):
def round(x, name=None):
"""Rounds the values of a tensor to the nearest integer, element-wise.
+ Rounds half to even. Also known as bankers rounding. If you want to round
+ according to the current system rounding mode use tf::cint.
For example:
```python
- # 'a' is [0.9, 2.5, 2.3, -4.4]
- tf.round(a) ==> [ 1.0, 3.0, 2.0, -4.0 ]
+ # 'a' is [0.9, 2.5, 2.3, 1.5, -4.5]
+ tf.round(a) ==> [ 1.0, 2.0, 2.0, 2.0, -4.0 ]
```
Args:
@@ -574,9 +576,14 @@ def round(x, name=None):
if x.dtype.is_integer:
return x
else:
+ # TODO(nolivia): Switch to new Round op
+ # return gen_math_ops.round(x, name=name)
return gen_math_ops.floor(x + 0.5, name=name)
+ops.RegisterShape("Round")(common_shapes.call_cpp_shape_fn)
+
+
def cast(x, dtype, name=None):
"""Casts a tensor to a new type.
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 50a913ebd8..6bfa13ba0a 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -115,7 +115,9 @@ class RoundTest(test_util.TensorFlowTestCase):
def testRounding(self):
x = [0.49, 0.7, -0.3, -0.8]
- for dtype in [np.float32, np.double]:
+ # TODO(nolivia): Remove this when RoundOp is forwards compatible
+ # x = np.arange(-5.0, 5.0, .25)
+ for dtype in [np.float32, np.double, np.int32]:
x_np = np.array(x, dtype=dtype)
with self.test_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)