aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-09-04 08:18:41 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-09-04 08:23:05 -0700
commit: 70550f8e2914d3ff677b72db53437e4433ea7250 (patch)
tree: fd78aaa6c09c5c974e5382139cd2d8f415eaa6c7
parent: 956159119b6c81ec500dc541f6f5ea3f776f2d0a (diff)
Fix Split, convert kernel signature to use runtime shapes.
PiperOrigin-RevId: 211459453
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h | 1
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 80
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/types.h | 1
3 files changed, 60 insertions, 22 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 70adffda3b..9b35648b4e 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -67,6 +67,7 @@ using reference_ops::Relu6;
using reference_ops::ReluX;
using reference_ops::Select;
using reference_ops::SpaceToBatchND;
+using reference_ops::Split;
using reference_ops::StridedSlice;
using reference_ops::Transpose;
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 62f7ade7d5..e5b71f81fa 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -2524,32 +2524,69 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
}
template <typename Scalar>
+void Split(const SplitParams& params, const RuntimeShape& input_shape,
+ const Scalar* input_data, const RuntimeShape* const* output_shapes,
+ Scalar* const* output_data) {  // Splits input_data along params.axis, copying each slice into output_data[i].
+ const int concat_dimensions = input_shape.DimensionsCount();  // Rank of the input; every output must have the same rank.
+ int axis = params.axis < 0 ? params.axis + concat_dimensions : params.axis;  // A negative axis is offset by the rank.
+ int outputs_count = params.num_split;  // Number of entries read from output_shapes / output_data.
+ TFLITE_DCHECK_LT(axis, concat_dimensions);  // NOTE(review): only the upper bound is checked; axis >= 0 is not asserted here.
+
+ int64_t concat_size = 0;  // Running sum of the outputs' extents along axis.
+ for (int i = 0; i < outputs_count; i++) {
+ TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), concat_dimensions);
+ for (int j = 0; j < concat_dimensions; j++) {
+ if (j != axis) {
+ MatchingDim(*output_shapes[i], j, input_shape, j);  // Presumably asserts both extents agree; result is unused here - confirm.
+ }
+ }
+ concat_size += output_shapes[i]->Dims(axis);
+ }
+ TFLITE_DCHECK_EQ(concat_size, input_shape.Dims(axis));  // Output extents along axis must add up to the input extent.
+ int64_t outer_size = 1;  // Product of the input dimensions before axis.
+ for (int i = 0; i < axis; ++i) {
+ outer_size *= input_shape.Dims(i);
+ }
+ // For all output arrays,
+ // FlatSize() = outer_size * Dims(axis) * base_inner_size;
+ int64_t base_inner_size = 1;  // Product of the input dimensions after axis.
+ for (int i = axis + 1; i < concat_dimensions; ++i) {
+ base_inner_size *= input_shape.Dims(i);
+ }
+
+ const Scalar* input_ptr = input_data;  // Read cursor; the input is consumed strictly front to back.
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < outputs_count; ++i) {
+ const int copy_size = output_shapes[i]->Dims(axis) * base_inner_size;  // Elements output i receives per outer index; NOTE(review): int64_t product narrowed to int.
+ memcpy(output_data[i] + k * copy_size, input_ptr,  // Destination is the k-th chunk of output i.
+ copy_size * sizeof(Scalar));
+ input_ptr += copy_size;  // Advance past the chunk just copied.
+ }
+ }
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4> overload: converts each output's dims to a RuntimeShape and forwards to Split().
+template <typename Scalar>
void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
int axis, int outputs_count, Scalar* const* output_data,
const Dims<4>* const* output_dims) {
- const int batches = ArraySize(*output_dims[0], 3);
- const int height = ArraySize(*output_dims[0], 2);
- const int width = ArraySize(*output_dims[0], 1);
- const int depth = ArraySize(*output_dims[0], 0);
-
- const int slice_size = ArraySize(*output_dims[0], axis);
-
+ std::vector<RuntimeShape> output_shapes(outputs_count);  // Holds the converted shape of each output.
+ std::vector<const RuntimeShape*> output_shapes_indirect(outputs_count);  // Pointer array in the form Split() accepts.
for (int i = 0; i < outputs_count; ++i) {
- int offset = i * slice_size * input_dims.strides[axis];
- for (int b = 0; b < batches; ++b) {
- for (int y = 0; y < height; ++y) {
- for (int x = 0; x < width; ++x) {
- for (int c = 0; c < depth; ++c) {
- auto out = Offset(*output_dims[i], c, x, y, b);
- auto in = Offset(input_dims, c, x, y, b);
- output_data[i][out] = input_data[offset + in];
- }
- }
- }
- }
+ ShapeFromDims(*output_dims[i], &output_shapes[i]);  // Fills output_shapes[i] from the legacy dims.
+ output_shapes_indirect[i] = &output_shapes[i];
}
+ tflite::SplitParams op_params;
+ op_params.axis = 3 - axis;  // NOTE(review): presumably maps the Dims<4> axis index onto the reversed RuntimeShape ordering - confirm.
+ op_params.num_split = outputs_count;
+
+ Split(op_params, DimsToShape(input_dims), input_data,  // All copying is delegated to Split().
+ output_shapes_indirect.data(), output_data);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
template <FusedActivationFunctionType Ac, typename Scalar>
void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
int outputs_count, Scalar* const* output_data,
@@ -2560,9 +2597,8 @@ void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
/* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
/* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
}
- // for now we dont have a model with a TensorFlowSplit
- // with fused activation function.
- TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+ // For now we don't have a model with a Split with fused activation.
+ TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
TensorFlowSplit(input_data, input_dims, /*axis=*/0, outputs_count,
output_data, output_dims);
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 3b296f024f..6ae4ebc79e 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -889,6 +889,7 @@ struct SplitParams {
// Graphs that split into, say, 2000 nodes are encountered. The indices in
// OperatorEdges are of type uint16.
uint16 num_split;
+ int16 axis;
};
struct SqueezeParams {