about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/lite/kernels/internal
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-08-20 07:54:26 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-08-20 07:59:00 -0700
commit: 9aab0861f45ae31850cc3d61cb48a628a9a809cc (patch)
tree: 961e63fd7bd39c047e6061030799aa9cd6b6dd36 /tensorflow/contrib/lite/kernels/internal
parent: f32afcd244beb255060c3d489f117289a3496efc (diff)
Automated rollback of commit f32afcd244beb255060c3d489f117289a3496efc
PiperOrigin-RevId: 209416871
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal')
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h | 12
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h | 37
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h | 24
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 261
4 files changed, 105 insertions, 229 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
index df4d871466..7f0676be27 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
@@ -46,8 +46,8 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
inline void Relu(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Relu(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
- output_data);
+ Relu(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
}
// legacy, for compatibility with old checked-in code
@@ -580,8 +580,8 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
inline void Logistic(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
- output_data);
+ Logistic(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
}
inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
@@ -601,8 +601,8 @@ inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
inline void Tanh(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Tanh(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
- output_data);
+ Tanh(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
}
inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index f19df5e17e..ca020215e6 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -2327,8 +2327,8 @@ void GlobalBatchNormalization(const float* input_data,
}
}
-inline void Relu(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
+inline void Relu(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("Relu (not fused)");
const auto input = MapAsVector(input_data, input_shape);
@@ -4544,8 +4544,8 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
+inline void Logistic(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("Logistic");
auto input_map = MapAsVector(input_data, input_shape);
auto output_map = MapAsVector(output_data, output_shape);
@@ -4690,8 +4690,8 @@ inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
- const RuntimeShape& output_shape, int16* output_data) {
+inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
+ int16* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("Logistic/Int16");
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -4750,14 +4750,8 @@ inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
}
}
-// Legacy version.
-inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
- int16* output_data, const RuntimeShape& output_shape) {
- Logistic(input_shape, input_data, output_shape, output_data);
-}
-
-inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
+inline void Tanh(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("Tanh");
auto input_map = MapAsVector(input_data, input_shape);
auto output_map = MapAsVector(output_data, output_shape);
@@ -5020,19 +5014,12 @@ inline void Cast(const SrcT* input_data, const Dims<4>& input_dims,
output_map.array() = input_map.array().template cast<DstT>();
}
-inline void Floor(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
- gemmlowp::ScopedProfilingLabel label("Floor");
- auto input_map = MapAsVector(input_data, input_shape);
- auto output_map = MapAsVector(output_data, output_shape);
- output_map.array() = Eigen::floor(input_map.array());
-}
-
-// Legacy Dims<4> version.
inline void Floor(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
- output_data);
+ gemmlowp::ScopedProfilingLabel label("Floor");
+ auto input_map = MapAsVector(input_data, input_dims);
+ auto output_map = MapAsVector(output_data, output_dims);
+ output_map.array() = Eigen::floor(input_map.array());
}
#ifdef USE_NEON
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
index 71ae74f34c..b862ae38c7 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
@@ -42,20 +42,20 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
inline void Relu(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Relu(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
- output_data);
+ Relu(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
}
inline void Relu1(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Relu1(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
- output_data);
+ Relu1(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
}
inline void Relu6(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Relu6(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
- output_data);
+ Relu6(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
}
template <FusedActivationFunctionType Ac>
@@ -583,8 +583,8 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
inline void Logistic(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
- output_data);
+ Logistic(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
}
inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
@@ -598,14 +598,14 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
int16* output_data, const Dims<4>& output_dims) {
- Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
- output_data);
+ Logistic(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
}
inline void Tanh(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Tanh(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
- output_data);
+ Tanh(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
}
inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index ef7953dded..5634b8384a 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -846,8 +846,8 @@ void GlobalBatchNormalization(const float* input_data,
}
}
-inline void Relu(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
+inline void Relu(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
const float val = input_data[i];
@@ -857,8 +857,8 @@ inline void Relu(const RuntimeShape& input_shape, const float* input_data,
}
}
-inline void Relu1(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
+inline void Relu1(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)");
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
@@ -870,8 +870,8 @@ inline void Relu1(const RuntimeShape& input_shape, const float* input_data,
}
}
-inline void Relu6(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
+inline void Relu6(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)");
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
@@ -3118,8 +3118,8 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
+inline void Logistic(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
@@ -3167,8 +3167,8 @@ inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
- const RuntimeShape& output_shape, int16* output_data) {
+inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
+ int16* output_data, const RuntimeShape& output_shape) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
@@ -3185,8 +3185,8 @@ inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
}
}
-inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
+inline void Tanh(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
@@ -4070,24 +4070,21 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
}
template <typename T, typename Op>
-void MaximumMinimumBroadcast4DSlow(const RuntimeShape& input1_shape,
- const T* input1_data,
- const RuntimeShape& input2_shape,
- const T* input2_data,
- const RuntimeShape& output_shape,
- T* output_data, Op op) {
+void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims,
+ Op op) {
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
- &desc2);
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
- for (int b = 0; b < output_shape.Dims(0); ++b) {
- for (int y = 0; y < output_shape.Dims(1); ++y) {
- for (int x = 0; x < output_shape.Dims(2); ++x) {
- for (int c = 0; c < output_shape.Dims(3); ++c) {
- auto out_idx = Offset(output_shape, b, y, x, c);
- auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
- auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ auto out_idx = Offset(output_dims, c, x, y, b);
+ auto in1_idx = SubscriptToIndex(desc1, c, x, y, b);
+ auto in2_idx = SubscriptToIndex(desc2, c, x, y, b);
auto in1_val = input1_data[in1_idx];
auto in2_val = input2_data[in2_idx];
output_data[out_idx] = op(in1_val, in2_val);
@@ -4097,20 +4094,9 @@ void MaximumMinimumBroadcast4DSlow(const RuntimeShape& input1_shape,
}
}
-template <typename T, typename Op>
-void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims,
- Op op) {
- MaximumMinimumBroadcast4DSlow(DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data, op);
-}
-
template <typename T1, typename T2, typename T3, typename Cmp>
-void ArgMinMax(const T3* axis, const RuntimeShape& input_shape,
- const T1* input_data, const RuntimeShape& output_shape,
- T2* output_data, const Cmp& cmp) {
+void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
+ T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) {
// The current ArgMax implementation can only determine the index of the maximum
// value in the last dimension. So the axis argument is ignored.
@@ -4118,11 +4104,9 @@ void ArgMinMax(const T3* axis, const RuntimeShape& input_shape,
// 1). For the sake of simplicity, the output dimensions are equal to the
// input dimensions here. We enforce the constraint that the last dimension
// must always be 1.
- TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_shape.Dims(3), 1);
- const int outer_size = MatchingFlatSizeSkipDim(input_shape, 3, output_shape);
- const int depth = input_shape.Dims(3);
+ TFLITE_DCHECK_EQ(ArraySize(output_dims, 0), 1);
+ const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
+ const int depth = ArraySize(input_dims, 0);
for (int i = 0; i < outer_size; ++i) {
auto min_max_value = input_data[i * depth];
@@ -4138,15 +4122,6 @@ void ArgMinMax(const T3* axis, const RuntimeShape& input_shape,
}
}
-// Legacy Dims<4> version.
-template <typename T1, typename T2, typename T3, typename Cmp>
-void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
- T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) {
- ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
- output_data, cmp);
-}
-
-// Legacy.
// TODO(renjieliu): Remove this one.
template <typename T1, typename T2, typename T3>
void ArgMax(const T3* axis, const T1* input_data,
@@ -4279,26 +4254,16 @@ template <typename T>
using ComparisonFn = bool (*)(T, T);
template <typename T, ComparisonFn<T> F>
-inline void Comparison(const RuntimeShape& input1_shape, const T* input1_data,
- const RuntimeShape& input2_shape, const T* input2_data,
- const RuntimeShape& output_shape, bool* output_data) {
+inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ bool* output_data, const Dims<4>& output_dims) {
const int64_t flatsize =
- MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ MatchingFlatSize(input1_dims, input2_dims, output_dims);
for (int64_t i = 0; i < flatsize; ++i) {
output_data[i] = F(input1_data[i], input2_data[i]);
}
}
-// Legacy Dims<4> version.
-template <typename T, ComparisonFn<T> F>
-inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- bool* output_data, const Dims<4>& output_dims) {
- Comparison<T, F>(DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data);
-}
-
template <typename T, ComparisonFn<int32> F>
inline void Comparison(int left_shift, const T* input1_data,
const Dims<4>& input1_dims, int32 input1_offset,
@@ -4509,156 +4474,69 @@ inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
}
template <typename T>
-inline void Pow(const RuntimeShape& input1_shape, const T* input1_data,
- const RuntimeShape& input2_shape, const T* input2_data,
- const RuntimeShape& output_shape, T* output_data) {
- const int flat_size =
- MatchingFlatSize(input1_shape, input2_shape, output_shape);
- for (int i = 0; i < flat_size; ++i) {
- output_data[i] = std::pow(input1_data[i], input2_data[i]);
- }
-}
-
-// Legacy Dims<4> version.
-template <typename T>
inline void Pow(const T* input1_data, const Dims<4>& input1_dims,
const T* input2_data, const Dims<4>& input2_dims,
T* output_data, const Dims<4>& output_dims) {
- Pow(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
- input2_data, DimsToShape(output_dims), output_data);
+ const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = std::pow(input1_data[i], input2_data[i]);
+ }
}
template <typename T>
-inline void BroadcastPow4DSlow(const RuntimeShape& input1_shape,
- const T* input1_data,
- const RuntimeShape& input2_shape,
- const T* input2_data,
- const RuntimeShape& output_shape,
- T* output_data) {
+inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
- &desc2);
-
- for (int b = 0; b < output_shape.Dims(0); ++b) {
- for (int y = 0; y < output_shape.Dims(1); ++y) {
- for (int x = 0; x < output_shape.Dims(2); ++x) {
- for (int c = 0; c < output_shape.Dims(3); ++c) {
- auto out_idx = Offset(output_shape, b, y, x, c);
- auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
- auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
- auto in1_val = input1_data[in1_idx];
- auto in2_val = input2_data[in2_idx];
- output_data[out_idx] = std::pow(in1_val, in2_val);
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ std::pow(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
}
}
}
}
}
-// Legacy Dims<4> version.
-template <typename T>
-inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- BroadcastPow4DSlow(DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data);
-}
-
-inline void Logical(const RuntimeShape& input1_shape, const bool* input1_data,
- const RuntimeShape& input2_shape, const bool* input2_data,
- const RuntimeShape& output_shape, bool* output_data,
- const std::function<bool(bool, bool)>& func) {
- const int flat_size =
- MatchingFlatSize(input1_shape, input2_shape, output_shape);
- for (int i = 0; i < flat_size; ++i) {
- output_data[i] = func(input1_data[i], input2_data[i]);
- }
-}
-
-// Legacy Dims<4> version.
inline void Logical(const bool* input1_data, const Dims<4>& input1_dims,
const bool* input2_data, const Dims<4>& input2_dims,
bool* output_data, const Dims<4>& output_dims,
const std::function<bool(bool, bool)>& func) {
- Logical(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
- input2_data, DimsToShape(output_dims), output_data, func);
-}
-
-inline void BroadcastLogical4DSlow(
- const RuntimeShape& input1_shape, const bool* input1_data,
- const RuntimeShape& input2_shape, const bool* input2_data,
- const RuntimeShape& output_shape, bool* output_data,
- const std::function<bool(bool, bool)>& func) {
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
- &desc2);
-
- for (int b = 0; b < output_shape.Dims(0); ++b) {
- for (int y = 0; y < output_shape.Dims(1); ++y) {
- for (int x = 0; x < output_shape.Dims(2); ++x) {
- for (int c = 0; c < output_shape.Dims(3); ++c) {
- auto out_idx = Offset(output_shape, b, y, x, c);
- auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
- auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
- auto in1_val = input1_data[in1_idx];
- auto in2_val = input2_data[in2_idx];
- output_data[out_idx] = func(in1_val, in2_val);
- }
- }
- }
+ const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = func(input1_data[i], input2_data[i]);
}
}
-// Legacy Dims<4> version.
inline void BroadcastLogical(const bool* input1_data,
const Dims<4>& input1_dims,
const bool* input2_data,
const Dims<4>& input2_dims, bool* output_data,
const Dims<4>& output_dims,
const std::function<bool(bool, bool)>& func) {
- BroadcastLogical4DSlow(DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data, func);
-}
-
-// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more
-// generalized and efficient BroadcastBinaryFunction.
-//
-// Also appears to duplicate MinimumMaximum.
-//
-// R: Result type. T1: Input 1 type. T2: Input 2 type.
-template <typename R, typename T1, typename T2>
-inline void BroadcastBinaryFunction4DSlow(const RuntimeShape& input1_shape,
- const T1* input1_data,
- const RuntimeShape& input2_shape,
- const T2* input2_data,
- const RuntimeShape& output_shape,
- R* output_data, R (*func)(T1, T2)) {
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
- &desc2);
-
- for (int b = 0; b < output_shape.Dims(0); ++b) {
- for (int y = 0; y < output_shape.Dims(1); ++y) {
- for (int x = 0; x < output_shape.Dims(2); ++x) {
- for (int c = 0; c < output_shape.Dims(3); ++c) {
- auto out_idx = Offset(output_shape, b, y, x, c);
- auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
- auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
- auto in1_val = input1_data[in1_idx];
- auto in2_val = input2_data[in2_idx];
- output_data[out_idx] = func(in1_val, in2_val);
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ func(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
}
}
}
}
}
-// Legacy Dims<4> version.
+// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more
+// generalized and efficient BroadcastBinaryFunction.
//
// R: Result type. T1: Input 1 type. T2: Input 2 type.
template <typename R, typename T1, typename T2>
@@ -4668,9 +4546,20 @@ inline void BroadcastBinaryFunction(const T1* input1_data,
const Dims<4>& input2_dims, R* output_data,
const Dims<4>& output_dims,
R (*func)(T1, T2)) {
- BroadcastBinaryFunction4DSlow(DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data, func);
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ func(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ }
+ }
+ }
+ }
}
} // namespace reference_ops