diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/tooling_util.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/tooling_util.cc | 41 |
1 file changed, 27 insertions, 14 deletions
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 7dc1af9f1d..98e416b76e 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -350,16 +350,16 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Less) HANDLE_OPERATORTYPENAME_CASE(LessEqual) HANDLE_OPERATORTYPENAME_CASE(MatMul) - HANDLE_OPERATORTYPENAME_CASE(Max) // Reduction Max - HANDLE_OPERATORTYPENAME_CASE(Maximum) // Element-wise Maximum + HANDLE_OPERATORTYPENAME_CASE(ReduceMax) // Reduction Max + HANDLE_OPERATORTYPENAME_CASE(Maximum) // Element-wise Maximum HANDLE_OPERATORTYPENAME_CASE(Merge) - HANDLE_OPERATORTYPENAME_CASE(Min) // Reduction Min - HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum + HANDLE_OPERATORTYPENAME_CASE(ReduceMin) // Reduction Min + HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum HANDLE_OPERATORTYPENAME_CASE(Neg) + HANDLE_OPERATORTYPENAME_CASE(Pack) HANDLE_OPERATORTYPENAME_CASE(Pad) HANDLE_OPERATORTYPENAME_CASE(PadV2) HANDLE_OPERATORTYPENAME_CASE(StridedSlice) - HANDLE_OPERATORTYPENAME_CASE(Stack) HANDLE_OPERATORTYPENAME_CASE(Range) HANDLE_OPERATORTYPENAME_CASE(Rank) HANDLE_OPERATORTYPENAME_CASE(Reshape) @@ -385,8 +385,10 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(SpaceToBatchND) HANDLE_OPERATORTYPENAME_CASE(BatchToSpaceND) HANDLE_OPERATORTYPENAME_CASE(Mean) + HANDLE_OPERATORTYPENAME_CASE(ReduceProd) HANDLE_OPERATORTYPENAME_CASE(Svdf) HANDLE_OPERATORTYPENAME_CASE(ArgMax) + HANDLE_OPERATORTYPENAME_CASE(ArgMin) HANDLE_OPERATORTYPENAME_CASE(TopK_V2) HANDLE_OPERATORTYPENAME_CASE(Unsupported) HANDLE_OPERATORTYPENAME_CASE(Exp) @@ -397,6 +399,9 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Equal) HANDLE_OPERATORTYPENAME_CASE(NotEqual) HANDLE_OPERATORTYPENAME_CASE(Pow) + HANDLE_OPERATORTYPENAME_CASE(Any) + HANDLE_OPERATORTYPENAME_CASE(LogicalAnd) + 
HANDLE_OPERATORTYPENAME_CASE(LogicalNot) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE @@ -447,8 +452,12 @@ void LogSummary(int log_level, const Model& model) { } void LogArray(int log_level, const Model& model, const string& name) { - const auto& array = model.GetArray(name); VLOG(log_level) << "Array: " << name; + if (!model.HasArray(name)) { + VLOG(log_level) << " DOES NOT EXIST"; + return; + } + const auto& array = model.GetArray(name); VLOG(log_level) << " Data type: " << ArrayDataTypeName(array.data_type); VLOG(log_level) << " Final type: " << ArrayDataTypeName(array.final_data_type); @@ -934,8 +943,12 @@ void CheckEachArray(const Model& model) { // shape. CHECK(array->has_shape()); // Constant buffer should has a valid shape. - for (int d : array->shape().dims()) { - CHECK_GE(d, 1); + bool is_scalar = + array->shape().dimensions_count() == 1 && array->shape().dims(0) == 0; + if (!is_scalar) { + for (int d : array->shape().dims()) { + CHECK_GE(d, 1); + } } // The shape flat-size should agree with the buffer length. 
CHECK_EQ(array->buffer->Length(), @@ -1261,8 +1274,13 @@ void InsertCopyOperator(Model* model, const string& source_array_name, auto* copy_op = new TensorFlowReshapeOperator; copy_op->inputs = { source_array_name, - CreateInt32Array(model, target_array_name + "_copy_shape", shape)}; + CreateInt32Array( + model, AvailableArrayName(*model, target_array_name + "_copy_shape"), + shape)}; copy_op->outputs = {target_array_name}; + if (target_array.has_shape()) { + copy_op->shape = target_array.shape().dims(); + } model->operators.emplace_back(copy_op); } @@ -1567,11 +1585,6 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { model); } - for (const auto& input_array : model->flags.input_arrays()) { - if (input_array.has_shape()) { - CHECK(input_array.shape().dims_size()); - } - } model->flags.set_change_concat_input_ranges( model_flags.change_concat_input_ranges()); model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays()); |