diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-17 12:20:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-17 12:26:06 -0700 |
commit | 0d9868d8f9c01c1402ae99d672599c4bac6e787d (patch) | |
tree | a3d0570efa9706ae4e84b0cee31a7a04addfeb40 /tensorflow/contrib/lite/kernels | |
parent | 779d87cfc1421eb6be2f9cc4ae29bca77c8d2929 (diff) |
Convert more kernel signatures to use runtime shapes.
PiperOrigin-RevId: 213316034
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 215 |
1 file changed, 165 insertions, 50 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 77927af227..09a4ba7701 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -511,24 +511,25 @@ inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params, } } -inline void FullyConnected(const float* input_data, const Dims<4>& input_dims, - const float* weights_data, - const Dims<4>& weights_dims, const float* bias_data, - const Dims<4>& bias_dims, - float output_activation_min, - float output_activation_max, float* output_data, - const Dims<4>& output_dims) { +inline void FullyConnected( + const FullyConnectedParams& params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& weights_shape, + const float* weights_data, const RuntimeShape& bias_shape, + const float* bias_data, const RuntimeShape& output_shape, + float* output_data) { + const float output_activation_min = params.float_activation_min; + const float output_activation_max = params.float_activation_max; // TODO(benoitjacob): This really should be: // const int batches = ArraySize(output_dims, 1); // but the current --variable_batch hack consists in overwriting the 3rd // dimension with the runtime batch size, as we don't keep track for each // array of which dimension is the batch dimension in it. 
- const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * - ArraySize(output_dims, 3); - const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0); - const int accum_depth = ArraySize(weights_dims, 0); - TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims)); + const int output_dims_count = output_shape.DimensionsCount(); + const int weights_dims_count = weights_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dims_count - 1); + const int output_depth = MatchingDim(weights_shape, weights_dims_count - 2, + output_shape, output_dims_count - 1); + const int accum_depth = weights_shape.Dims(weights_dims_count - 1); for (int b = 0; b < batches; ++b) { for (int out_c = 0; out_c < output_depth; ++out_c) { float total = 0.f; @@ -538,7 +539,7 @@ inline void FullyConnected(const float* input_data, const Dims<4>& input_dims, } float bias_value = 0.0f; if (bias_data) { - bias_value = bias_data[Offset(bias_dims, out_c, 0, 0, 0)]; + bias_value = bias_data[out_c]; } output_data[out_c + output_depth * b] = ActivationFunctionWithMinMax( total + bias_value, output_activation_min, output_activation_max); @@ -546,6 +547,26 @@ inline void FullyConnected(const float* input_data, const Dims<4>& input_dims, } } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. 
+inline void FullyConnected(const float* input_data, const Dims<4>& input_dims, + const float* weights_data, + const Dims<4>& weights_dims, const float* bias_data, + const Dims<4>& bias_dims, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + tflite::FullyConnectedParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + + FullyConnected(op_params, DimsToShape(input_dims), input_data, + DimsToShape(weights_dims), weights_data, + DimsToShape(bias_dims), bias_data, DimsToShape(output_dims), + output_data); +} + +// TODO(b/80418076): Move to legacy ops file, update invocations. // legacy, for compatibility with old checked-in code template <FusedActivationFunctionType Ac> void FullyConnected(const float* input_data, const Dims<4>& input_dims, @@ -559,28 +580,35 @@ void FullyConnected(const float* input_data, const Dims<4>& input_dims, output_data, output_dims); } -inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims, - gemmlowp::GemmContext* gemm_context) { +inline void FullyConnected( + const FullyConnectedParams& params, const RuntimeShape& input_shape, + const uint8* input_data, const RuntimeShape& filter_shape, + const uint8* filter_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, + uint8* output_data, gemmlowp::GemmContext* gemm_context) { (void)gemm_context; // only used in optimized code. 
+ const int32 input_offset = params.input_offset; + const int32 filter_offset = params.weights_offset; + const int32 output_offset = params.output_offset; + const int32 output_multiplier = params.output_multiplier; + const int output_shift = params.output_shift; + const int32 output_activation_min = params.quantized_activation_min; + const int32 output_activation_max = params.quantized_activation_max; + TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); + TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); // TODO(benoitjacob): This really should be: // const int batches = ArraySize(output_dims, 1); // but the current --variable_batch hack consists in overwriting the 3rd // dimension with the runtime batch size, as we don't keep track for each // array of which dimension is the batch dimension in it. - const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * - ArraySize(output_dims, 3); - const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0); - const int accum_depth = ArraySize(filter_dims, 0); - TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + const int output_dim_count = output_shape.DimensionsCount(); + const int filter_dim_count = filter_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); + const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2, + output_shape, output_dim_count - 1); + const int accum_depth = filter_shape.Dims(filter_dim_count - 1); for (int b = 0; b < batches; ++b) { for (int out_c = 0; out_c < output_depth; ++out_c) { int32 acc = 0; @@ -590,7 +618,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, acc += (filter_val + filter_offset) * (input_val + input_offset); } if (bias_data) { - acc += bias_data[Offset(bias_dims, out_c, 0, 0, 0)]; + acc += bias_data[out_c]; } acc = 
MultiplyByQuantizedMultiplier(acc, output_multiplier, kReverseShift * output_shift); @@ -602,16 +630,47 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, } } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, int32 input_offset, const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset, const int32* bias_data, const Dims<4>& bias_dims, int32 output_offset, int32 output_multiplier, int output_shift, int32 output_activation_min, - int32 output_activation_max, int16* output_data, + int32 output_activation_max, uint8* output_data, const Dims<4>& output_dims, gemmlowp::GemmContext* gemm_context) { + tflite::FullyConnectedParams op_params; + op_params.input_offset = input_offset; + op_params.weights_offset = filter_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = output_multiplier; + op_params.output_shift = output_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + + FullyConnected(op_params, DimsToShape(input_dims), input_data, + DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims), + bias_data, DimsToShape(output_dims), output_data, + gemm_context); +} + +inline void FullyConnected( + const FullyConnectedParams& params, const RuntimeShape& input_shape, + const uint8* input_data, const RuntimeShape& filter_shape, + const uint8* filter_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, + int16* output_data, gemmlowp::GemmContext* gemm_context) { (void)gemm_context; // only used in optimized code. 
+ const int32 input_offset = params.input_offset; + const int32 filter_offset = params.weights_offset; + const int32 output_offset = params.output_offset; + const int32 output_multiplier = params.output_multiplier; + const int output_shift = params.output_shift; + const int32 output_activation_min = params.quantized_activation_min; + const int32 output_activation_max = params.quantized_activation_max; + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); TFLITE_DCHECK_EQ(output_offset, 0); // TODO(benoitjacob): This really should be: @@ -619,12 +678,12 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, // but the current --variable_batch hack consists in overwriting the 3rd // dimension with the runtime batch size, as we don't keep track for each // array of which dimension is the batch dimension in it. - const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * - ArraySize(output_dims, 3); - const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0); - const int accum_depth = ArraySize(filter_dims, 0); - TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + const int output_dim_count = output_shape.DimensionsCount(); + const int filter_dim_count = filter_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); + const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2, + output_shape, output_dim_count - 1); + const int accum_depth = filter_shape.Dims(filter_dim_count - 1); for (int b = 0; b < batches; ++b) { for (int out_c = 0; out_c < output_depth; ++out_c) { // Internal accumulation. @@ -651,27 +710,60 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, } } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. 
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, int16* output_data, + const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + tflite::FullyConnectedParams op_params; + op_params.input_offset = input_offset; + op_params.weights_offset = filter_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = output_multiplier; + op_params.output_shift = output_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + + FullyConnected(op_params, DimsToShape(input_dims), input_data, + DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims), + bias_data, DimsToShape(output_dims), output_data, + gemm_context); +} + inline void ShuffledFullyConnected( - const uint8* input_data, const Dims<4>& input_dims, - const uint8* shuffled_weights_data, const Dims<4>& weights_dims, - const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - int16* output_data, const Dims<4>& output_dims, - uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) { + const FullyConnectedParams& params, const RuntimeShape& input_shape, + const uint8* input_data, const RuntimeShape& weights_shape, + const uint8* shuffled_weights_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, + int16* output_data, uint8* shuffled_input_workspace_data, + gemmlowp::GemmContext* gemm_context) { (void)gemm_context; // only used in optimized code. 
- + const int32 output_multiplier = params.output_multiplier; + const int output_shift = params.output_shift; + const int32 output_activation_min = params.quantized_activation_min; + const int32 output_activation_max = params.quantized_activation_max; TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + + TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1); + TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2); + TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); // TODO(benoitjacob): This really should be: // const int batches = ArraySize(output_dims, 1); // but the current --variable_batch hack consists in overwriting the 3rd // dimension with the runtime batch size, as we don't keep track for each // array of which dimension is the batch dimension in it. - const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * - ArraySize(output_dims, 3); - const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0); - const int accum_depth = ArraySize(weights_dims, 0); - TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims)); + const int output_dim_count = output_shape.DimensionsCount(); + const int weights_dim_count = weights_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); + const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2, + output_shape, output_dim_count - 1); + const int accum_depth = weights_shape.Dims(weights_dim_count - 1); TFLITE_DCHECK((accum_depth % 16) == 0); TFLITE_DCHECK((output_depth % 4) == 0); @@ -799,6 +891,29 @@ inline void ShuffledFullyConnected( } } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. 
+inline void ShuffledFullyConnected( + const uint8* input_data, const Dims<4>& input_dims, + const uint8* shuffled_weights_data, const Dims<4>& weights_dims, + const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier, + int output_shift, int32 output_activation_min, int32 output_activation_max, + int16* output_data, const Dims<4>& output_dims, + uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) { + tflite::FullyConnectedParams op_params; + op_params.output_multiplier = output_multiplier; + op_params.output_shift = output_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + + ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data, + DimsToShape(weights_dims), shuffled_weights_data, + DimsToShape(bias_dims), bias_data, + DimsToShape(output_dims), output_data, + shuffled_input_workspace_data, gemm_context); +} + +// TODO(b/80418076): Move to legacy ops file, update invocations. // legacy, for compatibility with old checked-in code template <FusedActivationFunctionType Ac> void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, |