path: root/tensorflow/contrib/lite/kernels/split.cc
author     A. Unique TensorFlower <gardener@tensorflow.org>    2018-09-27 06:12:59 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>     2018-09-27 06:16:42 -0700
commit     abf26356209cba1ba895a06d9ce55ad01dad7fc6 (patch)
tree       5ef1c907a30bf89d08ba241ef985b19938427420 /tensorflow/contrib/lite/kernels/split.cc
parent     19d8963bc0ea64e10ff08ad4e7cc76813a182196 (diff)
Update kernel evals to use new kernel signatures.
PiperOrigin-RevId: 214763814
Diffstat (limited to 'tensorflow/contrib/lite/kernels/split.cc')
-rw-r--r--    tensorflow/contrib/lite/kernels/split.cc    27
1 file changed, 13 insertions(+), 14 deletions(-)
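
The core of the change is visible on the "+" side of the hunk below: the per-scalar split call no longer goes through the legacy TensorFlowSplit<> template with GetTensorDims(), but instead packs its runtime arguments into a tflite::SplitParams struct and passes RuntimeShape-based arguments to optimized_ops::Split / reference_ops::Split. The sketch below paraphrases that new call pattern; EvalSplitImpl is a hypothetical helper name introduced here for illustration (the real kernel keeps this logic inside the TF_LITE_SPLIT(scalar) macro in Eval()), and it assumes the internal headers split.cc already includes for VectorOfTensors, GetTensorShape, GetTensorData, and the ops namespaces.

// Minimal sketch of the new-signature call pattern (assumes split.cc's
// existing internal includes; EvalSplitImpl is a hypothetical helper name).
template <typename scalar>
void EvalSplitImpl(TfLiteContext* context, TfLiteNode* node,
                   const TfLiteTensor* input, int axis_value) {
  // Collect every output tensor's shape and data pointer once.
  VectorOfTensors<scalar> all_outputs(*context, *node->outputs);

  // Runtime parameters now travel in a params struct instead of being
  // threaded through template and positional arguments.
  tflite::SplitParams op_params;
  op_params.num_split = NumOutputs(node);
  op_params.axis = axis_value;

  if (axis_value == 0) {
    // Optimized path: split along the outermost dimension.
    optimized_ops::Split(op_params, GetTensorShape(input),
                         GetTensorData<scalar>(input), all_outputs.shapes(),
                         all_outputs.data());
  } else {
    reference_ops::Split(op_params, GetTensorShape(input),
                         GetTensorData<scalar>(input), all_outputs.shapes(),
                         all_outputs.data());
  }
}

Note that the RemapDim() call is dropped and the optimized-path condition changes from axis_value == NumDimensions(input) to axis_value == 0; this appears to be the same outermost-axis case expressed against RuntimeShape's non-reversed dimension ordering rather than the legacy reversed Dims<4> convention.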
diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc
index 719e2dc606..dab887bf9c 100644
--- a/tensorflow/contrib/lite/kernels/split.cc
+++ b/tensorflow/contrib/lite/kernels/split.cc
@@ -109,25 +109,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (axis_value < 0) {
axis_value += NumDimensions(op_context.input);
}
- axis_value = RemapDim(NumDimensions(op_context.input), axis_value);
// TODO(ahentz): Our usage of VectorOfTensors could be optimized by
// calculating it in Prepare, unless we defer shape calculation.
// TODO(ahentz): We can improve the optimized_ops version to handle other
// cases too.
-#define TF_LITE_SPLIT(scalar) \
- VectorOfTensors<scalar> all_outputs(*context, *node->outputs); \
- if (axis_value == NumDimensions(op_context.input)) { \
- optimized_ops::TensorFlowSplit<FusedActivationFunctionType::kNone, \
- scalar>( \
- GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), NumOutputs(node), all_outputs.data(), \
- all_outputs.dims()); \
- } else { \
- reference_ops::TensorFlowSplit<scalar>( \
- GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), axis_value, NumOutputs(node), \
- all_outputs.data(), all_outputs.dims()); \
+#define TF_LITE_SPLIT(scalar) \
+ VectorOfTensors<scalar> all_outputs(*context, *node->outputs); \
+ tflite::SplitParams op_params; \
+ op_params.num_split = NumOutputs(node); \
+ op_params.axis = axis_value; \
+ if (axis_value == 0) { \
+ optimized_ops::Split(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ all_outputs.shapes(), all_outputs.data()); \
+ } else { \
+ reference_ops::Split(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ all_outputs.shapes(), all_outputs.data()); \
}
switch (op_context.input->type) {
case kTfLiteFloat32: {