aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Andrew Selle <aselle@google.com>2016-11-22 10:04:37 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-22 10:22:31 -0800
commitfcc3923ab50a98dcbe0f972231b3c4656ebb9228 (patch)
tree7f520876e3705c7584b1b3fd7a33e289aa74ecd4
parent5591ca5f02377b27c3827b34c14b3f2f86451e06 (diff)
Change division in TensorFlow to flooring semantics.
- tf.div changes to new behavior, but it will be deprecated - tf.divide is currently a synonym for tf.div but will remain - tf.mod changes to new behavior, but it will be deprecated, you can use % or tf.floormod in the future. - the op FloorDiv now is extended to work on reals Change: 139922734
-rw-r--r--RELEASE.md4
-rw-r--r--tensorflow/core/kernels/cwise_op_floor_div.cc4
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_floor_div.cu.cc1
-rw-r--r--tensorflow/core/kernels/cwise_ops.h21
-rw-r--r--tensorflow/python/ops/math_grad.py40
-rw-r--r--tensorflow/python/ops/math_ops.py37
-rw-r--r--tensorflow/python/ops/math_ops_test.py21
7 files changed, 120 insertions, 8 deletions
diff --git a/RELEASE.md b/RELEASE.md
index d618c865f5..939eee0f2d 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -2,6 +2,10 @@
## Breaking Changes to the API
+* Division and modulus operators (/, //, %) now match Python (flooring)
+ semantics. tf.div is renamed to tf.division. New operators tf.truncatediv and
+ tf.truncatemod are available for achieving the previous C++ (truncation)
+ division/modulus semantics.
* `BusAdjacency` enum replaced with a protocol buffer `DeviceLocality`. PCI bus
indexing now starts from 1 instead of 0, and bus_id==0 is used where previously
BUS_ANY was used.
diff --git a/tensorflow/core/kernels/cwise_op_floor_div.cc b/tensorflow/core/kernels/cwise_op_floor_div.cc
index 7930d83413..a5767476c3 100644
--- a/tensorflow/core/kernels/cwise_op_floor_div.cc
+++ b/tensorflow/core/kernels/cwise_op_floor_div.cc
@@ -28,9 +28,13 @@ REGISTER5(BinaryOp, CPU, "FloorDiv", functor::safe_floor_div, uint8, uint16,
TF_CALL_INTEGRAL_TYPES(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
+REGISTER3(BinaryOp, CPU, "FloorDiv", functor::floor_div_real, float,
+ Eigen::half, double);
#if GOOGLE_CUDA
REGISTER4(BinaryOp, GPU, "FloorDiv", functor::floor_div, uint8, uint16, int16,
int64);
+REGISTER3(BinaryOp, GPU, "FloorDiv", functor::floor_div_real, float,
+ Eigen::half, double);
#endif
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_floor_div.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_floor_div.cu.cc
index 1300bf2232..0e4887eafd 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_floor_div.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_floor_div.cu.cc
@@ -20,6 +20,7 @@ limitations under the License.
namespace tensorflow {
namespace functor {
DEFINE_BINARY5(floor_div, uint8, uint16, int16, int32, int64);
+DEFINE_BINARY3(floor_div_real, Eigen::half, float, double);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 7f35e03feb..34103347fb 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -269,6 +269,24 @@ struct functor_traits<google_floor_div<Scalar>> {
};
};
+// TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
+template <typename T, typename Enable = void>
+struct google_floor_div_real {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
+ const T& y) const {
+ return Eigen::numext::floor(x / y);
+ }
+};
+
+template <typename Scalar>
+struct functor_traits<google_floor_div_real<Scalar>> {
+ enum {
+ Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value +
+ 2 * NumTraits<Scalar>::AddCost,
+ PacketAccess = false
+ };
+};
+
// TODO(b//32239616): This kernel should be moved into Eigen and vectorized.
template <typename T>
struct google_floor_fmod {
@@ -612,6 +630,9 @@ struct safe_floor_div : base<T, Eigen::internal::safe_div_or_mod_op<
};
template <typename T>
+struct floor_div_real : base<T, Eigen::internal::google_floor_div_real<T>> {};
+
+template <typename T>
struct pow : base<T, Eigen::internal::scalar_binary_pow_op_google<T, T>> {};
template <typename T>
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 1fd69ae717..3502f11892 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -613,16 +613,48 @@ def _MulGrad(op, grad):
@ops.RegisterGradient("Div")
def _DivGrad(op, grad):
+ """The gradient for the Div operator."""
x = op.inputs[0]
y = op.inputs[1]
sx = array_ops.shape(x)
sy = array_ops.shape(y)
- rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) # pylint: disable=protected-access
+ # pylint: disable=protected-access
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ # pylint: enable=protected-access
+ x = math_ops.conj(x)
+ y = math_ops.conj(y)
+ return (array_ops.reshape(math_ops.reduce_sum(math_ops.div(grad, y), rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(
+ grad * math_ops.div(-x, math_ops.square(y)), ry), sy))
+
+
+@ops.RegisterGradient("FloorDiv")
+def _FloorDivGrad(_, unused_grad):
+ """The gradient for the FloorDiv operator."""
+ return None, None
+
+
+@ops.RegisterGradient("TruncateDiv")
+def _TruncateDivGrad(_, unused_grad):
+ return None, None
+
+
+@ops.RegisterGradient("RealDiv")
+def _RealDivGrad(op, grad):
+ """RealDiv op gradient."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ # pylint: disable=protected-access
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ # pylint: enable=protected-access
x = math_ops.conj(x)
y = math_ops.conj(y)
- return (array_ops.reshape(math_ops.reduce_sum(grad / y, rx), sx),
- array_ops.reshape(math_ops.reduce_sum(grad *
- (-x / math_ops.square(y)), ry), sy))
+ return (array_ops.reshape(math_ops.reduce_sum(
+ math_ops.realdiv(grad, y), rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(
+ grad * math_ops.realdiv(-x, math_ops.square(y)), ry), sy))
@ops.RegisterGradient("Pow")
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index c2aab4c945..6fce264bd9 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -928,7 +928,32 @@ def truediv(x, y, name=None):
if dtype is not None:
x = cast(x, dtype)
y = cast(y, dtype)
- return gen_math_ops.div(x, y, name=name)
+ return gen_math_ops.real_div(x, y, name=name)
+
+
+def div(x, y, name=None):
+ with ops.name_scope(name, "truediv", [x, y]) as name:
+ x = ops.convert_to_tensor(x, name="x")
+ y = ops.convert_to_tensor(y, name="y", dtype=x.dtype.base_dtype)
+ x_dtype = x.dtype.base_dtype
+ y_dtype = y.dtype.base_dtype
+ if x_dtype != y_dtype:
+ raise TypeError("x and y must have the same dtype, got %r != %r" %
+ (x_dtype, y_dtype))
+ if x_dtype.is_floating or x_dtype.is_complex:
+ return gen_math_ops.real_div(x, y, name=name)
+ else:
+ return gen_math_ops.floor_div(x, y, name=name)
+
+
+def div_deprecated(x, y, name=None):
+ return gen_math_ops.div(x, y, name)
+
+mod = gen_math_ops.floor_mod
+
+
+def mod_deprecated(x, y, name=None):
+ return gen_math_ops.mod(x, y, name)
# TODO(aselle): Deprecate this once all internal functionality uses
@@ -960,6 +985,11 @@ def floordiv(x, y, name=None):
TypeError: If the inputs are complex.
"""
with ops.name_scope(name, "floordiv", [x, y]) as name:
+ return gen_math_ops.floor_div(x, y, name=name)
+
+
+def floordiv_deprecated(x, y, name=None):
+ with ops.name_scope(name, "floordiv", [x, y]) as name:
x = ops.convert_to_tensor(x, name="x")
dtype = x.dtype
if dtype.is_floating:
@@ -971,7 +1001,6 @@ def floordiv(x, y, name=None):
# return gen_math_ops.floor_div(x, y, name=name)
return gen_math_ops.div(x, y, name=name)
-
realdiv = gen_math_ops.real_div
truncatediv = gen_math_ops.truncate_div
# TODO(aselle): Rename this to floordiv when we can.
@@ -1002,12 +1031,12 @@ _OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_mul, "mul",
_OverrideBinaryOperatorHelper(gen_math_ops.add, "add")
_OverrideBinaryOperatorHelper(gen_math_ops.sub, "sub")
_OverrideBinaryOperatorHelper(_mul_dispatch, "mul")
-_OverrideBinaryOperatorHelper(gen_math_ops.div, "div")
+_OverrideBinaryOperatorHelper(div, "div")
_OverrideBinaryOperatorHelper(truediv, "truediv")
_OverrideBinaryOperatorHelper(floordiv, "floordiv")
# TODO(aselle): Switch mod to floor_mod when ready
# _OverrideBinaryOperatorHelper(gen_math_ops.floor_mod, "mod")
-_OverrideBinaryOperatorHelper(gen_math_ops.mod, "mod")
+_OverrideBinaryOperatorHelper(gen_math_ops.floor_mod, "mod")
_OverrideBinaryOperatorHelper(pow, "pow")
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 4bbbc7b4f7..197ddb6a75 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -24,6 +24,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -305,6 +306,26 @@ class DivAndModTest(test_util.TensorFlowTestCase):
np_result = np.divide(nums, divs)
self.assertAllEqual(tf_result, np_result)
+ def testComplexDiv(self):
+ foo = array_ops.constant([1.+3.j])
+ with self.test_session():
+ _ = math_ops.div_deprecated(foo, 1.).eval()
+ _ = math_ops.div(foo, 2.).eval()
+
+ def testFloorDivGrad(self):
+ with self.test_session():
+ a = variables.Variable(2.)
+ b = variables.Variable(4.)
+ with self.test_session() as sess:
+ sess.run(variables.initialize_all_variables())
+ c_grad = gradients.gradients(math_ops.div_deprecated(a, b), [a, b])
+ self.assertAllEqual([x.eval() for x in c_grad], [.25, -.125])
+ c_grad = gradients.gradients(math_ops.div(a, b), [a, b])
+ self.assertAllEqual([x.eval() for x in c_grad], [.25, -.125])
+ c_grad = gradients.gradients(math_ops.floordiv(a, b), [a, b])
+ self.assertAllEqual([None if x is None else x.eval() for x in c_grad],
+ [None, None])
+
def testConsistent(self):
nums, divs = self.intTestData()
with self.test_session():