diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-23 13:22:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-23 13:27:24 -0700 |
commit | d7682bb16f575eb0c4cbb1622d8098c592fed2b7 (patch) | |
tree | 4be203fa763c68dc5024719f77d3258f2c0001b0 /tensorflow/contrib/lite/kernels/internal | |
parent | ba2ccdf164a77e49d867da27bf31e5b4b7d1a08d (diff) |
Convert more kernel signatures to use runtime shapes.
PiperOrigin-RevId: 209988056
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal')
3 files changed, 289 insertions, 115 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index 40160289c8..7319636bf5 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -2143,38 +2143,6 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims, im2col_data, im2col_dims, gemm_context); } -template <typename T> -inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, - int block_size, T* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("DepthToSpace"); - - const int input_depth = ArraySize(input_dims, 0); - const int input_width = ArraySize(input_dims, 1); - const int input_height = ArraySize(input_dims, 2); - - const int output_depth = ArraySize(output_dims, 0); - const int batch_size = ArraySize(output_dims, 3); - - // Number of continuous values that we can copy in one interation. - const int stride = block_size * output_depth; - - for (int batch = 0; batch < batch_size; ++batch) { - for (int in_h = 0; in_h < input_height; ++in_h) { - const T* input_ptr = input_data + Offset(input_dims, 0, 0, in_h, batch); - for (int offset_h = 0; offset_h < block_size; ++offset_h) { - const T* src = input_ptr; - for (int in_w = 0; in_w < input_width; ++in_w) { - memcpy(output_data, src, stride * sizeof(T)); - output_data += stride; - src += input_depth; - } - input_ptr += stride; - } - } - } -} - // legacy, for compatibility with old checked-in code template <FusedActivationFunctionType Ac, typename T> void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, @@ -2250,25 +2218,87 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims, } template <typename T> -inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, +inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params, + const RuntimeShape& unextended_input_shape, + const T* input_data, + const RuntimeShape& unextended_output_shape, + T* output_data) { + gemmlowp::ScopedProfilingLabel label("DepthToSpace"); + + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + const int input_depth = input_shape.Dims(3); + const int input_width = input_shape.Dims(2); + const int input_height = input_shape.Dims(1); + + const int output_depth = output_shape.Dims(3); + const int batch_size = output_shape.Dims(0); + + // Number of continuous values that we can copy in one interation. + const int stride = op_params.block_size * output_depth; + + for (int batch = 0; batch < batch_size; ++batch) { + for (int in_h = 0; in_h < input_height; ++in_h) { + const T* input_ptr = input_data + Offset(input_shape, batch, in_h, 0, 0); + for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) { + const T* src = input_ptr; + for (int in_w = 0; in_w < input_width; ++in_w) { + memcpy(output_data, src, stride * sizeof(T)); + output_data += stride; + src += input_depth; + } + input_ptr += stride; + } + } + } +} + +// Legacy Dims<4>. +template <typename T> +inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, int block_size, T* output_data, const Dims<4>& output_dims) { + tflite::DepthToSpaceParams op_params; + op_params.block_size = block_size; + + DepthToSpace(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + +template <typename T> +inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params, + const RuntimeShape& unextended_input_shape, + const T* input_data, + const RuntimeShape& unextended_output_shape, + T* output_data) { gemmlowp::ScopedProfilingLabel label("SpaceToDepth"); - const int output_depth = ArraySize(output_dims, 0); - const int output_width = ArraySize(output_dims, 1); - const int output_height = ArraySize(output_dims, 2); + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); - const int input_depth = ArraySize(input_dims, 0); - const int batch_size = ArraySize(input_dims, 3); + const int output_depth = output_shape.Dims(3); + const int output_width = output_shape.Dims(2); + const int output_height = output_shape.Dims(1); + + const int input_depth = input_shape.Dims(3); + const int batch_size = input_shape.Dims(0); // Number of continuous values that we can copy in one interation. - const int stride = block_size * input_depth; + const int stride = op_params.block_size * input_depth; for (int batch = 0; batch < batch_size; ++batch) { for (int out_h = 0; out_h < output_height; ++out_h) { - T* output_ptr = output_data + Offset(output_dims, 0, 0, out_h, batch); - for (int offset_h = 0; offset_h < block_size; ++offset_h) { + T* output_ptr = output_data + Offset(output_shape, batch, out_h, 0, 0); + for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) { T* dst = output_ptr; for (int out_w = 0; out_w < output_width; ++out_w) { memcpy(dst, input_data, stride * sizeof(T)); @@ -2281,6 +2311,18 @@ inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, } } +// Legacy Dims<4>. +template <typename T> +inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + tflite::SpaceToDepthParams op_params; + op_params.block_size = block_size; + + SpaceToDepth(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + template <FusedActivationFunctionType Ac> void NonGlobalBatchNormalization( const float* input_data, const Dims<4>& input_dims, const float* mean_data, @@ -5565,20 +5607,29 @@ inline void GetIndexRange(int spatial_index_dim, int block_shape_dim, } template <typename T> -inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, - const int32* block_shape_data, - const Dims<4>& block_shape_dims, - const int32* crops_data, const Dims<4>& crops_dims, - T* output_data, const Dims<4>& output_dims) { +inline void BatchToSpaceND( + const RuntimeShape& unextended_input1_shape, const T* input1_data, + const RuntimeShape& unextended_input2_shape, const int32* block_shape_data, + const RuntimeShape& unextended_input3_shape, const int32* crops_data, + const RuntimeShape& unextended_output_shape, T* output_data) { gemmlowp::ScopedProfilingLabel label("BatchToSpaceND"); - const int output_batch_size = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int input_batch_size = ArraySize(input_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int depth = ArraySize(input_dims, 0); + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input1_shape = + RuntimeShape::ExtendedShape(4, unextended_input1_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + const int output_width = output_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_batch_size = output_shape.Dims(0); + + const int depth = input1_shape.Dims(3); + const int input_width = input1_shape.Dims(2); + const int input_height = input1_shape.Dims(1); + const int input_batch_size = input1_shape.Dims(0); + const int block_shape_width = block_shape_data[1]; const int block_shape_height = block_shape_data[0]; const int crops_top = crops_data[0]; @@ -5613,14 +5664,28 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, spatial_offset % block_shape_width - crops_left; TFLITE_DCHECK_GE(out_w, 0); TFLITE_DCHECK_LT(out_w, output_width); - T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch); - const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch); + T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0); + const T* in = + input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0); memcpy(out, in, depth * sizeof(T)); } } } } +// Legacy Dims<4>. +template <typename T> +inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, + const int32* crops_data, const Dims<4>& crops_dims, + T* output_data, const Dims<4>& output_dims) { + BatchToSpaceND(DimsToShape(input_dims), input_data, + DimsToShape(block_shape_dims), block_shape_data, + DimsToShape(crops_dims), crops_data, DimsToShape(output_dims), + output_data); +} + template <typename T> void TypedMemset(void* ptr, T value, size_t num) { // Optimization for common cases where memset() will suffice. diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index a6aef4fa29..020d8fdcf0 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -407,18 +407,29 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims, } template <typename T> -inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, - int block_size, T* output_data, - const Dims<4>& output_dims) { - const int input_depth = ArraySize(input_dims, 0); - const int input_width = ArraySize(input_dims, 1); - const int input_height = ArraySize(input_dims, 2); - const int input_batch = ArraySize(input_dims, 3); +inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params, + const RuntimeShape& unextended_input_shape, + const T* input_data, + const RuntimeShape& unextended_output_shape, + T* output_data) { + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); - const int output_depth = ArraySize(output_dims, 0); - const int output_width = ArraySize(output_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_batch = ArraySize(output_dims, 3); + const int input_depth = input_shape.Dims(3); + const int input_width = input_shape.Dims(2); + const int input_height = input_shape.Dims(1); + const int input_batch = input_shape.Dims(0); + + const int output_depth = output_shape.Dims(3); + const int output_width = output_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_batch = output_shape.Dims(0); + + const int32 block_size = op_params.block_size; TFLITE_DCHECK_EQ(input_width * block_size, output_width); TFLITE_DCHECK_EQ(input_height * block_size, output_height); @@ -437,9 +448,9 @@ inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, const int in_h = out_h / block_size; const int in_b = out_b; + const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d); const int output_index = - Offset(output_dims, out_d, out_w, out_h, out_b); - const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b); + Offset(output_shape, out_b, out_h, out_w, out_d); output_data[output_index] = input_data[input_index]; } @@ -448,19 +459,42 @@ inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, } } +// Legacy Dims<4>. template <typename T> -inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, +inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, int block_size, T* output_data, const Dims<4>& output_dims) { - const int input_depth = ArraySize(input_dims, 0); - const int input_width = ArraySize(input_dims, 1); - const int input_height = ArraySize(input_dims, 2); - const int input_batch = ArraySize(input_dims, 3); + tflite::DepthToSpaceParams op_params; + op_params.block_size = block_size; - const int output_depth = ArraySize(output_dims, 0); - const int output_width = ArraySize(output_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_batch = ArraySize(output_dims, 3); + DepthToSpace(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + +template <typename T> +inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params, + const RuntimeShape& unextended_input_shape, + const T* input_data, + const RuntimeShape& unextended_output_shape, + T* output_data) { + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + const int input_depth = input_shape.Dims(3); + const int input_width = input_shape.Dims(2); + const int input_height = input_shape.Dims(1); + const int input_batch = input_shape.Dims(0); + + const int output_depth = output_shape.Dims(3); + const int output_width = output_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_batch = output_shape.Dims(0); + + const int32 block_size = op_params.block_size; TFLITE_DCHECK_EQ(input_width, output_width * block_size); TFLITE_DCHECK_EQ(input_height, output_height * block_size); @@ -478,9 +512,9 @@ inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, const int out_h = in_h / block_size; const int out_b = in_b; + const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d); const int output_index = - Offset(output_dims, out_d, out_w, out_h, out_b); - const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b); + Offset(output_shape, out_b, out_h, out_w, out_d); output_data[output_index] = input_data[input_index]; } @@ -489,6 +523,18 @@ inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, } } +// Legacy Dims<4>. +template <typename T> +inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + tflite::SpaceToDepthParams op_params; + op_params.block_size = block_size; + + SpaceToDepth(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + inline void FullyConnected(const float* input_data, const Dims<4>& input_dims, const float* weights_data, const Dims<4>& weights_dims, const float* bias_data, @@ -3467,45 +3513,56 @@ inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, } template <typename T> -inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, - const int32* block_shape_data, - const Dims<4>& block_shape_dims, - const int32* paddings_data, - const Dims<4>& paddings_dims, T* output_data, - const Dims<4>& output_dims, - const int32_t pad_value) { - const int output_batch_size = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int input_batch_size = ArraySize(input_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int depth = ArraySize(input_dims, 0); +inline void SpaceToBatchND( + const SpaceToBatchParams& params, + const RuntimeShape& unextended_input1_shape, const T* input1_data, + const RuntimeShape& unextended_input2_shape, const int32* block_shape_data, + const RuntimeShape& unextended_input3_shape, const int32* paddings_data, + const RuntimeShape& unextended_output_shape, T* output_data) { + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input1_shape = + RuntimeShape::ExtendedShape(4, unextended_input1_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + const int depth = input1_shape.Dims(3); + const int input_width = input1_shape.Dims(2); + const int input_height = input1_shape.Dims(1); + const int input_batch_size = input1_shape.Dims(0); + + const int output_width = output_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_batch_size = output_shape.Dims(0); + const int block_shape_height = block_shape_data[0]; const int block_shape_width = block_shape_data[1]; const int padding_top = paddings_data[0]; const int padding_left = paddings_data[2]; + // For uint8 quantized, the correct padding "zero value" is the output offset. + const int32_t pad_value = params.output_offset; + for (int out_b = 0; out_b < output_batch_size; ++out_b) { int input_batch = out_b % input_batch_size; int shift_w = (out_b / input_batch_size) % block_shape_width; int shift_h = (out_b / input_batch_size) / block_shape_width; for (int out_h = 0; out_h < output_height; ++out_h) { for (int out_w = 0; out_w < output_width; ++out_w) { - T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b); + T* out = output_data + Offset(output_shape, out_b, out_h, out_w, 0); if (out_h * block_shape_height + shift_h < padding_top || out_h * block_shape_height + shift_h >= padding_top + input_height || out_w * block_shape_width + shift_w < padding_left || out_w * block_shape_width + shift_w >= padding_left + input_width) { + // This may not execute correctly when pad_value != 0 and T != uint8. memset(out, pad_value, depth * sizeof(T)); } else { const T* in = - input_data + - Offset(input_dims, 0, - (out_w * block_shape_width + shift_w) - padding_left, + input1_data + + Offset(input1_shape, input_batch, (out_h * block_shape_height + shift_h) - padding_top, - input_batch); + (out_w * block_shape_width + shift_w) - padding_left, 0); memcpy(out, in, depth * sizeof(T)); } } @@ -3513,30 +3570,63 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, } } +// Legacy Dims<4>. template <typename T> inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, const int32* block_shape_data, const Dims<4>& block_shape_dims, const int32* paddings_data, const Dims<4>& paddings_dims, T* output_data, - const Dims<4>& output_dims) { - SpaceToBatchND(input_data, input_dims, block_shape_data, block_shape_dims, - paddings_data, paddings_dims, output_data, output_dims, 0); + const Dims<4>& output_dims, + const int32_t pad_value) { + tflite::SpaceToBatchParams op_params; + op_params.output_offset = pad_value; + + SpaceToBatchND(op_params, DimsToShape(input_dims), input_data, + DimsToShape(block_shape_dims), block_shape_data, + DimsToShape(paddings_dims), paddings_data, + DimsToShape(output_dims), output_data); } +// Legacy if no good reason to have signature with pad_value=0. template <typename T> -inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, +inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, const int32* block_shape_data, const Dims<4>& block_shape_dims, - const int32* crops_data, const Dims<4>& crops_dims, - T* output_data, const Dims<4>& output_dims) { - const int output_batch_size = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int input_batch_size = ArraySize(input_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int depth = ArraySize(input_dims, 0); + const int32* paddings_data, + const Dims<4>& paddings_dims, T* output_data, + const Dims<4>& output_dims) { + tflite::SpaceToBatchParams op_params; + op_params.output_offset = 0; + + SpaceToBatchND(op_params, DimsToShape(input_dims), input_data, + DimsToShape(block_shape_dims), block_shape_data, + DimsToShape(paddings_dims), paddings_data, + DimsToShape(output_dims), output_data); +} + +template <typename T> +inline void BatchToSpaceND( + const RuntimeShape& unextended_input1_shape, const T* input1_data, + const RuntimeShape& unextended_input2_shape, const int32* block_shape_data, + const RuntimeShape& unextended_input3_shape, const int32* crops_data, + const RuntimeShape& unextended_output_shape, T* output_data) { + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input1_shape = + RuntimeShape::ExtendedShape(4, unextended_input1_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + const int output_width = output_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_batch_size = output_shape.Dims(0); + + const int depth = input1_shape.Dims(3); + const int input_width = input1_shape.Dims(2); + const int input_height = input1_shape.Dims(1); + const int input_batch_size = input1_shape.Dims(0); + const int block_shape_width = block_shape_data[1]; const int block_shape_height = block_shape_data[0]; const int crops_top = crops_data[0]; @@ -3558,14 +3648,28 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, if (out_w < 0 || out_w >= output_width) { continue; } - T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch); - const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch); + T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0); + const T* in = + input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0); memcpy(out, in, depth * sizeof(T)); } } } } +// Legacy Dims<4>. +template <typename T> +inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, + const int32* crops_data, const Dims<4>& crops_dims, + T* output_data, const Dims<4>& output_dims) { + BatchToSpaceND(DimsToShape(input_dims), input_data, + DimsToShape(block_shape_dims), block_shape_data, + DimsToShape(crops_dims), crops_data, DimsToShape(output_dims), + output_data); +} + // There are two versions of pad: Pad and PadV2. In PadV2 there is a second // scalar input that provides the padding value. Therefore pad_value_ptr can be // equivalent to a simple input1_data. For Pad, it should point to a zero diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index 27b78aa225..2603ed2eb7 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -745,7 +745,7 @@ struct ConvParams { }; struct DepthToSpaceParams { - int16 block_size; + int32 block_size; }; struct DepthwiseParams { @@ -871,8 +871,13 @@ struct SoftmaxParams { int diff_min; }; +struct SpaceToBatchParams { + // "Zero" padding for uint8 means padding with the output offset. + int32 output_offset; +}; + struct SpaceToDepthParams { - int16 block_size; + int32 block_size; }; struct SplitParams { |