diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-10-25 09:47:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-25 09:52:40 -0700 |
commit | f1f60ac3e59b7cbbd2badef11cc3da42064fc695 (patch) | |
tree | af27e499aa54ae15d8e5ccba9f85af4f34b01d0f /tensorflow/core/kernels/cwise_ops_common.h | |
parent | 4f86cf60254126d28f5f653d810de0a2cf1473c8 (diff) |
1. Separate the special case BinaryFunctor when NDIMS == 2 into a template specialization. This prevents the NDIMS==2 optimization code from being compiled in the general case, which can lead to compile time errors when the underlying Eigen implementation becomes more strict about NDIMS.
2. Fix the 64-bit dimension() call on output in extract_image_patches_op.h when other operands have been cast to use 32-bit index.
PiperOrigin-RevId: 173409602
Diffstat (limited to 'tensorflow/core/kernels/cwise_ops_common.h')
-rw-r--r-- | tensorflow/core/kernels/cwise_ops_common.h | 75 |
1 files changed, 60 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h index 2454620776..8295fa939e 100644 --- a/tensorflow/core/kernels/cwise_ops_common.h +++ b/tensorflow/core/kernels/cwise_ops_common.h @@ -305,6 +305,62 @@ struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> { Assign(d, out, in.unaryExpr(Unary(scalar.data()))); } + void BCast(const CPUDevice& dev, + typename TTypes<typename Functor::out_type, NDIMS>::Tensor out, + typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0, + typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0, + typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1, + typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1, + bool* error) { + typename Functor::func func; + if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) { + Assign(dev, out, in0.binaryExpr(in1, func)); + } else if (AllOne<NDIMS>(bcast0)) { + auto rhs = in1.broadcast(bcast1); + Assign(dev, out, in0.binaryExpr(rhs, func)); + } else if (AllOne<NDIMS>(bcast1)) { + auto lhs = in0.broadcast(bcast0); + Assign(dev, out, lhs.binaryExpr(in1, func)); + } else { + auto lhs = in0.broadcast(bcast0); + auto rhs = in1.broadcast(bcast1); + Assign(dev, out, lhs.binaryExpr(rhs, func)); + } + } +}; + +// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, 2> +// for functors with with no error checking. +template <typename Functor> +struct BinaryFunctor<CPUDevice, Functor, 2, false> { + enum { NDIMS = 2 }; + + void operator()(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in0, + typename Functor::tin_type in1, bool* error) { + Assign(d, out, in0.binaryExpr(in1, typename Functor::func())); + } + + void Left(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tscalar_type scalar, + typename Functor::tin_type in, bool* error) { + typedef typename Functor::out_type Tout; + typedef typename Functor::in_type Tin; + typedef typename Functor::func Binary; + typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary; + Assign(d, out, in.unaryExpr(Unary(scalar.data()))); + } + + void Right(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in, + typename Functor::tscalar_type scalar, bool* error) { + typedef typename Functor::out_type Tout; + typedef typename Functor::in_type Tin; + typedef typename Functor::func Binary; + typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary; + Assign(d, out, in.unaryExpr(Unary(scalar.data()))); + } + #if !defined(EIGEN_HAS_INDEX_LIST) inline Eigen::DSizes<int, 2> NByOne(int n) { return Eigen::DSizes<int, 2>(n, 1); @@ -334,8 +390,7 @@ struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> { bool* error) { typedef typename Functor::in_type T; typename Functor::func func; - if ((NDIMS == 2) && Functor::use_bcast_optimization && - use_bcast_optimization<T>::value) { + if (Functor::use_bcast_optimization && use_bcast_optimization<T>::value) { // Optimize for speed by using Eigen::type2index and avoid // .broadcast() when we know its a no-op. // @@ -411,19 +466,9 @@ struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> { } // Fallback path. Always works and probably slower. - if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) { - Assign(dev, out, in0.binaryExpr(in1, func)); - } else if (AllOne<NDIMS>(bcast0)) { - auto rhs = in1.broadcast(bcast1); - Assign(dev, out, in0.binaryExpr(rhs, func)); - } else if (AllOne<NDIMS>(bcast1)) { - auto lhs = in0.broadcast(bcast0); - Assign(dev, out, lhs.binaryExpr(in1, func)); - } else { - auto lhs = in0.broadcast(bcast0); - auto rhs = in1.broadcast(bcast1); - Assign(dev, out, lhs.binaryExpr(rhs, func)); - } + auto lhs = in0.broadcast(bcast0); + auto rhs = in1.broadcast(bcast1); + Assign(dev, out, lhs.binaryExpr(rhs, func)); } }; |