diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-15 13:59:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-15 14:03:06 -0700 |
commit | b4e6098db6e707bccdb3ea9027365ddd9b38fb72 (patch) | |
tree | 9a6a1a013788f6a5a8f7ac69d6af5aae03bb26cc /tensorflow/contrib/lite/kernels/internal | |
parent | 6b13a15e0b9906abbb66f87d83db291d0099cb43 (diff) |
Move params to kernels/internal/types.h.
PiperOrigin-RevId: 208877784
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal')
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/types.h | 212 |
1 file changed, 197 insertions, 15 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index 7b6838db53..204df9ab19 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -660,6 +660,19 @@ enum class BroadcastableOpCategory : uint8 { kGenericBroadcast, // Fall-back. }; +struct MinMax { + float min; + float max; +}; +static_assert(sizeof(MinMax) == 8, ""); + +struct ActivationParams { + FusedActivationFunctionType activation_type; + // Quantized inference params. + int32 activation_min; + int32 activation_max; +}; + // For Add, Sub, Mul ops. struct ArithmeticParams { // Shape dependent / common to data / op types. @@ -695,29 +708,122 @@ struct ArithmeticParams { int broadcast_shape[5]; }; -template <typename T> -inline void SetActivationParams(T min, T max, ArithmeticParams* params); +struct ConcatenationParams { + int8 axis; +}; -template <> -inline void SetActivationParams(float min, float max, - ArithmeticParams* params) { - params->float_activation_min = min; - params->float_activation_max = max; -} +struct ComparisonParams { + // uint8 inference params. + int left_shift; + int32 input0_offset; + int32 input0_multiplier; + int input0_shift; + int32 input1_offset; + int32 input1_multiplier; + int input1_shift; + // Shape dependent / common to inference types. + bool is_broadcast; +}; -template <> -inline void SetActivationParams(int32 min, int32 max, - ArithmeticParams* params) { - params->quantized_activation_min = min; - params->quantized_activation_max = max; -} +struct ConvParams { + PaddingType padding_type; + PaddingValues padding_values; + // TODO(starka): This was just "stride", so check that width+height is OK. + int8 stride_width; + int8 stride_height; + int8 dilation_width_factor; + int8 dilation_height_factor; + // uint8 inference params. + // TODO(b/65838351): Use smaller types if appropriate. 
+ int32 input_offset; + int32 weights_offset; + int32 output_offset; + int32 output_multiplier; + int output_shift; + int32 output_activation_min; + int32 output_activation_max; +}; + +struct DepthToSpaceParams { + int16 block_size; +}; + +struct DepthwiseParams { + PaddingType padding_type; + PaddingValues padding_values; + int8 stride; + int8 depth_multiplier; + // uint8 inference params. + // TODO(b/65838351): Use smaller types if appropriate. + int32 input_offset; + int32 weights_offset; + int32 output_offset; + int32 output_multiplier; + int output_shift; + int32 output_activation_min; + int32 output_activation_max; +}; + +struct FakeQuantParams { + MinMax minmax; + int32 num_bits; +}; + +struct FullyConnectedParams { + // uint8 inference params. + // TODO(b/65838351): Use smaller types if appropriate. + int32 input_offset; + int32 weights_offset; + int32 output_offset; + int32 output_multiplier; + int output_shift; + int32 output_activation_min; + int32 output_activation_max; + FullyConnectedWeightsFormat weights_format; +}; + +struct GatherParams { + int8 input_rank; + int16 axis; +}; + +struct L2NormalizationParams { + // uint8 inference params. + int32 input_zero_point; +}; + +struct LocalResponseNormalizationParams { + int32 range; + double bias; + double alpha; + double beta; +}; + +struct LogisticParams { + // uint8 inference params. 
+ int32 input_zero_point; + int32 input_range_radius; + int32 input_multiplier; + int input_left_shift; +}; + +struct LstmCellParams { + int32 weights_zero_point; + int32 accum_multiplier; + int accum_shift; + int state_integer_bits; +}; + +struct MeanParams { + int8 axis_count; + int16 axis[4]; +}; struct PadParams { int8 left_padding_count; int32 left_padding[4]; int8 right_padding_count; int32 right_padding[4]; - // FloatOrInt pad_value; }; struct PoolParams { @@ -736,6 +842,15 @@ struct PoolParams { float float_activation_max; }; +struct ReshapeParams { + int8 shape_count; + int32 shape[4]; +}; + +struct ResizeBilinearParams { + bool align_corners; +}; + struct SliceParams { int8 begin_count; int32 begin[4]; @@ -743,6 +858,73 @@ struct SliceParams { int32 size[4]; }; +struct SoftmaxParams { + // beta is not really used (not a Tensorflow parameter) and not implemented + // for LogSoftmax. + double beta; + // uint8 inference params. Used even when beta defaults to 1.0. + int32 input_beta_multiplier; + int32 input_beta_left_shift; + // Reverse scaling is only used by LogSoftmax. + int32 reverse_scaling_divisor; + int32 reverse_scaling_right_shift; + int diff_min; +}; + +struct SpaceToDepthParams { + int16 block_size; +}; + +struct SplitParams { + // Graphs that split into, say, 2000 nodes are encountered. The indices in + // OperatorEdges are of type uint16. 
+ uint16 num_split; +}; + +struct SqueezeParams { + int8 squeeze_dims_count; + int32 squeeze_dims[4]; +}; + +struct StridedSliceParams { + int8 start_indices_count; + int16 start_indices[4]; + int8 stop_indices_count; + int16 stop_indices[4]; + int8 strides_count; + int16 strides[4]; + + int16 begin_mask; + int16 ellipsis_mask; + int16 end_mask; + int16 new_axis_mask; + int16 shrink_axis_mask; +}; + +struct TanhParams { + int32 input_zero_point; + int32 input_range_radius; + int32 input_multiplier; + int input_left_shift; +}; + +template <typename T> +inline void SetActivationParams(T min, T max, ArithmeticParams* params); + +template <> +inline void SetActivationParams(float min, float max, + ArithmeticParams* params) { + params->float_activation_min = min; + params->float_activation_max = max; +} + +template <> +inline void SetActivationParams(int32 min, int32 max, + ArithmeticParams* params) { + params->quantized_activation_min = min; + params->quantized_activation_max = max; +} + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ |