author    | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-19 10:32:48 -0700
----------|--------------------------------------------------|--------------------------
committer | TensorFlower Gardener <gardener@tensorflow.org>  | 2018-09-19 10:36:53 -0700
commit    | 5d5bc6d2b592374d7862cdebbc53e07b47e29c95 (patch)  |
tree      | 97fd843d458c7f26d5e51872c0b4fc61dd3886ba /tensorflow/contrib/lite/kernels |
parent    | 414ca1cda5aec72b48d5da127f61b0d05fbdc22c (diff)   |
Convert more kernel signatures to use runtime shapes.
PiperOrigin-RevId: 213651158
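The pattern applied throughout this diff is twofold: each kernel gains a new entry point that takes an op-params struct plus `RuntimeShape` arguments, and the old `Dims<4>`-based signature is kept as a thin legacy wrapper that forwards through `DimsToShape`. Two details help when reading the hunks below. First, `RuntimeShape` stores dimensions in NHWC order (index 0 = batch, index 3 = depth), whereas `Dims<4>` stored them innermost-first; that is why `MatchingArraySize(input_dims, 3, ...)` becomes `MatchingDim(input_shape, 0, ...)`. Second, shapes of rank lower than 4 are promoted with `RuntimeShape::ExtendedShape(4, shape)`, which left-pads the dimension list with 1s. As a rough illustration of that padding behavior, here is a minimal self-contained sketch; `ToyShape` is a simplified stand-in written for this note, not the real `tflite::RuntimeShape`:

```cpp
#include <cassert>
#include <cstdio>
#include <utility>
#include <vector>

// Simplified stand-in for tflite::RuntimeShape, for illustration only;
// the real class stores dims inline and has a much larger API.
class ToyShape {
 public:
  explicit ToyShape(std::vector<int> dims) : dims_(std::move(dims)) {}
  int DimensionsCount() const { return static_cast<int>(dims_.size()); }
  int Dims(int i) const { return dims_[i]; }

  // Mirrors the behavior relied on in the diff: left-pad the dimension
  // list with 1s until it has new_rank entries.
  static ToyShape ExtendedShape(int new_rank, const ToyShape& shape) {
    assert(shape.DimensionsCount() <= new_rank);
    std::vector<int> dims(new_rank, 1);
    const int offset = new_rank - shape.DimensionsCount();
    for (int i = 0; i < shape.DimensionsCount(); ++i) {
      dims[offset + i] = shape.Dims(i);
    }
    return ToyShape(dims);
  }

 private:
  std::vector<int> dims_;
};

int main() {
  // A rank-2 [batch, depth] shape promoted to rank 4 becomes [1, 1, 5, 3].
  const ToyShape ext = ToyShape::ExtendedShape(4, ToyShape({5, 3}));
  for (int i = 0; i < ext.DimensionsCount(); ++i) {
    std::printf("%d ", ext.Dims(i));  // prints: 1 1 5 3
  }
  std::printf("\n");
  return 0;
}
```

Left-padding with 1s preserves the flat element count, so kernels written against 4D indexing keep working unchanged on lower-rank inputs.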
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 292
1 file changed, 210 insertions, 82 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index d315debdda..76fa1944bc 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -2412,56 +2412,90 @@ void DepthConcatenation(const Scalar* const* input_data,
                      output_data, output_dims);
 }
 
-inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
-                     const float* prev_activ_data,
-                     const Dims<4>& prev_activ_dims, const float* weights_data,
-                     const Dims<4>& weights_dims, const float* bias_data,
-                     const Dims<4>& bias_dims, const float* prev_state_data,
-                     const Dims<4>& prev_state_dims, float* output_state_data,
-                     const Dims<4>& output_state_dims, float* output_activ_data,
-                     const Dims<4>& output_activ_dims, float* concat_temp_data,
-                     const Dims<4>& concat_temp_dims, float* activ_temp_data,
-                     const Dims<4>& activ_temp_dims) {
+inline void LstmCell(
+    const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+    const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
+    const float* prev_activ_data, const RuntimeShape& weights_shape,
+    const float* weights_data, const RuntimeShape& unextended_bias_shape,
+    const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
+    const float* prev_state_data,
+    const RuntimeShape& unextended_output_state_shape, float* output_state_data,
+    const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
+    const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
+    const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
+  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+  const RuntimeShape input_shape =
+      RuntimeShape::ExtendedShape(4, unextended_input_shape);
+  const RuntimeShape prev_activ_shape =
+      RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+  const RuntimeShape bias_shape =
+      RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+  const RuntimeShape prev_state_shape =
+      RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+  const RuntimeShape output_state_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+  const RuntimeShape output_activ_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+  const RuntimeShape concat_temp_shape =
+      RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+  const RuntimeShape activ_temp_shape =
+      RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+
+  const int weights_dim_count = weights_shape.DimensionsCount();
   const int batches =
-      MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3,
-                        output_state_dims, 3, output_activ_dims, 3);
+      MatchingDim(input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
+                  output_state_shape, 0, output_activ_shape, 0);
   const int height =
-      MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2,
-                        output_state_dims, 2, output_activ_dims, 2);
+      MatchingDim(input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
+                  output_state_shape, 1, output_activ_shape, 1);
   const int width =
-      MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1,
-                        output_state_dims, 1, output_activ_dims, 1);
-  TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
-  TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
-  const int input_depth = ArraySize(input_dims, 0);
-  const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+      MatchingDim(input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
+                  output_state_shape, 2, output_activ_shape, 2);
+  const int input_depth = input_shape.Dims(3);
+  const int prev_activ_depth = prev_activ_shape.Dims(3);
   const int total_input_depth = prev_activ_depth + input_depth;
-  TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
-  TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
-                  1);
+  TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+                   total_input_depth);
+  TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
   const int intern_activ_depth =
-      MatchingArraySize(weights_dims, 1, bias_dims, 0);
-  TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+      MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+  TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+                   intern_activ_depth * total_input_depth);
+  TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
   const int output_depth =
-      MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
-                        output_state_dims, 0, output_activ_dims, 0);
-  TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
+      MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+                  3, output_activ_shape, 3);
+  TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
 
   // Concatenate prev_activ and input data together
   std::vector<float const*> concat_input_arrays_data;
-  std::vector<Dims<4> const*> concat_input_arrays_dims;
+  std::vector<RuntimeShape const*> concat_input_arrays_shapes;
   concat_input_arrays_data.push_back(input_data);
   concat_input_arrays_data.push_back(prev_activ_data);
-  concat_input_arrays_dims.push_back(&input_dims);
-  concat_input_arrays_dims.push_back(&prev_activ_dims);
-  Concatenation<FusedActivationFunctionType::kNone, float>(
-      0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]),
-      concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims);
+  concat_input_arrays_shapes.push_back(&input_shape);
+  concat_input_arrays_shapes.push_back(&prev_activ_shape);
+  tflite::ConcatenationParams concat_params;
+  concat_params.axis = 3;
+  concat_params.inputs_count = concat_input_arrays_data.size();
+  Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
+                &(concat_input_arrays_data[0]), concat_temp_shape,
+                concat_temp_data);
 
   // Fully connected
-  FullyConnected<FusedActivationFunctionType::kNone>(
-      concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data,
-      bias_dims, activ_temp_data, activ_temp_dims);
+  tflite::FullyConnectedParams fc_params;
+  fc_params.float_activation_min = std::numeric_limits<float>::lowest();
+  fc_params.float_activation_max = std::numeric_limits<float>::max();
+  FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
+                 weights_data, bias_shape, bias_data, activ_temp_shape,
+                 activ_temp_data);
 
   // Memory state update (the LSTM "guts")
   for (int b = 0; b < batches; ++b) {
@@ -2470,24 +2504,24 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
         for (int c = 0; c < output_depth; ++c) {
           const float input_gate =
               1.f /
-              (1.f + std::exp(-activ_temp_data[Offset(
-                         activ_temp_dims, 0 * output_depth + c, w, h, b)]));
+              (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+                                                      0 * output_depth + c)]));
           const float new_input = std::tanh(activ_temp_data[Offset(
-              activ_temp_dims, 1 * output_depth + c, w, h, b)]);
+              activ_temp_shape, b, h, w, 1 * output_depth + c)]);
           const float forget_gate =
               1.f /
-              (1.f + std::exp(-activ_temp_data[Offset(
-                         activ_temp_dims, 2 * output_depth + c, w, h, b)]));
+              (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+                                                      2 * output_depth + c)]));
           const float output_gate =
               1.f /
-              (1.f + std::exp(-activ_temp_data[Offset(
-                         activ_temp_dims, 3 * output_depth + c, w, h, b)]));
+              (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+                                                      3 * output_depth + c)]));
           const float new_state =
               input_gate * new_input +
               forget_gate *
-                  prev_state_data[Offset(prev_state_dims, c, w, h, b)];
-          output_state_data[Offset(output_state_dims, c, w, h, b)] = new_state;
-          output_activ_data[Offset(output_activ_dims, c, w, h, b)] =
+                  prev_state_data[Offset(prev_state_shape, b, h, w, c)];
+          output_state_data[Offset(output_state_shape, b, h, w, c)] = new_state;
+          output_activ_data[Offset(output_activ_shape, b, h, w, c)] =
               output_gate * std::tanh(new_state);
         }
       }
@@ -2495,6 +2529,31 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
+                     const float* prev_activ_data,
+                     const Dims<4>& prev_activ_dims, const float* weights_data,
+                     const Dims<4>& weights_dims, const float* bias_data,
+                     const Dims<4>& bias_dims, const float* prev_state_data,
+                     const Dims<4>& prev_state_dims, float* output_state_data,
+                     const Dims<4>& output_state_dims, float* output_activ_data,
+                     const Dims<4>& output_activ_dims, float* concat_temp_data,
+                     const Dims<4>& concat_temp_dims, float* activ_temp_data,
+                     const Dims<4>& activ_temp_dims) {
+  tflite::LstmCellParams op_params;
+  // Float LSTM cell does not need parameters to be set: leave untouched.
+
+  LstmCell(op_params, DimsToShape(input_dims), input_data,
+           DimsToShape(prev_activ_dims), prev_activ_data,
+           DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
+           bias_data, DimsToShape(prev_state_dims), prev_state_data,
+           DimsToShape(output_state_dims), output_state_data,
+           DimsToShape(output_activ_dims), output_activ_data,
+           DimsToShape(concat_temp_dims), concat_temp_data,
+           DimsToShape(activ_temp_dims), activ_temp_data);
+}
+
 // Quantized LSTM cell implementation.
 // The quantization of the input, output arrays is as follows:
 //  - The input activations are quantized as uint8 on the interval
@@ -2580,52 +2639,90 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
 // aiming for 16-bit fixed-point quantization of these internal nodes here.
 //
 template <int StateIntegerBits>
-void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
-              const uint8* prev_activ_data_uint8,
-              const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
-              const Dims<4>& weights_dims, const int32* bias_data_int32,
-              const Dims<4>& bias_dims, const int16* prev_state_data_int16,
-              const Dims<4>& prev_state_dims, int16* output_state_data_int16,
-              const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
-              const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
-              const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
-              const Dims<4>& activ_temp_dims, int32 weights_zero_point,
-              int32 accum_multiplier, int accum_shift,
-              gemmlowp::GemmContext* gemm_context) {
+inline void LstmCell(
+    const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+    const uint8* input_data_uint8,
+    const RuntimeShape& unextended_prev_activ_shape,
+    const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
+    const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
+    const int32* bias_data_int32,
+    const RuntimeShape& unextended_prev_state_shape,
+    const int16* prev_state_data_int16,
+    const RuntimeShape& unextended_output_state_shape,
+    int16* output_state_data_int16,
+    const RuntimeShape& unextended_output_activ_shape,
+    uint8* output_activ_data_uint8,
+    const RuntimeShape& unextended_concat_temp_shape,
+    uint8* concat_temp_data_uint8,
+    const RuntimeShape& unextended_activ_temp_shape,
+    int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) {
   (void)gemm_context;  // only used in optimized code.
+  int32 weights_zero_point = params.weights_zero_point;
+  int32 accum_multiplier = params.accum_multiplier;
+  int accum_shift = params.accum_shift;
+  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+  const RuntimeShape input_shape =
+      RuntimeShape::ExtendedShape(4, unextended_input_shape);
+  const RuntimeShape prev_activ_shape =
+      RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+  const RuntimeShape bias_shape =
+      RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+  const RuntimeShape prev_state_shape =
+      RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+  const RuntimeShape output_state_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+  const RuntimeShape output_activ_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+  const RuntimeShape concat_temp_shape =
+      RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+  const RuntimeShape activ_temp_shape =
+      RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
 
   // Gather dimensions information, and perform consistency checks.
-  const int outer_size =
-      MatchingFlatSizeSkipDim(input_dims, 0, prev_activ_dims, prev_state_dims,
-                              output_state_dims, output_activ_dims);
-  TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
-  TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
-  const int input_depth = ArraySize(input_dims, 0);
-  const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+  const int weights_dim_count = weights_shape.DimensionsCount();
+  const int outer_size = MatchingFlatSizeSkipDim(
+      input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
+      output_activ_shape);
+  const int input_depth = input_shape.Dims(3);
+  const int prev_activ_depth = prev_activ_shape.Dims(3);
   const int total_input_depth = prev_activ_depth + input_depth;
-  TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
-  TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
-                  1);
+  TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+                   total_input_depth);
   const int intern_activ_depth =
-      MatchingArraySize(weights_dims, 1, bias_dims, 0);
-  TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+      MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+  TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+                   intern_activ_depth * total_input_depth);
+  TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
+  TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
   const int output_depth =
-      MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
-                        output_state_dims, 0, output_activ_dims, 0);
-  TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
-  const int fc_batches = FlatSizeSkipDim(activ_temp_dims, 0);
+      MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+                  3, output_activ_shape, 3);
+  TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
+  const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
   const int fc_output_depth =
-      MatchingArraySize(weights_dims, 1, activ_temp_dims, 0);
-  const int fc_accum_depth = ArraySize(weights_dims, 0);
-  TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth);
+      MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
+  const int fc_accum_depth = total_input_depth;
+  TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
 
   // Depth-concatenate prev_activ and input data together.
   uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
                                               prev_activ_data_uint8};
-  Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims};
-  Concatenation<FusedActivationFunctionType::kNone, uint8>(
-      0, concat_input_arrays_data, concat_input_arrays_dims, 2,
-      concat_temp_data_uint8, concat_temp_dims);
+  const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
                                                        &prev_activ_shape};
+  tflite::ConcatenationParams concat_params;
+  concat_params.axis = 3;
+  concat_params.inputs_count = 2;
+  Concatenation(concat_params, concat_input_arrays_shapes,
+                concat_input_arrays_data, concat_temp_shape,
+                concat_temp_data_uint8);
 
   // Implementation of the fully connected node inside the LSTM cell.
   // The operands are 8-bit integers, the accumulators are internally 32bit
@@ -2731,6 +2828,37 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <int StateIntegerBits>
+void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
+              const uint8* prev_activ_data_uint8,
+              const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
+              const Dims<4>& weights_dims, const int32* bias_data_int32,
+              const Dims<4>& bias_dims, const int16* prev_state_data_int16,
+              const Dims<4>& prev_state_dims, int16* output_state_data_int16,
+              const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
+              const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
+              const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
+              const Dims<4>& activ_temp_dims, int32 weights_zero_point,
+              int32 accum_multiplier, int accum_shift,
+              gemmlowp::GemmContext* gemm_context) {
+  tflite::LstmCellParams op_params;
+  op_params.weights_zero_point = weights_zero_point;
+  op_params.accum_multiplier = accum_multiplier;
+  op_params.accum_shift = accum_shift;
+
+  LstmCell<StateIntegerBits>(
+      op_params, DimsToShape(input_dims), input_data_uint8,
+      DimsToShape(prev_activ_dims), prev_activ_data_uint8,
+      DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
+      bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
+      DimsToShape(output_state_dims), output_state_data_int16,
+      DimsToShape(output_activ_dims), output_activ_data_uint8,
+      DimsToShape(concat_temp_dims), concat_temp_data_uint8,
+      DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context);
+}
+
 template <typename Scalar>
 void Split(const SplitParams& params, const RuntimeShape& input_shape,
            const Scalar* input_data, const RuntimeShape* const* output_shapes,
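A related detail in the hunks above is the argument-order flip in the `Offset` calls: the legacy `Dims<4>` overload took indices innermost-first, as in `Offset(dims, c, w, h, b)`, while the `RuntimeShape` overload takes them in NHWC order, as in `Offset(shape, b, h, w, c)`. Both compute the same flat index for the same element. The following is a hedged sketch of the NHWC flat-offset arithmetic, simplified from the real helper in that the four extents are passed directly rather than read from a shape object (`OffsetNHWC` is a name made up for this note):

```cpp
#include <cassert>

// NHWC flat-offset arithmetic, illustrating the index order used by the
// RuntimeShape-based Offset(shape, b, h, w, c) calls in the diff above.
inline int OffsetNHWC(int height, int width, int depth,
                      int b, int h, int w, int c) {
  assert(h >= 0 && h < height);
  assert(w >= 0 && w < width);
  assert(c >= 0 && c < depth);
  return ((b * height + h) * width + w) * depth + c;
}

int main() {
  // In a [2, 3, 4, 5] activation buffer, element (b=1, h=2, w=3, c=4)
  // lands at flat index ((1*3 + 2)*4 + 3)*5 + 4 = 119.
  return OffsetNHWC(/*height=*/3, /*width=*/4, /*depth=*/5,
                    /*b=*/1, /*h=*/2, /*w=*/3, /*c=*/4) == 119 ? 0 : 1;
}
```

This is why the gate reads in the converted kernel index `activ_temp_data` with the channel term `N * output_depth + c` last: in NHWC order, the channel is the innermost, contiguous dimension.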