aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar Andrew Selle <aselle@google.com>2016-10-18 16:08:51 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-18 17:23:52 -0700
commitbdae9c62caa19c7ecef1eca8e3b9149629404182 (patch)
treeca64c32090d9c21bb7942371bbd5f19b444e81bf /tensorflow/core
parent1855c80558fecebc1e8b90efbe2cfe8573bff38a (diff)
Add ops to implement NumPy parity for 1.0 in preparation for switching
We are holding off on switchover to the new ops for operator overloading and the Python API because of forward compatibility. - Implemented new FloorMod and FloorDiv which represent new implementations that match Pythonic semantics. These are not yet vectorized, but they are also not yet used. - Add TruncateMod in preparation to rename tf.mod to tf.truncateMod - Add RealDiv which only works on floating point inputs, but it uses Div's kernels - Add TruncateDiv which only works on integer inputs, but it also uses Div's kernels. Change: 136539473
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/kernels/cwise_op_div.cc8
-rw-r--r--tensorflow/core/kernels/cwise_op_floor_div.cc38
-rw-r--r--tensorflow/core/kernels/cwise_op_floor_mod.cc34
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_floor_div.cu.cc26
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_floor_mod.cu.cc26
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_mod.cu.cc2
-rw-r--r--tensorflow/core/kernels/cwise_op_mod.cc9
-rw-r--r--tensorflow/core/kernels/cwise_ops.h88
-rw-r--r--tensorflow/core/kernels/cwise_ops_common.cc2
-rw-r--r--tensorflow/core/ops/math_ops.cc68
10 files changed, 295 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc
index 305d149501..925c9e9916 100644
--- a/tensorflow/core/kernels/cwise_op_div.cc
+++ b/tensorflow/core/kernels/cwise_op_div.cc
@@ -20,9 +20,17 @@ REGISTER5(BinaryOp, CPU, "Div", functor::div, float, Eigen::half, double,
complex64, complex128);
REGISTER5(BinaryOp, CPU, "Div", functor::safe_div, uint8, uint16, int16, int32,
int64);
+REGISTER5(BinaryOp, CPU, "TruncateDiv", functor::safe_div, uint8, uint16, int16,
+ int32, int64);
+REGISTER5(BinaryOp, CPU, "RealDiv", functor::div, float, Eigen::half, double,
+ complex64, complex128);
#if GOOGLE_CUDA
REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
uint16, int16, int64, complex64, complex128);
+REGISTER4(BinaryOp, GPU, "TruncateDiv", functor::div, uint8, uint16, int16,
+ int64);
+REGISTER5(BinaryOp, GPU, "RealDiv", functor::div, float, Eigen::half, double,
+ complex64, complex128);
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
diff --git a/tensorflow/core/kernels/cwise_op_floor_div.cc b/tensorflow/core/kernels/cwise_op_floor_div.cc
new file mode 100644
index 0000000000..83b2771ed2
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_floor_div.cc
@@ -0,0 +1,38 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER5(BinaryOp, CPU, "FloorDiv", functor::safe_floor_div, uint8, uint16,
+ int16, int32, int64);
+#if GOOGLE_CUDA
+REGISTER4(BinaryOp, GPU, "FloorDiv", functor::floor_div, uint8, uint16, int16,
+ int64);
+#endif
+
+#if GOOGLE_CUDA
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("FloorDiv")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::safe_floor_div<int32>>);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_floor_mod.cc b/tensorflow/core/kernels/cwise_op_floor_mod.cc
new file mode 100644
index 0000000000..4e641a8bb3
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_floor_mod.cc
@@ -0,0 +1,34 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER2(BinaryOp, CPU, "FloorMod", functor::safe_floor_mod, int32, int64);
+REGISTER2(BinaryOp, CPU, "FloorMod", functor::floor_fmod, float, double);
+
+#if GOOGLE_CUDA
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("FloorMod")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::safe_floor_mod<int32>>);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_floor_div.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_floor_div.cu.cc
new file mode 100644
index 0000000000..1300bf2232
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_floor_div.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY5(floor_div, uint8, uint16, int16, int32, int64);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_floor_mod.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_floor_mod.cu.cc
new file mode 100644
index 0000000000..bbe97e4b4d
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_floor_mod.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+// TODO(b/32239807) No GPU ops for mod yet.
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_mod.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_mod.cu.cc
index 6ad66353fa..bbe97e4b4d 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_mod.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_mod.cu.cc
@@ -19,7 +19,7 @@ limitations under the License.
namespace tensorflow {
namespace functor {
-// No GPU ops for mod yet.
+// TODO(b/32239807) No GPU ops for mod yet.
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_mod.cc b/tensorflow/core/kernels/cwise_op_mod.cc
index 5c64794765..2d19156609 100644
--- a/tensorflow/core/kernels/cwise_op_mod.cc
+++ b/tensorflow/core/kernels/cwise_op_mod.cc
@@ -18,6 +18,8 @@ limitations under the License.
namespace tensorflow {
REGISTER2(BinaryOp, CPU, "Mod", functor::safe_mod, int32, int64);
REGISTER2(BinaryOp, CPU, "Mod", functor::fmod, float, double);
+REGISTER2(BinaryOp, CPU, "TruncateMod", functor::safe_mod, int32, int64);
+REGISTER2(BinaryOp, CPU, "TruncateMod", functor::fmod, float, double);
#if GOOGLE_CUDA
// A special GPU kernel for int32.
@@ -30,5 +32,12 @@ REGISTER_KERNEL_BUILDER(Name("Mod")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::safe_mod<int32>>);
+REGISTER_KERNEL_BUILDER(Name("TruncateMod")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::safe_mod<int32>>);
#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 766c7152b0..5d15cf0048 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -18,7 +18,6 @@ limitations under the License.
#include <cmath>
#include <functional>
-#include <typeinfo>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -86,7 +85,7 @@ struct safe_div_or_mod_op {
template <typename T, typename DivOrMod>
struct functor_traits<safe_div_or_mod_op<T, DivOrMod>> {
enum {
- Cost = scalar_div_cost<T, false>::value,
+ Cost = functor_traits<DivOrMod>::Cost + NumTraits<T>::AddCost,
PacketAccess = false,
};
};
@@ -237,6 +236,70 @@ struct functor_traits<scalar_compose_op<Scalar, UnaryFunctor, BinaryFunctor>> {
};
};
+// TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
+template <typename T>
+struct google_floor_div {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
+ const T& y) const {
+ if ((x < T(0)) != (y < T(0))) {
+ T abs_x = std::abs(x);
+ T abs_y = std::abs(y);
+ return -(abs_x + abs_y - 1) / abs_y;
+ } else {
+ return x / y;
+ }
+ }
+};
+
+template <typename Scalar>
+struct functor_traits<google_floor_div<Scalar>> {
+ enum {
+ Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value +
+ 2 * NumTraits<Scalar>::AddCost,
+ PacketAccess = false
+ };
+};
+
+// TODO(b//32239616): This kernel should be moved into Eigen and vectorized.
+template <typename T>
+struct google_floor_fmod {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
+ const T& y) const {
+ // EIGEN_STATIC_ASSERT(NUMERIC_TYPE_MUST_BE_REAL);
+ T trunc_mod = std::fmod(x, y);
+ return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y);
+ }
+};
+
+template <typename Scalar>
+struct functor_traits<google_floor_fmod<Scalar>> {
+ enum {
+ Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value +
+ 2 * NumTraits<Scalar>::AddCost,
+ PacketAccess = false
+ };
+};
+
+// TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
+template <typename T>
+struct google_floor_mod {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
+ const T& y) const {
+ // EIGEN_STATIC_ASSERT(!NUMERIC_TYPE_MUST_BE_REAL);
+ T trunc_mod = x % y;
+ return (x < T(0)) == (y < T(0)) ? trunc_mod : (trunc_mod + y) % y;
+ }
+};
+
+template <typename Scalar>
+struct functor_traits<google_floor_mod<Scalar>> {
+ enum {
+ Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value +
+ 2 * NumTraits<Scalar>::AddCost,
+ PacketAccess = false
+ };
+};
+
#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
_Pragma("GCC diagnostic push") \
@@ -254,8 +317,7 @@ struct scalar_round_op_google {
EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex),
NUMERIC_TYPE_MUST_BE_REAL)
- Scalar round_val;
- round_val = Eigen::numext::floor(x);
+ Scalar round_val = Eigen::numext::floor(x);
const Scalar fraction = x - round_val;
if (fraction > Scalar(.5)) {
round_val += Scalar(1.0);
@@ -499,6 +561,24 @@ struct safe_mod : base<T, Eigen::internal::safe_div_or_mod_op<
};
template <typename T>
+struct floor_fmod : base<T, Eigen::internal::google_floor_fmod<T>> {};
+
+template <typename T>
+struct safe_floor_mod : base<T, Eigen::internal::safe_div_or_mod_op<
+ T, Eigen::internal::google_floor_mod<T>>> {
+ static const bool has_errors = true;
+};
+
+template <typename T>
+struct floor_div : base<T, Eigen::internal::google_floor_div<T>> {};
+
+template <typename T>
+struct safe_floor_div : base<T, Eigen::internal::safe_div_or_mod_op<
+ T, Eigen::internal::google_floor_div<T>>> {
+ static const bool has_errors = true;
+};
+
+template <typename T>
struct pow : base<T, Eigen::internal::scalar_binary_pow_op_google<T, T>> {};
template <typename T>
diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc
index 79bb223ccb..c675faeea1 100644
--- a/tensorflow/core/kernels/cwise_ops_common.cc
+++ b/tensorflow/core/kernels/cwise_ops_common.cc
@@ -35,7 +35,7 @@ void BinaryOpShared::SetComputeError(OpKernelContext* ctx) {
// ops that have compute errors are integer division and mod, and the only
// error they produce is zero division.
const string& op = ctx->op_kernel().type_string();
- if ((op == "Div" || op == "Mod") &&
+ if ((op == "Div" || op == "Mod" || op == "FloorMod" || op == "FloorDiv") &&
DataTypeIsInteger(ctx->op_kernel().input_type(0))) {
ctx->CtxFailure(errors::InvalidArgument("Integer division by zero"));
} else {
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index d684ef242b..115fdae393 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -504,6 +504,43 @@ Returns x / y element-wise.
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");
+REGISTER_OP("FloorDiv")
+ .BINARY_MORE()
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
+ .Doc(R"doc(
+Returns x // y element-wise.
+
+*NOTE*: `FloorDiv` supports broadcasting. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+)doc");
+
+REGISTER_OP("TruncateDiv")
+ .BINARY_MORE()
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
+ .Doc(R"doc(
+Returns x / y element-wise for integer types.
+
+Truncation designates that negative numbers will round fractional quantities
+toward zero. I.e. -7 / 5 = 1. This matches C semantics but it is different
+than Python semantics. See `FloorDiv` for a division function that matches
+Python Semantics.
+
+*NOTE*: `TruncateDiv` supports broadcasting. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+)doc");
+
+REGISTER_OP("RealDiv")
+ .BINARY_MORE()
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
+ .Doc(R"doc(
+Returns x / y element-wise for real types.
+
+If `x` and `y` are reals, this will return the floating-point division.
+
+*NOTE*: `Div` supports broadcasting. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+)doc");
+
REGISTER_OP("SquaredDifference")
.BINARY_FEWER()
.SetIsCommutative()
@@ -559,6 +596,37 @@ Returns element-wise remainder of division.
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");
+REGISTER_OP("FloorMod")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {int32, int64, float, double}")
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
+ .Doc(R"doc(
+Returns element-wise remainder of division. When `x < 0` xor `y < 0` is
+true, this follows Python semantics in that the result here is consistent
+with a flooring divide. E.g. `floor(x / y) * y + mod(x, y) = x`.
+
+*NOTE*: `FloorMod` supports broadcasting. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+)doc");
+
+REGISTER_OP("TruncateMod")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {int32, int64, float, double}")
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
+ .Doc(R"doc(
+Returns element-wise remainder of division. This emulates C semantics where
+
+true, this follows C semantics in that the result here is consistent
+with a flooring divide. E.g. `floor(x / y) * y + mod(x, y) = x`.
+
+*NOTE*: `Mod` supports broadcasting. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+)doc");
+
REGISTER_OP("Pow")
.Input("x: T")
.Input("y: T")