aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-19 10:32:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-19 10:36:53 -0700
commit5d5bc6d2b592374d7862cdebbc53e07b47e29c95 (patch)
tree97fd843d458c7f26d5e51872c0b4fc61dd3886ba /tensorflow/contrib/lite/kernels
parent414ca1cda5aec72b48d5da127f61b0d05fbdc22c (diff)
Convert more kernel signatures to use runtime shapes.
PiperOrigin-RevId: 213651158
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h292
1 files changed, 210 insertions, 82 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index d315debdda..76fa1944bc 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -2412,56 +2412,90 @@ void DepthConcatenation(const Scalar* const* input_data,
output_data, output_dims);
}
-inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
- const float* prev_activ_data,
- const Dims<4>& prev_activ_dims, const float* weights_data,
- const Dims<4>& weights_dims, const float* bias_data,
- const Dims<4>& bias_dims, const float* prev_state_data,
- const Dims<4>& prev_state_dims, float* output_state_data,
- const Dims<4>& output_state_dims, float* output_activ_data,
- const Dims<4>& output_activ_dims, float* concat_temp_data,
- const Dims<4>& concat_temp_dims, float* activ_temp_data,
- const Dims<4>& activ_temp_dims) {
+inline void LstmCell(
+ const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+ const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
+ const float* prev_activ_data, const RuntimeShape& weights_shape,
+ const float* weights_data, const RuntimeShape& unextended_bias_shape,
+ const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
+ const float* prev_state_data,
+ const RuntimeShape& unextended_output_state_shape, float* output_state_data,
+ const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
+ const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
+ const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape prev_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+ const RuntimeShape bias_shape =
+ RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+ const RuntimeShape prev_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+ const RuntimeShape output_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+ const RuntimeShape output_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+ const RuntimeShape concat_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+ const RuntimeShape activ_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+
+ const int weights_dim_count = weights_shape.DimensionsCount();
const int batches =
- MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3,
- output_state_dims, 3, output_activ_dims, 3);
+ MatchingDim(input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
+ output_state_shape, 0, output_activ_shape, 0);
const int height =
- MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2,
- output_state_dims, 2, output_activ_dims, 2);
+ MatchingDim(input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
+ output_state_shape, 1, output_activ_shape, 1);
const int width =
- MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1,
- output_state_dims, 1, output_activ_dims, 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+ MatchingDim(input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
+ output_state_shape, 2, output_activ_shape, 2);
+ const int input_depth = input_shape.Dims(3);
+ const int prev_activ_depth = prev_activ_shape.Dims(3);
const int total_input_depth = prev_activ_depth + input_depth;
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
- TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
- 1);
+ TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+ total_input_depth);
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
const int intern_activ_depth =
- MatchingArraySize(weights_dims, 1, bias_dims, 0);
- TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+ MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+ TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+ intern_activ_depth * total_input_depth);
+ TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
const int output_depth =
- MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
- output_state_dims, 0, output_activ_dims, 0);
- TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
+ MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+ 3, output_activ_shape, 3);
+ TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
// Concatenate prev_activ and input data together
std::vector<float const*> concat_input_arrays_data;
- std::vector<Dims<4> const*> concat_input_arrays_dims;
+ std::vector<RuntimeShape const*> concat_input_arrays_shapes;
concat_input_arrays_data.push_back(input_data);
concat_input_arrays_data.push_back(prev_activ_data);
- concat_input_arrays_dims.push_back(&input_dims);
- concat_input_arrays_dims.push_back(&prev_activ_dims);
- Concatenation<FusedActivationFunctionType::kNone, float>(
- 0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]),
- concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims);
+ concat_input_arrays_shapes.push_back(&input_shape);
+ concat_input_arrays_shapes.push_back(&prev_activ_shape);
+ tflite::ConcatenationParams concat_params;
+ concat_params.axis = 3;
+ concat_params.inputs_count = concat_input_arrays_data.size();
+ Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
+ &(concat_input_arrays_data[0]), concat_temp_shape,
+ concat_temp_data);
// Fully connected
- FullyConnected<FusedActivationFunctionType::kNone>(
- concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data,
- bias_dims, activ_temp_data, activ_temp_dims);
+ tflite::FullyConnectedParams fc_params;
+ fc_params.float_activation_min = std::numeric_limits<float>::lowest();
+ fc_params.float_activation_max = std::numeric_limits<float>::max();
+ FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
+ weights_data, bias_shape, bias_data, activ_temp_shape,
+ activ_temp_data);
// Memory state update (the LSTM "guts")
for (int b = 0; b < batches; ++b) {
@@ -2470,24 +2504,24 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
for (int c = 0; c < output_depth; ++c) {
const float input_gate =
1.f /
- (1.f + std::exp(-activ_temp_data[Offset(
- activ_temp_dims, 0 * output_depth + c, w, h, b)]));
+ (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+ 0 * output_depth + c)]));
const float new_input = std::tanh(activ_temp_data[Offset(
- activ_temp_dims, 1 * output_depth + c, w, h, b)]);
+ activ_temp_shape, b, h, w, 1 * output_depth + c)]);
const float forget_gate =
1.f /
- (1.f + std::exp(-activ_temp_data[Offset(
- activ_temp_dims, 2 * output_depth + c, w, h, b)]));
+ (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+ 2 * output_depth + c)]));
const float output_gate =
1.f /
- (1.f + std::exp(-activ_temp_data[Offset(
- activ_temp_dims, 3 * output_depth + c, w, h, b)]));
+ (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+ 3 * output_depth + c)]));
const float new_state =
input_gate * new_input +
forget_gate *
- prev_state_data[Offset(prev_state_dims, c, w, h, b)];
- output_state_data[Offset(output_state_dims, c, w, h, b)] = new_state;
- output_activ_data[Offset(output_activ_dims, c, w, h, b)] =
+ prev_state_data[Offset(prev_state_shape, b, h, w, c)];
+ output_state_data[Offset(output_state_shape, b, h, w, c)] = new_state;
+ output_activ_data[Offset(output_activ_shape, b, h, w, c)] =
output_gate * std::tanh(new_state);
}
}
@@ -2495,6 +2529,31 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
+ const float* prev_activ_data,
+ const Dims<4>& prev_activ_dims, const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims, const float* prev_state_data,
+ const Dims<4>& prev_state_dims, float* output_state_data,
+ const Dims<4>& output_state_dims, float* output_activ_data,
+ const Dims<4>& output_activ_dims, float* concat_temp_data,
+ const Dims<4>& concat_temp_dims, float* activ_temp_data,
+ const Dims<4>& activ_temp_dims) {
+ tflite::LstmCellParams op_params;
+ // Float LSTM cell does not need parameters to be set: leave untouched.
+
+ LstmCell(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(prev_activ_dims), prev_activ_data,
+ DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(prev_state_dims), prev_state_data,
+ DimsToShape(output_state_dims), output_state_data,
+ DimsToShape(output_activ_dims), output_activ_data,
+ DimsToShape(concat_temp_dims), concat_temp_data,
+ DimsToShape(activ_temp_dims), activ_temp_data);
+}
+
// Quantized LSTM cell implementation.
// The quantization of the input, output arrays is as follows:
// - The input activations are quantized as uint8 on the interval
@@ -2580,52 +2639,90 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
// aiming for 16-bit fixed-point quantization of these internal nodes here.
//
template <int StateIntegerBits>
-void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
- const uint8* prev_activ_data_uint8,
- const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
- const Dims<4>& weights_dims, const int32* bias_data_int32,
- const Dims<4>& bias_dims, const int16* prev_state_data_int16,
- const Dims<4>& prev_state_dims, int16* output_state_data_int16,
- const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
- const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
- const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
- const Dims<4>& activ_temp_dims, int32 weights_zero_point,
- int32 accum_multiplier, int accum_shift,
- gemmlowp::GemmContext* gemm_context) {
+inline void LstmCell(
+ const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+ const uint8* input_data_uint8,
+ const RuntimeShape& unextended_prev_activ_shape,
+ const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
+ const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
+ const int32* bias_data_int32,
+ const RuntimeShape& unextended_prev_state_shape,
+ const int16* prev_state_data_int16,
+ const RuntimeShape& unextended_output_state_shape,
+ int16* output_state_data_int16,
+ const RuntimeShape& unextended_output_activ_shape,
+ uint8* output_activ_data_uint8,
+ const RuntimeShape& unextended_concat_temp_shape,
+ uint8* concat_temp_data_uint8,
+ const RuntimeShape& unextended_activ_temp_shape,
+ int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) {
(void)gemm_context; // only used in optimized code.
+ int32 weights_zero_point = params.weights_zero_point;
+ int32 accum_multiplier = params.accum_multiplier;
+ int accum_shift = params.accum_shift;
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape prev_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+ const RuntimeShape bias_shape =
+ RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+ const RuntimeShape prev_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+ const RuntimeShape output_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+ const RuntimeShape output_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+ const RuntimeShape concat_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+ const RuntimeShape activ_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
// Gather dimensions information, and perform consistency checks.
- const int outer_size =
- MatchingFlatSizeSkipDim(input_dims, 0, prev_activ_dims, prev_state_dims,
- output_state_dims, output_activ_dims);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ const int outer_size = MatchingFlatSizeSkipDim(
+ input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
+ output_activ_shape);
+ const int input_depth = input_shape.Dims(3);
+ const int prev_activ_depth = prev_activ_shape.Dims(3);
const int total_input_depth = prev_activ_depth + input_depth;
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
- TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
- 1);
+ TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+ total_input_depth);
const int intern_activ_depth =
- MatchingArraySize(weights_dims, 1, bias_dims, 0);
- TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+ MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+ TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+ intern_activ_depth * total_input_depth);
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
+ TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
const int output_depth =
- MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
- output_state_dims, 0, output_activ_dims, 0);
- TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
- const int fc_batches = FlatSizeSkipDim(activ_temp_dims, 0);
+ MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+ 3, output_activ_shape, 3);
+ TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
+ const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
const int fc_output_depth =
- MatchingArraySize(weights_dims, 1, activ_temp_dims, 0);
- const int fc_accum_depth = ArraySize(weights_dims, 0);
- TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth);
+ MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
+ const int fc_accum_depth = total_input_depth;
+ TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
// Depth-concatenate prev_activ and input data together.
uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
prev_activ_data_uint8};
- Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims};
- Concatenation<FusedActivationFunctionType::kNone, uint8>(
- 0, concat_input_arrays_data, concat_input_arrays_dims, 2,
- concat_temp_data_uint8, concat_temp_dims);
+ const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
+ &prev_activ_shape};
+ tflite::ConcatenationParams concat_params;
+ concat_params.axis = 3;
+ concat_params.inputs_count = 2;
+ Concatenation(concat_params, concat_input_arrays_shapes,
+ concat_input_arrays_data, concat_temp_shape,
+ concat_temp_data_uint8);
// Implementation of the fully connected node inside the LSTM cell.
// The operands are 8-bit integers, the accumulators are internally 32bit
@@ -2731,6 +2828,37 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <int StateIntegerBits>
+void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
+ const uint8* prev_activ_data_uint8,
+ const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
+ const Dims<4>& weights_dims, const int32* bias_data_int32,
+ const Dims<4>& bias_dims, const int16* prev_state_data_int16,
+ const Dims<4>& prev_state_dims, int16* output_state_data_int16,
+ const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
+ const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
+ const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
+ const Dims<4>& activ_temp_dims, int32 weights_zero_point,
+ int32 accum_multiplier, int accum_shift,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::LstmCellParams op_params;
+ op_params.weights_zero_point = weights_zero_point;
+ op_params.accum_multiplier = accum_multiplier;
+ op_params.accum_shift = accum_shift;
+
+ LstmCell<StateIntegerBits>(
+ op_params, DimsToShape(input_dims), input_data_uint8,
+ DimsToShape(prev_activ_dims), prev_activ_data_uint8,
+ DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
+ bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
+ DimsToShape(output_state_dims), output_state_data_int16,
+ DimsToShape(output_activ_dims), output_activ_data_uint8,
+ DimsToShape(concat_temp_dims), concat_temp_data_uint8,
+ DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context);
+}
+
template <typename Scalar>
void Split(const SplitParams& params, const RuntimeShape& input_shape,
const Scalar* input_data, const RuntimeShape* const* output_shapes,