aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/tooling_util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/tooling_util.cc')
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc41
1 files changed, 27 insertions, 14 deletions
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 7dc1af9f1d..98e416b76e 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -350,16 +350,16 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Less)
HANDLE_OPERATORTYPENAME_CASE(LessEqual)
HANDLE_OPERATORTYPENAME_CASE(MatMul)
- HANDLE_OPERATORTYPENAME_CASE(Max) // Reduction Max
- HANDLE_OPERATORTYPENAME_CASE(Maximum) // Element-wise Maximum
+ HANDLE_OPERATORTYPENAME_CASE(ReduceMax) // Reduction Max
+ HANDLE_OPERATORTYPENAME_CASE(Maximum) // Element-wise Maximum
HANDLE_OPERATORTYPENAME_CASE(Merge)
- HANDLE_OPERATORTYPENAME_CASE(Min) // Reduction Min
- HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum
+ HANDLE_OPERATORTYPENAME_CASE(ReduceMin) // Reduction Min
+ HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum
HANDLE_OPERATORTYPENAME_CASE(Neg)
+ HANDLE_OPERATORTYPENAME_CASE(Pack)
HANDLE_OPERATORTYPENAME_CASE(Pad)
HANDLE_OPERATORTYPENAME_CASE(PadV2)
HANDLE_OPERATORTYPENAME_CASE(StridedSlice)
- HANDLE_OPERATORTYPENAME_CASE(Stack)
HANDLE_OPERATORTYPENAME_CASE(Range)
HANDLE_OPERATORTYPENAME_CASE(Rank)
HANDLE_OPERATORTYPENAME_CASE(Reshape)
@@ -385,8 +385,10 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(SpaceToBatchND)
HANDLE_OPERATORTYPENAME_CASE(BatchToSpaceND)
HANDLE_OPERATORTYPENAME_CASE(Mean)
+ HANDLE_OPERATORTYPENAME_CASE(ReduceProd)
HANDLE_OPERATORTYPENAME_CASE(Svdf)
HANDLE_OPERATORTYPENAME_CASE(ArgMax)
+ HANDLE_OPERATORTYPENAME_CASE(ArgMin)
HANDLE_OPERATORTYPENAME_CASE(TopK_V2)
HANDLE_OPERATORTYPENAME_CASE(Unsupported)
HANDLE_OPERATORTYPENAME_CASE(Exp)
@@ -397,6 +399,9 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Equal)
HANDLE_OPERATORTYPENAME_CASE(NotEqual)
HANDLE_OPERATORTYPENAME_CASE(Pow)
+ HANDLE_OPERATORTYPENAME_CASE(Any)
+ HANDLE_OPERATORTYPENAME_CASE(LogicalAnd)
+ HANDLE_OPERATORTYPENAME_CASE(LogicalNot)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
@@ -447,8 +452,12 @@ void LogSummary(int log_level, const Model& model) {
}
void LogArray(int log_level, const Model& model, const string& name) {
- const auto& array = model.GetArray(name);
VLOG(log_level) << "Array: " << name;
+ if (!model.HasArray(name)) {
+ VLOG(log_level) << " DOES NOT EXIST";
+ return;
+ }
+ const auto& array = model.GetArray(name);
VLOG(log_level) << " Data type: " << ArrayDataTypeName(array.data_type);
VLOG(log_level) << " Final type: "
<< ArrayDataTypeName(array.final_data_type);
@@ -934,8 +943,12 @@ void CheckEachArray(const Model& model) {
// shape.
CHECK(array->has_shape());
// Constant buffer should has a valid shape.
- for (int d : array->shape().dims()) {
- CHECK_GE(d, 1);
+ bool is_scalar =
+ array->shape().dimensions_count() == 1 && array->shape().dims(0) == 0;
+ if (!is_scalar) {
+ for (int d : array->shape().dims()) {
+ CHECK_GE(d, 1);
+ }
}
// The shape flat-size should agree with the buffer length.
CHECK_EQ(array->buffer->Length(),
@@ -1261,8 +1274,13 @@ void InsertCopyOperator(Model* model, const string& source_array_name,
auto* copy_op = new TensorFlowReshapeOperator;
copy_op->inputs = {
source_array_name,
- CreateInt32Array(model, target_array_name + "_copy_shape", shape)};
+ CreateInt32Array(
+ model, AvailableArrayName(*model, target_array_name + "_copy_shape"),
+ shape)};
copy_op->outputs = {target_array_name};
+ if (target_array.has_shape()) {
+ copy_op->shape = target_array.shape().dims();
+ }
model->operators.emplace_back(copy_op);
}
@@ -1567,11 +1585,6 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
model);
}
- for (const auto& input_array : model->flags.input_arrays()) {
- if (input_array.has_shape()) {
- CHECK(input_array.shape().dims_size());
- }
- }
model->flags.set_change_concat_input_ranges(
model_flags.change_concat_input_ranges());
model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays());