diff options
author | 2018-09-04 08:18:41 -0700 | |
---|---|---|
committer | 2018-09-04 08:23:05 -0700 | |
commit | 70550f8e2914d3ff677b72db53437e4433ea7250 (patch) | |
tree | fd78aaa6c09c5c974e5382139cd2d8f415eaa6c7 | |
parent | 956159119b6c81ec500dc541f6f5ea3f776f2d0a (diff) |
Fix Split, convert kernel signature to use runtime shapes.
PiperOrigin-RevId: 211459453
3 files changed, 60 insertions, 22 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index 70adffda3b..9b35648b4e 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -67,6 +67,7 @@ using reference_ops::Relu6; using reference_ops::ReluX; using reference_ops::Select; using reference_ops::SpaceToBatchND; +using reference_ops::Split; using reference_ops::StridedSlice; using reference_ops::Transpose; diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 62f7ade7d5..e5b71f81fa 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -2524,32 +2524,69 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims, } template <typename Scalar> +void Split(const SplitParams& params, const RuntimeShape& input_shape, + const Scalar* input_data, const RuntimeShape* const* output_shapes, + Scalar* const* output_data) { + const int concat_dimensions = input_shape.DimensionsCount(); + int axis = params.axis < 0 ? params.axis + concat_dimensions : params.axis; + int outputs_count = params.num_split; + TFLITE_DCHECK_LT(axis, concat_dimensions); + + int64_t concat_size = 0; + for (int i = 0; i < outputs_count; i++) { + TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), concat_dimensions); + for (int j = 0; j < concat_dimensions; j++) { + if (j != axis) { + MatchingDim(*output_shapes[i], j, input_shape, j); + } + } + concat_size += output_shapes[i]->Dims(axis); + } + TFLITE_DCHECK_EQ(concat_size, input_shape.Dims(axis)); + int64_t outer_size = 1; + for (int i = 0; i < axis; ++i) { + outer_size *= input_shape.Dims(i); + } + // For all output arrays, + // FlatSize() = outer_size * Dims(axis) * base_inner_size; + int64_t base_inner_size = 1; + for (int i = axis + 1; i < concat_dimensions; ++i) { + base_inner_size *= input_shape.Dims(i); + } + + const Scalar* input_ptr = input_data; + for (int k = 0; k < outer_size; k++) { + for (int i = 0; i < outputs_count; ++i) { + const int copy_size = output_shapes[i]->Dims(axis) * base_inner_size; + memcpy(output_data[i] + k * copy_size, input_ptr, + copy_size * sizeof(Scalar)); + input_ptr += copy_size; + } + } +} + +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy Dims<4>. +template <typename Scalar> void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, int axis, int outputs_count, Scalar* const* output_data, const Dims<4>* const* output_dims) { - const int batches = ArraySize(*output_dims[0], 3); - const int height = ArraySize(*output_dims[0], 2); - const int width = ArraySize(*output_dims[0], 1); - const int depth = ArraySize(*output_dims[0], 0); - - const int slice_size = ArraySize(*output_dims[0], axis); - + std::vector<RuntimeShape> output_shapes(outputs_count); + std::vector<const RuntimeShape*> output_shapes_indirect(outputs_count); for (int i = 0; i < outputs_count; ++i) { - int offset = i * slice_size * input_dims.strides[axis]; - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - for (int c = 0; c < depth; ++c) { - auto out = Offset(*output_dims[i], c, x, y, b); - auto in = Offset(input_dims, c, x, y, b); - output_data[i][out] = input_data[offset + in]; - } - } - } - } + ShapeFromDims(*output_dims[i], &output_shapes[i]); + output_shapes_indirect[i] = &output_shapes[i]; } + tflite::SplitParams op_params; + op_params.axis = 3 - axis; + op_params.num_split = outputs_count; + + Split(op_params, DimsToShape(input_dims), input_data, + output_shapes_indirect.data(), output_data); } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy Dims<4>. template <FusedActivationFunctionType Ac, typename Scalar> void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, int outputs_count, Scalar* const* output_data, @@ -2560,9 +2597,8 @@ void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2); /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1); } - // for now we dont have a model with a TensorFlowSplit - // with fused activation function. - TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + // For now we don't have a model with a Split with fused activation. + TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone); TensorFlowSplit(input_data, input_dims, /*axis=*/0, outputs_count, output_data, output_dims); diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index 3b296f024f..6ae4ebc79e 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -889,6 +889,7 @@ struct SplitParams { // Graphs that split into, say, 2000 nodes are encountered. The indices in // OperatorEdges are of type uint16. uint16 num_split; + int16 axis; }; struct SqueezeParams { |