aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/internal
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-15 13:59:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-15 14:03:06 -0700
commitb4e6098db6e707bccdb3ea9027365ddd9b38fb72 (patch)
tree9a6a1a013788f6a5a8f7ac69d6af5aae03bb26cc /tensorflow/contrib/lite/kernels/internal
parent6b13a15e0b9906abbb66f87d83db291d0099cb43 (diff)
Move params to kernels/internal/types.h.
PiperOrigin-RevId: 208877784
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal')
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h212
1 files changed, 197 insertions, 15 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 7b6838db53..204df9ab19 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -660,6 +660,19 @@ enum class BroadcastableOpCategory : uint8 {
kGenericBroadcast, // Fall-back.
};
+struct MinMax {
+ float min;
+ float max;
+};
+static_assert(sizeof(MinMax) == 8, "");
+
+struct ActivationParams {
+ FusedActivationFunctionType activation_type;
+ // Quantized inference params.
+ int32 activation_min;
+ int32 activation_max;
+};
+
// For Add, Sub, Mul ops.
struct ArithmeticParams {
// Shape dependent / common to data / op types.
@@ -695,29 +708,122 @@ struct ArithmeticParams {
int broadcast_shape[5];
};
-template <typename T>
-inline void SetActivationParams(T min, T max, ArithmeticParams* params);
+struct ConcatenationParams {
+ int8 axis;
+};
-template <>
-inline void SetActivationParams(float min, float max,
- ArithmeticParams* params) {
- params->float_activation_min = min;
- params->float_activation_max = max;
-}
+struct ComparisonParams {
+ // uint8 inference params.
+ int left_shift;
+ int32 input0_offset;
+ int32 input0_multiplier;
+ int input0_shift;
+ int32 input1_offset;
+ int32 input1_multiplier;
+ int input1_shift;
+ // Shape dependent / common to inference types.
+ bool is_broadcast;
+};
-template <>
-inline void SetActivationParams(int32 min, int32 max,
- ArithmeticParams* params) {
- params->quantized_activation_min = min;
- params->quantized_activation_max = max;
-}
+struct ConvParams {
+ PaddingType padding_type;
+ PaddingValues padding_values;
+ // TODO(starka): This was just "stride", so check that width+height is OK.
+ int8 stride_width;
+ int8 stride_height;
+ int8 dilation_width_factor;
+ int8 dilation_height_factor;
+ // uint8 inference params.
+ // TODO(b/65838351): Use smaller types if appropriate.
+ int32 input_offset;
+ int32 weights_offset;
+ int32 output_offset;
+ int32 output_multiplier;
+ int output_shift;
+ int32 output_activation_min;
+ int32 output_activation_max;
+};
+
+struct DepthToSpaceParams {
+ int16 block_size;
+};
+
+struct DepthwiseParams {
+ PaddingType padding_type;
+ PaddingValues padding_values;
+ int8 stride;
+ int8 depth_multiplier;
+ // uint8 inference params.
+ // TODO(b/65838351): Use smaller types if appropriate.
+ int32 input_offset;
+ int32 weights_offset;
+ int32 output_offset;
+ int32 output_multiplier;
+ int output_shift;
+ int32 output_activation_min;
+ int32 output_activation_max;
+};
+
+struct FakeQuantParams {
+ MinMax minmax;
+ int32 num_bits;
+};
+
+struct FullyConnectedParams {
+ // uint8 inference params.
+ // TODO(b/65838351): Use smaller types if appropriate.
+ int32 input_offset;
+ int32 weights_offset;
+ int32 output_offset;
+ int32 output_multiplier;
+ int output_shift;
+ int32 output_activation_min;
+ int32 output_activation_max;
+ FullyConnectedWeightsFormat weights_format;
+};
+
+struct GatherParams {
+ int8 input_rank;
+ int16 axis;
+};
+
+struct L2NormalizationParams {
+ // uint8 inference params.
+ int32 input_zero_point;
+};
+
+struct LocalResponseNormalizationParams {
+ int32 range;
+ double bias;
+ double alpha;
+ double beta;
+};
+
+struct LogisticParams {
+ // uint8 inference params.
+ int32 input_zero_point;
+ int32 input_range_radius;
+ int32 input_multiplier;
+ int input_left_shift;
+};
+
+struct LstmCellParams {
+ int32 weights_zero_point;
+ int32 accum_multiplier;
+ int accum_shift;
+ int state_integer_bits;
+};
+
+struct MeanParams {
+ int8 axis_count;
+ int16 axis[4];
+};
struct PadParams {
int8 left_padding_count;
int32 left_padding[4];
int8 right_padding_count;
int32 right_padding[4];
- // FloatOrInt pad_value;
};
struct PoolParams {
@@ -736,6 +842,15 @@ struct PoolParams {
float float_activation_max;
};
+struct ReshapeParams {
+ int8 shape_count;
+ int32 shape[4];
+};
+
+struct ResizeBilinearParams {
+ bool align_corners;
+};
+
struct SliceParams {
int8 begin_count;
int32 begin[4];
@@ -743,6 +858,73 @@ struct SliceParams {
int32 size[4];
};
+struct SoftmaxParams {
+ // beta is not really used (not a Tensorflow parameter) and not implemented
+ // for LogSoftmax.
+ double beta;
+ // uint8 inference params. Used even when beta defaults to 1.0.
+ int32 input_beta_multiplier;
+ int32 input_beta_left_shift;
+ // Reverse scaling is only used by LogSoftmax.
+ int32 reverse_scaling_divisor;
+ int32 reverse_scaling_right_shift;
+ int diff_min;
+};
+
+struct SpaceToDepthParams {
+ int16 block_size;
+};
+
+struct SplitParams {
+ // Graphs that split into, say, 2000 nodes are encountered. The indices in
+ // OperatorEdges are of type uint16.
+ uint16 num_split;
+};
+
+struct SqueezeParams {
+ int8 squeeze_dims_count;
+ int32 squeeze_dims[4];
+};
+
+struct StridedSliceParams {
+ int8 start_indices_count;
+ int16 start_indices[4];
+ int8 stop_indices_count;
+ int16 stop_indices[4];
+ int8 strides_count;
+ int16 strides[4];
+
+ int16 begin_mask;
+ int16 ellipsis_mask;
+ int16 end_mask;
+ int16 new_axis_mask;
+ int16 shrink_axis_mask;
+};
+
+struct TanhParams {
+ int32 input_zero_point;
+ int32 input_range_radius;
+ int32 input_multiplier;
+ int input_left_shift;
+};
+
+template <typename T>
+inline void SetActivationParams(T min, T max, ArithmeticParams* params);
+
+template <>
+inline void SetActivationParams(float min, float max,
+ ArithmeticParams* params) {
+ params->float_activation_min = min;
+ params->float_activation_max = max;
+}
+
+template <>
+inline void SetActivationParams(int32 min, int32 max,
+ ArithmeticParams* params) {
+ params->quantized_activation_min = min;
+ params->quantized_activation_max = max;
+}
+
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_