diff options
author | Suharsh Sivakumar <suharshs@google.com> | 2018-07-26 20:41:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-26 20:45:04 -0700 |
commit | 82293a9cb0606b4dfbe45fc4e2f4aa8778ff7e9a (patch) | |
tree | 862f9a045738f8d233b1e5c3bcd748fd65d16210 /tensorflow/contrib/lite/kernels/space_to_batch_nd.cc | |
parent | 9beaf7038c4f8ca5b6a5168c47efbb3fc669b64b (diff) |
SpaceToBatchND should pad with zero_point when inference_type is uint8
PiperOrigin-RevId: 206265356
Diffstat (limited to 'tensorflow/contrib/lite/kernels/space_to_batch_nd.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/space_to_batch_nd.cc | 22 |
1 files changed, 12 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc index c9269599e5..03079f1c3b 100644 --- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc +++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc @@ -113,7 +113,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); } -#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar) \ +#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar, pad_value) \ type::SpaceToBatchND(GetTensorData<scalar>(op_context.input), \ GetTensorDims(op_context.input), \ GetTensorData<int32_t>(op_context.block_shape), \ @@ -121,34 +121,36 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTensorData<int32_t>(op_context.paddings), \ GetTensorDims(op_context.paddings), \ GetTensorData<scalar>(op_context.output), \ - GetTensorDims(op_context.output)) + GetTensorDims(op_context.output), pad_value) switch (op_context.input->type) { // Already know in/out types are same. case kTfLiteFloat32: if (kernel_type == kReference) { - TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float); + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float, 0); } else { - TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float); + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float, 0); } break; case kTfLiteUInt8: if (kernel_type == kReference) { - TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t); + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t, + op_context.output->params.zero_point); } else { - TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t); + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t, + op_context.output->params.zero_point); } break; case kTfLiteInt32: if (kernel_type == kReference) { - TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t); + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t, 0); } else { - TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t); + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t, 0); } break; case kTfLiteInt64: if (kernel_type == kReference) { - TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t); + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t, 0); } else { - TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t); + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t, 0); } break; default: |