diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-27 06:12:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 06:16:42 -0700 |
commit | abf26356209cba1ba895a06d9ce55ad01dad7fc6 (patch) | |
tree | 5ef1c907a30bf89d08ba241ef985b19938427420 /tensorflow/contrib/lite/kernels/split.cc | |
parent | 19d8963bc0ea64e10ff08ad4e7cc76813a182196 (diff) |
Update kernel evals to use new kernel signatures.
PiperOrigin-RevId: 214763814
Diffstat (limited to 'tensorflow/contrib/lite/kernels/split.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/split.cc | 27 |
1 files changed, 13 insertions, 14 deletions
diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc index 719e2dc606..dab887bf9c 100644 --- a/tensorflow/contrib/lite/kernels/split.cc +++ b/tensorflow/contrib/lite/kernels/split.cc @@ -109,25 +109,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (axis_value < 0) { axis_value += NumDimensions(op_context.input); } - axis_value = RemapDim(NumDimensions(op_context.input), axis_value); // TODO(ahentz): Our usage of VectorOfTensors could be optimized by // calculating it in Prepare, unless we defer shape calculation. // TODO(ahentz): We can improve the optimized_ops version to handle other // cases too. -#define TF_LITE_SPLIT(scalar) \ - VectorOfTensors<scalar> all_outputs(*context, *node->outputs); \ - if (axis_value == NumDimensions(op_context.input)) { \ - optimized_ops::TensorFlowSplit<FusedActivationFunctionType::kNone, \ - scalar>( \ - GetTensorData<scalar>(op_context.input), \ - GetTensorDims(op_context.input), NumOutputs(node), all_outputs.data(), \ - all_outputs.dims()); \ - } else { \ - reference_ops::TensorFlowSplit<scalar>( \ - GetTensorData<scalar>(op_context.input), \ - GetTensorDims(op_context.input), axis_value, NumOutputs(node), \ - all_outputs.data(), all_outputs.dims()); \ +#define TF_LITE_SPLIT(scalar) \ + VectorOfTensors<scalar> all_outputs(*context, *node->outputs); \ + tflite::SplitParams op_params; \ + op_params.num_split = NumOutputs(node); \ + op_params.axis = axis_value; \ + if (axis_value == 0) { \ + optimized_ops::Split(op_params, GetTensorShape(op_context.input), \ + GetTensorData<scalar>(op_context.input), \ + all_outputs.shapes(), all_outputs.data()); \ + } else { \ + reference_ops::Split(op_params, GetTensorShape(op_context.input), \ + GetTensorData<scalar>(op_context.input), \ + all_outputs.shapes(), all_outputs.data()); \ } switch (op_context.input->type) { case kTfLiteFloat32: { |