diff options
4 files changed, 105 insertions, 229 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h index df4d871466..7f0676be27 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h @@ -46,8 +46,8 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, inline void Relu(const float* input_data, const Dims<4>& input_dims, float* output_data, const Dims<4>& output_dims) { - Relu(DimsToShape(input_dims), input_data, DimsToShape(output_dims), - output_data); + Relu(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); } // legacy, for compatibility with old checked-in code @@ -580,8 +580,8 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, inline void Logistic(const float* input_data, const Dims<4>& input_dims, float* output_data, const Dims<4>& output_dims) { - Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims), - output_data); + Logistic(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); } inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, @@ -601,8 +601,8 @@ inline void Logistic(const int16* input_data, const Dims<4>& input_dims, inline void Tanh(const float* input_data, const Dims<4>& input_dims, float* output_data, const Dims<4>& output_dims) { - Tanh(DimsToShape(input_dims), input_data, DimsToShape(output_dims), - output_data); + Tanh(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); } inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index f19df5e17e..ca020215e6 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -2327,8 +2327,8 @@ void GlobalBatchNormalization(const float* input_data, } } -inline void Relu(const RuntimeShape& input_shape, const float* input_data, - const RuntimeShape& output_shape, float* output_data) { +inline void Relu(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Relu (not fused)"); const auto input = MapAsVector(input_data, input_shape); @@ -4544,8 +4544,8 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape, } } -inline void Logistic(const RuntimeShape& input_shape, const float* input_data, - const RuntimeShape& output_shape, float* output_data) { +inline void Logistic(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Logistic"); auto input_map = MapAsVector(input_data, input_shape); auto output_map = MapAsVector(output_data, output_shape); @@ -4690,8 +4690,8 @@ inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape, } } -inline void Logistic(const RuntimeShape& input_shape, const int16* input_data, - const RuntimeShape& output_shape, int16* output_data) { +inline void Logistic(const int16* input_data, const RuntimeShape& input_shape, + int16* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Logistic/Int16"); const int flat_size = MatchingFlatSize(input_shape, output_shape); @@ -4750,14 +4750,8 @@ inline void Logistic(const RuntimeShape& input_shape, const int16* input_data, } } -// Legacy version. -inline void Logistic(const int16* input_data, const RuntimeShape& input_shape, - int16* output_data, const RuntimeShape& output_shape) { - Logistic(input_shape, input_data, output_shape, output_data); -} - -inline void Tanh(const RuntimeShape& input_shape, const float* input_data, - const RuntimeShape& output_shape, float* output_data) { +inline void Tanh(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Tanh"); auto input_map = MapAsVector(input_data, input_shape); auto output_map = MapAsVector(output_data, output_shape); @@ -5020,19 +5014,12 @@ inline void Cast(const SrcT* input_data, const Dims<4>& input_dims, output_map.array() = input_map.array().template cast<DstT>(); } -inline void Floor(const RuntimeShape& input_shape, const float* input_data, - const RuntimeShape& output_shape, float* output_data) { - gemmlowp::ScopedProfilingLabel label("Floor"); - auto input_map = MapAsVector(input_data, input_shape); - auto output_map = MapAsVector(output_data, output_shape); - output_map.array() = Eigen::floor(input_map.array()); -} - -// Legacy Dims<4> version. inline void Floor(const float* input_data, const Dims<4>& input_dims, float* output_data, const Dims<4>& output_dims) { - Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims), - output_data); + gemmlowp::ScopedProfilingLabel label("Floor"); + auto input_map = MapAsVector(input_data, input_dims); + auto output_map = MapAsVector(output_data, output_dims); + output_map.array() = Eigen::floor(input_map.array()); } #ifdef USE_NEON diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h index 71ae74f34c..b862ae38c7 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h @@ -42,20 +42,20 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, inline void Relu(const float* input_data, const Dims<4>& input_dims, float* output_data, const Dims<4>& output_dims) { - Relu(DimsToShape(input_dims), input_data, DimsToShape(output_dims), - output_data); + Relu(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); } inline void Relu1(const float* input_data, const Dims<4>& input_dims, float* output_data, const Dims<4>& output_dims) { - Relu1(DimsToShape(input_dims), input_data, DimsToShape(output_dims), - output_data); + Relu1(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); } inline void Relu6(const float* input_data, const Dims<4>& input_dims, float* output_data, const Dims<4>& output_dims) { - Relu6(DimsToShape(input_dims), input_data, DimsToShape(output_dims), - output_data); + Relu6(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); } template <FusedActivationFunctionType Ac> @@ -583,8 +583,8 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, inline void Logistic(const float* input_data, const Dims<4>& input_dims, float* output_data, const Dims<4>& output_dims) { - Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims), - output_data); + Logistic(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); } inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, @@ -598,14 +598,14 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, inline void Logistic(const int16* input_data, const Dims<4>& input_dims, int16* output_data, const Dims<4>& output_dims) { - Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims), - output_data); + Logistic(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); } inline void Tanh(const float* input_data, const Dims<4>& input_dims, float* output_data, const Dims<4>& output_dims) { - Tanh(DimsToShape(input_dims), input_data, DimsToShape(output_dims), - output_data); + Tanh(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); } inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index ef7953dded..5634b8384a 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -846,8 +846,8 @@ void GlobalBatchNormalization(const float* input_data, } } -inline void Relu(const RuntimeShape& input_shape, const float* input_data, - const RuntimeShape& output_shape, float* output_data) { +inline void Relu(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const float val = input_data[i]; @@ -857,8 +857,8 @@ inline void Relu(const RuntimeShape& input_shape, const float* input_data, } } -inline void Relu1(const RuntimeShape& input_shape, const float* input_data, - const RuntimeShape& output_shape, float* output_data) { +inline void Relu1(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)"); const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { @@ -870,8 +870,8 @@ inline void Relu1(const RuntimeShape& input_shape, const float* input_data, } } -inline void Relu6(const RuntimeShape& input_shape, const float* input_data, - const RuntimeShape& output_shape, float* output_data) { +inline void Relu6(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)"); const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { @@ -3118,8 +3118,8 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape, } } -inline void Logistic(const RuntimeShape& input_shape, const float* input_data, - const RuntimeShape& output_shape, float* output_data) { +inline void Logistic(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { @@ -3167,8 +3167,8 @@ inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape, } } -inline void Logistic(const RuntimeShape& input_shape, const int16* input_data, - const RuntimeShape& output_shape, int16* output_data) { +inline void Logistic(const int16* input_data, const RuntimeShape& input_shape, + int16* output_data, const RuntimeShape& output_shape) { const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { @@ -3185,8 +3185,8 @@ inline void Logistic(const RuntimeShape& input_shape, const int16* input_data, } } -inline void Tanh(const RuntimeShape& input_shape, const float* input_data, - const RuntimeShape& output_shape, float* output_data) { +inline void Tanh(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { @@ -4070,24 +4070,21 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, } template <typename T, typename Op> -void MaximumMinimumBroadcast4DSlow(const RuntimeShape& input1_shape, - const T* input1_data, - const RuntimeShape& input2_shape, - const T* input2_data, - const RuntimeShape& output_shape, - T* output_data, Op op) { +void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims, + Op op) { NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, - &desc2); + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); - for (int b = 0; b < output_shape.Dims(0); ++b) { - for (int y = 0; y < output_shape.Dims(1); ++y) { - for (int x = 0; x < output_shape.Dims(2); ++x) { - for (int c = 0; c < output_shape.Dims(3); ++c) { - auto out_idx = Offset(output_shape, b, y, x, c); - auto in1_idx = SubscriptToIndex(desc1, b, y, x, c); - auto in2_idx = SubscriptToIndex(desc2, b, y, x, c); + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + auto out_idx = Offset(output_dims, c, x, y, b); + auto in1_idx = SubscriptToIndex(desc1, c, x, y, b); + auto in2_idx = SubscriptToIndex(desc2, c, x, y, b); auto in1_val = input1_data[in1_idx]; auto in2_val = input2_data[in2_idx]; output_data[out_idx] = op(in1_val, in2_val); @@ -4097,20 +4094,9 @@ void MaximumMinimumBroadcast4DSlow(const RuntimeShape& input1_shape, } } -template <typename T, typename Op> -void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - T* output_data, const Dims<4>& output_dims, - Op op) { - MaximumMinimumBroadcast4DSlow(DimsToShape(input1_dims), input1_data, - DimsToShape(input2_dims), input2_data, - DimsToShape(output_dims), output_data, op); -} - template <typename T1, typename T2, typename T3, typename Cmp> -void ArgMinMax(const T3* axis, const RuntimeShape& input_shape, - const T1* input_data, const RuntimeShape& output_shape, - T2* output_data, const Cmp& cmp) { +void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, + T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) { // The current ArgMax implemention can only determine the index of the maximum // value in the last dimension. So the axis argument is ignored. @@ -4118,11 +4104,9 @@ void ArgMinMax(const T3* axis, const RuntimeShape& input_shape, // 1). For the sake of simplicity, the output dimensions are equal to the // input dimensions here. We enforce the constraint that the last dimension // must always be 1. - TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); - TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); - TFLITE_DCHECK_EQ(output_shape.Dims(3), 1); - const int outer_size = MatchingFlatSizeSkipDim(input_shape, 3, output_shape); - const int depth = input_shape.Dims(3); + TFLITE_DCHECK_EQ(ArraySize(output_dims, 0), 1); + const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); + const int depth = ArraySize(input_dims, 0); for (int i = 0; i < outer_size; ++i) { auto min_max_value = input_data[i * depth]; @@ -4138,15 +4122,6 @@ void ArgMinMax(const T3* axis, const RuntimeShape& input_shape, } } -// Legacy Dims<4> version. -template <typename T1, typename T2, typename T3, typename Cmp> -void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, - T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) { - ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims), - output_data, cmp); -} - -// Legacy. // TODO(renjieliu): Remove this one. template <typename T1, typename T2, typename T3> void ArgMax(const T3* axis, const T1* input_data, @@ -4279,26 +4254,16 @@ template <typename T> using ComparisonFn = bool (*)(T, T); template <typename T, ComparisonFn<T> F> -inline void Comparison(const RuntimeShape& input1_shape, const T* input1_data, - const RuntimeShape& input2_shape, const T* input2_data, - const RuntimeShape& output_shape, bool* output_data) { +inline void Comparison(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + bool* output_data, const Dims<4>& output_dims) { const int64_t flatsize = - MatchingFlatSize(input1_shape, input2_shape, output_shape); + MatchingFlatSize(input1_dims, input2_dims, output_dims); for (int64_t i = 0; i < flatsize; ++i) { output_data[i] = F(input1_data[i], input2_data[i]); } } -// Legacy Dims<4> version. -template <typename T, ComparisonFn<T> F> -inline void Comparison(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - bool* output_data, const Dims<4>& output_dims) { - Comparison<T, F>(DimsToShape(input1_dims), input1_data, - DimsToShape(input2_dims), input2_data, - DimsToShape(output_dims), output_data); -} - template <typename T, ComparisonFn<int32> F> inline void Comparison(int left_shift, const T* input1_data, const Dims<4>& input1_dims, int32 input1_offset, @@ -4509,156 +4474,69 @@ inline void SparseToDense(const std::vector<std::vector<TI>>& indices, } template <typename T> -inline void Pow(const RuntimeShape& input1_shape, const T* input1_data, - const RuntimeShape& input2_shape, const T* input2_data, - const RuntimeShape& output_shape, T* output_data) { - const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); - for (int i = 0; i < flat_size; ++i) { - output_data[i] = std::pow(input1_data[i], input2_data[i]); - } -} - -// Legacy Dims<4> version. -template <typename T> inline void Pow(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, const Dims<4>& input2_dims, T* output_data, const Dims<4>& output_dims) { - Pow(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims), - input2_data, DimsToShape(output_dims), output_data); + const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = std::pow(input1_data[i], input2_data[i]); + } } template <typename T> -inline void BroadcastPow4DSlow(const RuntimeShape& input1_shape, - const T* input1_data, - const RuntimeShape& input2_shape, - const T* input2_data, - const RuntimeShape& output_shape, - T* output_data) { +inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, - &desc2); - - for (int b = 0; b < output_shape.Dims(0); ++b) { - for (int y = 0; y < output_shape.Dims(1); ++y) { - for (int x = 0; x < output_shape.Dims(2); ++x) { - for (int c = 0; c < output_shape.Dims(3); ++c) { - auto out_idx = Offset(output_shape, b, y, x, c); - auto in1_idx = SubscriptToIndex(desc1, b, y, x, c); - auto in2_idx = SubscriptToIndex(desc2, b, y, x, c); - auto in1_val = input1_data[in1_idx]; - auto in2_val = input2_data[in2_idx]; - output_data[out_idx] = std::pow(in1_val, in2_val); + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + std::pow(input1_data[SubscriptToIndex(desc1, c, x, y, b)], + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); } } } } } -// Legacy Dims<4> version. -template <typename T> -inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - T* output_data, const Dims<4>& output_dims) { - BroadcastPow4DSlow(DimsToShape(input1_dims), input1_data, - DimsToShape(input2_dims), input2_data, - DimsToShape(output_dims), output_data); -} - -inline void Logical(const RuntimeShape& input1_shape, const bool* input1_data, - const RuntimeShape& input2_shape, const bool* input2_data, - const RuntimeShape& output_shape, bool* output_data, - const std::function<bool(bool, bool)>& func) { - const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); - for (int i = 0; i < flat_size; ++i) { - output_data[i] = func(input1_data[i], input2_data[i]); - } -} - -// Legacy Dims<4> version. inline void Logical(const bool* input1_data, const Dims<4>& input1_dims, const bool* input2_data, const Dims<4>& input2_dims, bool* output_data, const Dims<4>& output_dims, const std::function<bool(bool, bool)>& func) { - Logical(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims), - input2_data, DimsToShape(output_dims), output_data, func); -} - -inline void BroadcastLogical4DSlow( - const RuntimeShape& input1_shape, const bool* input1_data, - const RuntimeShape& input2_shape, const bool* input2_data, - const RuntimeShape& output_shape, bool* output_data, - const std::function<bool(bool, bool)>& func) { - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, - &desc2); - - for (int b = 0; b < output_shape.Dims(0); ++b) { - for (int y = 0; y < output_shape.Dims(1); ++y) { - for (int x = 0; x < output_shape.Dims(2); ++x) { - for (int c = 0; c < output_shape.Dims(3); ++c) { - auto out_idx = Offset(output_shape, b, y, x, c); - auto in1_idx = SubscriptToIndex(desc1, b, y, x, c); - auto in2_idx = SubscriptToIndex(desc2, b, y, x, c); - auto in1_val = input1_data[in1_idx]; - auto in2_val = input2_data[in2_idx]; - output_data[out_idx] = func(in1_val, in2_val); - } - } - } + const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = func(input1_data[i], input2_data[i]); } } -// Legacy Dims<4> version. inline void BroadcastLogical(const bool* input1_data, const Dims<4>& input1_dims, const bool* input2_data, const Dims<4>& input2_dims, bool* output_data, const Dims<4>& output_dims, const std::function<bool(bool, bool)>& func) { - BroadcastLogical4DSlow(DimsToShape(input1_dims), input1_data, - DimsToShape(input2_dims), input2_data, - DimsToShape(output_dims), output_data, func); -} - -// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more -// generalized and efficient BroadcastBinaryFunction. -// -// Also appears to duplicte MinimumMaximum. -// -// R: Result type. T1: Input 1 type. T2: Input 2 type. -template <typename R, typename T1, typename T2> -inline void BroadcastBinaryFunction4DSlow(const RuntimeShape& input1_shape, - const T1* input1_data, - const RuntimeShape& input2_shape, - const T2* input2_data, - const RuntimeShape& output_shape, - R* output_data, R (*func)(T1, T2)) { NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, - &desc2); - - for (int b = 0; b < output_shape.Dims(0); ++b) { - for (int y = 0; y < output_shape.Dims(1); ++y) { - for (int x = 0; x < output_shape.Dims(2); ++x) { - for (int c = 0; c < output_shape.Dims(3); ++c) { - auto out_idx = Offset(output_shape, b, y, x, c); - auto in1_idx = SubscriptToIndex(desc1, b, y, x, c); - auto in2_idx = SubscriptToIndex(desc2, b, y, x, c); - auto in1_val = input1_data[in1_idx]; - auto in2_val = input2_data[in2_idx]; - output_data[out_idx] = func(in1_val, in2_val); + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + func(input1_data[SubscriptToIndex(desc1, c, x, y, b)], + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); } } } } } -// Legacy Dims<4> version. +// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more +// generalized and efficient BroadcastBinaryFunction. // // R: Result type. T1: Input 1 type. T2: Input 2 type. template <typename R, typename T1, typename T2> @@ -4668,9 +4546,20 @@ inline void BroadcastBinaryFunction(const T1* input1_data, const Dims<4>& input2_dims, R* output_data, const Dims<4>& output_dims, R (*func)(T1, T2)) { - BroadcastBinaryFunction4DSlow(DimsToShape(input1_dims), input1_data, - DimsToShape(input2_dims), input2_data, - DimsToShape(output_dims), output_data, func); + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + func(input1_data[SubscriptToIndex(desc1, c, x, y, b)], + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } } } // namespace reference_ops |