aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-17 12:20:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 12:26:06 -0700
commit0d9868d8f9c01c1402ae99d672599c4bac6e787d (patch)
treea3d0570efa9706ae4e84b0cee31a7a04addfeb40 /tensorflow/contrib/lite/kernels
parent779d87cfc1421eb6be2f9cc4ae29bca77c8d2929 (diff)
Convert more kernel signatures to use runtime shapes.
PiperOrigin-RevId: 213316034
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h215
1 files changed, 165 insertions, 50 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 77927af227..09a4ba7701 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -511,24 +511,25 @@ inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
}
}
-inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
- const float* weights_data,
- const Dims<4>& weights_dims, const float* bias_data,
- const Dims<4>& bias_dims,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
+inline void FullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& weights_shape,
+ const float* weights_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
// TODO(benoitjacob): This really should be:
// const int batches = ArraySize(output_dims, 1);
// but the current --variable_batch hack consists in overwriting the 3rd
// dimension with the runtime batch size, as we don't keep track for each
// array of which dimension is the batch dimension in it.
- const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
- ArraySize(output_dims, 3);
- const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(weights_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
+ const int output_dims_count = output_shape.DimensionsCount();
+ const int weights_dims_count = weights_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dims_count - 1);
+ const int output_depth = MatchingDim(weights_shape, weights_dims_count - 2,
+ output_shape, output_dims_count - 1);
+ const int accum_depth = weights_shape.Dims(weights_dims_count - 1);
for (int b = 0; b < batches; ++b) {
for (int out_c = 0; out_c < output_depth; ++out_c) {
float total = 0.f;
@@ -538,7 +539,7 @@ inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
}
float bias_value = 0.0f;
if (bias_data) {
- bias_value = bias_data[Offset(bias_dims, out_c, 0, 0, 0)];
+ bias_value = bias_data[out_c];
}
output_data[out_c + output_depth * b] = ActivationFunctionWithMinMax(
total + bias_value, output_activation_min, output_activation_max);
@@ -546,6 +547,26 @@ inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::FullyConnectedParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(weights_dims), weights_data,
+ DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void FullyConnected(const float* input_data, const Dims<4>& input_dims,
@@ -559,28 +580,35 @@ void FullyConnected(const float* input_data, const Dims<4>& input_dims,
output_data, output_dims);
}
-inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
+inline void FullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data, gemmlowp::GemmContext* gemm_context) {
(void)gemm_context; // only used in optimized code.
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
// TODO(benoitjacob): This really should be:
// const int batches = ArraySize(output_dims, 1);
// but the current --variable_batch hack consists in overwriting the 3rd
// dimension with the runtime batch size, as we don't keep track for each
// array of which dimension is the batch dimension in it.
- const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
- ArraySize(output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(filter_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int filter_dim_count = filter_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+ const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
+ output_shape, output_dim_count - 1);
+ const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
for (int b = 0; b < batches; ++b) {
for (int out_c = 0; out_c < output_depth; ++out_c) {
int32 acc = 0;
@@ -590,7 +618,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
acc += (filter_val + filter_offset) * (input_val + input_offset);
}
if (bias_data) {
- acc += bias_data[Offset(bias_dims, out_c, 0, 0, 0)];
+ acc += bias_data[out_c];
}
acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
kReverseShift * output_shift);
@@ -602,16 +630,47 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
int32 input_offset, const uint8* filter_data,
const Dims<4>& filter_dims, int32 filter_offset,
const int32* bias_data, const Dims<4>& bias_dims,
int32 output_offset, int32 output_multiplier,
int output_shift, int32 output_activation_min,
- int32 output_activation_max, int16* output_data,
+ int32 output_activation_max, uint8* output_data,
const Dims<4>& output_dims,
gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data,
+ gemm_context);
+}
+
+inline void FullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ int16* output_data, gemmlowp::GemmContext* gemm_context) {
(void)gemm_context; // only used in optimized code.
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
TFLITE_DCHECK_EQ(output_offset, 0);
// TODO(benoitjacob): This really should be:
@@ -619,12 +678,12 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
// but the current --variable_batch hack consists in overwriting the 3rd
// dimension with the runtime batch size, as we don't keep track for each
// array of which dimension is the batch dimension in it.
- const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
- ArraySize(output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(filter_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int filter_dim_count = filter_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+ const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
+ output_shape, output_dim_count - 1);
+ const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
for (int b = 0; b < batches; ++b) {
for (int out_c = 0; out_c < output_depth; ++out_c) {
// Internal accumulation.
@@ -651,27 +710,60 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, int16* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data,
+ gemm_context);
+}
+
inline void ShuffledFullyConnected(
- const uint8* input_data, const Dims<4>& input_dims,
- const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
- const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- int16* output_data, const Dims<4>& output_dims,
- uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& weights_shape,
+ const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ int16* output_data, uint8* shuffled_input_workspace_data,
+ gemmlowp::GemmContext* gemm_context) {
(void)gemm_context; // only used in optimized code.
-
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+
+ TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
// TODO(benoitjacob): This really should be:
// const int batches = ArraySize(output_dims, 1);
// but the current --variable_batch hack consists in overwriting the 3rd
// dimension with the runtime batch size, as we don't keep track for each
// array of which dimension is the batch dimension in it.
- const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
- ArraySize(output_dims, 3);
- const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(weights_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+ const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
+ output_shape, output_dim_count - 1);
+ const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
TFLITE_DCHECK((accum_depth % 16) == 0);
TFLITE_DCHECK((output_depth % 4) == 0);
@@ -799,6 +891,29 @@ inline void ShuffledFullyConnected(
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void ShuffledFullyConnected(
+ const uint8* input_data, const Dims<4>& input_dims,
+ const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
+ const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
+ int output_shift, int32 output_activation_min, int32 output_activation_max,
+ int16* output_data, const Dims<4>& output_dims,
+ uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(weights_dims), shuffled_weights_data,
+ DimsToShape(bias_dims), bias_data,
+ DimsToShape(output_dims), output_data,
+ shuffled_input_workspace_data, gemm_context);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,