aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/import_tensorflow.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-21 19:49:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-21 19:52:21 -0700
commitb912ce1b83570eabd3a14db678bb752a71846756 (patch)
tree1125217a8b67407f0a81cc0eb2f14ebfda0afa12 /tensorflow/contrib/lite/toco/import_tensorflow.cc
parenteea807c75d6d18f3efc6b988edb8d3c93a48a16c (diff)
Fix import int32/uint8/int64/bool array for toco.
PiperOrigin-RevId: 201627909
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc26
1 files changed, 21 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 8da33e8a22..da7e5add7e 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -263,7 +263,11 @@ tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor,
output_array->GetMutableBuffer<ArrayDataType::kUint8>().data;
output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
CHECK_GE(output_int_data.size(), input_flat_size);
- if (input_tensor.int_val_size()) {
+ if (input_tensor.int_val_size() == 1) {
+ for (int i = 0; i < input_flat_size; i++) {
+ output_int_data[i] = input_tensor.int_val(0);
+ }
+ } else if (input_tensor.int_val_size() == input_flat_size) {
for (int i = 0; i < input_tensor.int_val_size(); i++) {
output_int_data[i] = input_tensor.int_val(i);
}
@@ -296,7 +300,11 @@ tensorflow::Status ImportInt32Array(const TensorProto& input_tensor,
output_array->GetMutableBuffer<ArrayDataType::kInt32>().data;
output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
CHECK_GE(output_int_data.size(), input_flat_size);
- if (input_tensor.int_val_size()) {
+ if (input_tensor.int_val_size() == 1) {
+ for (int i = 0; i < input_flat_size; i++) {
+ output_int_data[i] = input_tensor.int_val(0);
+ }
+ } else if (input_tensor.int_val_size() == input_flat_size) {
for (int i = 0; i < input_tensor.int_val_size(); i++) {
output_int_data[i] = input_tensor.int_val(i);
}
@@ -328,8 +336,12 @@ tensorflow::Status ImportInt64Array(const TensorProto& input_tensor,
output_array->GetMutableBuffer<ArrayDataType::kInt64>().data;
output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
CHECK_GE(output_int_data.size(), input_flat_size);
- if (input_tensor.int64_val_size()) {
- for (int i = 0; i < input_tensor.int64_val_size(); i++) {
+ if (input_tensor.int64_val_size() == 1) {
+ for (int i = 0; i < input_flat_size; i++) {
+ output_int_data[i] = input_tensor.int64_val(0);
+ }
+ } else if (input_tensor.int64_val_size() == input_flat_size) {
+ for (int i = 0; i < input_tensor.float_val_size(); i++) {
output_int_data[i] = input_tensor.int64_val(i);
}
} else if (input_tensor.tensor_content().size() ==
@@ -362,7 +374,11 @@ tensorflow::Status ImportBoolArray(const TensorProto& input_tensor,
output_bool_data.resize(RequiredBufferSizeForShape(output_array->shape()),
false);
CHECK_GE(output_bool_data.size(), input_flat_size);
- if (input_tensor.bool_val_size()) {
+ if (input_tensor.bool_val_size() == 1) {
+ for (int i = 0; i < input_flat_size; i++) {
+ output_bool_data[i] = input_tensor.bool_val(0);
+ }
+ } else if (input_tensor.int_val_size() == input_flat_size) {
for (int i = 0; i < input_tensor.bool_val_size(); i++) {
output_bool_data[i] = input_tensor.bool_val(i);
}