author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-08-30 11:17:57 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-30 11:23:53 -0700
commit    9e12f1df3270b5e0b310645e6c3cae9fbd3f5dfc (patch)
tree      6fb67b08ce4747aaf27f40d71a42edab04ea176c /tensorflow/contrib/lite
parent    35bae087dce1e88c66007907f9e1b6b5b2958f10 (diff)
Consolidate refactoring of runtime shapes.
PiperOrigin-RevId: 210945714
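
The consolidation follows one pattern throughout: kernel entry points stop taking (data pointer, Dims<4>) pairs and instead take (RuntimeShape, data pointer) pairs, with each shape immediately preceding its data, and loose scalar arguments move into per-op params structs (ArithmeticParams, ResizeBilinearParams, SpaceToDepthParams, and so on). A minimal sketch of the two conventions, using Floor from this change (simplified signatures, not the exact library declarations):

// Legacy convention: data pointer first, Dims<4> second, scalars loose.
inline void Floor(const float* input_data, const Dims<4>& input_dims,
                  float* output_data, const Dims<4>& output_dims);

// New convention: shape first, matching data pointer second; any op
// attributes travel in a params struct instead of loose arguments.
inline void Floor(const RuntimeShape& input_shape, const float* input_data,
                  const RuntimeShape& output_shape, float* output_data);

Call sites change accordingly, swapping GetTensorDims(t) for GetTensorShape(t) and reordering arguments, as the hunks below show.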
Diffstat (limited to 'tensorflow/contrib/lite')
-rw-r--r--  tensorflow/contrib/lite/kernels/activations.cc | 8
-rw-r--r--  tensorflow/contrib/lite/kernels/arg_min_max.cc | 11
-rw-r--r--  tensorflow/contrib/lite/kernels/batch_to_space_nd.cc | 12
-rw-r--r--  tensorflow/contrib/lite/kernels/floor.cc | 5
-rw-r--r--  tensorflow/contrib/lite/kernels/floor_div.cc | 14
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h | 18
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h | 36
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 107
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc | 30
-rw-r--r--  tensorflow/contrib/lite/kernels/l2norm.cc | 20
-rw-r--r--  tensorflow/contrib/lite/kernels/local_response_norm.cc | 14
-rw-r--r--  tensorflow/contrib/lite/kernels/logical.cc | 14
-rw-r--r--  tensorflow/contrib/lite/kernels/maximum_minimum.cc | 9
-rw-r--r--  tensorflow/contrib/lite/kernels/mul.cc | 75
-rw-r--r--  tensorflow/contrib/lite/kernels/pad.cc | 22
-rw-r--r--  tensorflow/contrib/lite/kernels/pow.cc | 14
-rw-r--r--  tensorflow/contrib/lite/kernels/resize_bilinear.cc | 12
-rw-r--r--  tensorflow/contrib/lite/kernels/slice.cc | 26
-rw-r--r--  tensorflow/contrib/lite/kernels/space_to_batch_nd.cc | 14
-rw-r--r--  tensorflow/contrib/lite/kernels/space_to_depth.cc | 10
20 files changed, 286 insertions, 185 deletions
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index d6d62580e2..9c891fe904 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -590,10 +590,10 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
input->type);
return kTfLiteError;
}
- reference_ops::BroadcastBinaryFunction<float, float, float>(
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(alpha), GetTensorDims(alpha),
- GetTensorData<float>(output), GetTensorDims(output), ApplyPrelu<float>);
+ reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
+ GetTensorShape(input), GetTensorData<float>(input), GetTensorShape(alpha),
+ GetTensorData<float>(alpha), GetTensorShape(output),
+ GetTensorData<float>(output), ApplyPrelu<float>);
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/kernels/arg_min_max.cc b/tensorflow/contrib/lite/kernels/arg_min_max.cc
index 4f30d09030..6e05f5a9b2 100644
--- a/tensorflow/contrib/lite/kernels/arg_min_max.cc
+++ b/tensorflow/contrib/lite/kernels/arg_min_max.cc
@@ -96,11 +96,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
const TfLiteTensor* axis = GetInput(context, node, kAxis);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-#define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \
- optimized_ops::ArgMinMax( \
- GetTensorData<axis_type>(axis), GetTensorData<data_type>(input), \
- GetTensorDims(input), GetTensorData<output_type>(output), \
- GetTensorDims(output), GetComparefunction<data_type>(is_arg_max))
+#define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \
+ optimized_ops::ArgMinMax( \
+ GetTensorShape(input), GetTensorData<data_type>(input), \
+ GetTensorData<axis_type>(axis), GetTensorShape(output), \
+ GetTensorData<output_type>(output), \
+ GetComparefunction<data_type>(is_arg_max))
if (axis->type == kTfLiteInt32) {
switch (output->type) {
case kTfLiteInt32: {
diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
index c8cee88edf..4efa9d596d 100644
--- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
+++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
@@ -125,14 +125,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \
- type::BatchToSpaceND(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), \
+ type::BatchToSpaceND(GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ GetTensorShape(op_context.block_shape), \
GetTensorData<int32_t>(op_context.block_shape), \
- GetTensorDims(op_context.block_shape), \
+ GetTensorShape(op_context.crops), \
GetTensorData<int32_t>(op_context.crops), \
- GetTensorDims(op_context.crops), \
- GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output))
+ GetTensorShape(op_context.output), \
+ GetTensorData<scalar>(op_context.output))
switch (op_context.input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
if (kernel_type == kReference) {
diff --git a/tensorflow/contrib/lite/kernels/floor.cc b/tensorflow/contrib/lite/kernels/floor.cc
index 697b777693..f7d5f5146d 100644
--- a/tensorflow/contrib/lite/kernels/floor.cc
+++ b/tensorflow/contrib/lite/kernels/floor.cc
@@ -41,8 +41,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- optimized_ops::Floor(GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(output), GetTensorDims(output));
+ optimized_ops::Floor(GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(output), GetTensorData<float>(output));
+
return kTfLiteOk;
}
} // namespace floor
diff --git a/tensorflow/contrib/lite/kernels/floor_div.cc b/tensorflow/contrib/lite/kernels/floor_div.cc
index 3c177ea330..75cf19a5a7 100644
--- a/tensorflow/contrib/lite/kernels/floor_div.cc
+++ b/tensorflow/contrib/lite/kernels/floor_div.cc
@@ -97,15 +97,15 @@ TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast,
}
}
if (requires_broadcast) {
- reference_ops::BroadcastBinaryFunction<T, T, T>(
- GetTensorData<T>(input1), GetTensorDims(input1), denominator_data,
- GetTensorDims(input2), GetTensorData<T>(output), GetTensorDims(output),
- FloorDiv<T>);
+ reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
+ GetTensorShape(input1), GetTensorData<T>(input1),
+ GetTensorShape(input2), denominator_data, GetTensorShape(output),
+ GetTensorData<T>(output), FloorDiv<T>);
} else {
reference_ops::BinaryFunction<T, T, T>(
- GetTensorData<T>(input1), GetTensorDims(input1),
- GetTensorData<T>(input2), GetTensorDims(input2),
- GetTensorData<T>(output), GetTensorDims(output), FloorDiv<T>);
+ GetTensorShape(input1), GetTensorData<T>(input1),
+ GetTensorShape(input2), GetTensorData<T>(input2),
+ GetTensorShape(output), GetTensorData<T>(output), FloorDiv<T>);
}
return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
index df4d871466..332e7f803b 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
@@ -296,13 +296,17 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
int output_shift, int32 output_activation_min,
int32 output_activation_max, uint8* output_data,
const Dims<4>& output_dims) {
- BroadcastMul4DSlow(
- input1_data, input1_dims, input1_offset, input2_data, input2_dims,
- input2_offset, output_offset, output_multiplier,
- // This legacy version switches the sign of the output shift.
- kReverseShift * output_shift,
- // (Break to highlight preceding line.)
- output_activation_min, output_activation_max, output_data, output_dims);
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+ op_params.input1_offset = input1_offset;
+ op_params.input2_offset = input2_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
}
// legacy, for compatibility with old checked-in code
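
Each legacy Dims<4> entry point in these headers becomes a thin adapter over the RuntimeShape kernel: gather the loose scalars into a params struct, convert each Dims<4> with DimsToShape, and forward. A sketch of the general adapter shape (LegacyOp and NewOp are illustrative names, not library functions):

inline void LegacyOp(const float* input_data, const Dims<4>& input_dims,
                     float output_activation_min, float output_activation_max,
                     float* output_data, const Dims<4>& output_dims) {
  tflite::ArithmeticParams op_params;
  SetActivationParams(output_activation_min, output_activation_max,
                      &op_params);
  // Forward to the RuntimeShape overload; DimsToShape converts the legacy
  // reversed-order Dims<4> into a RuntimeShape.
  NewOp(op_params, DimsToShape(input_dims), input_data,
        DimsToShape(output_dims), output_data);
}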
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index e4bb4e0534..c7ee65d63a 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -5586,18 +5586,15 @@ inline void ResizeBilinearGenericSmallChannel(
inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
const RuntimeShape& unextended_input_shape,
const float* input_data,
- const RuntimeShape& unextended_output_size_shape,
+ const RuntimeShape& output_size_shape,
const int32* output_size_data,
const RuntimeShape& unextended_output_shape,
float* output_data) {
gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_LE(unextended_output_size_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
- RuntimeShape output_size_shape =
- RuntimeShape::ExtendedShape(4, unextended_output_size_shape);
RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
@@ -5606,12 +5603,9 @@ inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
int32 input_width = input_shape.Dims(2);
int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
- TFLITE_DCHECK_EQ(output_size_shape.Dims(0), 1);
- TFLITE_DCHECK_EQ(output_size_shape.Dims(1), 1);
- TFLITE_DCHECK_EQ(output_size_shape.Dims(2), 1);
- TFLITE_DCHECK_EQ(output_size_shape.Dims(3), 2);
- int32 output_height = output_size_data[Offset(output_size_shape, 0, 0, 0, 0)];
- int32 output_width = output_size_data[Offset(output_size_shape, 0, 0, 0, 1)];
+ TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
+ int32 output_height = output_size_data[0];
+ int32 output_width = output_size_data[1];
// Specialize for 2x2 upsample.
if (!op_params.align_corners && output_height == 2 * input_height &&
@@ -5651,28 +5645,28 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
// TODO(prabhumk): This is not a real quantized bilinear. It does not use int8
// or int16 arithmetic.
inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
- const RuntimeShape& input_shape,
+ const RuntimeShape& unextended_input_shape,
const uint8* input_data,
const RuntimeShape& output_size_shape,
const int32* output_size_data,
- const RuntimeShape& output_shape,
+ const RuntimeShape& unextended_output_shape,
uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
- TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_size_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
int32 input_height = input_shape.Dims(1);
int32 input_width = input_shape.Dims(2);
int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
- TFLITE_DCHECK_EQ(output_size_shape.Dims(0), 1);
- TFLITE_DCHECK_EQ(output_size_shape.Dims(1), 1);
- TFLITE_DCHECK_EQ(output_size_shape.Dims(2), 1);
- TFLITE_DCHECK_EQ(output_size_shape.Dims(3), 2);
- int32 output_height = output_size_data[Offset(output_size_shape, 0, 0, 0, 0)];
- int32 output_width = output_size_data[Offset(output_size_shape, 0, 0, 0, 1)];
+ TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
+ int32 output_height = output_size_data[0];
+ int32 output_width = output_size_data[1];
float height_scale =
(op_params.align_corners && output_height > 1)
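
The unextended-shape handling that recurs above is the heart of the consolidation: kernels now accept shapes of up to four dimensions, DCHECK that bound, and left-pad to 4-D with size-1 dimensions before any indexing. A toy version of the idea to make it concrete (an illustrative helper, not the library's RuntimeShape::ExtendedShape):

#include <cassert>
#include <vector>

// Left-pad a shape with 1s until it has exactly four dimensions, so 4-D
// indexing code also handles 1-D through 3-D inputs. Flat size is unchanged.
std::vector<int> ExtendShapeTo4D(const std::vector<int>& dims) {
  assert(dims.size() <= 4);
  std::vector<int> extended(4 - dims.size(), 1);  // leading size-1 dims
  extended.insert(extended.end(), dims.begin(), dims.end());
  return extended;
}
// Example: {5, 3} -> {1, 1, 5, 3}.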
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 3875b73e05..5f84c737eb 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -4421,16 +4421,22 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
}
template <typename T, typename Op>
-void MaximumMinimumBroadcast4DSlow(const RuntimeShape& input1_shape,
+void MaximumMinimumBroadcast4DSlow(const RuntimeShape& unextended_input1_shape,
const T* input1_data,
- const RuntimeShape& input2_shape,
+ const RuntimeShape& unextended_input2_shape,
const T* input2_data,
- const RuntimeShape& output_shape,
+ const RuntimeShape& unextended_output_shape,
T* output_data, Op op) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
- &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
for (int b = 0; b < output_shape.Dims(0); ++b) {
for (int y = 0; y < output_shape.Dims(1); ++y) {
@@ -4459,8 +4465,8 @@ void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
}
template <typename T1, typename T2, typename T3, typename Cmp>
-void ArgMinMax(const T3* axis, const RuntimeShape& input_shape,
- const T1* input_data, const RuntimeShape& output_shape,
+void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
+ const T3* input2_data, const RuntimeShape& output_shape,
T2* output_data, const Cmp& cmp) {
// The current ArgMax implementation can only determine the index of the maximum
// value in the last dimension. So the axis argument is ignored.
@@ -4469,17 +4475,19 @@ void ArgMinMax(const T3* axis, const RuntimeShape& input_shape,
// 1). For the sake of simplicity, the output dimensions are equal to the
// input dimensions here. We enforce the constraint that the last dimension
// must always be 1.
- TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_shape.Dims(3), 1);
- const int outer_size = MatchingFlatSizeSkipDim(input_shape, 3, output_shape);
- const int depth = input_shape.Dims(3);
+ const int trailing_dim = output_shape.DimensionsCount() - 1;
+ TFLITE_DCHECK_EQ(input1_shape.DimensionsCount(),
+ output_shape.DimensionsCount());
+ TFLITE_DCHECK_EQ(output_shape.Dims(trailing_dim), 1);
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input1_shape, trailing_dim, output_shape);
+ const int depth = input1_shape.Dims(trailing_dim);
for (int i = 0; i < outer_size; ++i) {
- auto min_max_value = input_data[i * depth];
+ auto min_max_value = input1_data[i * depth];
int min_max_index = 0;
for (int d = 1; d < depth; ++d) {
- const auto& curr_value = input_data[i * depth + d];
+ const auto& curr_value = input1_data[i * depth + d];
if (cmp(curr_value, min_max_value)) {
min_max_value = curr_value;
min_max_index = d;
@@ -4493,12 +4501,19 @@ void ArgMinMax(const T3* axis, const RuntimeShape& input_shape,
template <typename T1, typename T2, typename T3, typename Cmp>
void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) {
- ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ ArgMinMax(DimsToShape(input_dims), input_data, axis, DimsToShape(output_dims),
output_data, cmp);
}
+template <typename T1, typename T2, typename T3>
+void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
+ const T3* input2_data, const RuntimeShape& output_shape,
+ T2* output_data) {
+ ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
+ std::greater<T1>());
+}
+
// Legacy.
-// TODO(renjieliu): Remove this one.
template <typename T1, typename T2, typename T3>
void ArgMax(const T3* axis, const T1* input_data,
const tflite::Dims<4>& input_dims, T2* output_data,
@@ -4938,14 +4953,20 @@ inline void Logical(const bool* input1_data, const Dims<4>& input1_dims,
}
inline void BroadcastLogical4DSlow(
- const RuntimeShape& input1_shape, const bool* input1_data,
- const RuntimeShape& input2_shape, const bool* input2_data,
- const RuntimeShape& output_shape, bool* output_data,
+ const RuntimeShape& unextended_input1_shape, const bool* input1_data,
+ const RuntimeShape& unextended_input2_shape, const bool* input2_data,
+ const RuntimeShape& unextended_output_shape, bool* output_data,
const std::function<bool(bool, bool)>& func) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
- &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
for (int b = 0; b < output_shape.Dims(0); ++b) {
for (int y = 0; y < output_shape.Dims(1); ++y) {
@@ -4982,16 +5003,21 @@ inline void BroadcastLogical(const bool* input1_data,
//
// R: Result type. T1: Input 1 type. T2: Input 2 type.
template <typename R, typename T1, typename T2>
-inline void BroadcastBinaryFunction4DSlow(const RuntimeShape& input1_shape,
- const T1* input1_data,
- const RuntimeShape& input2_shape,
- const T2* input2_data,
- const RuntimeShape& output_shape,
- R* output_data, R (*func)(T1, T2)) {
+inline void BroadcastBinaryFunction4DSlow(
+ const RuntimeShape& unextended_input1_shape, const T1* input1_data,
+ const RuntimeShape& unextended_input2_shape, const T2* input2_data,
+ const RuntimeShape& unextended_output_shape, R* output_data,
+ R (*func)(T1, T2)) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
- &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
for (int b = 0; b < output_shape.Dims(0); ++b) {
for (int y = 0; y < output_shape.Dims(1); ++y) {
@@ -5024,6 +5050,22 @@ inline void BroadcastBinaryFunction(const T1* input1_data,
DimsToShape(output_dims), output_data, func);
}
+// R: Result type. T1: Input 1 type. T2: Input 2 type.
+// TODO(renjieliu): Refactor other binary functions to use this one.
+template <typename R, typename T1, typename T2>
+inline void BinaryFunction(const RuntimeShape& input1_shape,
+ const T1* input1_data,
+ const RuntimeShape& input2_shape,
+ const T2* input2_data,
+ const RuntimeShape& output_shape, R* output_data,
+ R (*func)(T1, T2)) {
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = func(input1_data[i], input2_data[i]);
+ }
+}
+
// Legacy Dims<4> version.
//
// R: Result type. T1: Input 1 type. T2: Input 2 type.
@@ -5033,10 +5075,9 @@ inline void BinaryFunction(const T1* input1_data, const Dims<4>& input1_dims,
const T2* input2_data, const Dims<4>& input2_dims,
R* output_data, const Dims<4>& output_dims,
R (*func)(T1, T2)) {
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
- for (int i = 0; i < flat_size; ++i) {
- output_data[i] = func(input1_data[i], input2_data[i]);
- }
+ BinaryFunction(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data, func);
}
} // namespace reference_ops
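
The BinaryFunction change inverts the usual direction of the shims: the elementwise loop moves into the new RuntimeShape overload, and the Dims<4> overload becomes the forwarder. The new overload's contract is strict: no broadcasting, all three flat sizes must match, and func runs index-by-index. A self-contained toy equivalent of that contract (raw pointers and an explicit size instead of RuntimeShape; the name is illustrative):

#include <cstddef>

// Toy equivalent of the non-broadcast BinaryFunction contract: callers
// guarantee identical flat sizes (MatchingFlatSize in the real kernel),
// and func is applied elementwise over the flat buffers.
template <typename R, typename T1, typename T2>
void FlatBinaryFunction(const T1* input1_data, const T2* input2_data,
                        R* output_data, std::size_t flat_size,
                        R (*func)(T1, T2)) {
  for (std::size_t i = 0; i < flat_size; ++i) {
    output_data[i] = func(input1_data[i], input2_data[i]);
  }
}

Inputs that do need broadcasting go through BroadcastBinaryFunction4DSlow instead, as the floor_div.cc and activations.cc hunks above show.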
diff --git a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc
index 3d8765f11b..15df31f75a 100644
--- a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc
@@ -28,14 +28,12 @@ template <typename T>
void TestOneResizeBilinear(int batch, int depth, int input_width,
int input_height, int output_width,
int output_height, float error_threshold) {
- Dims<4> input_dims_inference =
- MakeDimsForInference(depth, input_width, input_height, batch);
- Dims<4> output_dims_inference =
- MakeDimsForInference(depth, output_width, output_height, batch);
+ RuntimeShape input_dims_inference({batch, input_height, input_width, depth});
+ RuntimeShape output_dims_inference(
+ {batch, output_height, output_width, depth});
- const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference);
- const int output_buffer_size =
- RequiredBufferSizeForDims(output_dims_inference);
+ const int input_buffer_size = input_dims_inference.FlatSize();
+ const int output_buffer_size = output_dims_inference.FlatSize();
std::vector<T> input_data(input_buffer_size, 0);
std::vector<T> reference_output_data(output_buffer_size, 0);
@@ -47,15 +45,19 @@ void TestOneResizeBilinear(int batch, int depth, int input_width,
const T max_amplitude = static_cast<T>(255);
FillRandom(&input_data, min_amplitude, max_amplitude);
- Dims<4> output_size_dims = MakeDimsForInference(2, 1, 1, 1);
+ RuntimeShape output_size_dims({1, 1, 1, 2});
std::vector<int32> output_size_data = {output_height, output_width};
- reference_ops::ResizeBilinear(
- input_data.data(), input_dims_inference, output_size_data.data(),
- output_size_dims, reference_output_data.data(), output_dims_inference);
- optimized_ops::ResizeBilinear(input_data.data(), input_dims_inference,
- output_size_data.data(), output_size_dims,
- output_data.data(), output_dims_inference);
+ tflite::ResizeBilinearParams op_params;
+ op_params.align_corners = false;
+
+ reference_ops::ResizeBilinear(op_params, input_dims_inference,
+ input_data.data(), output_size_dims,
+ output_size_data.data(), output_dims_inference,
+ reference_output_data.data());
+ optimized_ops::ResizeBilinear(
+ op_params, input_dims_inference, input_data.data(), output_size_dims,
+ output_size_data.data(), output_dims_inference, output_data.data());
double sum_diff = 0;
float max_abs_val = 0;
diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc
index a7b54c6b84..5b3536de0c 100644
--- a/tensorflow/contrib/lite/kernels/l2norm.cc
+++ b/tensorflow/contrib/lite/kernels/l2norm.cc
@@ -68,10 +68,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
-#define TF_LITE_L2NORM(type) \
- type::L2Normalization<FusedActivationFunctionType::kNone>( \
- GetTensorData<float>(input), GetTensorShape(input), \
- GetTensorData<float>(output), GetTensorShape(output))
+#define TF_LITE_L2NORM(type) \
+ tflite::L2NormalizationParams op_params; \
+ op_params.input_zero_point = 0; \
+ type::L2Normalization(op_params, GetTensorShape(input), \
+ GetTensorData<float>(input), GetTensorShape(output), \
+ GetTensorData<float>(output))
if (kernel_type == kReference) {
TF_LITE_L2NORM(reference_ops);
@@ -81,10 +83,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
#undef TF_LITE_L2NORM
} else if (output->type == kTfLiteUInt8) {
-#define TF_LITE_L2NORM(type) \
- type::L2Normalization(GetTensorData<uint8>(input), GetTensorShape(input), \
- input->params.zero_point, \
- GetTensorData<uint8>(output), GetTensorShape(output))
+#define TF_LITE_L2NORM(type) \
+ tflite::L2NormalizationParams op_params; \
+ op_params.input_zero_point = input->params.zero_point; \
+ type::L2Normalization(op_params, GetTensorShape(input), \
+ GetTensorData<uint8>(input), GetTensorShape(output), \
+ GetTensorData<uint8>(output))
if (kernel_type == kReference) {
TF_LITE_L2NORM(reference_ops);
diff --git a/tensorflow/contrib/lite/kernels/local_response_norm.cc b/tensorflow/contrib/lite/kernels/local_response_norm.cc
index 36dca299d0..799c1528bd 100644
--- a/tensorflow/contrib/lite/kernels/local_response_norm.cc
+++ b/tensorflow/contrib/lite/kernels/local_response_norm.cc
@@ -64,11 +64,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
-#define TF_LITE_LOCAL_RESPONSE_NORM(type) \
- type::LocalResponseNormalization( \
- GetTensorData<float>(input), GetTensorDims(input), params->radius, \
- params->bias, params->alpha, params->beta, GetTensorData<float>(output), \
- GetTensorDims(output))
+#define TF_LITE_LOCAL_RESPONSE_NORM(type) \
+ tflite::LocalResponseNormalizationParams op_params; \
+ op_params.range = params->radius; \
+ op_params.bias = params->bias; \
+ op_params.alpha = params->alpha; \
+ op_params.beta = params->beta; \
+ type::LocalResponseNormalization( \
+ op_params, GetTensorShape(input), GetTensorData<float>(input), \
+ GetTensorShape(output), GetTensorData<float>(output))
if (kernel_type == kReference) {
TF_LITE_LOCAL_RESPONSE_NORM(reference_ops);
}
diff --git a/tensorflow/contrib/lite/kernels/logical.cc b/tensorflow/contrib/lite/kernels/logical.cc
index 87c2fee667..c71f3b4701 100644
--- a/tensorflow/contrib/lite/kernels/logical.cc
+++ b/tensorflow/contrib/lite/kernels/logical.cc
@@ -86,14 +86,14 @@ TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (data->requires_broadcast) {
- reference_ops::BroadcastLogical(
- GetTensorData<bool>(input1), GetTensorDims(input1),
- GetTensorData<bool>(input2), GetTensorDims(input2),
- GetTensorData<bool>(output), GetTensorDims(output), func);
+ reference_ops::BroadcastLogical4DSlow(
+ GetTensorShape(input1), GetTensorData<bool>(input1),
+ GetTensorShape(input2), GetTensorData<bool>(input2),
+ GetTensorShape(output), GetTensorData<bool>(output), func);
} else {
- reference_ops::Logical(GetTensorData<bool>(input1), GetTensorDims(input1),
- GetTensorData<bool>(input2), GetTensorDims(input2),
- GetTensorData<bool>(output), GetTensorDims(output),
+ reference_ops::Logical(GetTensorShape(input1), GetTensorData<bool>(input1),
+ GetTensorShape(input2), GetTensorData<bool>(input2),
+ GetTensorShape(output), GetTensorData<bool>(output),
func);
}
diff --git a/tensorflow/contrib/lite/kernels/maximum_minimum.cc b/tensorflow/contrib/lite/kernels/maximum_minimum.cc
index 8d676218bd..0308a3976a 100644
--- a/tensorflow/contrib/lite/kernels/maximum_minimum.cc
+++ b/tensorflow/contrib/lite/kernels/maximum_minimum.cc
@@ -86,13 +86,14 @@ struct MinimumOp {
template <typename data_type, typename op_type>
void TFLiteOperation(TfLiteContext* context, TfLiteNode* node,
const OpContext& op_context) {
- reference_ops::TensorFlowMaximumMinimum<data_type>(
+ reference_ops::MaximumMinimumBroadcast4DSlow(
+ GetTensorShape(op_context.input1),
GetTensorData<data_type>(op_context.input1),
- GetTensorDims(op_context.input1),
+ GetTensorShape(op_context.input2),
GetTensorData<data_type>(op_context.input2),
- GetTensorDims(op_context.input2),
+ GetTensorShape(op_context.output),
GetTensorData<data_type>(op_context.output),
- GetTensorDims(op_context.output), op_type::template op<data_type>);
+ op_type::template op<data_type>);
}
template <KernelType kernel_type, typename OpType>
diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc
index 561e39cfc6..92d8bc8b67 100644
--- a/tensorflow/contrib/lite/kernels/mul.cc
+++ b/tensorflow/contrib/lite/kernels/mul.cc
@@ -102,24 +102,28 @@ template <KernelType kernel_type>
void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
const OpData* data, const TfLiteTensor* input1,
const TfLiteTensor* input2, TfLiteTensor* output) {
-#define TF_LITE_MUL(type, opname, data_type) \
- data_type output_activation_min, output_activation_max; \
- CalculateActivationRange(params->activation, &output_activation_min, \
- &output_activation_max); \
- type::opname(GetTensorData<data_type>(input1), GetTensorDims(input1), \
- GetTensorData<data_type>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<data_type>(output), GetTensorDims(output))
+#define TF_LITE_MUL(type, opname, data_type) \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(output_activation_min, output_activation_max, \
+ &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<data_type>(input1), GetTensorShape(input2), \
+ GetTensorData<data_type>(input2), GetTensorShape(output), \
+ GetTensorData<data_type>(output))
+
if (output->type == kTfLiteInt32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
- TF_LITE_MUL(reference_ops, BroadcastMul, int32_t);
+ TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, int32_t);
} else {
TF_LITE_MUL(reference_ops, Mul, int32_t);
}
} else {
if (data->requires_broadcast) {
- TF_LITE_MUL(optimized_ops, BroadcastMul, int32_t);
+ TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, int32_t);
} else {
TF_LITE_MUL(optimized_ops, Mul, int32_t);
}
@@ -127,13 +131,13 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
} else if (output->type == kTfLiteFloat32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
- TF_LITE_MUL(reference_ops, BroadcastMul, float);
+ TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, float);
} else {
TF_LITE_MUL(reference_ops, Mul, float);
}
} else {
if (data->requires_broadcast) {
- TF_LITE_MUL(optimized_ops, BroadcastMul, float);
+ TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, float);
} else {
TF_LITE_MUL(optimized_ops, Mul, float);
}
@@ -149,14 +153,20 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* input2, TfLiteTensor* output) {
if (input1->type == kTfLiteUInt8 && input2->type == kTfLiteUInt8 &&
output->type == kTfLiteUInt8) {
-#define TF_LITE_MUL(type, opname) \
- type::opname(GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
- -input1->params.zero_point, GetTensorData<uint8_t>(input2), \
- GetTensorDims(input2), -input2->params.zero_point, \
- output->params.zero_point, data->output_multiplier, \
- data->output_shift, data->output_activation_min, \
- data->output_activation_max, GetTensorData<uint8_t>(output), \
- GetTensorDims(output));
+#define TF_LITE_MUL(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(data->output_activation_min, \
+ data->output_activation_max, &op_params); \
+ op_params.input1_offset = -input1->params.zero_point; \
+ op_params.input2_offset = -input2->params.zero_point; \
+ op_params.output_offset = output->params.zero_point; \
+ op_params.output_multiplier = data->output_multiplier; \
+ op_params.output_shift = data->output_shift; \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<uint8_t>(input1), GetTensorShape(input2), \
+ GetTensorData<uint8_t>(input2), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
+
// The quantized version of Mul doesn't support activations, so we
// always use BroadcastMul.
if (kernel_type == kReference) {
@@ -167,10 +177,12 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
#undef TF_LITE_MUL
} else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
output->type == kTfLiteInt16) {
-#define TF_LITE_MUL(type, opname) \
- type::opname(GetTensorData<int16_t>(input1), GetTensorDims(input1), \
- GetTensorData<int16_t>(input2), GetTensorDims(input2), \
- GetTensorData<int16_t>(output), GetTensorDims(output));
+#define TF_LITE_MUL(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<int16_t>(input1), GetTensorShape(input2), \
+ GetTensorData<int16_t>(input2), GetTensorShape(output), \
+ GetTensorData<int16_t>(output))
if (kernel_type == kReference) {
TF_LITE_MUL(reference_ops, Mul);
} else {
@@ -179,12 +191,15 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
#undef TF_LITE_MUL
} else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
output->type == kTfLiteUInt8) {
-#define TF_LITE_MUL(type, opname) \
- type::opname(GetTensorData<int16_t>(input1), GetTensorDims(input1), \
- GetTensorData<int16_t>(input2), GetTensorDims(input2), \
- output->params.zero_point, data->output_activation_min, \
- data->output_activation_max, GetTensorData<uint8_t>(output), \
- GetTensorDims(output));
+#define TF_LITE_MUL(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(data->output_activation_min, \
+ data->output_activation_max, &op_params); \
+ op_params.output_offset = output->params.zero_point; \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<int16_t>(input1), GetTensorShape(input2), \
+ GetTensorData<int16_t>(input2), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
if (kernel_type == kReference) {
TF_LITE_MUL(reference_ops, Mul);
} else {
diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc
index 4be8c243c1..55bcf3b533 100644
--- a/tensorflow/contrib/lite/kernels/pad.cc
+++ b/tensorflow/contrib/lite/kernels/pad.cc
@@ -134,12 +134,22 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
after_padding.push_back(paddings_data[idx * 2 + 1]);
}
-#define TF_LITE_PAD(type, scalar, pad_value) \
- type::PadV2(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), before_padding, after_padding, \
- GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output), pad_value)
-
+#define TF_LITE_PAD(type, scalar, pad_value) \
+ TF_LITE_ENSURE_EQ(context, before_padding.size(), 4); \
+ TF_LITE_ENSURE_EQ(context, after_padding.size(), 4); \
+ tflite::PadParams op_params; \
+ op_params.left_padding_count = 4; \
+ op_params.right_padding_count = 4; \
+ for (int i = 0; i < 4; ++i) { \
+ op_params.left_padding[i] = before_padding[3 - i]; \
+ op_params.right_padding[i] = after_padding[3 - i]; \
+ } \
+ const scalar pad_value_copy = pad_value; \
+ \
+ type::Pad(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), &pad_value_copy, \
+ GetTensorShape(op_context.output), \
+ GetTensorData<scalar>(op_context.output))
switch (op_context.input->type) {
case kTfLiteFloat32: {
float pad_value = op_context.constant_values == nullptr
diff --git a/tensorflow/contrib/lite/kernels/pow.cc b/tensorflow/contrib/lite/kernels/pow.cc
index 4a539c47a8..d676de5b1d 100644
--- a/tensorflow/contrib/lite/kernels/pow.cc
+++ b/tensorflow/contrib/lite/kernels/pow.cc
@@ -80,14 +80,14 @@ template <typename T>
void PowImpl(const TfLiteTensor* input1, const TfLiteTensor* input2,
TfLiteTensor* output, bool requires_broadcast) {
if (requires_broadcast) {
- reference_ops::BroadcastPow(GetTensorData<T>(input1), GetTensorDims(input1),
- GetTensorData<T>(input2), GetTensorDims(input2),
- GetTensorData<T>(output),
- GetTensorDims(output));
+ reference_ops::BroadcastPow4DSlow(
+ GetTensorShape(input1), GetTensorData<T>(input1),
+ GetTensorShape(input2), GetTensorData<T>(input2),
+ GetTensorShape(output), GetTensorData<T>(output));
} else {
- reference_ops::Pow(GetTensorData<T>(input1), GetTensorDims(input1),
- GetTensorData<T>(input2), GetTensorDims(input2),
- GetTensorData<T>(output), GetTensorDims(output));
+ reference_ops::Pow(GetTensorShape(input1), GetTensorData<T>(input1),
+ GetTensorShape(input2), GetTensorData<T>(input2),
+ GetTensorShape(output), GetTensorData<T>(output));
}
}
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
index 86c4cd3ee8..dafa3aebab 100644
--- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
@@ -88,11 +88,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
if (output->type == kTfLiteFloat32) {
-#define TF_LITE_RESIZE_BILINEAR(type, datatype) \
- type::ResizeBilinear(GetTensorData<datatype>(input), GetTensorDims(input), \
- GetTensorData<int32>(size), GetTensorDims(size), \
- GetTensorData<datatype>(output), GetTensorDims(output), \
- params->align_corners)
+#define TF_LITE_RESIZE_BILINEAR(type, datatype) \
+ tflite::ResizeBilinearParams op_params; \
+ op_params.align_corners = params->align_corners; \
+ type::ResizeBilinear(op_params, GetTensorShape(input), \
+ GetTensorData<datatype>(input), GetTensorShape(size), \
+ GetTensorData<int32>(size), GetTensorShape(output), \
+ GetTensorData<datatype>(output))
if (kernel_type == kReference) {
TF_LITE_RESIZE_BILINEAR(reference_ops, float);
diff --git a/tensorflow/contrib/lite/kernels/slice.cc b/tensorflow/contrib/lite/kernels/slice.cc
index 6a20e802a9..55e16506df 100644
--- a/tensorflow/contrib/lite/kernels/slice.cc
+++ b/tensorflow/contrib/lite/kernels/slice.cc
@@ -159,10 +159,28 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
sizes.push_back(1);
}
-#define TF_LITE_SLICE(data_type) \
- optimized_ops::Slice<data_type>( \
- GetTensorData<data_type>(input), GetTensorDims(input), begins, sizes, \
- GetTensorData<data_type>(output), GetTensorDims(output))
+ // The original Slice op implementation only accepted 4-D sizes. That
+ // constraint is, for the present, maintained here.
+ //
+ // The dimensions in the kernel used to be in reverse-order, and TFLite
+ // arranged the begins and sizes vectors accordingly. This macro incorporates
+ // the needed reversing.
+#define TF_LITE_SLICE(data_type) \
+ { \
+ TF_LITE_ENSURE_EQ(context, begins.size(), 4); \
+ TF_LITE_ENSURE_EQ(context, sizes.size(), 4); \
+ tflite::SliceParams op_params; \
+ op_params.begin_count = 4; \
+ op_params.size_count = 4; \
+ for (int i = 0; i < 4; ++i) { \
+ op_params.begin[i] = begins[3 - i]; \
+ op_params.size[i] = sizes[3 - i]; \
+ } \
+ \
+ optimized_ops::Slice<data_type>( \
+ op_params, GetTensorShape(input), GetTensorData<data_type>(input), \
+ GetTensorShape(output), GetTensorData<data_type>(output)); \
+ }
switch (input->type) {
case kTfLiteFloat32:
diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
index 03079f1c3b..8332ae32cf 100644
--- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
@@ -114,14 +114,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar, pad_value) \
- type::SpaceToBatchND(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), \
+ tflite::SpaceToBatchParams op_params; \
+ op_params.output_offset = pad_value; \
+ type::SpaceToBatchND(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ GetTensorShape(op_context.block_shape), \
GetTensorData<int32_t>(op_context.block_shape), \
- GetTensorDims(op_context.block_shape), \
+ GetTensorShape(op_context.paddings), \
GetTensorData<int32_t>(op_context.paddings), \
- GetTensorDims(op_context.paddings), \
- GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output), pad_value)
+ GetTensorShape(op_context.output), \
+ GetTensorData<scalar>(op_context.output))
switch (op_context.input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
if (kernel_type == kReference) {
diff --git a/tensorflow/contrib/lite/kernels/space_to_depth.cc b/tensorflow/contrib/lite/kernels/space_to_depth.cc
index 9dbe9b9eda..9238e879f8 100644
--- a/tensorflow/contrib/lite/kernels/space_to_depth.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_depth.cc
@@ -79,10 +79,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \
- type::SpaceToDepth<scalar>( \
- GetTensorData<scalar>(input), GetTensorDims(input), params->block_size, \
- GetTensorData<scalar>(output), GetTensorDims(output))
+#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \
+ tflite::SpaceToDepthParams op_params; \
+ op_params.block_size = params->block_size; \
+ type::SpaceToDepth(op_params, GetTensorShape(input), \
+ GetTensorData<scalar>(input), GetTensorShape(output), \
+ GetTensorData<scalar>(output))
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
if (kernel_type == kReference) {