Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h')
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h  425
1 file changed, 340 insertions(+), 85 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index f4176e474e..cb254f36cc 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -105,6 +105,11 @@ namespace reference_ops {
// Used mainly to convert from old-style shifts (right) to new-style (left).
static constexpr int kReverseShift = -1;
+inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
+ return RuntimeShape(
+ {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
+}
+
template <typename T>
int CountLeadingZeros(T integer_input) {
static_assert(std::is_unsigned<T>::value,
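Note on the DimsToShape helper added above: Dims<4> stores extents innermost-first (depth, width, height, batch), while RuntimeShape lists them outermost-first, so the conversion simply reverses the array. A minimal standalone sketch of that mapping (Dims4 below is a stand-in for tflite::Dims<4>; the sizes are made up):

    #include <cstdio>

    // Stand-in for tflite::Dims<4>: sizes[0] is the innermost (depth) extent.
    struct Dims4 { int sizes[4]; };

    int main() {
      // A batch=2, height=3, width=4, depth=5 tensor in legacy layout:
      Dims4 dims = {{5, 4, 3, 2}};  // {depth, width, height, batch}
      // DimsToShape emits {batch, height, width, depth}:
      int shape[4] = {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]};
      printf("%d %d %d %d\n", shape[0], shape[1], shape[2], shape[3]);  // 2 3 4 5
      return 0;
    }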
@@ -271,12 +276,12 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
int32 input_offset, const uint8* filter_data,
const Dims<4>& filter_dims, int32 filter_offset,
const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims,
+ uint8* im2col_data, const Dims<4>& im2col_dims,
gemmlowp::GemmContext* gemm_context) {
(void)im2col_data; // only used in optimized code.
(void)im2col_dims; // only used in optimized code.
@@ -302,8 +307,9 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
- const int in_x = in_x_origin + filter_x;
- const int in_y = in_y_origin + filter_y;
+ const int in_x = in_x_origin + dilation_width_factor * filter_x;
+ const int in_y =
+ in_y_origin + dilation_height_factor * filter_y;
// If the location is outside the bounds of the input image,
// use zero as a default value.
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
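The new index arithmetic above is the whole of dilation support: filter taps are spread dilation_factor apart in the input. A standalone sketch of just that arithmetic, with made-up origins and factors, showing how a 3x3 filter with dilation 2 touches a 5x5 region:

    #include <cstdio>

    int main() {
      const int in_x_origin = 0, in_y_origin = 0;
      const int dilation_width_factor = 2, dilation_height_factor = 2;
      for (int filter_y = 0; filter_y < 3; ++filter_y) {
        for (int filter_x = 0; filter_x < 3; ++filter_x) {
          const int in_x = in_x_origin + dilation_width_factor * filter_x;
          const int in_y = in_y_origin + dilation_height_factor * filter_y;
          printf("tap(%d,%d) -> input(%d,%d)\n", filter_x, filter_y, in_x, in_y);
        }
      }
      // Taps land on x,y in {0, 2, 4}: a 5x5 receptive field from a 3x3 filter.
      return 0;
    }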
@@ -335,6 +341,24 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
}
}
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
+ pad_width, pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
@@ -1374,13 +1398,143 @@ void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
output_dims);
}
-inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
- int32 input1_offset, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
+// Element-wise mul that can often be used for the inner loop of broadcast
+// Mul as well as the non-broadcast Mul.
+inline void MulElementwise(int size, const ArithmeticParams& params,
+ const uint8* input1_data, const uint8* input2_data,
+ uint8* output_data) {
+ for (int i = 0; i < size; ++i) {
+ const int32 input1_val = params.input1_offset + input1_data[i];
+ const int32 input2_val = params.input2_offset + input2_data[i];
+ const int32 unclamped_result =
+ params.output_offset +
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val,
+ params.output_multiplier,
+ params.output_shift);
+ const int32 clamped_output =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, unclamped_result));
+ output_data[i] = static_cast<uint8>(clamped_output);
+ }
+}
+
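MulElementwise works on offset-adjusted values: each uint8 is shifted by its input offset, the exact 32-bit product is rescaled to the output scale by a Q31 fixed-point multiplier plus a rounding right shift, and the output offset is added before clamping. A self-contained sketch of that requantization step with made-up parameters; the real kernel delegates the rescale to MultiplyByQuantizedMultiplierSmallerThanOneExp, which this Requantize only mimics:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Integer-only rescale by a real factor < 1, expressed as a Q31 multiplier
    // plus a rounding right shift (a sketch of gemmlowp-style arithmetic).
    int32_t Requantize(int32_t acc, int32_t quantized_multiplier, int shift) {
      const int64_t prod = static_cast<int64_t>(acc) * quantized_multiplier;
      const int32_t nudge = prod >= 0 ? (1 << 30) : (1 - (1 << 30));
      // Rounding doubling high multiply (division truncates toward zero).
      int32_t x = static_cast<int32_t>((prod + nudge) / (int64_t{1} << 31));
      const int exponent = -shift;  // shift <= 0 when the multiplier is < 1
      const int32_t mask = (1 << exponent) - 1;
      const int32_t remainder = x & mask;
      const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
      return (x >> exponent) + (remainder > threshold ? 1 : 0);
    }

    int main() {
      // Made-up quantization parameters for one multiply.
      const int32_t input1_offset = -128, input2_offset = -128;
      const int32_t output_offset = 128;
      const int32_t output_multiplier = 1288490189;  // ~0.6 in Q31
      const int output_shift = -6;
      const uint8_t a = 200, b = 50;
      const int32_t input1_val = input1_offset + a;  // 72
      const int32_t input2_val = input2_offset + b;  // -78
      const int32_t unclamped =
          output_offset +
          Requantize(input1_val * input2_val, output_multiplier, output_shift);
      const int32_t clamped = std::min(255, std::max(0, unclamped));
      printf("%d\n", clamped);  // 75, i.e. round(-5616 * 0.6 / 64) + 128
      return 0;
    }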
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const uint8* input1_data,
+ const RuntimeShape& input2_shape, const uint8* input2_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+ gemmlowp::ScopedProfilingLabel label("Mul/8bit");
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+
+ MulElementwise(flat_size, params, input1_data, input2_data, output_data);
+}
+
+inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
+ const RuntimeShape& unswitched_input1_shape,
+ const uint8* unswitched_input1_data,
+ const RuntimeShape& unswitched_input2_shape,
+ const uint8* unswitched_input2_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ ArithmeticParams switched_params = unswitched_params;
+ switched_params.input1_offset = unswitched_params.input2_offset;
+ switched_params.input2_offset = unswitched_params.input1_offset;
+
+ const bool use_unswitched =
+ unswitched_params.broadcast_category ==
+ tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
+
+ const ArithmeticParams& params =
+ use_unswitched ? unswitched_params : switched_params;
+ const uint8* input1_data =
+ use_unswitched ? unswitched_input1_data : unswitched_input2_data;
+ const uint8* input2_data =
+ use_unswitched ? unswitched_input2_data : unswitched_input1_data;
+
+ // Fivefold nested loops. The second input resets its position for each
+ // iteration of the second loop. The first input keeps its position through
+ // the fourth loop, advancing only after that loop completes. The innermost
+ // loop is an elementwise Mul of sections of the arrays.
+ uint8* output_data_ptr = output_data;
+ const uint8* input1_data_ptr = input1_data;
+ const uint8* input2_data_reset = input2_data;
+ int y0 = params.broadcast_shape[0];
+ int y1 = params.broadcast_shape[1];
+ int y2 = params.broadcast_shape[2];
+ int y3 = params.broadcast_shape[3];
+ int y4 = params.broadcast_shape[4];
+ for (int i0 = 0; i0 < y0; ++i0) {
+ const uint8* input2_data_ptr;
+ for (int i1 = 0; i1 < y1; ++i1) {
+ input2_data_ptr = input2_data_reset;
+ for (int i2 = 0; i2 < y2; ++i2) {
+ for (int i3 = 0; i3 < y3; ++i3) {
+ MulElementwise(y4, params, input1_data_ptr, input2_data_ptr,
+ output_data_ptr);
+ input2_data_ptr += y4;
+ output_data_ptr += y4;
+ }
+ input1_data_ptr += y4;
+ }
+ }
+ input2_data_reset = input2_data_ptr;
+ }
+}
+
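Concretely, with broadcast_shape = {y0, y1, y2, y3, y4}, input1 holds y0*y1*y2*y4 elements (repeated across y3), input2 holds y0*y2*y3*y4 elements (repeated across y1), and y4 is the contiguous run handed to MulElementwise on each call. A float sketch of the same loop structure for shapes [2,1,4] * [1,3,4] -> [2,3,4], using the made-up decomposition {1, 2, 1, 3, 4}:

    #include <cstdio>

    // Innermost elementwise run shared by both inputs.
    void MulRun(int size, const float* a, const float* b, float* out) {
      for (int i = 0; i < size; ++i) out[i] = a[i] * b[i];
    }

    int main() {
      // input1: shape [2,1,4], repeated across the middle output dim;
      // input2: shape [1,3,4], repeated across the leading output dim.
      float input1[8], input2[12], output[24];
      for (int i = 0; i < 8; ++i) input1[i] = 1.0f + i;
      for (int i = 0; i < 12; ++i) input2[i] = 0.5f * i;
      const int y0 = 1, y1 = 2, y2 = 1, y3 = 3, y4 = 4;
      float* out_ptr = output;
      const float* in1_ptr = input1;
      const float* in2_reset = input2;
      for (int i0 = 0; i0 < y0; ++i0) {
        const float* in2_ptr = nullptr;
        for (int i1 = 0; i1 < y1; ++i1) {
          in2_ptr = in2_reset;            // input2 rewinds for every i1
          for (int i2 = 0; i2 < y2; ++i2) {
            for (int i3 = 0; i3 < y3; ++i3) {
              MulRun(y4, in1_ptr, in2_ptr, out_ptr);
              in2_ptr += y4;
              out_ptr += y4;
            }
            in1_ptr += y4;                // input1 advances once per i2
          }
        }
        in2_reset = in2_ptr;              // input2 resumes here next i0
      }
      // Spot-check against the naive broadcast:
      // output[b][h][d] == input1[b*4 + d] * input2[h*4 + d].
      printf("%g %g\n", output[0], output[23]);  // 0 44
      return 0;
    }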
+inline void BroadcastMul4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const uint8* input1_data,
+ const RuntimeShape& input2_shape,
+ const uint8* input2_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastMul4DSlow/8bit");
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
+
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ const int32 input1_val =
+ params.input1_offset +
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)];
+ const int32 input2_val =
+ params.input2_offset +
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)];
+ const int32 unclamped_result =
+ params.output_offset +
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ input1_val * input2_val, params.output_multiplier,
+ params.output_shift);
+ const int32 clamped_output = std::min(
+ params.quantized_activation_max,
+ std::max(params.quantized_activation_min, unclamped_result));
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
+ static_cast<uint8>(clamped_output);
+ }
+ }
+ }
+ }
+}
+
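NdArrayDescsForElementwiseBroadcast lets SubscriptToIndex map every output coordinate into each (possibly smaller) input. A common way to realize this, and the assumption behind the sketch below, is to give any broadcast dimension a stride of zero so it keeps re-reading the same element (the real NdArrayDesc also carries extents):

    #include <cstdio>

    struct Desc4 { int strides[4]; };  // stand-in for NdArrayDesc<4>

    // Row-major strides, with stride 0 on any dim of extent 1 (broadcast dim).
    Desc4 MakeDesc(const int dims[4]) {
      Desc4 d;
      int stride = 1;
      for (int i = 3; i >= 0; --i) {
        d.strides[i] = (dims[i] == 1) ? 0 : stride;
        stride *= dims[i];
      }
      return d;
    }

    int SubscriptToIndex(const Desc4& d, int b, int y, int x, int c) {
      return b * d.strides[0] + y * d.strides[1] + x * d.strides[2] +
             c * d.strides[3];
    }

    int main() {
      const int dims[4] = {1, 3, 1, 4};  // broadcasts over batch and width
      Desc4 d = MakeDesc(dims);
      // Any batch/width subscript lands on the same element:
      printf("%d %d\n", SubscriptToIndex(d, 0, 2, 0, 1),
             SubscriptToIndex(d, 5, 2, 9, 1));  // both 9
      return 0;
    }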
+// Transitional version that will be moved shortly to legacy_reference_ops, as
+// part of RuntimeShape revisions.
+inline void BroadcastMul4DSlow(const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit");
NdArrayDesc<4> desc1;
@@ -1407,9 +1561,9 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
const int32 input2_val =
input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
const int32 unclamped_result =
- output_offset + MultiplyByQuantizedMultiplierSmallerThanOneExp(
- input1_val * input2_val, output_multiplier,
- kReverseShift * output_shift);
+ output_offset +
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ input1_val * input2_val, output_multiplier, output_shift);
const int32 clamped_output =
std::min(output_activation_max,
std::max(output_activation_min, unclamped_result));
@@ -1464,21 +1618,6 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
- int32 input1_offset, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
- input2_dims, input2_offset, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_data, output_dims);
-}
-
// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
@@ -3370,28 +3509,50 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
}
}
-template <typename T>
-inline void PadV2(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const T pad_value) {
- TFLITE_DCHECK_EQ(left_paddings.size(), 4);
- TFLITE_DCHECK_EQ(right_paddings.size(), 4);
-
- const int output_batch = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int output_depth = ArraySize(output_dims, 0);
-
- const int left_b_padding = left_paddings[3];
- const int left_h_padding = left_paddings[2];
- const int left_w_padding = left_paddings[1];
- const int left_d_padding = left_paddings[0];
-
- const int right_b_padding = right_paddings[3];
- const int right_h_padding = right_paddings[2];
- const int right_w_padding = right_paddings[1];
- const int right_d_padding = right_paddings[0];
+// There are two versions of pad: Pad and PadV2. In PadV2 there is a second
+// scalar input that provides the padding value. Therefore pad_value_ptr can be
+// equivalent to a simple input1_data. For Pad, it should point to a zero
+// value.
+//
+// Note that two typenames are required, so that T=P=int32 is considered a
+// specialization distinct from P=int32.
+template <typename T, typename P>
+inline void PadImpl(const tflite::PadParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const P* pad_value_ptr, const RuntimeShape& output_shape,
+ T* output_data) {
+ RuntimeShape ext_input_shape = RuntimeShape::ExtendedShape(4, input_shape);
+ RuntimeShape ext_output_shape = RuntimeShape::ExtendedShape(4, output_shape);
+ TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
+ TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
+
+ // Runtime calls are currently fixed at 4 dimensions. Copy inputs so
+ // we can pad them to 4 dims (yes, we are "padding the padding").
+ std::vector<int> left_padding_copy(4, 0);
+ for (int i = 0; i < op_params.left_padding_count; ++i) {
+ left_padding_copy[i] = op_params.left_padding[i];
+ }
+ std::vector<int> right_padding_copy(4, 0);
+ for (int i = 0; i < op_params.right_padding_count; ++i) {
+ right_padding_copy[i] = op_params.right_padding[i];
+ }
+
+ const int output_batch = ext_output_shape.Dims(0);
+ const int output_height = ext_output_shape.Dims(1);
+ const int output_width = ext_output_shape.Dims(2);
+ const int output_depth = ext_output_shape.Dims(3);
+
+ const int left_b_padding = left_padding_copy[0];
+ const int left_h_padding = left_padding_copy[1];
+ const int left_w_padding = left_padding_copy[2];
+ const int left_d_padding = left_padding_copy[3];
+
+ const int right_b_padding = right_padding_copy[0];
+ const int right_h_padding = right_padding_copy[1];
+ const int right_w_padding = right_padding_copy[2];
+ const int right_d_padding = right_padding_copy[3];
+
+ const T pad_value = *pad_value_ptr;
const T* in_ptr = input_data;
T* out_ptr = output_data;
@@ -3417,7 +3578,59 @@ inline void PadV2(const T* input_data, const Dims<4>& input_dims,
}
}
-// Legacy Pad() method that casts an int32_t to T before padding.
+template <typename T, typename P>
+inline void Pad(const tflite::PadParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const P* pad_value_ptr, const RuntimeShape& output_shape,
+ T* output_data) {
+ PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
+ output_data);
+}
+
+// The second (pad-value) input can be int32 when, say, the first is uint8.
+template <typename T>
+inline void Pad(const tflite::PadParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const int32* pad_value_ptr, const RuntimeShape& output_shape,
+ T* output_data) {
+ const T converted_pad_value = static_cast<T>(*pad_value_ptr);
+ PadImpl(op_params, input_shape, input_data, &converted_pad_value,
+ output_shape, output_data);
+}
+
+// This version avoids conflicting template matching.
+template <>
+inline void Pad(const tflite::PadParams& op_params,
+ const RuntimeShape& input_shape, const int32* input_data,
+ const int32* pad_value_ptr, const RuntimeShape& output_shape,
+ int32* output_data) {
+ PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
+ output_data);
+}
+
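What PadImpl computes, reduced to a 2-D sketch with batch and depth collapsed: a 2x2 input padded to 4x4 with one-element borders and an explicit PadV2-style value (all values made up):

    #include <cstdio>

    int main() {
      const float input[2][2] = {{1, 2}, {3, 4}};
      const float pad_value = -1.0f;     // PadV2-style explicit value
      const int left_h = 1, left_w = 1;  // left paddings per dim
      float output[4][4];
      for (int out_h = 0; out_h < 4; ++out_h) {
        for (int out_w = 0; out_w < 4; ++out_w) {
          const int in_h = out_h - left_h;
          const int in_w = out_w - left_w;
          const bool inside = in_h >= 0 && in_h < 2 && in_w >= 0 && in_w < 2;
          output[out_h][out_w] = inside ? input[in_h][in_w] : pad_value;
          printf("%5.1f%c", output[out_h][out_w], out_w == 3 ? '\n' : ' ');
        }
      }
      return 0;
    }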
+// Legacy signature; this function covered both Pad and PadV2.
+template <typename T>
+inline void PadV2(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const T pad_value) {
+ TFLITE_DCHECK_EQ(left_paddings.size(), 4);
+ TFLITE_DCHECK_EQ(right_paddings.size(), 4);
+ tflite::PadParams op_params;
+ op_params.left_padding_count = 4;
+ op_params.right_padding_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.left_padding[i] = left_paddings[3 - i];
+ op_params.right_padding[i] = right_paddings[3 - i];
+ }
+ // SetFloatOrInt(pad_value, &op_params.pad_value);
+ const T pad_value_copy = pad_value;
+
+ Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
+ DimsToShape(output_dims), output_data);
+}
+
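The 3 - i flip above mirrors DimsToShape: legacy padding vectors are stored innermost-first while PadParams is outermost-first. A tiny sketch of the reindexing with made-up paddings:

    #include <cstdio>
    #include <vector>

    int main() {
      // Legacy order: {depth, width, height, batch} paddings.
      std::vector<int> left_paddings = {0, 2, 1, 0};
      int left_padding[4];  // PadParams order: {batch, height, width, depth}
      for (int i = 0; i < 4; ++i) left_padding[i] = left_paddings[3 - i];
      printf("%d %d %d %d\n", left_padding[0], left_padding[1],
             left_padding[2], left_padding[3]);  // 0 1 2 0
      return 0;
    }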
+// Old Pad that calls legacy PadV2.
template <typename T>
inline void Pad(const T* input_data, const Dims<4>& input_dims,
const std::vector<int>& left_paddings,
@@ -3428,13 +3641,15 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims,
output_dims, converted_pad_value);
}
+// Old Pad that only padded with 0.
template <typename T>
inline void Pad(const T* input_data, const Dims<4>& input_dims,
const std::vector<int>& left_paddings,
const std::vector<int>& right_paddings, T* output_data,
const Dims<4>& output_dims) {
- Pad(input_data, input_dims, left_paddings, right_paddings, output_data,
- output_dims, 0);
+ const T pad_value = static_cast<T>(0);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, pad_value);
}
template <typename T>
@@ -3491,31 +3706,39 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void Slice(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& begin, const std::vector<int>& size,
- T* output_data, const Dims<4>& output_dims) {
- // TODO(dkalenichenko): This op only supports 4D tensors.
- TFLITE_DCHECK_EQ(begin.size(), 4);
- TFLITE_DCHECK_EQ(size.size(), 4);
- const int start_b = begin[3];
- const int stop_b =
- size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3];
- const int start_h = begin[2];
- const int stop_h =
- size[2] == -1 ? input_dims.sizes[2] - start_h : start_h + size[2];
- const int start_w = begin[1];
- const int stop_w =
- size[1] == -1 ? input_dims.sizes[1] - start_w : start_w + size[1];
- const int start_d = begin[0];
- const int stop_d =
- size[0] == -1 ? input_dims.sizes[0] - start_d : start_d + size[0];
+inline void Slice(const tflite::SliceParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
+ // TODO(dkalenichenko): This op only supports 4D tensors or smaller.
+ TFLITE_DCHECK_LE(op_params.begin_count, 4);
+ TFLITE_DCHECK_LE(op_params.size_count, 4);
+ const int begin_count = op_params.begin_count;
+ const int size_count = op_params.size_count;
+ // We front-pad the begin and size vectors.
+ const int start_b = 4 - begin_count > 0 ? 0 : op_params.begin[0];
+ const int stop_b = (4 - size_count > 0 || op_params.size[0] == -1)
+ ? ext_shape.Dims(0) - start_b
+ : start_b + op_params.size[0];
+ const int start_h = begin_count < 3 ? 0 : op_params.begin[begin_count - 3];
+ const int stop_h = (size_count < 3 || op_params.size[size_count - 3] == -1)
+ ? ext_shape.Dims(1) - start_h
+ : start_h + op_params.size[size_count - 3];
+ const int start_w = begin_count < 2 ? 0 : op_params.begin[begin_count - 2];
+ const int stop_w = (size_count < 2 || op_params.size[size_count - 2] == -1)
+ ? ext_shape.Dims(2) - start_w
+ : start_w + op_params.size[size_count - 2];
+ const int start_d = begin_count < 1 ? 0 : op_params.begin[begin_count - 1];
+ const int stop_d = (size_count < 1 || op_params.size[size_count - 1] == -1)
+ ? ext_shape.Dims(3) - start_d
+ : start_d + op_params.size[size_count - 1];
T* out_ptr = output_data;
for (int in_b = start_b; in_b < stop_b; ++in_b) {
for (int in_h = start_h; in_h < stop_h; ++in_h) {
for (int in_w = start_w; in_w < stop_w; ++in_w) {
for (int in_d = start_d; in_d < stop_d; ++in_d) {
- *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
+ *out_ptr++ = input_data[Offset(ext_shape, in_b, in_h, in_w, in_d)];
}
}
}
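When begin/size carry fewer than four entries, they bind to the trailing (innermost) dimensions and the missing leading dimensions default to their full range. A standalone sketch of the start/stop computation above for a made-up 2-D call:

    #include <cstdio>

    // The start/stop computation from the Slice above, extracted so the
    // counts are runtime values. Here: begin = {1, 0}, size = {3, -1} with
    // begin_count = size_count = 2, on an input extended to {1, 1, 5, 6};
    // the two entries bind to the trailing (width, depth) dims.
    void PrintSliceRanges(const int ext_dims[4], int begin_count,
                          const int* begin, int size_count, const int* size) {
      const int start_b = 4 - begin_count > 0 ? 0 : begin[0];
      const int stop_b = (4 - size_count > 0 || size[0] == -1)
                             ? ext_dims[0] - start_b
                             : start_b + size[0];
      const int start_h = begin_count < 3 ? 0 : begin[begin_count - 3];
      const int stop_h = (size_count < 3 || size[size_count - 3] == -1)
                             ? ext_dims[1] - start_h
                             : start_h + size[size_count - 3];
      const int start_w = begin_count < 2 ? 0 : begin[begin_count - 2];
      const int stop_w = (size_count < 2 || size[size_count - 2] == -1)
                             ? ext_dims[2] - start_w
                             : start_w + size[size_count - 2];
      const int start_d = begin_count < 1 ? 0 : begin[begin_count - 1];
      const int stop_d = (size_count < 1 || size[size_count - 1] == -1)
                             ? ext_dims[3] - start_d
                             : start_d + size[size_count - 1];
      printf("b:[%d,%d) h:[%d,%d) w:[%d,%d) d:[%d,%d)\n", start_b, stop_b,
             start_h, stop_h, start_w, stop_w, start_d, stop_d);
    }

    int main() {
      const int ext_dims[4] = {1, 1, 5, 6};
      const int begin[2] = {1, 0};
      const int size[2] = {3, -1};
      // Prints: b:[0,1) h:[0,1) w:[1,4) d:[0,6)
      PrintSliceRanges(ext_dims, 2, begin, 2, size);
      return 0;
    }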
@@ -3523,6 +3746,22 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims,
}
template <typename T>
+inline void Slice(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& begin, const std::vector<int>& size,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::SliceParams op_params;
+ op_params.begin_count = 4;
+ op_params.size_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.begin[i] = begin[3 - i];
+ op_params.size[i] = size[3 - i];
+ }
+
+ Slice(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
inline void Exp(const T* input_data, const size_t num_elements,
T* output_data) {
for (size_t idx = 0; idx < num_elements; ++idx) {
@@ -3790,10 +4029,10 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims,
}
template <typename T>
-void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input1_dims);
+void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+ const T* input2_data, const RuntimeShape& output_shape,
+ T* output_data) {
+ const int flat_size = MatchingFlatSize(input1_shape, output_shape);
auto min_value = input2_data[0];
for (int i = 0; i < flat_size; i++) {
@@ -3802,10 +4041,10 @@ void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
}
template <typename T>
-void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input1_dims);
+void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+ const T* input2_data, const RuntimeShape& output_shape,
+ T* output_data) {
+ const int flat_size = MatchingFlatSize(input1_shape, output_shape);
auto max_value = input2_data[0];
for (int i = 0; i < flat_size; i++) {
@@ -3813,6 +4052,22 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
}
}
+template <typename T>
+void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Minimum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Maximum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
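Note that in Minimum/Maximum above, input2_data is only read at index 0: the second input contributes a single scalar bound applied across the whole first input. A float sketch of the same contract (names and values made up):

    #include <algorithm>
    #include <cstdio>

    // Mirrors the contract above: input2 supplies one scalar bound.
    void MinimumScalar(const float* input1, int flat_size, const float* input2,
                       float* output) {
      const float min_value = input2[0];
      for (int i = 0; i < flat_size; ++i) {
        output[i] = std::min(input1[i], min_value);
      }
    }

    int main() {
      const float in1[4] = {-1.0f, 0.5f, 2.0f, 7.0f};
      const float bound[1] = {1.5f};
      float out[4];
      MinimumScalar(in1, 4, bound, out);
      for (float v : out) printf("%g ", v);  // -1 0.5 1.5 1.5
      printf("\n");
      return 0;
    }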
template <typename T, typename Op>
void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
const T* input2_data, const Dims<4>& input2_dims,