diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-21 19:49:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-21 19:52:21 -0700 |
commit | b912ce1b83570eabd3a14db678bb752a71846756 (patch) | |
tree | 1125217a8b67407f0a81cc0eb2f14ebfda0afa12 /tensorflow/contrib/lite/toco/import_tensorflow.cc | |
parent | eea807c75d6d18f3efc6b988edb8d3c93a48a16c (diff) |
Fix import int32/uint8/int64/bool array for toco.
PiperOrigin-RevId: 201627909
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow.cc | 26 |
1 files changed, 21 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 8da33e8a22..da7e5add7e 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -263,7 +263,11 @@ tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor, output_array->GetMutableBuffer<ArrayDataType::kUint8>().data; output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); CHECK_GE(output_int_data.size(), input_flat_size); - if (input_tensor.int_val_size()) { + if (input_tensor.int_val_size() == 1) { + for (int i = 0; i < input_flat_size; i++) { + output_int_data[i] = input_tensor.int_val(0); + } + } else if (input_tensor.int_val_size() == input_flat_size) { for (int i = 0; i < input_tensor.int_val_size(); i++) { output_int_data[i] = input_tensor.int_val(i); } @@ -296,7 +300,11 @@ tensorflow::Status ImportInt32Array(const TensorProto& input_tensor, output_array->GetMutableBuffer<ArrayDataType::kInt32>().data; output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); CHECK_GE(output_int_data.size(), input_flat_size); - if (input_tensor.int_val_size()) { + if (input_tensor.int_val_size() == 1) { + for (int i = 0; i < input_flat_size; i++) { + output_int_data[i] = input_tensor.int_val(0); + } + } else if (input_tensor.int_val_size() == input_flat_size) { for (int i = 0; i < input_tensor.int_val_size(); i++) { output_int_data[i] = input_tensor.int_val(i); } @@ -328,8 +336,12 @@ tensorflow::Status ImportInt64Array(const TensorProto& input_tensor, output_array->GetMutableBuffer<ArrayDataType::kInt64>().data; output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); CHECK_GE(output_int_data.size(), input_flat_size); - if (input_tensor.int64_val_size()) { - for (int i = 0; i < input_tensor.int64_val_size(); i++) { + if (input_tensor.int64_val_size() == 1) { + for (int i = 0; i < input_flat_size; i++) { + output_int_data[i] = input_tensor.int64_val(0); + } + } else if (input_tensor.int64_val_size() == input_flat_size) { + for (int i = 0; i < input_tensor.float_val_size(); i++) { output_int_data[i] = input_tensor.int64_val(i); } } else if (input_tensor.tensor_content().size() == @@ -362,7 +374,11 @@ tensorflow::Status ImportBoolArray(const TensorProto& input_tensor, output_bool_data.resize(RequiredBufferSizeForShape(output_array->shape()), false); CHECK_GE(output_bool_data.size(), input_flat_size); - if (input_tensor.bool_val_size()) { + if (input_tensor.bool_val_size() == 1) { + for (int i = 0; i < input_flat_size; i++) { + output_bool_data[i] = input_tensor.bool_val(0); + } + } else if (input_tensor.int_val_size() == input_flat_size) { for (int i = 0; i < input_tensor.bool_val_size(); i++) { output_bool_data[i] = input_tensor.bool_val(i); } |