aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cwise_ops_common.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-25 09:47:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-25 09:52:40 -0700
commitf1f60ac3e59b7cbbd2badef11cc3da42064fc695 (patch)
treeaf27e499aa54ae15d8e5ccba9f85af4f34b01d0f /tensorflow/core/kernels/cwise_ops_common.h
parent4f86cf60254126d28f5f653d810de0a2cf1473c8 (diff)
1. Separate the special case BinaryFunctor when NDIMS == 2 into a template specialization. This prevents the NDIMS==2 optimization code from being compiled in the general case, which can lead to compile time errors when the underlying Eigen implementation becomes more strict about NDIMS.
2. Fix the 64-bit dimension() call on output in extract_image_patches_op.h when other operands have been cast to use 32-bit index. PiperOrigin-RevId: 173409602
Diffstat (limited to 'tensorflow/core/kernels/cwise_ops_common.h')
-rw-r--r--tensorflow/core/kernels/cwise_ops_common.h75
1 files changed, 60 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h
index 2454620776..8295fa939e 100644
--- a/tensorflow/core/kernels/cwise_ops_common.h
+++ b/tensorflow/core/kernels/cwise_ops_common.h
@@ -305,6 +305,62 @@ struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
Assign(d, out, in.unaryExpr(Unary(scalar.data())));
}
+ void BCast(const CPUDevice& dev,
+ typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
+ typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
+ typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
+ typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
+ typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
+ bool* error) {
+ typename Functor::func func;
+ if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) {
+ Assign(dev, out, in0.binaryExpr(in1, func));
+ } else if (AllOne<NDIMS>(bcast0)) {
+ auto rhs = in1.broadcast(bcast1);
+ Assign(dev, out, in0.binaryExpr(rhs, func));
+ } else if (AllOne<NDIMS>(bcast1)) {
+ auto lhs = in0.broadcast(bcast0);
+ Assign(dev, out, lhs.binaryExpr(in1, func));
+ } else {
+ auto lhs = in0.broadcast(bcast0);
+ auto rhs = in1.broadcast(bcast1);
+ Assign(dev, out, lhs.binaryExpr(rhs, func));
+ }
+ }
+};
+
+// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, 2>
+// for functors with no error checking.
+template <typename Functor>
+struct BinaryFunctor<CPUDevice, Functor, 2, false> {
+ enum { NDIMS = 2 };
+
+ void operator()(const CPUDevice& d, typename Functor::tout_type out,
+ typename Functor::tin_type in0,
+ typename Functor::tin_type in1, bool* error) {
+ Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
+ }
+
+ void Left(const CPUDevice& d, typename Functor::tout_type out,
+ typename Functor::tscalar_type scalar,
+ typename Functor::tin_type in, bool* error) {
+ typedef typename Functor::out_type Tout;
+ typedef typename Functor::in_type Tin;
+ typedef typename Functor::func Binary;
+ typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary;
+ Assign(d, out, in.unaryExpr(Unary(scalar.data())));
+ }
+
+ void Right(const CPUDevice& d, typename Functor::tout_type out,
+ typename Functor::tin_type in,
+ typename Functor::tscalar_type scalar, bool* error) {
+ typedef typename Functor::out_type Tout;
+ typedef typename Functor::in_type Tin;
+ typedef typename Functor::func Binary;
+ typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
+ Assign(d, out, in.unaryExpr(Unary(scalar.data())));
+ }
+
#if !defined(EIGEN_HAS_INDEX_LIST)
inline Eigen::DSizes<int, 2> NByOne(int n) {
return Eigen::DSizes<int, 2>(n, 1);
@@ -334,8 +390,7 @@ struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
bool* error) {
typedef typename Functor::in_type T;
typename Functor::func func;
- if ((NDIMS == 2) && Functor::use_bcast_optimization &&
- use_bcast_optimization<T>::value) {
+ if (Functor::use_bcast_optimization && use_bcast_optimization<T>::value) {
// Optimize for speed by using Eigen::type2index and avoid
// .broadcast() when we know its a no-op.
//
@@ -411,19 +466,9 @@ struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
}
// Fallback path. Always works and probably slower.
- if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) {
- Assign(dev, out, in0.binaryExpr(in1, func));
- } else if (AllOne<NDIMS>(bcast0)) {
- auto rhs = in1.broadcast(bcast1);
- Assign(dev, out, in0.binaryExpr(rhs, func));
- } else if (AllOne<NDIMS>(bcast1)) {
- auto lhs = in0.broadcast(bcast0);
- Assign(dev, out, lhs.binaryExpr(in1, func));
- } else {
- auto lhs = in0.broadcast(bcast0);
- auto rhs = in1.broadcast(bcast1);
- Assign(dev, out, lhs.binaryExpr(rhs, func));
- }
+ auto lhs = in0.broadcast(bcast0);
+ auto rhs = in1.broadcast(bcast1);
+ Assign(dev, out, lhs.binaryExpr(rhs, func));
}
};