aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/tools
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2018-01-29 10:42:32 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-29 10:46:04 -0800
commitfd63d4e30a01cf860baf60b990b223cd54bc895c (patch)
treefcea79b1e89bcf30ac80d087edf051c3711d06b1 /tensorflow/contrib/lite/tools
parent730071d0dca35a9e08f3bdc49661ae34d109da74 (diff)
Add C0326 bad-whitespace error to pylint sanity check.
PiperOrigin-RevId: 183689499
Diffstat (limited to 'tensorflow/contrib/lite/tools')
-rw-r--r--tensorflow/contrib/lite/tools/visualize.py104
1 files changed, 58 insertions, 46 deletions
diff --git a/tensorflow/contrib/lite/tools/visualize.py b/tensorflow/contrib/lite/tools/visualize.py
index d0d78e3afa..f571dd59da 100644
--- a/tensorflow/contrib/lite/tools/visualize.py
+++ b/tensorflow/contrib/lite/tools/visualize.py
@@ -198,10 +198,13 @@ class TensorMapper(object):
def GenerateGraph(subgraph_idx, g, opcode_mapper):
"""Produces the HTML required to have a d3 visualization of the dag."""
+
def TensorName(idx):
- return "t%d"%idx
+ return "t%d" % idx
+
def OpName(idx):
- return "o%d"%idx
+ return "o%d" % idx
+
edges = []
nodes = []
first = {}
@@ -210,27 +213,35 @@ def GenerateGraph(subgraph_idx, g, opcode_mapper):
for tensor_input_position, tensor_index in enumerate(op["inputs"]):
if tensor_index not in first:
first[tensor_index] = (
- op_index*pixel_mult,
- tensor_input_position*pixel_mult - pixel_mult/2)
- edges.append(
- {"source": TensorName(tensor_index), "target": OpName(op_index)})
+ op_index * pixel_mult,
+ tensor_input_position * pixel_mult - pixel_mult / 2)
+ edges.append({
+ "source": TensorName(tensor_index),
+ "target": OpName(op_index)
+ })
for tensor_index in op["outputs"]:
- edges.append(
- {"target": TensorName(tensor_index), "source": OpName(op_index)})
- nodes.append({"id": OpName(op_index),
- "name": opcode_mapper(op["opcode_index"]),
- "group": 2,
- "x": pixel_mult,
- "y": op_index * pixel_mult})
+ edges.append({
+ "target": TensorName(tensor_index),
+ "source": OpName(op_index)
+ })
+ nodes.append({
+ "id": OpName(op_index),
+ "name": opcode_mapper(op["opcode_index"]),
+ "group": 2,
+ "x": pixel_mult,
+ "y": op_index * pixel_mult
+ })
for tensor_index, tensor in enumerate(g["tensors"]):
- initial_y = (first[tensor_index] if tensor_index in first
- else len(g["operators"]))
-
- nodes.append({"id": TensorName(tensor_index),
- "name": "%s (%d)" % (tensor["name"], tensor_index),
- "group": 1,
- "x": 2,
- "y": initial_y})
+ initial_y = (
+ first[tensor_index] if tensor_index in first else len(g["operators"]))
+
+ nodes.append({
+ "id": TensorName(tensor_index),
+ "name": "%s (%d)" % (tensor["name"], tensor_index),
+ "group": 1,
+ "x": 2,
+ "y": initial_y
+ })
graph_str = json.dumps({"nodes": nodes, "edges": edges})
html = _D3_HTML_TEMPLATE % (graph_str, subgraph_idx)
@@ -267,7 +278,7 @@ def GenerateTableHtml(items, keys_to_print, display_index=True):
for h, mapper in keys_to_print:
val = tensor[h] if h in tensor else None
val = val if mapper is None else mapper(val)
- html += "<td>%s</td>\n"%val
+ html += "<td>%s</td>\n" % val
html += "</tr>\n"
html += "</table>\n"
@@ -279,18 +290,19 @@ def CreateHtmlFile(tflite_input, html_output):
# Convert the model into a JSON flatbuffer using flatc (build if doesn't
# exist.
- if not os.path.exists(tflite_input):
+ if not os.path.exists(tflite_input):
raise RuntimeError("Invalid filename %r" % tflite_input)
if tflite_input.endswith(".tflite") or tflite_input.endswith(".bin"):
# Run convert
- cmd = (_BINARY + " -t "
- "--strict-json --defaults-json -o /tmp {schema} -- {input}".format(
- input=tflite_input, schema=_SCHEMA))
+ cmd = (
+ _BINARY + " -t "
+ "--strict-json --defaults-json -o /tmp {schema} -- {input}".format(
+ input=tflite_input, schema=_SCHEMA))
print(cmd)
os.system(cmd)
- real_output = ("/tmp/"+ os.path.splitext(os.path.split(tflite_input)[-1])[0]
- + ".json")
+ real_output = ("/tmp/" + os.path.splitext(
+ os.path.split(tflite_input)[-1])[0] + ".json")
data = json.load(open(real_output))
elif tflite_input.endswith(".json"):
@@ -302,12 +314,13 @@ def CreateHtmlFile(tflite_input, html_output):
html += "<h1>TensorFlow Lite Model</h2>"
data["filename"] = tflite_input # Avoid special case
- toplevel_stuff = [("filename", None), ("version", None),
- ("description", None)]
+ toplevel_stuff = [("filename", None), ("version", None), ("description",
+ None)]
html += "<table>\n"
for key, mapping in toplevel_stuff:
- if not mapping: mapping = lambda x: x
+ if not mapping:
+ mapping = lambda x: x
html += "<tr><th>%s</th><td>%s</td></tr>\n" % (key, mapping(data[key]))
html += "</table>\n"
@@ -320,22 +333,22 @@ def CreateHtmlFile(tflite_input, html_output):
html += "<div class='subgraph'>"
tensor_mapper = TensorMapper(g)
opcode_mapper = OpCodeMapper(data)
- op_keys_to_display = [
- ("inputs", tensor_mapper), ("outputs", tensor_mapper),
- ("builtin_options", None), ("opcode_index", opcode_mapper)]
- tensor_keys_to_display = [
- ("name", None), ("type", None), ("shape", None), ("buffer", None),
- ("quantization", None)]
+ op_keys_to_display = [("inputs", tensor_mapper), ("outputs", tensor_mapper),
+ ("builtin_options", None), ("opcode_index",
+ opcode_mapper)]
+ tensor_keys_to_display = [("name", None), ("type", None), ("shape", None),
+ ("buffer", None), ("quantization", None)]
html += "<h2>Subgraph %d</h2>\n" % subgraph_idx
# Inputs and outputs.
html += "<h3>Inputs/Outputs</h3>\n"
- html += GenerateTableHtml([{"inputs": g["inputs"],
- "outputs": g["outputs"]}],
- [("inputs", tensor_mapper),
- ("outputs", tensor_mapper)],
- display_index=False)
+ html += GenerateTableHtml(
+ [{
+ "inputs": g["inputs"],
+ "outputs": g["outputs"]
+ }], [("inputs", tensor_mapper), ("outputs", tensor_mapper)],
+ display_index=False)
# Print the tensors.
html += "<h3>Tensors</h3>\n"
@@ -357,8 +370,7 @@ def CreateHtmlFile(tflite_input, html_output):
# Operator codes
html += "<h2>Operator Codes</h2>\n"
- html += GenerateTableHtml(data["operator_codes"],
- operator_keys_to_display)
+ html += GenerateTableHtml(data["operator_codes"], operator_keys_to_display)
html += "</body></html>\n"
@@ -370,10 +382,10 @@ def main(argv):
tflite_input = argv[1]
html_output = argv[2]
except IndexError:
- print ("Usage: %s <input tflite> <output html>" % (argv[0]))
+ print("Usage: %s <input tflite> <output html>" % (argv[0]))
else:
CreateHtmlFile(tflite_input, html_output)
+
if __name__ == "__main__":
main(sys.argv)
-