aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-09-26 10:26:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 10:30:47 -0700
commita8203086b9bd0a4cd874e42aead0758a3365c387 (patch)
treeb75742a3e2fee7edbb05b31b267f8221a2ba4552 /tensorflow/tools
parent00ae12ad8bf5c348e4c31448e3922cbaab54cc03 (diff)
Remove quantize_graph script. TF Lite quantization is the supported mobile quantization tooling.
PiperOrigin-RevId: 214625933
Diffstat (limited to 'tensorflow/tools')
-rw-r--r--tensorflow/tools/quantization/BUILD78
-rw-r--r--tensorflow/tools/quantization/graph_to_dot.py68
-rw-r--r--tensorflow/tools/quantization/quantize_graph.py1302
-rw-r--r--tensorflow/tools/quantization/quantize_graph_test.py966
4 files changed, 0 insertions, 2414 deletions
diff --git a/tensorflow/tools/quantization/BUILD b/tensorflow/tools/quantization/BUILD
deleted file mode 100644
index 17443a8617..0000000000
--- a/tensorflow/tools/quantization/BUILD
+++ /dev/null
@@ -1,78 +0,0 @@
-# Description:
-# Utilities for quantizing TensorFlow graphs to lower bit depths.
-
-package(default_visibility = ["//visibility:public"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-load("//tensorflow:tensorflow.bzl", "py_test")
-
-py_library(
- name = "quantize_graph_lib",
- srcs = ["quantize_graph.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:graph_util",
- "//tensorflow/python:platform",
- "//tensorflow/python:session",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python:tensor_util",
- "//third_party/py/numpy",
- ],
-)
-
-py_binary(
- name = "quantize_graph",
- srcs = ["quantize_graph.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python", # TODO(b/34059704): remove when fixed
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:graph_util",
- "//tensorflow/python:platform",
- "//tensorflow/python:tensor_util",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "quantize_graph_test",
- size = "small",
- srcs = ["quantize_graph_test.py"],
- srcs_version = "PY2AND3",
- tags = ["nomsan"], # http://b/32242946
- deps = [
- ":quantize_graph",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:graph_util",
- "//tensorflow/python:platform",
- "//third_party/py/numpy",
- ],
-)
-
-py_binary(
- name = "graph_to_dot",
- srcs = ["graph_to_dot.py"],
- main = "graph_to_dot.py",
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:platform",
- ],
-)
diff --git a/tensorflow/tools/quantization/graph_to_dot.py b/tensorflow/tools/quantization/graph_to_dot.py
deleted file mode 100644
index 81d6aa62c8..0000000000
--- a/tensorflow/tools/quantization/graph_to_dot.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Converts a GraphDef file into a DOT format suitable for visualization.
-
-This script takes a GraphDef representing a network, and produces a DOT file
-that can then be visualized by GraphViz tools like dot and xdot.
-
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import re
-
-from google.protobuf import text_format
-
-from tensorflow.core.framework import graph_pb2
-from tensorflow.python.platform import app
-from tensorflow.python.platform import flags
-from tensorflow.python.platform import gfile
-
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string("graph", "", """TensorFlow 'GraphDef' file to load.""")
-flags.DEFINE_bool("input_binary", True,
- """Whether the input files are in binary format.""")
-flags.DEFINE_string("dot_output", "", """Where to write the DOT output.""")
-
-
-def main(unused_args):
- if not gfile.Exists(FLAGS.graph):
- print("Input graph file '" + FLAGS.graph + "' does not exist!")
- return -1
-
- graph = graph_pb2.GraphDef()
- with open(FLAGS.graph, "r") as f:
- if FLAGS.input_binary:
- graph.ParseFromString(f.read())
- else:
- text_format.Merge(f.read(), graph)
-
- with open(FLAGS.dot_output, "wb") as f:
- print("digraph graphname {", file=f)
- for node in graph.node:
- output_name = node.name
- print(" \"" + output_name + "\" [label=\"" + node.op + "\"];", file=f)
- for input_full_name in node.input:
- parts = input_full_name.split(":")
- input_name = re.sub(r"^\^", "", parts[0])
- print(" \"" + input_name + "\" -> \"" + output_name + "\";", file=f)
- print("}", file=f)
- print("Created DOT file '" + FLAGS.dot_output + "'.")
-
-
-if __name__ == "__main__":
- app.run()
diff --git a/tensorflow/tools/quantization/quantize_graph.py b/tensorflow/tools/quantization/quantize_graph.py
deleted file mode 100644
index 3acb532263..0000000000
--- a/tensorflow/tools/quantization/quantize_graph.py
+++ /dev/null
@@ -1,1302 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Transforms a float-trained graph into an equivalent quantized version.
-
-An example of command-line usage is:
-bazel build tensorflow/tools/quantization:quantize_graph \
-&& bazel-bin/tensorflow/tools/quantization/quantize_graph \
---input=tensorflow_inception_graph.pb
---output_node_names="softmax2" --print_nodes --output=/tmp/quantized_graph.pb \
---mode=eightbit --logtostderr
-
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import collections
-import re
-import numpy as np
-
-from tensorflow.core.framework import attr_value_pb2
-from tensorflow.core.framework import graph_pb2
-from tensorflow.core.framework import node_def_pb2
-from tensorflow.python.client import session
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import graph_util
-from tensorflow.python.framework import importer
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import app
-from tensorflow.python.platform import flags as flags_lib
-from tensorflow.python.platform import gfile
-
-flags = flags_lib
-FLAGS = flags.FLAGS
-
-flags.DEFINE_boolean("print_nodes", False, """Lists all nodes in the model.""")
-flags.DEFINE_string("input", "", """TensorFlow 'GraphDef' file to load.""")
-flags.DEFINE_string("output_node_names", "",
- """Output node names, comma separated.""")
-flags.DEFINE_string("output", "", """File to save the output graph to.""")
-flags.DEFINE_integer("bitdepth", 8,
- """How many bits to quantize the graph to.""")
-flags.DEFINE_string("mode", "round",
- """What transformation to apply (round, quantize,"""
- """ eightbit, weights, or weights_rounded).""")
-flags.DEFINE_string("test_input_dims", "1,224,224,3",
- """The size of the input tensor to use when testing a"""
- """ graph loaded from a file.""")
-flags.DEFINE_boolean("strip_redundant_quantization", True,
- """Removes redundant dequantize/quantize pairs.""")
-flags.DEFINE_boolean("quantized_input", False,
- "If true, assume Placeholders are quantized with values "
- "covering [--quantized_input_min,--quantized_input_max]. "
- "Only supported when --mode=eightbit")
-flags.DEFINE_float("quantized_input_min", 0,
- "The minimum of the actual input range when "
- "--quantized_input")
-flags.DEFINE_float("quantized_input_max", 1,
- "The maximum of the actual input range when "
- "--quantized_input")
-flags.DEFINE_float(
- "quantized_fallback_min", None,
- "The fallback 'min' value to use for layers which lack min-max "
- "information. Note: this should be considered a coarse tool just good "
- "enough for experimentation purposes, since graphs quantized in this way "
- "would be very inaccurate.")
-flags.DEFINE_float(
- "quantized_fallback_max", None,
- "The fallback 'max' value to use for layers which lack min-max "
- "information. Note: this should be considered a coarse tool just good "
- "enough for experimentation purposes, since graphs quantized in this way "
- "would be very inaccurate.")
-
-
-def print_input_nodes(current_node, nodes_map, indent, already_visited):
- print(" " * indent + current_node.op + ":" + current_node.name)
- already_visited[current_node.name] = True
- for input_node_name in current_node.input:
- if input_node_name in already_visited:
- continue
- input_node = nodes_map[input_node_name]
- print_input_nodes(input_node, nodes_map, indent + 1, already_visited)
-
-
-def create_node(op, name, inputs):
- new_node = node_def_pb2.NodeDef()
- new_node.op = op
- new_node.name = name
- for input_name in inputs:
- new_node.input.extend([input_name])
- return new_node
-
-
-def create_constant_node(name, value, dtype, shape=None):
- node = create_node("Const", name, [])
- set_attr_dtype(node, "dtype", dtype)
- set_attr_tensor(node, "value", value, dtype, shape)
- return node
-
-
-def copy_attr(node, key, attr_value):
- try:
- node.attr[key].CopyFrom(attr_value)
- except KeyError:
- pass
-
-
-def set_attr_dtype(node, key, value):
- try:
- node.attr[key].CopyFrom(
- attr_value_pb2.AttrValue(type=value.as_datatype_enum))
- except KeyError:
- pass
-
-
-def set_attr_shape(node, key, value):
- try:
- node.attr[key].CopyFrom(
- attr_value_pb2.AttrValue(shape=tensor_shape.as_shape(value).as_proto()))
- except KeyError:
- pass
-
-
-def set_attr_tensor(node, key, value, dtype, shape=None):
- try:
- node.attr[key].CopyFrom(
- attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
- value, dtype=dtype, shape=shape)))
- except KeyError:
- pass
-
-
-def set_attr_string(node, key, value):
- try:
- node.attr[key].CopyFrom(attr_value_pb2.AttrValue(s=value))
- except KeyError:
- pass
-
-
-def set_attr_int_list(node, key, value):
- list_value = attr_value_pb2.AttrValue.ListValue(i=value)
- try:
- node.attr[key].CopyFrom(attr_value_pb2.AttrValue(list=list_value))
- except KeyError:
- pass
-
-
-def set_attr_bool(node, key, value):
- try:
- node.attr[key].CopyFrom(attr_value_pb2.AttrValue(b=value))
- except KeyError:
- pass
-
-
-def set_attr_int(node, key, value):
- try:
- node.attr[key].CopyFrom(attr_value_pb2.AttrValue(i=value))
- except KeyError:
- pass
-
-
-def set_attr_float(node, key, value):
- try:
- node.attr[key].CopyFrom(attr_value_pb2.AttrValue(f=value))
- except KeyError:
- pass
-
-
-def node_name_from_input(node_name):
- """Strips off ports and other decorations to get the underlying node name."""
- if node_name.startswith("^"):
- node_name = node_name[1:]
- m = re.search(r"(.*):\d+$", node_name)
- if m:
- node_name = m.group(1)
- return node_name
-
-
-def ensure_tensor_name_has_port(node_name):
- """Makes sure that a tensor name has :0 if no explicit port exists."""
- m = re.search(r"(.*):\d+$", node_name)
- if m:
- name_with_port = node_name
- else:
- name_with_port = node_name + ":0"
- return name_with_port
-
-
-def unique_node_name_from_input(node_name):
- """Replaces invalid characters in input names to get a unique node name."""
- return node_name.replace(":", "__port__").replace("^", "__hat__")
-
-
-def quantize_array(arr, num_buckets):
- """Quantizes a numpy array.
-
- This function maps each scalar in arr to the center of one of num_buckets
- buckets. For instance,
- quantize_array([0, 0.3, 0.6, 1], 2) => [0.25, 0.25, 0.75, 0.75]
-
- Args:
- arr: The numpy array to quantize.
- num_buckets: The number of buckets to map "var" to.
- Returns:
- The quantized numpy array.
- Raises:
- ValueError: when num_buckets < 1.
- """
- if num_buckets < 1:
- raise ValueError("num_buckets must be >= 1")
- arr_max = arr.max()
- arr_min = arr.min()
- if arr_max == arr_min:
- return arr
- bucket_width = (arr_max - arr_min) / num_buckets
- # Map scalars to bucket indices. Take special care of max(arr).
- bucket_indices = np.floor((arr - arr_min) / bucket_width)
- bucket_indices[bucket_indices == num_buckets] = num_buckets - 1
- # Map each scalar to the center of a bucket.
- arr = arr_min + bucket_width * (bucket_indices + 0.5)
- return arr
-
-
-def quantize_weight_rounded(input_node):
- """Returns a replacement node for input_node containing bucketed floats."""
- input_tensor = input_node.attr["value"].tensor
- tensor_value = tensor_util.MakeNdarray(input_tensor)
- shape = input_tensor.tensor_shape
- # Currently, the parameter FLAGS.bitdepth is used to compute the
- # number of buckets as 1 << FLAGS.bitdepth, meaning the number of
- # buckets can only be a power of 2.
- # This could be fixed by introducing a new parameter, num_buckets,
- # which would allow for more flexibility in chosing the right model
- # size/accuracy tradeoff. But I didn't want to add more parameters
- # to this script than absolutely necessary.
- num_buckets = 1 << FLAGS.bitdepth
- tensor_value_rounded = quantize_array(tensor_value, num_buckets)
- tensor_shape_list = tensor_util.TensorShapeProtoToList(shape)
- return [
- create_constant_node(
- input_node.name,
- tensor_value_rounded,
- dtypes.float32,
- shape=tensor_shape_list)
- ]
-
-
-def quantize_weight_eightbit(input_node, quantization_mode):
- """Returns replacement nodes for input_node using the Dequantize op."""
- base_name = input_node.name + "_"
- quint8_const_name = base_name + "quint8_const"
- min_name = base_name + "min"
- max_name = base_name + "max"
- float_tensor = tensor_util.MakeNdarray(input_node.attr["value"].tensor)
- min_value = np.min(float_tensor.flatten())
- max_value = np.max(float_tensor.flatten())
- # Make sure that the range includes zero.
- if min_value > 0.0:
- min_value = 0.0
- # min_value == max_value is a tricky case. It can occur for general
- # tensors, and of course for scalars. The quantized ops cannot deal
- # with this case, so we set max_value to something else.
- # It's a tricky question what is the numerically best solution to
- # deal with this degeneracy.
- # TODO(petewarden): Better use a tolerance than a hard comparison?
- if min_value == max_value:
- if abs(min_value) < 0.000001:
- max_value = min_value + 1.0
- elif min_value > 0:
- max_value = 2 * min_value
- else:
- max_value = min_value / 2.0
-
- sess = session.Session()
- with sess.as_default():
- quantize_op = array_ops.quantize_v2(
- float_tensor,
- min_value,
- max_value,
- dtypes.quint8,
- mode=quantization_mode)
- quint8_tensor = quantize_op[0].eval()
- shape = tensor_util.TensorShapeProtoToList(input_node.attr["value"]
- .tensor.tensor_shape)
- quint8_const_node = create_constant_node(
- quint8_const_name, quint8_tensor, dtypes.quint8, shape=shape)
- min_node = create_constant_node(min_name, min_value, dtypes.float32)
- max_node = create_constant_node(max_name, max_value, dtypes.float32)
- dequantize_node = create_node("Dequantize", input_node.name,
- [quint8_const_name, min_name, max_name])
- set_attr_dtype(dequantize_node, "T", dtypes.quint8)
- set_attr_string(dequantize_node, "mode", quantization_mode)
- return [quint8_const_node, min_node, max_node, dequantize_node]
-
-
-EightbitizeRecursionState = collections.namedtuple(
- "EightbitizeRecursionState",
- ["already_visited", "output_node_stack", "merged_with_fake_quant"])
-
-
-class GraphRewriter(object):
- """Takes a float graph, and rewrites it in quantized form."""
-
- def __init__(self,
- input_graph,
- mode,
- quantized_input_range,
- fallback_quantization_range=None):
- """Sets up the class to rewrite a float graph.
-
- Args:
- input_graph: A float graph to transform.
- mode: A string controlling how quantization is performed -
- round, quantize, eightbit, or weights.
- quantized_input_range: if set, assume the input is
- quantized and represents the range
- [quantized_input_range[0], quantized_input_range[1]]
- fallback_quantization_range: if set, then for nodes where the quantization
- range can't be inferred from the graph, use the range
- [fallback_quantization_range[0], fallback_quantization_range[1]) instead
- of using a RequantizationRange node in the graph.
-
- Raises:
- ValueError: Two nodes with the same name were found in the graph.
- """
- self.input_graph = input_graph
- self.nodes_map = self.create_nodes_map(input_graph)
- self.output_graph = None
- self.mode = mode
- self.final_node_renames = {}
- if quantized_input_range:
- self.input_range = (quantized_input_range[0], quantized_input_range[1])
- if self.input_range[0] >= self.input_range[1]:
- raise ValueError("Invalid quantized_input_range: [%s,%s]" %
- self.input_range)
- if self.mode != "eightbit":
- raise ValueError(
- "quantized_input_range can only be specified in eightbit mode")
- else:
- self.input_range = None
-
- if fallback_quantization_range:
- self.fallback_quantization_range = [
- fallback_quantization_range[0], fallback_quantization_range[1]
- ]
- if (self.fallback_quantization_range[0] >=
- self.fallback_quantization_range[1]):
- raise ValueError("Invalid fallback_quantization_range: [%s,%s]" %
- self.fallback_quantization_range)
- if self.mode != "eightbit":
- raise ValueError("fallback_quantization_range can only be "
- "specified in eightbit mode")
- else:
- self.fallback_quantization_range = None
-
- # Data that is valid only during the recursive call to rewrite the graph.
- self.state = None
-
- def create_nodes_map(self, graph):
- """Builds a mapping of node names to their defs from the graph."""
- nodes_map = {}
- for node in graph.node:
- if node.name not in nodes_map.keys():
- nodes_map[node.name] = node
- else:
- raise ValueError("Duplicate node names detected.")
- return nodes_map
-
- def rewrite(self, output_node_names):
- """Triggers rewriting of the float graph.
-
- Args:
- output_node_names: A list of names of the nodes that produce the final
- results.
-
- Returns:
- A quantized version of the float graph.
- """
- self.output_graph = graph_pb2.GraphDef()
- output_nodes = [
- self.nodes_map[output_node_name]
- for output_node_name in output_node_names
- ]
- if self.mode == "round":
- self.already_visited = {}
- for output_node in output_nodes:
- self.round_nodes_recursively(output_node)
- elif self.mode == "quantize":
- self.already_visited = {}
- self.already_quantized = {}
- for output_node in output_nodes:
- self.quantize_nodes_recursively(output_node)
- elif self.mode == "eightbit":
- self.set_input_graph(graph_util.remove_training_nodes(
- self.input_graph, protected_nodes=output_node_names))
- output_nodes = [
- self.nodes_map[output_node_name]
- for output_node_name in output_node_names
- ]
-
- self.state = EightbitizeRecursionState(
- already_visited={}, output_node_stack=[], merged_with_fake_quant={})
- for output_node in output_nodes:
- self.eightbitize_nodes_recursively(output_node)
- self.state = None
- if self.input_range:
- self.add_output_graph_node(
- create_constant_node("quantized_input_min_value", self.input_range[
- 0], dtypes.float32, []))
- self.add_output_graph_node(
- create_constant_node("quantized_input_max_value", self.input_range[
- 1], dtypes.float32, []))
- if self.fallback_quantization_range:
- self.add_output_graph_node(
- create_constant_node("fallback_quantization_min_value",
- self.fallback_quantization_range[0],
- dtypes.float32, []))
- self.add_output_graph_node(
- create_constant_node("fallback_quantization_max_value",
- self.fallback_quantization_range[1],
- dtypes.float32, []))
- if FLAGS.strip_redundant_quantization:
- self.output_graph = self.remove_redundant_quantization(
- self.output_graph)
- self.remove_dead_nodes(output_node_names)
- self.apply_final_node_renames()
- elif self.mode == "weights":
- self.output_graph = self.quantize_weights(self.input_graph,
- b"MIN_COMBINED")
- self.remove_dead_nodes(output_node_names)
- elif self.mode == "weights_rounded":
- self.output_graph = self.quantize_weights(self.input_graph, self.mode)
- self.remove_dead_nodes(output_node_names)
- else:
- print("Bad mode - " + self.mode + ".")
- return self.output_graph
-
- def round_nodes_recursively(self, current_node):
- """The entry point for simple rounding quantization."""
- if (current_node.name in self.already_visited
- ) and self.already_visited[current_node.name]:
- return
- self.already_visited[current_node.name] = True
- for input_node_name in current_node.input:
- input_node_name = node_name_from_input(input_node_name)
- input_node = self.nodes_map[input_node_name]
- self.round_nodes_recursively(input_node)
- nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"]
- if any(current_node.op in s for s in nodes_to_quantize):
- new_node = node_def_pb2.NodeDef()
- new_node.CopyFrom(current_node)
- new_node.name = current_node.name + "_original"
- self.add_output_graph_node(new_node)
- levels = 1 << FLAGS.bitdepth
- constant_name = current_node.name + "_round_depth"
- constant_tensor = constant_op.constant(
- levels, dtype=dtypes.int32, name=constant_name)
- constant_node = constant_tensor.op.node_def
- self.add_output_graph_node(constant_node)
- quantize_node = node_def_pb2.NodeDef()
- quantize_node.op = "RoundToSteps"
- quantize_node.name = current_node.name
- quantize_node.input.extend([current_node.name + "_original"])
- quantize_node.input.extend([constant_node.name])
- self.add_output_graph_node(quantize_node)
- else:
- new_node = node_def_pb2.NodeDef()
- new_node.CopyFrom(current_node)
- self.add_output_graph_node(new_node)
-
- def quantize_nodes_recursively(self, current_node):
- """The entry point for quantizing nodes to eight bit and back."""
- if self.already_visited[current_node.name]:
- return
- self.already_visited[current_node.name] = True
- for input_node_name in current_node.input:
- input_node_name = node_name_from_input(input_node_name)
- input_node = self.nodes_map[input_node_name]
- self.quantize_nodes_recursively(input_node)
- nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"]
- if any(current_node.op in s for s in nodes_to_quantize):
- for input_name in current_node.input:
- input_name = node_name_from_input(input_name)
- input_node = self.nodes_map[input_name]
- self.quantize_node(input_node)
- self.quantize_node(current_node)
- else:
- new_node = node_def_pb2.NodeDef()
- new_node.CopyFrom(current_node)
- self.add_output_graph_node(new_node)
-
- def quantize_node(self, input_node):
- """Handles quantizing a single node."""
- input_name = input_node.name
- if input_name in self.already_quantized:
- return
- self.already_quantized[input_name] = True
- original_input_name = input_name + "_original"
- reshape_name = input_name + "_reshape"
- reshape_dims_name = input_name + "_reshape_dims"
- max_name = input_name + "_max"
- min_name = input_name + "_min"
- dims_name = input_name + "_dims"
- quantize_name = input_name + "_quantize"
- dequantize_name = input_name
- original_input_node = node_def_pb2.NodeDef()
- original_input_node.CopyFrom(input_node)
- original_input_node.name = original_input_name
- self.add_output_graph_node(original_input_node)
- reshape_dims_node = create_constant_node(reshape_dims_name, -1,
- dtypes.int32, [1])
- self.add_output_graph_node(reshape_dims_node)
- reshape_node = create_node("Reshape", reshape_name,
- [original_input_name, reshape_dims_name])
- set_attr_dtype(reshape_node, "T", dtypes.float32)
- self.add_output_graph_node(reshape_node)
- dims_node = create_constant_node(dims_name, 0, dtypes.int32, [1])
- self.add_output_graph_node(dims_node)
- max_node = create_node("Max", max_name, [reshape_name, dims_name])
- set_attr_dtype(max_node, "T", dtypes.float32)
- set_attr_bool(max_node, "keep_dims", False)
- self.add_output_graph_node(max_node)
- min_node = create_node("Min", min_name, [reshape_name, dims_name])
- set_attr_dtype(min_node, "T", dtypes.float32)
- set_attr_bool(min_node, "keep_dims", False)
- self.add_output_graph_node(min_node)
- quantize_node = create_node("Quantize", quantize_name,
- [original_input_name, min_name, max_name])
- set_attr_dtype(quantize_node, "T", dtypes.quint8)
- set_attr_string(quantize_node, "mode", b"MIN_FIRST")
- self.add_output_graph_node(quantize_node)
- dequantize_node = create_node("Dequantize", dequantize_name,
- [quantize_name, min_name, max_name])
- set_attr_dtype(dequantize_node, "T", dtypes.quint8)
- set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
- self.add_output_graph_node(dequantize_node)
-
- def should_merge_with_fake_quant_node(self):
- """Should the current node merge with self.state.output_node_stack[-1]?"""
- if not self.state.output_node_stack:
- return False
- top = self.state.output_node_stack[-1]
- return top[1] == 0 and top[0].op in ["FakeQuantWithMinMaxVars"]
-
- def should_quantize_const(self, node):
- if not self.state.output_node_stack:
- return False
- top = self.state.output_node_stack[-1]
- if not top[2]:
- return False
- dtype = dtypes.as_dtype(node.attr["dtype"].type)
- assert dtype == dtypes.float32, (
- "Failed to quantized constant %s of type %s" % (node.name, dtype))
- return True
-
- def eightbitize_nodes_recursively(self, current_node):
- """The entry point for transforming a graph into full eight bit."""
- if current_node.name in self.state.already_visited:
- if (self.should_merge_with_fake_quant_node() or
- current_node.name in self.state.merged_with_fake_quant):
- raise ValueError("Unsupported graph structure: output of node %s "
- "is processed by a FakeQuant* node and should have "
- "no other outputs.", current_node.name)
- return
- self.state.already_visited[current_node.name] = True
-
- for i, input_node_name in enumerate(current_node.input):
- quantize_input = False
- if current_node.op in ("MatMul", "Conv2D", "BiasAdd", "MaxPool",
- "AvgPool", "Relu", "Relu6",
- "BatchNormWithGlobalNormalization"):
- quantize_input = True
- elif current_node.op == "Concat" and i > 0:
- quantize_input = (
- dtypes.as_dtype(current_node.attr["T"].type) == dtypes.float32)
- elif current_node.op == "Reshape" and i == 0:
- quantize_input = (
- dtypes.as_dtype(current_node.attr["T"].type) == dtypes.float32)
-
- self.state.output_node_stack.append((current_node, i, quantize_input))
-
- input_node_name = node_name_from_input(input_node_name)
- input_node = self.nodes_map[input_node_name]
- self.eightbitize_nodes_recursively(input_node)
-
- self.state.output_node_stack.pop()
-
- if current_node.op == "MatMul":
- self.eightbitize_mat_mul_node(current_node)
- elif current_node.op == "Conv2D":
- self.eightbitize_conv_node(current_node)
- elif current_node.op == "BiasAdd":
- self.eightbitize_bias_add_node(current_node)
- elif current_node.op == "MaxPool" or current_node.op == "AvgPool":
- self.eightbitize_single_input_tensor_node(current_node,
- self.add_pool_function)
- elif current_node.op == "Relu" or current_node.op == "Relu6":
- self.eightbitize_single_input_tensor_node(current_node,
- self.add_relu_function)
- elif (current_node.op == "Concat" and
- dtypes.as_dtype(current_node.attr["T"].type) == dtypes.float32):
- self.eightbitize_concat_node(current_node)
- elif current_node.op == "BatchNormWithGlobalNormalization":
- self.eightbitize_batch_norm_node(current_node)
- elif (current_node.op == "Reshape" and
- dtypes.as_dtype(current_node.attr["T"].type) == dtypes.float32):
- self.eightbitize_reshape_node(current_node)
- elif (self.input_range and
- current_node.op in ("Placeholder", "PlaceholderV2")):
- self.eightbitize_placeholder_node(current_node)
- elif current_node.op == "FakeQuantWithMinMaxVars":
- # It will have been merged into the underlying node.
- pass
- elif current_node.op == "Const":
- if self.should_quantize_const(current_node):
- for n in quantize_weight_eightbit(current_node, b"MIN_FIRST"):
- self.add_output_graph_node(n)
- else:
- new_node = node_def_pb2.NodeDef()
- new_node.CopyFrom(current_node)
- self.add_output_graph_node(new_node)
-
- ###################################################################
- # Note: if more cases are added here, you may need to update the op
- # name lists in the loop over children at the start of the function.
- ###################################################################
- else:
- new_node = node_def_pb2.NodeDef()
- new_node.CopyFrom(current_node)
- self.add_output_graph_node(new_node)
-
- if (self.should_merge_with_fake_quant_node() and
- current_node.name not in self.state.merged_with_fake_quant):
- raise ValueError(
- "FakeQuant* node %s failed to merge with node %s of type %s" %
- (self.state.output_node_stack[-1][0], current_node.name,
- current_node.op))
-
- def add_eightbit_prologue_nodes(self, original_node):
- """Adds input conversion nodes to handle quantizing the underlying node."""
- namespace_prefix = original_node.name + "_eightbit"
- reshape_dims_name, reduction_dims_name = self.add_common_quantization_nodes(
- namespace_prefix)
- input_names = []
- min_max_names = []
- for original_input_name in original_node.input:
- quantize_input_name, min_input_name, max_input_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_input_name,
- reshape_dims_name,
- reduction_dims_name))
- input_names.append(quantize_input_name)
- min_max_names.append(min_input_name)
- min_max_names.append(max_input_name)
- all_input_names = []
- all_input_names.extend(input_names)
- all_input_names.extend(min_max_names)
- return all_input_names
-
- def add_common_quantization_nodes(self, namespace_prefix):
- """Builds constant nodes needed for quantization of inputs."""
- reshape_dims_name = namespace_prefix + "_reshape_dims"
- reduction_dims_name = namespace_prefix + "_reduction_dims"
-
- reshape_dims_node = create_constant_node(reshape_dims_name, -1,
- dtypes.int32, [1])
- self.add_output_graph_node(reshape_dims_node)
- reduction_dims_node = create_constant_node(reduction_dims_name, 0,
- dtypes.int32, [1])
- self.add_output_graph_node(reduction_dims_node)
- return reshape_dims_name, reduction_dims_name
-
- def eightbitize_input_to_node(self, namespace_prefix, original_input_name,
- reshape_dims_name, reduction_dims_name):
- """Takes one float input to an op, and converts it to quantized form."""
- unique_input_name = unique_node_name_from_input(original_input_name)
- reshape_input_name = namespace_prefix + "_reshape_" + unique_input_name
- min_input_name = namespace_prefix + "_min_" + unique_input_name
- max_input_name = namespace_prefix + "_max_" + unique_input_name
- quantize_input_name = namespace_prefix + "_quantize_" + unique_input_name
- reshape_input_node = create_node("Reshape", reshape_input_name,
- [original_input_name, reshape_dims_name])
- set_attr_dtype(reshape_input_node, "T", dtypes.float32)
- self.add_output_graph_node(reshape_input_node)
- min_input_node = create_node("Min", min_input_name,
- [reshape_input_name, reduction_dims_name])
- set_attr_dtype(min_input_node, "T", dtypes.float32)
- set_attr_bool(min_input_node, "keep_dims", False)
- self.add_output_graph_node(min_input_node)
- max_input_node = create_node("Max", max_input_name,
- [reshape_input_name, reduction_dims_name])
- set_attr_dtype(max_input_node, "T", dtypes.float32)
- set_attr_bool(max_input_node, "keep_dims", False)
- self.add_output_graph_node(max_input_node)
- quantize_input_node = create_node(
- "QuantizeV2", quantize_input_name,
- [original_input_name, min_input_name, max_input_name])
- set_attr_dtype(quantize_input_node, "T", dtypes.quint8)
- set_attr_string(quantize_input_node, "mode", b"MIN_FIRST")
- self.add_output_graph_node(quantize_input_node)
- min_output_name = quantize_input_name + ":1"
- max_output_name = quantize_input_name + ":2"
- return quantize_input_name, min_output_name, max_output_name
-
- def add_quantize_down_nodes(self, original_node, quantized_output_name):
- quantized_outputs = [
- quantized_output_name, quantized_output_name + ":1",
- quantized_output_name + ":2"
- ]
- min_max_inputs = None
- if self.should_merge_with_fake_quant_node():
- # Use the inputs to the FakeQuantWithMinMaxVars node as the inputs to
- # Requantize.
- fake_quant_node = self.state.output_node_stack[-1][0]
- min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]]
- assert original_node.name not in self.state.merged_with_fake_quant
- self.state.merged_with_fake_quant[original_node.name] = True
- elif self.fallback_quantization_range:
- min_max_inputs = [
- "fallback_quantization_min_value:0",
- "fallback_quantization_max_value:0"
- ]
- else:
- # Add a RequantizationRange node for finding the min and max values.
- requant_range_node = create_node(
- "RequantizationRange", original_node.name + "_eightbit_requant_range",
- quantized_outputs)
- set_attr_dtype(requant_range_node, "Tinput", dtypes.qint32)
- self.add_output_graph_node(requant_range_node)
- min_max_inputs = [
- requant_range_node.name + ":0", requant_range_node.name + ":1"
- ]
- requantize_node = create_node("Requantize",
- original_node.name + "_eightbit_requantize",
- quantized_outputs + min_max_inputs)
- set_attr_dtype(requantize_node, "Tinput", dtypes.qint32)
- set_attr_dtype(requantize_node, "out_type", dtypes.quint8)
- self.add_output_graph_node(requantize_node)
- return requantize_node.name
-
- def add_dequantize_result_node(self,
- quantized_output_name,
- original_node_name,
- min_tensor_index=1):
- min_max_inputs = [
- "%s:%s" % (quantized_output_name, min_tensor_index),
- "%s:%s" % (quantized_output_name, (min_tensor_index + 1))
- ]
- dequantize_name = original_node_name
- if self.should_merge_with_fake_quant_node():
- fake_quant_node = self.state.output_node_stack[-1][0]
- if original_node_name not in self.state.merged_with_fake_quant:
- min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]]
- self.state.merged_with_fake_quant[original_node_name] = True
- dequantize_name = fake_quant_node.name
-
- dequantize_node = create_node(
- "Dequantize", dequantize_name,
- [quantized_output_name, min_max_inputs[0], min_max_inputs[1]])
- set_attr_dtype(dequantize_node, "T", dtypes.quint8)
- set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
- self.add_output_graph_node(dequantize_node)
-
- def eightbitize_mat_mul_node(self, original_node):
- """Replaces a MatMul node with the eight bit equivalent sub-graph."""
- quantized_mat_mul_name = original_node.name + "_eightbit_quantized_mat_mul"
- all_input_names = self.add_eightbit_prologue_nodes(original_node)
- quantized_mat_mul_node = create_node("QuantizedMatMul",
- quantized_mat_mul_name,
- all_input_names)
- set_attr_dtype(quantized_mat_mul_node, "T1", dtypes.quint8)
- set_attr_dtype(quantized_mat_mul_node, "T2", dtypes.quint8)
- set_attr_dtype(quantized_mat_mul_node, "Toutput", dtypes.qint32)
- copy_attr(quantized_mat_mul_node, "transpose_a",
- original_node.attr["transpose_a"])
- copy_attr(quantized_mat_mul_node, "transpose_b",
- original_node.attr["transpose_b"])
- self.add_output_graph_node(quantized_mat_mul_node)
- quantize_down_name = self.add_quantize_down_nodes(original_node,
- quantized_mat_mul_name)
- self.add_dequantize_result_node(quantize_down_name, original_node.name)
-
- def eightbitize_conv_node(self, original_node):
- """Replaces a Conv2D node with the eight bit equivalent sub-graph."""
- all_input_names = self.add_eightbit_prologue_nodes(original_node)
- quantized_conv_name = original_node.name + "_eightbit_quantized_conv"
- quantized_conv_node = create_node("QuantizedConv2D", quantized_conv_name,
- all_input_names)
- copy_attr(quantized_conv_node, "strides", original_node.attr["strides"])
- copy_attr(quantized_conv_node, "padding", original_node.attr["padding"])
- set_attr_dtype(quantized_conv_node, "Tinput", dtypes.quint8)
- set_attr_dtype(quantized_conv_node, "Tfilter", dtypes.quint8)
- set_attr_dtype(quantized_conv_node, "out_type", dtypes.qint32)
- self.add_output_graph_node(quantized_conv_node)
- quantize_down_name = self.add_quantize_down_nodes(original_node,
- quantized_conv_name)
- self.add_dequantize_result_node(quantize_down_name, original_node.name)
-
- def eightbitize_bias_add_node(self, original_node):
- """Replaces a BiasAdd node with the eight bit equivalent sub-graph."""
- quantized_bias_add_name = (
- original_node.name + "_eightbit_quantized_bias_add")
- all_input_names = self.add_eightbit_prologue_nodes(original_node)
- quantized_bias_add_node = create_node("QuantizedBiasAdd",
- quantized_bias_add_name,
- all_input_names)
- set_attr_dtype(quantized_bias_add_node, "T1", dtypes.quint8)
- set_attr_dtype(quantized_bias_add_node, "T2", dtypes.quint8)
- set_attr_dtype(quantized_bias_add_node, "out_type", dtypes.qint32)
- self.add_output_graph_node(quantized_bias_add_node)
- quantize_down_name = self.add_quantize_down_nodes(original_node,
- quantized_bias_add_name)
- self.add_dequantize_result_node(quantize_down_name, original_node.name)
-
- def eightbitize_single_input_tensor_node(self, original_node,
- add_op_function):
- """Replaces a single-tensor node with the eight bit equivalent sub-graph.
-
- Converts a node like this:
-
- Shape(f) Input(f)
- | |
- +--------v v
- Operation
- |
- v
- (f)
-
- Into a quantized equivalent:
-
- Input(f) ReshapeDims
- +------v v-------------+
- | Reshape
- | |
- | | ReductionDims
- | +-----+ |
- | | +---c---------+
- | v v v v-------+
- | Min Max
- | +----+ |
- v v v--------+
- Quantize
- |
- v
- QuantizedOperation
- | | |
- v v v
- Dequantize
- |
- v
- (f)
-
-
- Args:
- original_node: Float node to be converted.
- add_op_function: Function to create the actual node.
-
- Returns:
- Subgraph representing the quantized version of the original node.
-
- """
- quantized_op_name = original_node.name + "_eightbit_quantized"
- quantized_op_type = "Quantized" + original_node.op
- all_input_names = self.add_eightbit_prologue_nodes(original_node)
- quantized_op_node = create_node(quantized_op_type, quantized_op_name,
- all_input_names)
- add_op_function(original_node, quantized_op_node)
- self.add_output_graph_node(quantized_op_node)
- self.add_dequantize_result_node(quantized_op_name, original_node.name)
-
- def add_pool_function(self, original_node, quantized_op_node):
- set_attr_dtype(quantized_op_node, "T", dtypes.quint8)
- copy_attr(quantized_op_node, "ksize", original_node.attr["ksize"])
- copy_attr(quantized_op_node, "strides", original_node.attr["strides"])
- copy_attr(quantized_op_node, "padding", original_node.attr["padding"])
-
- def add_relu_function(self, unused_arg_node, quantized_op_node):
- set_attr_dtype(quantized_op_node, "Tinput", dtypes.quint8)
-
- def eightbitize_concat_node(self, original_node):
- """Replaces a Concat node with the eight bit equivalent sub-graph.
-
- Converts a node like this:
-
- Shape(f) Input0(f) Input1(f)
- | | |
- +--------v v v----------+
- Concat
- |
- v
- (f)
-
- Into a quantized equivalent:
-
- Shape(f) Input0(f) ReshapeDims Input1(f)
- | +------v v--------------+------------------v v------+
- | | Reshape Reshape |
- | | | | |
- | | | ReductionDims | |
- | | +------+ | +--------+ |
- | | | +---c---------+-----------c-----+ | |
- | | +v v v v-------+---------v v v v+ |
- | | Min Max Min Max |
- | | +----+ | | +-----+ |
- | v v v--------+ +----------v v v
- | Quantize Quantize
- | +------------------+ +----------------------+
- +-------------------------------+ | |
- v v v
- QuantizedConcat
- | | |
- v v v
- Dequantize
- |
- v
- (f)
- Args:
- original_node: Float node to be converted.
-
- Returns:
- Subgraph representing the quantized version of the original node.
-
- """
- namespace_prefix = original_node.name + "_eightbit"
- quantized_concat_name = namespace_prefix + "_quantized_concat"
- reshape_dims_name, reduction_dims_name = self.add_common_quantization_nodes(
- namespace_prefix)
- shape_input_name = original_node.input[0]
- original_inputs = original_node.input[1:]
- input_names = []
- min_names = []
- max_names = []
- for original_input_name in original_inputs:
- quantize_input_name, min_input_name, max_input_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_input_name,
- reshape_dims_name,
- reduction_dims_name))
- input_names.append(quantize_input_name)
- min_names.append(min_input_name)
- max_names.append(max_input_name)
- all_input_names = [shape_input_name]
- all_input_names.extend(input_names)
- all_input_names.extend(min_names)
- all_input_names.extend(max_names)
- quantized_concat_node = create_node("QuantizedConcat",
- quantized_concat_name, all_input_names)
- set_attr_int(quantized_concat_node, "N", len(original_inputs))
- set_attr_dtype(quantized_concat_node, "T", dtypes.quint8)
- self.add_output_graph_node(quantized_concat_node)
- self.add_dequantize_result_node(quantized_concat_name, original_node.name)
-
- def eightbitize_placeholder_node(self, current_node):
- """Replaces a placeholder node with a quint8 placeholder node+dequantize."""
- name = current_node.name
-
- # Convert the placeholder into a quantized type.
- output_node = node_def_pb2.NodeDef()
- output_node.CopyFrom(current_node)
- set_attr_dtype(output_node, "dtype", dtypes.quint8)
- output_node.name += "_original_input"
- self.add_output_graph_node(output_node)
-
- # Add a dequantize to convert back to float.
- dequantize_node = create_node("Dequantize", name, [
- output_node.name, "quantized_input_min_value",
- "quantized_input_max_value"
- ])
- set_attr_dtype(dequantize_node, "T", dtypes.quint8)
- set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
- self.add_output_graph_node(dequantize_node)
-
- # For the descent over the graph to work, the dequantize node must be named
- # current_node.name. However, for the feeding of the graph to work, the
- # placeholder must have the name current_node.name; so record a final set
- # of renames to apply after all processing has been done.
- self.final_node_renames[output_node.name] = name
- self.final_node_renames[dequantize_node.name] = name + "_dequantize"
-
- def eightbitize_reshape_node(self, original_node):
- """Replaces a Reshape node with the eight bit equivalent sub-graph.
-
- Args:
- original_node: Float node to be converted.
-
- Returns:
- Subgraph representing the quantized version of the original node.
-
- """
- namespace_prefix = original_node.name + "_eightbit"
- quantized_reshape_name = namespace_prefix + "_quantized_reshape"
- reshape_dims_name, reduction_dims_name = self.add_common_quantization_nodes(
- namespace_prefix)
- shape_input_name = original_node.input[1]
- quantize_input_name, min_input_name, max_input_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_node.input[0],
- reshape_dims_name, reduction_dims_name))
- quantized_reshape_node = create_node(
- "QuantizedReshape", quantized_reshape_name,
- [quantize_input_name, shape_input_name, min_input_name, max_input_name])
- set_attr_dtype(quantized_reshape_node, "T", dtypes.quint8)
- self.add_output_graph_node(quantized_reshape_node)
- self.add_dequantize_result_node(quantized_reshape_name, original_node.name)
-
- def eightbitize_batch_norm_node(self, original_node):
- """Replaces a MatMul node with the eight bit equivalent sub-graph."""
- namespace_prefix = original_node.name + "_eightbit"
- original_input_name = original_node.input[0]
- original_mean_name = original_node.input[1]
- original_variance_name = original_node.input[2]
- original_beta_name = original_node.input[3]
- original_gamma_name = original_node.input[4]
- quantized_batch_norm_name = namespace_prefix + "_quantized_batch_norm"
-
- reshape_dims_name, reduction_dims_name = self.add_common_quantization_nodes(
- namespace_prefix)
- quantize_input_name, min_input_name, max_input_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_input_name,
- reshape_dims_name, reduction_dims_name))
- quantize_mean_name, min_mean_name, max_mean_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_mean_name,
- reshape_dims_name, reduction_dims_name))
- quantize_variance_name, min_variance_name, max_variance_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_variance_name,
- reshape_dims_name, reduction_dims_name))
- quantize_beta_name, min_beta_name, max_beta_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_beta_name,
- reshape_dims_name, reduction_dims_name))
- quantize_gamma_name, min_gamma_name, max_gamma_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_gamma_name,
- reshape_dims_name, reduction_dims_name))
- quantized_batch_norm_node = create_node(
- "QuantizedBatchNormWithGlobalNormalization", quantized_batch_norm_name,
- [
- quantize_input_name, min_input_name, max_input_name,
- quantize_mean_name, min_mean_name, max_mean_name,
- quantize_variance_name, min_variance_name, max_variance_name,
- quantize_beta_name, min_beta_name, max_beta_name,
- quantize_gamma_name, min_gamma_name, max_gamma_name
- ])
- set_attr_dtype(quantized_batch_norm_node, "Tinput", dtypes.quint8)
- set_attr_dtype(quantized_batch_norm_node, "out_type", dtypes.qint32)
- copy_attr(quantized_batch_norm_node, "scale_after_normalization",
- original_node.attr["scale_after_normalization"])
- copy_attr(quantized_batch_norm_node, "variance_epsilon",
- original_node.attr["variance_epsilon"])
- self.add_output_graph_node(quantized_batch_norm_node)
- quantize_down_name = self.add_quantize_down_nodes(original_node,
- quantized_batch_norm_name)
- self.add_dequantize_result_node(quantize_down_name, original_node.name)
-
- def add_output_graph_node(self, output_node):
- """Inserts one node into the new graph."""
- self.output_graph.node.extend([output_node])
-
- def remove_redundant_quantization(self, old_graph):
- """Removes unneeded pairs of quantize/dequantize ops from the graph.
-
- This is a bit of a tricky function, because it's attempting to spot the
- pattern of dequantizing from eight-bit up to float, and then immediately
- quantizing back down to eight bits again, that's introduced by previous
- passes that do 'key-hole' conversions of individual nodes but have to
- convert back to float to match the previous output interface, since they
- don't know that the next op can handle quantized tensors.
- It works by:
- - Looking for Quantize nodes.
- - Checking to see if their first input is a Dequantize node.
- - Seeing if their min/max inputs come from Min/Max nodes.
- - Making sure those Min/Max nodes are being fed from the same Dequantize.
- - Or that the Min is indirectly being fed from the same Dequantize as Max.
- - Making sure the Dequantize is going through a Reshape (which we add
- during the previous pass when we create the quantize sub-graph).
- - Looking for the dims Const op for the Min/Max dims.
- If all of these conditions are met, then it's a sub-graph pattern that
- we know how to optimize out (and is likely the common one we've introduced).
- We then rewire the graph to skip it entirely, and then rely on the dead node
- removal pass to get rid of any nodes that are no longer needed.
-
- Args:
- old_graph: The model we'll be stripping redundant nodes from.
-
- Returns:
- A graph with the unnecessary nodes removed.
-
- Raises:
- ValueError: Two nodes with the same name were found in the graph.
- """
- old_nodes_map = self.create_nodes_map(old_graph)
- self.output_graph = graph_pb2.GraphDef()
- inputs_to_rename = {}
- # We go through all the nodes, looking for any that match the patterns we
- # know how to optimize away.
- for node in old_graph.node:
- # We always start with a Quantize node, and examine its inputs to see if
- # they are in a form that can be removed.
- if node.op not in ["Quantize", "QuantizeV2"]:
- continue
- dequantize_node_name = node_name_from_input(node.input[0])
- if dequantize_node_name not in old_nodes_map:
- raise ValueError("Input node name '" + dequantize_node_name +
- "' not found in node '" + node.name + "'")
- dequantize_node = old_nodes_map[dequantize_node_name]
- # Do we have a Dequantize feeding in, with the same type as the Quantize?
- if dequantize_node.op != "Dequantize":
- continue
- if node.attr["T"] != dequantize_node.attr["T"]:
- continue
- # Now look at the other inputs, and ensure they're Min/Max nodes.
- min_node_name = node_name_from_input(node.input[1])
- max_node_name = node_name_from_input(node.input[2])
- min_node = old_nodes_map[min_node_name]
- max_node = old_nodes_map[max_node_name]
- is_min_right_type = (min_node.op in ["Min", "Dequantize"])
- is_max_right_type = (max_node.op in ["Max", "Dequantize"])
- if not is_min_right_type or not is_max_right_type:
- print("Didn't find expected types on inputs : %s, %s." % (min_node.op,
- max_node.op))
- continue
- min_node_input_name = node_name_from_input(min_node.input[0])
- max_node_input_name = node_name_from_input(max_node.input[0])
- # There are two different patterns for Min nodes we can recognize, one
- # where the input comes directly from the same one as the Max, and
- # another where we run it through another Min first, so check for both.
- is_same_input = False
- if min_node_input_name == max_node_input_name:
- is_same_input = True
- else:
- first_min_node_input = old_nodes_map[min_node_input_name]
- if first_min_node_input.op == "Concat":
- second_min_node_name = node_name_from_input(
- first_min_node_input.input[1])
- second_min_node = old_nodes_map[second_min_node_name]
- if second_min_node.op == "Min":
- second_min_node_input_name = node_name_from_input(
- second_min_node.input[0])
- is_same_input = (second_min_node_input_name == max_node_input_name)
- if not is_same_input:
- print("Different min/max inputs: " + min_node_input_name)
- continue
- # We recognize this pattern, so mark the graph edges to be rewired to
- # route around it entirely, since we know it's a no-op.
- dequantize_source_name = node_name_from_input(dequantize_node.input[0])
- node_tensor_name = ensure_tensor_name_has_port(node.name)
- min_tensor_name = node.name + ":1"
- max_tensor_name = node.name + ":2"
- inputs_to_rename[node_tensor_name] = dequantize_source_name
- inputs_to_rename[min_tensor_name] = dequantize_node.input[1]
- inputs_to_rename[max_tensor_name] = dequantize_node.input[2]
- # Finally we apply all the rewiring we've marked to the graph.
- for node in old_graph.node:
- for index, input_full_name in enumerate(node.input):
- input_name = ensure_tensor_name_has_port(input_full_name)
- if input_name in inputs_to_rename:
- node.input[index] = inputs_to_rename[input_name]
- self.add_output_graph_node(node)
- return self.output_graph
-
- def apply_final_node_renames(self):
- """Applies node renames in self.final_node_renames to self.output_graph."""
- old_graph = self.output_graph
- self.output_graph = graph_pb2.GraphDef()
- for node in old_graph.node:
- node.name = self.final_node_renames.get(node.name, node.name)
- for index, input_name in enumerate(node.input):
- node_name = node_name_from_input(input_name)
- input_full_name = ensure_tensor_name_has_port(input_name)
- if node_name in self.final_node_renames:
- node.input[index] = "%s%s" % (self.final_node_renames[node_name],
- input_full_name[len(node_name):])
- self.add_output_graph_node(node)
- return self.output_graph
-
- def remove_dead_nodes(self, output_names):
- """Removes nodes that are no longer needed for inference from the graph."""
- old_output_graph = self.output_graph
- self.output_graph = graph_util.extract_sub_graph(old_output_graph,
- output_names)
-
- def quantize_weights(self, input_graph, quantization_mode):
- """Quantize float Const ops.
-
- There are two modes of operations, both replace float Const ops with
- quantized values.
- 1. If quantization_mode is "weights_rounded", this function replaces float
- Const ops with quantized float Const ops - same as the original op, but
- float values being mapped to the center of one of 1<<FLAGS.bitdepth buckets.
- This does not change the raw model size, but compression algorithms such as
- zip (as used for compressing apks) or bzip2 will achieve a very good
- compression ratio.
- 2. For other quantization modes ("MIN_COMBINED" or "MIN_FIRST"), float
- Const ops are quantized and replaced by a tuple of four ops to perform
- the dequantization at runtime:
- * eight-bit Const (bucket indices, same shape as original float Const op
- * two float Const ops (min and max value of original float Const op)
- * Dequantize op to convert the eight-bit consts to float tensors.
- The quantization mode is important because we see accuracy problems when
- quantizing weights for different situations depending on the algorithm
- used. We haven't figured out exactly what the underlying cause is yet,
- unfortunately.
-
- Args:
- input_graph: A GraphDef of the model containing float Const ops.
- quantization_mode: How to quantize and dequantize the values.
-
- Returns:
- A GraphDef of the converted graph.
-
- Raises:
- ValueError: If quantization_mode is unsupported.
- """
- output_graph = graph_pb2.GraphDef()
- for input_node in input_graph.node:
- should_quantize = False
- if input_node.op == "Const":
- dtype = dtypes.as_dtype(input_node.attr["dtype"].type)
- if dtype == dtypes.float32:
- should_quantize = True
- if should_quantize:
- if quantization_mode == "weights_rounded":
- output_graph.node.extend(quantize_weight_rounded(input_node))
- elif quantization_mode in (b"MIN_COMBINED", b"MIN_FIRST"):
- output_graph.node.extend(
- quantize_weight_eightbit(input_node, quantization_mode))
- else:
- raise ValueError("Unsupported quantization mode %s." %
- quantization_mode)
- else:
- output_node = node_def_pb2.NodeDef()
- output_node.CopyFrom(input_node)
- output_graph.node.extend([output_node])
- return output_graph
-
- def set_input_graph(self, new_input_graph):
- self.input_graph = new_input_graph
- self.nodes_map = self.create_nodes_map(self.input_graph)
-
-
-def main(unused_args):
- if not gfile.Exists(FLAGS.input):
- print("Input graph file '" + FLAGS.input + "' does not exist!")
- return -1
-
- known_modes = [
- "round", "quantize", "eightbit", "weights", "test", "weights_rounded"
- ]
- if not any(FLAGS.mode in s for s in known_modes):
- print("mode is '" + FLAGS.mode + "', not in " + ", ".join(known_modes) +
- ".")
- return -1
-
- tf_graph = graph_pb2.GraphDef()
- with gfile.Open(FLAGS.input, "rb") as f:
- data = f.read()
- tf_graph.ParseFromString(data)
-
- graph = ops.Graph()
- with graph.as_default():
- importer.import_graph_def(tf_graph, input_map={}, name="")
-
- quantized_input_range = None
- if FLAGS.quantized_input:
- quantized_input_range = [
- FLAGS.quantized_input_min, FLAGS.quantized_input_max
- ]
-
- fallback_quantization_range = None
- if (FLAGS.quantized_fallback_min is not None or
- FLAGS.quantized_fallback_max is not None):
- assert FLAGS.quantized_fallback_min is not None
- assert FLAGS.quantized_fallback_max is not None
- fallback_quantization_range = [
- FLAGS.quantized_fallback_min, FLAGS.quantized_fallback_max
- ]
-
- rewriter = GraphRewriter(tf_graph, FLAGS.mode, quantized_input_range,
- fallback_quantization_range)
-
- output_graph = rewriter.rewrite(FLAGS.output_node_names.split(","))
-
- f = gfile.FastGFile(FLAGS.output, "wb")
- f.write(output_graph.SerializeToString())
-
- return 0
-
-
-if __name__ == "__main__":
- app.run()
diff --git a/tensorflow/tools/quantization/quantize_graph_test.py b/tensorflow/tools/quantization/quantize_graph_test.py
deleted file mode 100644
index 92bb5127da..0000000000
--- a/tensorflow/tools/quantization/quantize_graph_test.py
+++ /dev/null
@@ -1,966 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests the graph quantization script.
-
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import sys
-import numpy as np
-
-from tensorflow.core.framework import graph_pb2
-from tensorflow.python.client import session
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import graph_util
-from tensorflow.python.framework import importer
-from tensorflow.python.framework import ops as ops_lib
-from tensorflow.python.platform import flags as flags_lib
-from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging
-from tensorflow.tools.quantization import quantize_graph
-
-flags = flags_lib
-FLAGS = flags.FLAGS
-
-
-def run_graph_def(graph_def, input_map, outputs):
- graph = ops_lib.Graph()
- with graph.as_default():
- importer.import_graph_def(graph_def, input_map={}, name="")
- with session.Session(graph=graph) as sess:
- results = sess.run(outputs, feed_dict=input_map)
- return results
-
-
-def test_mat_mul(m, n, k, a, b):
- """Tests a MatMul replacement."""
- a_constant_name = "a_constant"
- b_constant_name = "b_constant"
- mat_mul_name = "mat_mul"
-
- float_graph_def = graph_pb2.GraphDef()
- a_constant = quantize_graph.create_constant_node(
- a_constant_name, value=a, dtype=dtypes.float32, shape=[m, k])
- float_graph_def.node.extend([a_constant])
- b_constant = quantize_graph.create_constant_node(
- b_constant_name, value=b, dtype=dtypes.float32, shape=[k, n])
- float_graph_def.node.extend([b_constant])
- mat_mul_node = quantize_graph.create_node("MatMul", mat_mul_name,
- [a_constant_name, b_constant_name])
- quantize_graph.set_attr_dtype(mat_mul_node, "T", dtypes.float32)
- quantize_graph.set_attr_bool(mat_mul_node, "transpose_a", False)
- quantize_graph.set_attr_bool(mat_mul_node, "transpose_b", False)
- float_graph_def.node.extend([mat_mul_node])
-
- test_graph(float_graph_def, {}, [mat_mul_name])
-
-
-def test_conv(depth, image_width, image_height, image_batch_count, filter_size,
- filter_count, stride, padding, input_values, filter_values):
- """Tests a Conv replacement."""
- input_constant_name = "input_constant"
- filter_constant_name = "filter_constant"
- conv_name = "conv"
-
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=input_values,
- dtype=dtypes.float32,
- shape=[image_batch_count, image_height, image_width, depth])
- float_graph_def.node.extend([input_constant])
- filter_constant = quantize_graph.create_constant_node(
- filter_constant_name,
- value=filter_values,
- dtype=dtypes.float32,
- shape=[filter_size, filter_size, depth, filter_count])
- float_graph_def.node.extend([filter_constant])
- conv_node = quantize_graph.create_node(
- "Conv2D", conv_name, [input_constant_name, filter_constant_name])
- quantize_graph.set_attr_dtype(conv_node, "T", dtypes.float32)
- quantize_graph.set_attr_int_list(conv_node, "strides", [1, stride, stride, 1])
- quantize_graph.set_attr_string(conv_node, "padding", padding)
- float_graph_def.node.extend([conv_node])
-
- test_graph(float_graph_def, {}, [conv_name])
-
-
-def are_tensors_near(a, b, tolerance):
- """Tests whether two tensors are nearly identical.
-
- This is a specialized comparison function designed to help debug problems with
- quantization. It prints out information about the differences between tensors
- on failure, paying special attention to possible biases by looking at the mean
- and absolute average errors.
-
- Args:
- a: First comparison tensor.
- b: Second comparison tensor.
- tolerance: Float value indicating how large an error between values is ok.
-
- Returns:
- Boolean indicating whether the two inputs were close enough.
- """
- flat_a = a.flatten()
- flat_b = b.flatten()
- if len(flat_a) != len(flat_b):
- tf_logging.info("Tensors are different sizes: " + str(len(flat_a)) + " vs "
- + str(len(flat_b)))
- return False
- value_count = len(flat_a)
- how_many_different = 0
- total_difference = 0
- total_abs_difference = 0
- for index in range(value_count):
- a_value = flat_a[index]
- b_value = flat_b[index]
- difference = a_value - b_value
- total_difference += difference
- total_abs_difference += abs(difference)
- if abs(difference) > tolerance:
- how_many_different += 1
- mean_difference = total_difference / value_count
- mean_abs_difference = total_abs_difference / value_count
- proportion_different = (how_many_different * 1.0) / value_count
- if how_many_different == 0:
- return True
- else:
- tf_logging.info("Tensors have {0} different values ({1}%), with mean"
- " difference {2} and mean absolute difference {3}".format(
- how_many_different, proportion_different * 100,
- mean_difference, mean_abs_difference))
- return False
-
-
-def get_top_value(input_values):
- max_value = None
- max_index = None
- for index, value in enumerate(input_values.flatten()):
- if max_value is None or value > max:
- max_value = value
- max_index = index
- return max_index, max_value
-
-
-def test_graph(float_graph_def, input_map, output_names, log_graph=False):
- """Runs the float graph through the rewriter and tests the results."""
- float_results = run_graph_def(
- float_graph_def, input_map,
- [output_name + ":0" for output_name in output_names])
- # TODO(petewarden): round test is currently failing because there is no
- # RoundToSteps op available.
- # round_rewriter = quantize_graph.GraphRewriter(float_graph_def, "round")
- # round_graph_def = round_rewriter.rewrite(output_name)
- # round_results = run_graph_def(round_graph_def, input_map,
- # [output_name + ":0"])
- # assert are_tensors_near(expected, round_results[0], 1.0)
- #
- # TODO(petewarden): Add test for "quantize" mode.
-
- eightbit_rewriter = quantize_graph.GraphRewriter(
- float_graph_def, "eightbit", quantized_input_range=None)
- eightbit_graph_def = eightbit_rewriter.rewrite(output_names)
- eightbit_results = run_graph_def(
- eightbit_graph_def, input_map,
- [output_name + ":0" for output_name in output_names])
- for expected, result in zip(float_results, eightbit_results):
- assert are_tensors_near(expected, result, 1.0)
-
- if log_graph:
- tf_logging.info("8bit:\n%s", str(eightbit_graph_def))
-
- # Test the weights_rounded mode. This uses the default bit_depth.
- weights_rounded_rewriter = quantize_graph.GraphRewriter(
- float_graph_def, "weights_rounded", quantized_input_range=None)
- weights_rounded_graph_def = weights_rounded_rewriter.rewrite(output_names)
- weights_rounded_results = run_graph_def(
- weights_rounded_graph_def, input_map,
- [output_name + ":0" for output_name in output_names])
- for expected, result in zip(float_results, weights_rounded_results):
- assert are_tensors_near(expected, result, 1.0)
-
-
-class QuantizeGraphTest(test.TestCase):
-
- def test_negative_const_problem(self):
- shape_constant_name = "shape_constant"
- shape_constant = quantize_graph.create_constant_node(
- shape_constant_name, value=-0.8, dtype=dtypes.float32, shape=[1])
- quantization_result = quantize_graph.quantize_weight_eightbit(
- shape_constant, b"MIN_COMBINED")
- self.assertEqual(4, len(quantization_result))
-
- def test_odd_padding_problem(self):
- """Tests one error case we ran into in a real graph."""
- test_conv(1, 4, 4, 1, 3, 1, 2, b"SAME",
- [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
- [1, 2, 3, 4, 5, 6, 7, 8, 9])
-
- def test_mat_mul_tiny(self):
- # These tests are added to test the generate case where
- # min(matrix) == max(matrix), which used to cause problems.
- test_mat_mul(1, 1, 1, [2], [3])
- test_mat_mul(1, 2, 1, [1], [2, 3])
- test_mat_mul(1, 1, 2, [1, 1], [1, 1])
- test_mat_mul(1, 1, 2, [0, 0], [1, 1])
- # The general case.
- test_mat_mul(1, 1, 2, [1, 2], [1, 2])
-
- def test_mat_mul_small(self):
- test_mat_mul(2, 4, 3, [1, 2, 3, 4, 5, 6],
- [7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18])
-
- def test_conv(self):
- test_conv(1, 4, 3, 1, 3, 1, 1, b"SAME",
- [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- [1, 4, 7, 2, 5, 8, 3, 6, 9])
-
- def test_reshape(self):
- """Tests that MatMul->Reshape->MatMul avoids extra quantize/dequantize."""
-
- def make_matmul(name, a, b):
- n = quantize_graph.create_node("MatMul", name, [a.name, b.name])
- quantize_graph.set_attr_dtype(n, "T", dtypes.float32)
- quantize_graph.set_attr_bool(n, "transpose_a", False)
- quantize_graph.set_attr_bool(n, "transpose_b", False)
- return n
-
- # matmul_1 = input*weight_1
- input_node = quantize_graph.create_constant_node(
- "input", value=[0, 1, 2, 3], dtype=dtypes.float32, shape=[4, 1])
- weight_1_node = quantize_graph.create_constant_node(
- "weight_1",
- value=[.5, .6, .7, .8, .9],
- dtype=dtypes.float32,
- shape=[1, 5])
- matmul_1_node = make_matmul("matmul_1", input_node, weight_1_node)
-
- # Reshape 4x5 to 10x2.
- new_shape_node = quantize_graph.create_constant_node(
- "new_shape_node", value=[10, 2], dtype=dtypes.int32, shape=[2])
- reshape_node = quantize_graph.create_node(
- "Reshape", "reshape", [matmul_1_node.name, new_shape_node.name])
- quantize_graph.set_attr_dtype(reshape_node, "T", dtypes.float32)
-
- # matmul_2_node = reshape*weight_2
- weight_2_node = quantize_graph.create_constant_node(
- "weight_2", value=[1.5, 2.5], dtype=dtypes.float32, shape=[2, 1])
- matmul_2_node = make_matmul("matmul_2", reshape_node, weight_2_node)
-
- g = graph_pb2.GraphDef()
- g.node.extend([
- input_node, weight_1_node, matmul_1_node, new_shape_node, reshape_node,
- weight_2_node, matmul_2_node
- ])
-
- # Test the graph
- test_graph(g, {}, ["matmul_2"])
-
- # Verify there is only one Quantize and one Requantize op.
- eightbit_rewriter = quantize_graph.GraphRewriter(
- g, "eightbit", quantized_input_range=None)
- eightbit_graph_def = eightbit_rewriter.rewrite(["matmul_2"])
-
- ops = [node.op for node in eightbit_graph_def.node]
- # No quantize since all inputs are const and can be quantized up-front.
- self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
- self.assertEqual(1, ops.count("QuantizedReshape"))
-
- # One dequantize at the end.
- self.assertEqual(1, ops.count("Dequantize"))
-
- def test_quantize_array(self):
- # Test invalid parameters (empty array, or 0 buckets.
- self.assertRaises(ValueError, quantize_graph.quantize_array, np.array([]),
- 2)
- self.assertRaises(ValueError, quantize_graph.quantize_array,
- np.array([1, 2]), 0)
- # Test input array of length 1.
- arr = np.array([1])
- qarr = quantize_graph.quantize_array(arr, 1)
- self.assertEqual(arr, qarr)
- qarr = quantize_graph.quantize_array(arr, 2)
- self.assertEqual(arr, qarr)
- # Test input array with all elements equal.
- arr = np.array([1, 1, 1])
- qarr = quantize_graph.quantize_array(arr, 10)
- self.assertTrue((np.array([1, 1, 1]) == qarr).all())
- # Test "normal" input arrays.
- arr = np.array([0, 0.3, 0.6, 1])
- qarr = quantize_graph.quantize_array(arr, 1)
- self.assertTrue((np.array([0.5, 0.5, 0.5, 0.5]) == qarr).all())
- qarr = quantize_graph.quantize_array(arr, 2)
- self.assertTrue((np.array([0.25, 0.25, 0.75, 0.75]) == qarr).all())
- qarr = quantize_graph.quantize_array(arr.reshape((2, 2)), 2)
- self.assertTrue((np.array([[0.25, 0.25], [0.75, 0.75]]) == qarr).all())
-
- def test_non_float_concat(self):
- concat_dim = quantize_graph.create_constant_node(
- "concat_dim", value=0, dtype=dtypes.int32, shape=[])
- a = quantize_graph.create_constant_node(
- "a",
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.int32,
- shape=[2, 2, 3])
- b = quantize_graph.create_constant_node(
- "b",
- value=[13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
- dtype=dtypes.int32,
- shape=[2, 2, 3])
- concat = quantize_graph.create_node("Concat", "concat",
- [concat_dim.name, a.name, b.name])
- quantize_graph.set_attr_int(concat, "N", 2)
- quantize_graph.set_attr_dtype(concat, "T", dtypes.int32)
-
- g = graph_pb2.GraphDef()
- g.node.extend([concat_dim, a, b, concat])
- test_graph(g, {}, [concat.name])
-
- def test_non_float_reshape(self):
- a = quantize_graph.create_constant_node(
- "a",
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.int32,
- shape=[2, 2, 3])
- shape = quantize_graph.create_constant_node(
- "shape", value=[12], dtype=dtypes.int32, shape=[1])
- reshape = quantize_graph.create_node("Reshape", "reshape",
- [a.name, shape.name])
- quantize_graph.set_attr_dtype(reshape, "T", dtypes.int32)
-
- g = graph_pb2.GraphDef()
- g.node.extend([a, shape, reshape])
- test_graph(g, {}, [reshape.name])
-
- def test_concat(self):
- shape_constant_name = "shape_constant"
- a_constant_name = "a_constant"
- b_constant_name = "b_constant"
- concat_name = "concat"
-
- float_graph_def = graph_pb2.GraphDef()
- shape_constant = quantize_graph.create_constant_node(
- shape_constant_name, value=0, dtype=dtypes.int32, shape=[])
- float_graph_def.node.extend([shape_constant])
- a_constant = quantize_graph.create_constant_node(
- a_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[2, 2, 3])
- float_graph_def.node.extend([a_constant])
- b_constant = quantize_graph.create_constant_node(
- b_constant_name,
- value=[13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
- dtype=dtypes.float32,
- shape=[2, 2, 3])
- float_graph_def.node.extend([b_constant])
- concat_node = quantize_graph.create_node(
- "Concat", concat_name,
- [shape_constant_name, a_constant_name, b_constant_name])
- quantize_graph.set_attr_int(concat_node, "N", 2)
- quantize_graph.set_attr_dtype(concat_node, "T", dtypes.float32)
- float_graph_def.node.extend([concat_node])
-
- test_graph(float_graph_def, {}, [concat_name])
-
- # Verify the concat is quantized.
- eightbit_rewriter = quantize_graph.GraphRewriter(
- float_graph_def, "eightbit", quantized_input_range=None)
- eightbit_graph_def = eightbit_rewriter.rewrite([concat_name])
-
- ops = [node.op for node in eightbit_graph_def.node]
- self.assertEqual(1, ops.count("QuantizedConcat"))
-
- def test_multiple_outputs(self):
- input_constant_name = "input_constant"
- split_constant_name = "split_constant"
- split_name = "split"
- concat_constant_name = "concat_constant"
- concat_name = "concat"
-
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[2, 6])
- float_graph_def.node.extend([input_constant])
- split_constant = quantize_graph.create_constant_node(
- split_constant_name, value=1, dtype=dtypes.int32, shape=[])
- float_graph_def.node.extend([split_constant])
- split_node = quantize_graph.create_node(
- "Split", split_name, [split_constant_name, input_constant_name])
- quantize_graph.set_attr_int(split_node, "num_split", 2)
- quantize_graph.set_attr_dtype(split_node, "T", dtypes.float32)
- float_graph_def.node.extend([split_node])
- concat_constant = quantize_graph.create_constant_node(
- concat_constant_name, value=1, dtype=dtypes.int32, shape=[])
- float_graph_def.node.extend([concat_constant])
- concat_node = quantize_graph.create_node(
- "Concat", concat_name,
- [concat_constant_name, split_name + ":0", split_name + ":1"])
- quantize_graph.set_attr_int(concat_node, "N", 2)
- quantize_graph.set_attr_dtype(concat_node, "T", dtypes.float32)
- float_graph_def.node.extend([concat_node])
-
- test_graph(float_graph_def, {}, [concat_name])
-
- def test_node_name_from_input(self):
- self.assertEqual("SomeName",
- quantize_graph.node_name_from_input("^SomeName:2"))
-
- def test_unique_node_name_from_input(self):
- self.assertEqual("__hat__SomeName__port__2",
- quantize_graph.unique_node_name_from_input("^SomeName:2"))
-
- def test_identity(self):
- input_constant_name = "input_constant"
- identity_name = "identity"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[2, 6])
- float_graph_def.node.extend([input_constant])
- identity_node = quantize_graph.create_node("Identity", identity_name,
- [input_constant_name])
- quantize_graph.set_attr_dtype(identity_node, "T", dtypes.float32)
- float_graph_def.node.extend([identity_node])
-
- mul_name = "mul"
- mul_node = quantize_graph.create_node("Mul", mul_name,
- [identity_name, identity_name])
- quantize_graph.set_attr_dtype(mul_node, "T", dtypes.float32)
- float_graph_def.node.extend([mul_node])
-
- test_graph(float_graph_def, {}, [mul_name])
-
- def test_keep_control_edges(self):
- no_op_name = "no_op"
- a_constant_name = "a_constant"
- b_constant_name = "b_constant"
- a_check_name = "a_check"
- b_check_name = "b_check"
- a_identity_name = "a_identity"
- b_identity_name = "b_identity"
- add_name = "add"
- graph_def = graph_pb2.GraphDef()
- no_op = quantize_graph.create_node("NoOp", no_op_name, [])
- graph_def.node.extend([no_op])
- a_constant = quantize_graph.create_constant_node(
- a_constant_name, value=1, dtype=dtypes.float32, shape=[])
- graph_def.node.extend([a_constant])
- a_check_node = quantize_graph.create_node("CheckNumerics", a_check_name,
- [a_constant_name])
- graph_def.node.extend([a_check_node])
- a_identity_node = quantize_graph.create_node(
- "Identity", a_identity_name,
- [a_constant_name, "^" + a_check_name, "^" + no_op_name])
- graph_def.node.extend([a_identity_node])
- b_constant = quantize_graph.create_constant_node(
- b_constant_name, value=1, dtype=dtypes.float32, shape=[])
- graph_def.node.extend([b_constant])
- b_check_node = quantize_graph.create_node("CheckNumerics", b_check_name,
- [b_constant_name])
- graph_def.node.extend([b_check_node])
- b_identity_node = quantize_graph.create_node(
- "Identity", b_identity_name, [b_constant_name, "^" + b_check_name])
- graph_def.node.extend([b_identity_node])
- add_node = quantize_graph.create_node("Add", add_name,
- [a_identity_name, b_identity_name])
- quantize_graph.set_attr_dtype(add_node, "T", dtypes.float32)
- graph_def.node.extend([add_node])
-
- expected_output = graph_pb2.GraphDef()
- no_op = quantize_graph.create_node("NoOp", no_op_name, [])
- expected_output.node.extend([no_op])
- a_constant = quantize_graph.create_constant_node(
- a_constant_name, value=1, dtype=dtypes.float32, shape=[])
- expected_output.node.extend([a_constant])
- a_identity_node = quantize_graph.create_node(
- "Identity", a_identity_name, [a_constant_name, "^" + no_op_name])
- expected_output.node.extend([a_identity_node])
- b_constant = quantize_graph.create_constant_node(
- b_constant_name, value=1, dtype=dtypes.float32, shape=[])
- expected_output.node.extend([b_constant])
- add_node = quantize_graph.create_node("Add", add_name,
- [a_identity_name, b_constant_name])
- quantize_graph.set_attr_dtype(add_node, "T", dtypes.float32)
- expected_output.node.extend([add_node])
- expected_output.versions.CopyFrom(graph_def.versions)
- expected_output.library.CopyFrom(graph_def.library)
-
- output = graph_util.remove_training_nodes(graph_def)
- stripped_output = graph_util.extract_sub_graph(output, [add_name])
- self.assertProtoEquals(expected_output, stripped_output)
-
- def test_batch_norm(self):
- input_constant_name = "input_constant"
- mean_constant_name = "mean_constant"
- variance_constant_name = "variance_constant"
- beta_constant_name = "beta_constant"
- gamma_constant_name = "gamma_constant"
- batch_norm_name = "batch_norm"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6],
- dtype=dtypes.float32,
- shape=[1, 1, 6, 2])
- float_graph_def.node.extend([input_constant])
- mean_constant = quantize_graph.create_constant_node(
- mean_constant_name, value=[10, 20], dtype=dtypes.float32, shape=[2])
- float_graph_def.node.extend([mean_constant])
- variance_constant = quantize_graph.create_constant_node(
- variance_constant_name,
- value=[0.25, 0.5],
- dtype=dtypes.float32,
- shape=[2])
- float_graph_def.node.extend([variance_constant])
- beta_constant = quantize_graph.create_constant_node(
- beta_constant_name, value=[0.1, 0.6], dtype=dtypes.float32, shape=[2])
- float_graph_def.node.extend([beta_constant])
- gamma_constant = quantize_graph.create_constant_node(
- gamma_constant_name, value=[0, 0], dtype=dtypes.float32, shape=[2])
- float_graph_def.node.extend([gamma_constant])
- batch_norm_node = quantize_graph.create_node(
- "BatchNormWithGlobalNormalization", batch_norm_name, [
- input_constant_name, mean_constant_name, variance_constant_name,
- beta_constant_name, gamma_constant_name
- ])
- quantize_graph.set_attr_dtype(batch_norm_node, "T", dtypes.float32)
- quantize_graph.set_attr_bool(batch_norm_node, "scale_after_normalization",
- False)
- quantize_graph.set_attr_float(batch_norm_node, "variance_epsilon", 0.001)
- float_graph_def.node.extend([batch_norm_node])
- test_graph(float_graph_def, {}, [batch_norm_name])
-
- def test_max_pool(self):
- input_constant_name = "input_constant"
- max_pool_name = "max_pool"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[1, 2, 6, 1])
- float_graph_def.node.extend([input_constant])
- max_pool_node = quantize_graph.create_node("MaxPool", max_pool_name,
- [input_constant_name])
- quantize_graph.set_attr_int_list(max_pool_node, "ksize", [1, 2, 2, 1])
- quantize_graph.set_attr_int_list(max_pool_node, "strides", [1, 1, 1, 1])
- quantize_graph.set_attr_string(max_pool_node, "padding", b"SAME")
- float_graph_def.node.extend([max_pool_node])
- test_graph(float_graph_def, {}, [max_pool_name])
-
- def test_avg_pool(self):
- input_constant_name = "input_constant"
- avg_pool_name = "avg_pool"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[1, 2, 6, 1])
- float_graph_def.node.extend([input_constant])
- avg_pool_node = quantize_graph.create_node("AvgPool", avg_pool_name,
- [input_constant_name])
- quantize_graph.set_attr_dtype(avg_pool_node, "T", dtypes.float32)
- quantize_graph.set_attr_int_list(avg_pool_node, "ksize", [1, 2, 2, 1])
- quantize_graph.set_attr_int_list(avg_pool_node, "strides", [1, 1, 1, 1])
- quantize_graph.set_attr_string(avg_pool_node, "padding", b"SAME")
- float_graph_def.node.extend([avg_pool_node])
- test_graph(float_graph_def, {}, [avg_pool_name])
-
- def test_relu(self):
- input_constant_name = "input_constant"
- relu_name = "relu"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[1, 2, 6, 1])
- float_graph_def.node.extend([input_constant])
- relu_node = quantize_graph.create_node("Relu", relu_name,
- [input_constant_name])
- quantize_graph.set_attr_dtype(relu_node, "T", dtypes.float32)
- float_graph_def.node.extend([relu_node])
- test_graph(float_graph_def, {}, [relu_name])
-
- def test_relu_w_fake_quant_w_min_max_vars(self):
- input_node = quantize_graph.create_constant_node(
- "input",
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[1, 2, 6, 1])
- relu_node = quantize_graph.create_node("Relu", "relu", [input_node.name])
- quantize_graph.set_attr_dtype(relu_node, "T", dtypes.float32)
-
- min_node = quantize_graph.create_constant_node(
- "min_bias_add", value=0, dtype=dtypes.float32, shape=[])
- max_node = quantize_graph.create_constant_node(
- "max_bias_add", value=12, dtype=dtypes.float32, shape=[])
- fake_quant_node = quantize_graph.create_node(
- "FakeQuantWithMinMaxVars", "fake_quant",
- [relu_node.name, min_node.name, max_node.name])
-
- float_graph_def = graph_pb2.GraphDef()
- float_graph_def.node.extend(
- [input_node, relu_node, min_node, max_node, fake_quant_node])
- test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True)
-
- # Verify there is only one Quantize and one Requantize op.
- eightbit_rewriter = quantize_graph.GraphRewriter(
- float_graph_def, "eightbit", quantized_input_range=None)
- eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
-
- ops = [node.op for node in eightbit_graph_def.node]
- # No quantize since all inputs are const and can be quantized up-front.
- self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
-
- # One dequantize at the end.
- self.assertEqual(1, ops.count("Dequantize"))
-
- def test_relu6(self):
- input_constant_name = "input_constant"
- relu6_name = "relu6"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[1, 2, 6, 1])
- float_graph_def.node.extend([input_constant])
- relu6_node = quantize_graph.create_node("Relu6", relu6_name,
- [input_constant_name])
- quantize_graph.set_attr_dtype(relu6_node, "T", dtypes.float32)
- float_graph_def.node.extend([relu6_node])
- test_graph(float_graph_def, {}, [relu6_name])
-
- def test_bias_add(self):
- input_constant_name = "input_constant"
- offset_constant_name = "offset_constant"
- bias_add_name = "bias_add"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[1, 1, 2, 6])
- float_graph_def.node.extend([input_constant])
- offset_constant = quantize_graph.create_constant_node(
- offset_constant_name,
- value=[1, 2, 3, 4, 5, 6],
- dtype=dtypes.float32,
- shape=[6])
- float_graph_def.node.extend([offset_constant])
- bias_add_node = quantize_graph.create_node(
- "BiasAdd", bias_add_name, [input_constant_name, offset_constant_name])
- quantize_graph.set_attr_dtype(bias_add_node, "T", dtypes.float32)
- float_graph_def.node.extend([bias_add_node])
- test_graph(float_graph_def, {}, [bias_add_name])
-
- def test_quantized_input_range_errors(self):
- with self.assertRaises(ValueError):
- # Invalid mode.
- quantize_graph.GraphRewriter(graph_pb2.GraphDef(), "weights_rounded",
- [0, 1])
- with self.assertRaises(ValueError):
- # Invalid range.
- quantize_graph.GraphRewriter(graph_pb2.GraphDef(), "eightbit", [0, -1])
-
- def test_quantized_input_range_bias_add(self):
- input_shape = [1, 1, 2, 6]
- input_n = quantize_graph.create_node("Placeholder", "input", [])
- quantize_graph.set_attr_dtype(input_n, "dtype", dtypes.float32)
- quantize_graph.set_attr_shape(input_n, "shape", input_shape)
- offset_n = quantize_graph.create_constant_node(
- "offset", value=[1, 2, 3, 4, 5, 6], dtype=dtypes.float32, shape=[6])
- bias_add_n = quantize_graph.create_node("BiasAdd", "bias_add",
- [input_n.name, offset_n.name])
- quantize_graph.set_attr_dtype(bias_add_n, "T", dtypes.float32)
-
- float_graph_def = graph_pb2.GraphDef()
- float_graph_def.node.extend([input_n, offset_n, bias_add_n])
-
- input_map = {
- input_n.name + ":0":
- np.reshape([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], input_shape)
- }
- self._RunTestsForQuantizedInputRange(float_graph_def, input_map,
- [bias_add_n.name], [-1, 20.])
- self._RunTestsForQuantizedInputRange(float_graph_def, input_map,
- [bias_add_n.name], [0, 12.])
-
- def test_quantized_input_range_mat_mul(self):
- shapes = [[3, 2], [2, 4]]
- inputs = []
- for i, shape in enumerate(shapes):
- node = quantize_graph.create_node("Placeholder", "input_%s" % i, [])
- quantize_graph.set_attr_dtype(node, "dtype", dtypes.float32)
- quantize_graph.set_attr_shape(node, "shape", shape)
- inputs.append(node)
- mat_mul_node = quantize_graph.create_node("MatMul", "mat_mul",
- [n.name for n in inputs])
- quantize_graph.set_attr_dtype(mat_mul_node, "T", dtypes.float32)
-
- float_graph_def = graph_pb2.GraphDef()
- float_graph_def.node.extend(inputs + [mat_mul_node])
-
- input_map = {
- inputs[0].name + ":0":
- np.reshape([1, 2, 3, 4, 5, 6], shapes[0]),
- inputs[1].name + ":0":
- np.reshape([.8, .7, .6, .5, .4, .3, .2, .1], shapes[1])
- }
- self._RunTestsForQuantizedInputRange(float_graph_def, input_map,
- [mat_mul_node.name], [-1, 20.])
- self._RunTestsForQuantizedInputRange(float_graph_def, input_map,
- [mat_mul_node.name], [0, 6.])
-
- def _RunTestsForQuantizedInputRange(self, float_graph_def, input_map,
- output_names, input_range):
- if sys.version_info[0] == 3:
- # uint8->quint8 conversion for numpy is not working currently.
- return
-
- quantized_input_map = {}
- for k, v in input_map.items():
- arr = [
- int(
- round((n - input_range[0]) * 255 / (input_range[1] - input_range[
- 0]))) for n in v.flat
- ]
- arr = np.array(arr, np.uint8)
- arr = arr.reshape(v.shape)
- arr = arr.astype(dtypes.quint8.as_numpy_dtype)
- quantized_input_map[k] = arr
- output_tensors = [output_name + ":0" for output_name in output_names]
- float_results = run_graph_def(float_graph_def, input_map, output_tensors)
-
- # Quantize treating the input as quantized in range <input_range>.
- rewriter = quantize_graph.GraphRewriter(float_graph_def, "eightbit",
- input_range)
- graph_def = rewriter.rewrite(output_names)
- results = run_graph_def(graph_def, quantized_input_map, output_tensors)
- for expected, result in zip(float_results, results):
- assert are_tensors_near(expected, result, .5)
- ops = [node.op for node in graph_def.node]
- self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
- self.assertEqual(len(output_names), ops.count("Dequantize"))
-
- # Quantize without treating input as quantized.
- rewriter = quantize_graph.GraphRewriter(
- float_graph_def, "eightbit", quantized_input_range=None)
- graph_def = rewriter.rewrite(output_names)
- results = run_graph_def(graph_def, input_map, output_tensors)
- for expected, result in zip(float_results, results):
- assert are_tensors_near(expected, result, .5)
- ops = [node.op for node in graph_def.node]
- self.assertEqual(
- len(input_map), ops.count("QuantizeV2") + ops.count("Quantize"))
- self.assertEqual(len(output_names), ops.count("Dequantize"))
-
- def test_bias_add_w_fake_quant_w_min_max_vars(self):
- input_node = quantize_graph.create_constant_node(
- "input",
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
- dtype=dtypes.float32,
- shape=[1, 1, 2, 5])
- offset_node = quantize_graph.create_constant_node(
- "offset", value=[1, 2, 3, 4, 5], dtype=dtypes.float32, shape=[5])
- bias_add_node = quantize_graph.create_node(
- "BiasAdd", "bias_add", [input_node.name, offset_node.name])
- quantize_graph.set_attr_dtype(bias_add_node, "T", dtypes.float32)
-
- min_node = quantize_graph.create_constant_node(
- "min_bias_add", value=-.5, dtype=dtypes.float32, shape=[])
- max_node = quantize_graph.create_constant_node(
- "max_bias_add", value=15.5, dtype=dtypes.float32, shape=[])
- fake_quant_node = quantize_graph.create_node(
- "FakeQuantWithMinMaxVars", "fake_quant",
- [bias_add_node.name, min_node.name, max_node.name])
-
- float_graph_def = graph_pb2.GraphDef()
- float_graph_def.node.extend([
- input_node, offset_node, bias_add_node, min_node, max_node,
- fake_quant_node
- ])
- test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True)
-
- # Verify there is only one Quantize and one Requantize op.
- # Pass in fallback_quantization_range, although it will have no effect
- # because the FakeQuantWithMinMaxVars are used instead.
- eightbit_rewriter = quantize_graph.GraphRewriter(
- float_graph_def,
- "eightbit",
- quantized_input_range=None,
- fallback_quantization_range=[-100, 100])
- eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
-
- ops = [node.op for node in eightbit_graph_def.node]
- node_names = [node.name for node in eightbit_graph_def.node]
- # No quantize since all inputs are const and can be quantized up-front.
- self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
-
- # One dequantize at the end.
- self.assertEqual(1, ops.count("Dequantize"))
-
- # The fallback constants are not in the graph.
- self.assertEqual(0, node_names.count("fallback_quantization_min_value"))
- self.assertEqual(0, node_names.count("fallback_quantization_max_value"))
-
- def test_bias_add_w_fallback_min_max_vars(self):
- input_node = quantize_graph.create_constant_node(
- "input",
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
- dtype=dtypes.float32,
- shape=[1, 1, 2, 5])
- offset_node = quantize_graph.create_constant_node(
- "offset", value=[1, 2, 3, 4, 5], dtype=dtypes.float32, shape=[5])
- bias_add_node = quantize_graph.create_node(
- "BiasAdd", "bias_add", [input_node.name, offset_node.name])
- quantize_graph.set_attr_dtype(bias_add_node, "T", dtypes.float32)
-
- float_graph_def = graph_pb2.GraphDef()
- float_graph_def.node.extend([input_node, offset_node, bias_add_node])
- test_graph(float_graph_def, {}, [bias_add_node.name], log_graph=True)
-
- # Verify there is only one Quantize, one Requantize op, and no
- # RequantizationRange op.
- eightbit_rewriter = quantize_graph.GraphRewriter(
- float_graph_def,
- "eightbit",
- quantized_input_range=None,
- fallback_quantization_range=[-.5, 15.5])
- eightbit_graph_def = eightbit_rewriter.rewrite([bias_add_node.name])
-
- ops = [node.op for node in eightbit_graph_def.node]
- node_names = [node.name for node in eightbit_graph_def.node]
- # No quantize since all inputs are const and can be quantized up-front.
- self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
-
- # One dequantize at the end.
- self.assertEqual(1, ops.count("Dequantize"))
-
- # No RequantizationRange
- self.assertEqual(0, ops.count("RequantizationRange"))
-
- # The fallback constants are in the graph.
- self.assertEqual(1, node_names.count("fallback_quantization_min_value"))
- self.assertEqual(1, node_names.count("fallback_quantization_max_value"))
-
- def test_remove_redundant_quantization(self):
- a_constant_name = "a_constant"
- a_constant_min_name = "a_constant_min"
- a_constant_max_name = "a_constant_max"
- a_dequantize_name = "a_dequantize"
- a_quantize_name = "a_quantize"
- b_constant_name = "b_constant"
- b_constant_min_name = "b_constant_min"
- b_constant_max_name = "b_constant_max"
- b_dequantize_name = "b_dequantize"
- b_quantize_name = "b_quantize"
- mat_mul_name = "mat_mul"
- graph_def = graph_pb2.GraphDef()
- a_constant = quantize_graph.create_constant_node(
- a_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
- graph_def.node.extend([a_constant])
- a_constant_min = quantize_graph.create_constant_node(
- a_constant_min_name, value=2, dtype=dtypes.float32, shape=[])
- graph_def.node.extend([a_constant_min])
- a_constant_max = quantize_graph.create_constant_node(
- a_constant_max_name, value=2, dtype=dtypes.float32, shape=[])
- graph_def.node.extend([a_constant_max])
- a_dequantize_node = quantize_graph.create_node(
- "Dequantize", a_dequantize_name,
- [a_constant_name, a_constant_min_name, a_constant_max_name])
- quantize_graph.set_attr_dtype(a_dequantize_node, "T", dtypes.uint8)
- graph_def.node.extend([a_dequantize_node])
- a_quantize_node = quantize_graph.create_node(
- "QuantizeV2", a_quantize_name,
- [a_dequantize_name, a_dequantize_name + ":1", a_dequantize_name + ":2"])
- quantize_graph.set_attr_dtype(a_quantize_node, "T", dtypes.uint8)
- graph_def.node.extend([a_quantize_node])
- b_constant = quantize_graph.create_constant_node(
- b_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
- graph_def.node.extend([b_constant])
- b_constant_min = quantize_graph.create_constant_node(
- b_constant_min_name, value=3, dtype=dtypes.float32, shape=[])
- graph_def.node.extend([b_constant_min])
- b_constant_max = quantize_graph.create_constant_node(
- b_constant_max_name, value=3, dtype=dtypes.float32, shape=[])
- graph_def.node.extend([b_constant_max])
- b_dequantize_node = quantize_graph.create_node(
- "Dequantize", b_dequantize_name,
- [b_constant_name, b_constant_min_name, b_constant_max_name])
- quantize_graph.set_attr_dtype(b_dequantize_node, "T", dtypes.uint8)
- graph_def.node.extend([b_dequantize_node])
- b_quantize_node = quantize_graph.create_node(
- "QuantizeV2", b_quantize_name,
- [b_dequantize_name, b_dequantize_name + ":1", b_dequantize_name + ":2"])
- quantize_graph.set_attr_dtype(b_quantize_node, "T", dtypes.uint8)
- graph_def.node.extend([b_quantize_node])
- mat_mul_node = quantize_graph.create_node("QuantizedMatMul", mat_mul_name, [
- a_quantize_name, b_quantize_name, a_quantize_name + ":1",
- a_quantize_name + ":2", b_quantize_name + ":1", b_quantize_name + ":2"
- ])
- quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8)
- quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32)
- graph_def.node.extend([mat_mul_node])
-
- expected_output = graph_pb2.GraphDef()
- a_constant = quantize_graph.create_constant_node(
- a_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
- expected_output.node.extend([a_constant])
- a_constant_min = quantize_graph.create_constant_node(
- a_constant_min_name, value=2, dtype=dtypes.float32, shape=[])
- expected_output.node.extend([a_constant_min])
- a_constant_max = quantize_graph.create_constant_node(
- a_constant_max_name, value=2, dtype=dtypes.float32, shape=[])
- expected_output.node.extend([a_constant_max])
- b_constant = quantize_graph.create_constant_node(
- b_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
- expected_output.node.extend([b_constant])
- b_constant_min = quantize_graph.create_constant_node(
- b_constant_min_name, value=3, dtype=dtypes.float32, shape=[])
- expected_output.node.extend([b_constant_min])
- b_constant_max = quantize_graph.create_constant_node(
- b_constant_max_name, value=3, dtype=dtypes.float32, shape=[])
- expected_output.node.extend([b_constant_max])
- mat_mul_node = quantize_graph.create_node("QuantizedMatMul", mat_mul_name, [
- a_constant_name, b_constant_name, a_constant_min_name,
- a_constant_max_name, b_constant_min_name, b_constant_max_name
- ])
- quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8)
- quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32)
- expected_output.node.extend([mat_mul_node])
- expected_output.versions.CopyFrom(graph_def.versions)
- expected_output.library.CopyFrom(graph_def.library)
-
- rewriter = quantize_graph.GraphRewriter(
- graph_def, [mat_mul_name], quantized_input_range=None)
- output = rewriter.remove_redundant_quantization(graph_def)
- stripped_output = graph_util.extract_sub_graph(output, [mat_mul_name])
- self.assertProtoEquals(expected_output, stripped_output)
-
-
-if __name__ == "__main__":
- test.main()