diff options
-rw-r--r-- | tensorflow/core/kernels/cwise_ops_common.h | 75 | ||||
-rw-r--r-- | tensorflow/core/kernels/extract_image_patches_op.h | 5 |
2 files changed, 63 insertions, 17 deletions
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h index 2454620776..8295fa939e 100644 --- a/tensorflow/core/kernels/cwise_ops_common.h +++ b/tensorflow/core/kernels/cwise_ops_common.h @@ -305,6 +305,62 @@ struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> { Assign(d, out, in.unaryExpr(Unary(scalar.data()))); } + void BCast(const CPUDevice& dev, + typename TTypes<typename Functor::out_type, NDIMS>::Tensor out, + typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0, + typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0, + typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1, + typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1, + bool* error) { + typename Functor::func func; + if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) { + Assign(dev, out, in0.binaryExpr(in1, func)); + } else if (AllOne<NDIMS>(bcast0)) { + auto rhs = in1.broadcast(bcast1); + Assign(dev, out, in0.binaryExpr(rhs, func)); + } else if (AllOne<NDIMS>(bcast1)) { + auto lhs = in0.broadcast(bcast0); + Assign(dev, out, lhs.binaryExpr(in1, func)); + } else { + auto lhs = in0.broadcast(bcast0); + auto rhs = in1.broadcast(bcast1); + Assign(dev, out, lhs.binaryExpr(rhs, func)); + } + } +}; + +// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, 2> +// for functors with with no error checking. +template <typename Functor> +struct BinaryFunctor<CPUDevice, Functor, 2, false> { + enum { NDIMS = 2 }; + + void operator()(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in0, + typename Functor::tin_type in1, bool* error) { + Assign(d, out, in0.binaryExpr(in1, typename Functor::func())); + } + + void Left(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tscalar_type scalar, + typename Functor::tin_type in, bool* error) { + typedef typename Functor::out_type Tout; + typedef typename Functor::in_type Tin; + typedef typename Functor::func Binary; + typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary; + Assign(d, out, in.unaryExpr(Unary(scalar.data()))); + } + + void Right(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in, + typename Functor::tscalar_type scalar, bool* error) { + typedef typename Functor::out_type Tout; + typedef typename Functor::in_type Tin; + typedef typename Functor::func Binary; + typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary; + Assign(d, out, in.unaryExpr(Unary(scalar.data()))); + } + #if !defined(EIGEN_HAS_INDEX_LIST) inline Eigen::DSizes<int, 2> NByOne(int n) { return Eigen::DSizes<int, 2>(n, 1); @@ -334,8 +390,7 @@ struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> { bool* error) { typedef typename Functor::in_type T; typename Functor::func func; - if ((NDIMS == 2) && Functor::use_bcast_optimization && - use_bcast_optimization<T>::value) { + if (Functor::use_bcast_optimization && use_bcast_optimization<T>::value) { // Optimize for speed by using Eigen::type2index and avoid // .broadcast() when we know its a no-op. // @@ -411,19 +466,9 @@ struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> { } // Fallback path. Always works and probably slower. - if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) { - Assign(dev, out, in0.binaryExpr(in1, func)); - } else if (AllOne<NDIMS>(bcast0)) { - auto rhs = in1.broadcast(bcast1); - Assign(dev, out, in0.binaryExpr(rhs, func)); - } else if (AllOne<NDIMS>(bcast1)) { - auto lhs = in0.broadcast(bcast0); - Assign(dev, out, lhs.binaryExpr(in1, func)); - } else { - auto lhs = in0.broadcast(bcast0); - auto rhs = in1.broadcast(bcast1); - Assign(dev, out, lhs.binaryExpr(rhs, func)); - } + auto lhs = in0.broadcast(bcast0); + auto rhs = in1.broadcast(bcast1); + Assign(dev, out, lhs.binaryExpr(rhs, func)); } }; diff --git a/tensorflow/core/kernels/extract_image_patches_op.h b/tensorflow/core/kernels/extract_image_patches_op.h index 9d34daca64..e430a23d20 100644 --- a/tensorflow/core/kernels/extract_image_patches_op.h +++ b/tensorflow/core/kernels/extract_image_patches_op.h @@ -34,11 +34,12 @@ struct ExtractImagePatchesForward { // NHWC format while Eigen assumes NWHC format. const int64 N = std::max(input.size(), output.size()); if (N <= std::numeric_limits<Index32>::max()) { - To32Bit(output).device(d) = + auto output_32bit = To32Bit(output); + output_32bit.device(d) = To32Bit(input) .extract_image_patches(patch_cols, patch_rows, stride_cols, stride_rows, rate_cols, rate_rows, padding) - .reshape(output.dimensions()); + .reshape(output_32bit.dimensions()); } else { output.device(d) = input |