aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/kernels/cwise_ops_common.h75
-rw-r--r--tensorflow/core/kernels/extract_image_patches_op.h5
2 files changed, 63 insertions, 17 deletions
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h
index 2454620776..8295fa939e 100644
--- a/tensorflow/core/kernels/cwise_ops_common.h
+++ b/tensorflow/core/kernels/cwise_ops_common.h
@@ -305,6 +305,62 @@ struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
Assign(d, out, in.unaryExpr(Unary(scalar.data())));
}
+ void BCast(const CPUDevice& dev,
+ typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
+ typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
+ typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
+ typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
+ typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
+ bool* error) {
+ typename Functor::func func;
+ if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) {
+ Assign(dev, out, in0.binaryExpr(in1, func));
+ } else if (AllOne<NDIMS>(bcast0)) {
+ auto rhs = in1.broadcast(bcast1);
+ Assign(dev, out, in0.binaryExpr(rhs, func));
+ } else if (AllOne<NDIMS>(bcast1)) {
+ auto lhs = in0.broadcast(bcast0);
+ Assign(dev, out, lhs.binaryExpr(in1, func));
+ } else {
+ auto lhs = in0.broadcast(bcast0);
+ auto rhs = in1.broadcast(bcast1);
+ Assign(dev, out, lhs.binaryExpr(rhs, func));
+ }
+ }
+};
+
+// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, 2>
+// for functors with no error checking.
+template <typename Functor>
+struct BinaryFunctor<CPUDevice, Functor, 2, false> {
+ enum { NDIMS = 2 };
+
+ void operator()(const CPUDevice& d, typename Functor::tout_type out,
+ typename Functor::tin_type in0,
+ typename Functor::tin_type in1, bool* error) {
+ Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
+ }
+
+ void Left(const CPUDevice& d, typename Functor::tout_type out,
+ typename Functor::tscalar_type scalar,
+ typename Functor::tin_type in, bool* error) {
+ typedef typename Functor::out_type Tout;
+ typedef typename Functor::in_type Tin;
+ typedef typename Functor::func Binary;
+ typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary;
+ Assign(d, out, in.unaryExpr(Unary(scalar.data())));
+ }
+
+ void Right(const CPUDevice& d, typename Functor::tout_type out,
+ typename Functor::tin_type in,
+ typename Functor::tscalar_type scalar, bool* error) {
+ typedef typename Functor::out_type Tout;
+ typedef typename Functor::in_type Tin;
+ typedef typename Functor::func Binary;
+ typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
+ Assign(d, out, in.unaryExpr(Unary(scalar.data())));
+ }
+
#if !defined(EIGEN_HAS_INDEX_LIST)
inline Eigen::DSizes<int, 2> NByOne(int n) {
return Eigen::DSizes<int, 2>(n, 1);
@@ -334,8 +390,7 @@ struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
bool* error) {
typedef typename Functor::in_type T;
typename Functor::func func;
- if ((NDIMS == 2) && Functor::use_bcast_optimization &&
- use_bcast_optimization<T>::value) {
+ if (Functor::use_bcast_optimization && use_bcast_optimization<T>::value) {
// Optimize for speed by using Eigen::type2index and avoid
// .broadcast() when we know its a no-op.
//
@@ -411,19 +466,9 @@ struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
}
// Fallback path. Always works and probably slower.
- if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) {
- Assign(dev, out, in0.binaryExpr(in1, func));
- } else if (AllOne<NDIMS>(bcast0)) {
- auto rhs = in1.broadcast(bcast1);
- Assign(dev, out, in0.binaryExpr(rhs, func));
- } else if (AllOne<NDIMS>(bcast1)) {
- auto lhs = in0.broadcast(bcast0);
- Assign(dev, out, lhs.binaryExpr(in1, func));
- } else {
- auto lhs = in0.broadcast(bcast0);
- auto rhs = in1.broadcast(bcast1);
- Assign(dev, out, lhs.binaryExpr(rhs, func));
- }
+ auto lhs = in0.broadcast(bcast0);
+ auto rhs = in1.broadcast(bcast1);
+ Assign(dev, out, lhs.binaryExpr(rhs, func));
}
};
diff --git a/tensorflow/core/kernels/extract_image_patches_op.h b/tensorflow/core/kernels/extract_image_patches_op.h
index 9d34daca64..e430a23d20 100644
--- a/tensorflow/core/kernels/extract_image_patches_op.h
+++ b/tensorflow/core/kernels/extract_image_patches_op.h
@@ -34,11 +34,12 @@ struct ExtractImagePatchesForward {
// NHWC format while Eigen assumes NWHC format.
const int64 N = std::max(input.size(), output.size());
if (N <= std::numeric_limits<Index32>::max()) {
- To32Bit(output).device(d) =
+ auto output_32bit = To32Bit(output);
+ output_32bit.device(d) =
To32Bit(input)
.extract_image_patches(patch_cols, patch_rows, stride_cols,
stride_rows, rate_cols, rate_rows, padding)
- .reshape(output.dimensions());
+ .reshape(output_32bit.dimensions());
} else {
output.device(d) =
input