aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/lite/toco/dump_graphviz.cc10
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc74
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.h2
3 files changed, 36 insertions, 50 deletions
diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc
index c726eb6d86..2184e8f607 100644
--- a/tensorflow/contrib/lite/toco/dump_graphviz.cc
+++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc
@@ -142,14 +142,8 @@ NodeProperties GetPropertiesForArray(const Model& model,
// Append array shape to the label.
auto& array = model.GetArray(array_name);
-
- if (array.data_type == ArrayDataType::kFloat) {
- AppendF(&node_properties.label, "\\nType: float");
- } else if (array.data_type == ArrayDataType::kInt32) {
- AppendF(&node_properties.label, "\\nType: int32");
- } else if (array.data_type == ArrayDataType::kUint8) {
- AppendF(&node_properties.label, "\\nType: uint8");
- }
+ AppendF(&node_properties.label, "\\nType: %s",
+ ArrayDataTypeName(array.data_type));
if (array.has_shape()) {
auto& array_shape = array.shape();
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index dcb409c84d..eec35b7b59 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -62,6 +62,35 @@ string LogName(const Operator& op) {
}
}
+string ArrayDataTypeName(ArrayDataType data_type) {
+ switch (data_type) {
+ case ArrayDataType::kFloat:
+ return "Float";
+ case ArrayDataType::kInt8:
+ return "Int8";
+ case ArrayDataType::kUint8:
+ return "Uint8";
+ case ArrayDataType::kInt16:
+ return "Int16";
+ case ArrayDataType::kUint16:
+ return "Uint16";
+ case ArrayDataType::kInt32:
+ return "Int32";
+ case ArrayDataType::kUint32:
+ return "Uint32";
+ case ArrayDataType::kInt64:
+ return "Int64";
+ case ArrayDataType::kUint64:
+ return "Uint64";
+ case ArrayDataType::kString:
+ return "String";
+ case ArrayDataType::kNone:
+ return "None";
+ default:
+ LOG(FATAL) << "Unhandled array data type " << static_cast<int>(data_type);
+ }
+}
+
bool IsInputArray(const Model& model, const string& name) {
for (const auto& input_array : model.flags.input_arrays()) {
if (input_array.name() == name) {
@@ -363,48 +392,9 @@ void LogSummary(int log_level, const Model& model) {
void LogArray(int log_level, const Model& model, const string& name) {
const auto& array = model.GetArray(name);
VLOG(log_level) << "Array: " << name;
- switch (array.data_type) {
- case ArrayDataType::kNone:
- VLOG(log_level) << " Data type:";
- break;
- case ArrayDataType::kFloat:
- VLOG(log_level) << " Data type: kFloat";
- break;
- case ArrayDataType::kInt32:
- VLOG(log_level) << " Data type: kInt32";
- break;
- case ArrayDataType::kUint8:
- VLOG(log_level) << " Data type: kUint8";
- break;
- case ArrayDataType::kString:
- VLOG(log_level) << " Data type: kString";
- break;
- default:
- VLOG(log_level) << " Data type: other (numerical value: "
- << static_cast<int>(array.data_type) << ")";
- break;
- }
- switch (array.final_data_type) {
- case ArrayDataType::kNone:
- VLOG(log_level) << " Final type:";
- break;
- case ArrayDataType::kFloat:
- VLOG(log_level) << " Final type: kFloat";
- break;
- case ArrayDataType::kInt32:
- VLOG(log_level) << " Final type: kInt32";
- break;
- case ArrayDataType::kUint8:
- VLOG(log_level) << " Final type: kUint8";
- break;
- case ArrayDataType::kString:
- VLOG(log_level) << " Final type: kString";
- break;
- default:
- VLOG(log_level) << " Final type: other (numerical value: "
- << static_cast<int>(array.data_type) << ")";
- break;
- }
+ VLOG(log_level) << " Data type: " << ArrayDataTypeName(array.data_type);
+ VLOG(log_level) << " Final type: "
+ << ArrayDataTypeName(array.final_data_type);
if (array.buffer) {
VLOG(log_level) << " Constant Buffer";
}
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index 0aaa0f6a21..11208ed667 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -54,6 +54,8 @@ absl::string_view FindLongestCommonPrefix(absl::string_view a,
absl::string_view b);
string LogName(const Operator& op);
+string ArrayDataTypeName(ArrayDataType data_type);
+
bool IsInputArray(const Model& model, const string& name);
bool IsArrayConsumed(const Model& model, const string& name);
int CountTrueOutputs(const Model& model, const Operator& op);