author     Suharsh Sivakumar <suharshs@google.com>  2018-07-26 20:41:37 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-07-26 20:45:04 -0700
commit     82293a9cb0606b4dfbe45fc4e2f4aa8778ff7e9a (patch)
tree       862f9a045738f8d233b1e5c3bcd748fd65d16210 /tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
parent     9beaf7038c4f8ca5b6a5168c47efbb3fc669b64b (diff)
SpaceToBatchND should pad with zero_point when inference_type is uint8
PiperOrigin-RevId: 206265356
Diffstat (limited to 'tensorflow/contrib/lite/kernels/space_to_batch_nd.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/space_to_batch_nd.cc  22
1 file changed, 12 insertions, 10 deletions
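Background on why the pad value matters: in affine uint8 quantization the real value 0.0 maps to the tensor's zero_point (real = scale * (quantized - zero_point)), so padding a quantized tensor with the literal quantized value 0 would insert a large non-zero real value into the output instead of zero. The standalone sketch below illustrates the effect; the scale and zero_point numbers are illustrative assumptions, not values taken from this kernel.

// Standalone sketch (not TFLite code): effect of the pad value on a
// quantized uint8 tensor. Parameters below are assumed for illustration.
#include <cstdint>
#include <cstdio>

int main() {
  const float scale = 0.05f;       // hypothetical quantization scale
  const int32_t zero_point = 128;  // hypothetical quantization zero point

  // Affine dequantization: real = scale * (quantized - zero_point).
  auto dequantize = [&](uint8_t q) { return scale * (q - zero_point); };

  // Padding with literal 0 injects real value 0.05 * (0 - 128) = -6.4.
  std::printf("pad with 0          -> real %.2f\n", dequantize(0));
  // Padding with zero_point represents real 0.0, which is the intent.
  std::printf("pad with zero_point -> real %.2f\n", dequantize(128));
  return 0;
}
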
diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
index c9269599e5..03079f1c3b 100644
--- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
@@ -113,7 +113,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
   }
 
-#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar)                          \
+#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar, pad_value)               \
   type::SpaceToBatchND(GetTensorData<scalar>(op_context.input),          \
                        GetTensorDims(op_context.input),                  \
                        GetTensorData<int32_t>(op_context.block_shape),   \
@@ -121,34 +121,36 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                        GetTensorData<int32_t>(op_context.paddings),      \
                        GetTensorDims(op_context.paddings),               \
                        GetTensorData<scalar>(op_context.output),         \
-                       GetTensorDims(op_context.output))
+                       GetTensorDims(op_context.output), pad_value)
   switch (op_context.input->type) {  // Already know in/out types are same.
     case kTfLiteFloat32:
       if (kernel_type == kReference) {
-        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float);
+        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float, 0);
       } else {
-        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float);
+        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float, 0);
       }
       break;
     case kTfLiteUInt8:
       if (kernel_type == kReference) {
-        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t);
+        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t,
+                                  op_context.output->params.zero_point);
       } else {
-        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t);
+        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t,
+                                  op_context.output->params.zero_point);
       }
       break;
     case kTfLiteInt32:
       if (kernel_type == kReference) {
-        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t);
+        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t, 0);
       } else {
-        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t);
+        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t, 0);
      }
       break;
     case kTfLiteInt64:
       if (kernel_type == kReference) {
-        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t);
+        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t, 0);
      } else {
-        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t);
+        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t, 0);
       }
       break;
     default: