path: root/tensorflow/core/graph/quantize_training.cc
diff options
author Suharsh Sivakumar <suharshs@google.com> 2017-03-29 18:17:36 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-03-29 19:31:58 -0700
commit 20c86952373f61417a4a0829c57ff9640d91c177 (patch)
tree bebd119ef41ad22f7bc38b7cf3d089b1c843d0f0 /tensorflow/core/graph/quantize_training.cc
parent 3e239dc36147ef2d730ff4de50de59d9acfe0181 (diff)
Make quantize training rewriter add variables to save and restore ops if a save op exists. This allows checkpointing of variables created by the rewriter.
Change: 151656461
Diffstat (limited to 'tensorflow/core/graph/quantize_training.cc')
1 files changed, 213 insertions, 13 deletions
diff --git a/tensorflow/core/graph/quantize_training.cc b/tensorflow/core/graph/quantize_training.cc
index 48915393ce..63294c695e 100644
--- a/tensorflow/core/graph/quantize_training.cc
+++ b/tensorflow/core/graph/quantize_training.cc
@@ -46,7 +46,7 @@ const std::unordered_set<string, StringPiece::Hasher> nodes_to_rewrite{
// Contains necessary parameters to convert an edge.
struct EdgeToConvert {
- // Edge is not owned here.
+ // edge is not owned here.
const Edge* edge;
int32 num_bits;
bool signed_input;
@@ -135,6 +135,195 @@ bool FindType(const Graph* graph, const Node* node, bool* signed_input,
return true;
+// Finds the SaveV2 op in `graph` and collects its input edges.
+//
+// On return `*found` tells whether a node whose type string is "SaveV2"
+// exists. When it does, `*save_op` points at that node and `*in_edges` has
+// been filled by Node::input_edges(). `*save_op` is not written when no such
+// node exists, so callers must check `*found` before using it.
+//
+// Returns InvalidArgument if the graph contains more than one SaveV2 op, and
+// propagates any error reported by Node::input_edges().
+Status FindSaveOp(const Graph* graph, Node** save_op,
+                  std::vector<const Edge*>* in_edges, bool* found) {
+  *found = false;
+  for (Node* node : graph->nodes()) {
+    if (node->type_string() == "SaveV2") {
+      // We found multiple save ops.
+      if (*found) {
+        return errors::InvalidArgument("Input graph has multiple SaveV2 ops.");
+      }
+      *save_op = node;
+      *found = true;
+      TF_RETURN_IF_ERROR(node->input_edges(in_edges));
+    }
+  }
+  return Status::OK();
+}
+// Returns the node named "<save_prefix>/restore_all", or nullptr if the graph
+// has no node with that name.
+Node* FindRestoreAllOp(const Graph* graph, StringPiece save_prefix) {
+  // The restore_all op should have the same prefix as the save_op. The target
+  // name does not depend on the node being visited, so build it only once.
+  const string restore_all_name = strings::StrCat(save_prefix, "/restore_all");
+  for (Node* node : graph->nodes()) {
+    if (node->name() == restore_all_name) {
+      return node;
+    }
+  }
+  return nullptr;
+}
+// Strips the last "/suffix" from a name.
+// We use this to construct the name of restore ops in the same way they are
+// constructed by the Saver.
+//
+// The result is a sub-view of a StringPiece built from node->name(), not a
+// copy; presumably it must not be used after the node (or its name) goes away.
+// NOTE(review): when the name contains no '/', rfind() finds nothing and the
+// whole name appears to be returned unchanged -- confirm this is intended.
+StringPiece GetNodeNamePrefix(const Node* node) {
+  StringPiece name = node->name();
+  return name.substr(0, name.rfind('/'));
+}
+// Copies every string element of `src` into the same flat index of `dst`.
+//
+// Only the first src.NumElements() entries of `dst` are written; any further
+// entries are left as they were. `dst` must therefore hold at least as many
+// elements as `src` (no bounds check is performed here).
+void FillStringTensor(Tensor* dst, const Tensor& src) {
+  auto dst_flat = dst->flat<string>();
+  auto src_flat = src.flat<string>();
+  // Read the element count once and index with a 64-bit counter so the loop
+  // bound and the counter have the same width.
+  const int64 num_elements = src.NumElements();
+  for (int64 i = 0; i < num_elements; ++i) {
+    dst_flat(i) = src_flat(i);
+  }
+}
+// Adds the added_variables as inputs to the Save op.
+//
+// We change the tensor_names and shape_and_slices inputs of the SaveV2 op to
+// include the names of the added variables (each with an empty shape-and-slice
+// spec). We also add the variables as inputs to the save op, which is done by
+// building a replacement node with the same name and type and removing the
+// old one.
+//
+// `in_edges` must be the input edges of `save_op`: edges 0-2 are read as
+// prefix, tensor_names and shape_and_slices, and every later edge is treated
+// as a tensor to save, so at least three edges are required.
+//
+// Returns an error if the "value" attr of either input op cannot be read or if
+// the replacement save op cannot be built. On success `save_op` has been
+// removed from `graph` and must not be used afterwards.
+Status ConnectVariablesToSaveOp(Graph* graph, Node* save_op,
+                                const std::vector<const Edge*>& in_edges,
+                                const std::vector<Node*>& added_variables) {
+  Node* tensor_names_op = in_edges[1]->src();
+  Node* shape_and_slices_op = in_edges[2]->src();
+
+  // Get the tensor_names and shape_and_slices tensors from the const op.
+  Tensor tensor_names;
+  Tensor shape_and_slices;
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(AttrSlice(tensor_names_op->def()), "value", &tensor_names));
+  TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(shape_and_slices_op->def()), "value",
+                                 &shape_and_slices));
+
+  int tn_size = tensor_names.NumElements();
+  int var_size = added_variables.size();
+
+  // Create a new save_op that has inputs to all the new variables.
+  NodeBuilder save_op_builder =
+      NodeBuilder(save_op->name(), save_op->type_string());
+  // The first three inputs are prefix, tensor_names, and shapes_and_slices.
+  for (int i = 0; i < 3; i++) {
+    save_op_builder = save_op_builder.Input(in_edges[i]->src());
+  }
+  std::vector<NodeBuilder::NodeOut> var_nodeouts;
+  var_nodeouts.reserve(tn_size + var_size);
+  // The rest of the inputs need to be used to construct the tensor list arg.
+  for (size_t i = 3; i < in_edges.size(); ++i) {
+    var_nodeouts.emplace_back(in_edges[i]->src());
+  }
+
+  // Add the new values to the tensors and the op input.
+  Tensor new_tensor_names(DT_STRING, TensorShape({tn_size + var_size}));
+  Tensor new_shape_and_slices(DT_STRING, TensorShape({tn_size + var_size}));
+  FillStringTensor(&new_tensor_names, tensor_names);
+  FillStringTensor(&new_shape_and_slices, shape_and_slices);
+  for (int i = 0; i < var_size; i++) {
+    Node* var = added_variables[i];
+    new_tensor_names.flat<string>()(tn_size + i) = var->name();
+    new_shape_and_slices.flat<string>()(tn_size + i) = "";
+    var_nodeouts.emplace_back(var);
+  }
+  save_op_builder = save_op_builder.Input(var_nodeouts);
+
+  // Clear the old attr for the two constants and add the new ones.
+  tensor_names_op->ClearAttr("value");
+  shape_and_slices_op->ClearAttr("value");
+  tensor_names_op->AddAttr("value", new_tensor_names);
+  shape_and_slices_op->AddAttr("value", new_shape_and_slices);
+
+  // Remove the old save_op and add the new one.
+  Node* new_save_op;
+  TF_RETURN_IF_ERROR(save_op_builder.Finalize(graph, &new_save_op));
+  // Add outputs to the new_save_op, all outputs are control edges.
+  for (const Edge* edge : save_op->out_edges()) {
+    graph->AddControlEdge(new_save_op, edge->dst());
+  }
+  graph->RemoveNode(save_op);
+
+  return Status::OK();
+}
+// Add a restore subgraph for each variable and connect to the restore_all op.
+// For each variable we add the following subgraph:
+//
+//            Assign ---- restore_all
+//            |    |
+//    RestoreV2    Variable
+//
+// The RestoreV2 op reads from the same prefix input as the save op
+// (in_edges[0]) and gets two new single-element string Const inputs: the
+// variable's name as tensor_names and an empty string as shape_and_slices.
+// Every restored value is declared as DT_FLOAT.
+//
+// Returns InvalidArgument if the graph has no node named
+// "<save op name prefix>/restore_all", and propagates any NodeBuilder error.
+Status AddRestoreVariableSubgraphs(Graph* graph, Node* save_op,
+                                   const std::vector<const Edge*>& in_edges,
+                                   const std::vector<Node*>& variables) {
+  Node* prefix_op = in_edges[0]->src();
+  StringPiece name_prefix = GetNodeNamePrefix(save_op);
+  Node* restore_all = FindRestoreAllOp(graph, name_prefix);
+  if (restore_all == nullptr) {
+    return errors::InvalidArgument("graph has SaveOp, but no restore_all NoOp");
+  }
+  const string restore_op_name = strings::StrCat(name_prefix, "/RestoreV2");
+  const string assign_op_name = strings::StrCat(name_prefix, "/Assign");
+  for (Node* var : variables) {
+    // Graph::NewName is asked for a fresh name per variable so that each
+    // restore/assign pair gets its own node names.
+    string new_restore_op_name = graph->NewName(restore_op_name);
+    string new_assign_op_name = graph->NewName(assign_op_name);
+    string tensor_names_op_name =
+        strings::StrCat(new_restore_op_name, "/tensor_names");
+    string shape_and_slices_op_name =
+        strings::StrCat(new_restore_op_name, "/shape_and_slices");
+
+    // Construct the tensor_names input with the variable name.
+    Node* tensor_names;
+    Tensor tensor_names_val(DT_STRING, TensorShape({1}));
+    tensor_names_val.flat<string>()(0) = var->name();
+    TF_RETURN_IF_ERROR(NodeBuilder(tensor_names_op_name, "Const")
+                           .Attr("dtype", DT_STRING)
+                           .Attr("value", tensor_names_val)
+                           .Finalize(graph, &tensor_names));
+
+    // Construct the shape_and_slices input with empty string.
+    Node* shape_and_slices;
+    Tensor shape_and_slices_val(DT_STRING, TensorShape({1}));
+    shape_and_slices_val.flat<string>()(0) = "";
+    TF_RETURN_IF_ERROR(NodeBuilder(shape_and_slices_op_name, "Const")
+                           .Attr("dtype", DT_STRING)
+                           .Attr("value", shape_and_slices_val)
+                           .Finalize(graph, &shape_and_slices));
+
+    // Build the new Restore op for this variable.
+    Node* restore_op;
+    TF_RETURN_IF_ERROR(NodeBuilder(new_restore_op_name, "RestoreV2")
+                           .Input(prefix_op)
+                           .Input(tensor_names)
+                           .Input(shape_and_slices)
+                           .Attr("dtypes", {DT_FLOAT})
+                           .Finalize(graph, &restore_op));
+
+    // Create Assign op, attaching the variable and Restore op to it.
+    Node* assign_op;
+    TF_RETURN_IF_ERROR(NodeBuilder(new_assign_op_name, "Assign")
+                           .Input(var)
+                           .Input(restore_op)
+                           .Finalize(graph, &assign_op));
+
+    // Add a control edge from the assign op to restore_all op.
+    graph->AddControlEdge(assign_op, restore_all);
+  }
+  return Status::OK();
+}
+// Adds new variables to save and restore ops matching the Save and Restore
+// graphs created in tensorflow/python/training/saver.py.
+//
+// Does nothing and returns OK when the graph has no SaveV2 op. Otherwise
+// returns the first error reported by FindSaveOp,
+// AddRestoreVariableSubgraphs or ConnectVariablesToSaveOp.
+Status AddSaveAndRestore(Graph* graph, const std::vector<Node*>& variables) {
+  Node* save_op = nullptr;
+  std::vector<const Edge*> in_edges;
+  bool found = false;
+  TF_RETURN_IF_ERROR(FindSaveOp(graph, &save_op, &in_edges, &found));
+  if (found) {
+    // The restore subgraphs are added first: ConnectVariablesToSaveOp removes
+    // save_op from the graph, and the restore step still reads its name.
+    TF_RETURN_IF_ERROR(
+        AddRestoreVariableSubgraphs(graph, save_op, in_edges, variables));
+    TF_RETURN_IF_ERROR(
+        ConnectVariablesToSaveOp(graph, save_op, in_edges, variables));
+  }
+  return Status::OK();
+}
// Sets output to the Node that computes reduction axes corresponding to all
// dimensions of input and return.
Status MakeReductionAxes(Graph* graph, string name_prefix, Node* input,
@@ -225,12 +414,16 @@ Status MakeExponentialMovingAverage(Graph* graph, string name_prefix,
// | \ /
// +----------- assign
Status MakeInitializedEMAVariable(Graph* graph, const string& name, Node* decay,
- Node* init_val, Node** var) {
+ Node* init_val,
+ std::vector<Node*>* added_variables,
+ Node** var) {
+ // TODO(suharshs): Update this to use ResourceVariables when they are ready.
NodeBuilder(strings::StrCat(name, "/Variable"), "VariableV2")
.Attr("shape", TensorShape())
.Attr("dtype", DT_FLOAT)
.Finalize(graph, var));
+ added_variables->push_back(*var);
Node* is_initialized;
TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/IsInitialized"),
@@ -264,7 +457,8 @@ Status MakeInitializedEMAVariable(Graph* graph, const string& name, Node* decay,
// Computes the min and max EMA of input and stores them in min_var and max_var.
Status MakeEMAMinMaxVars(Graph* graph, const string& name_prefix, Node* input,
- Node** min_var, Node** max_var) {
+ std::vector<Node*>* added_variables, Node** min_var,
+ Node** max_var) {
// TODO(suharshs): The decay will be constant, so we could make only one for
// all quantize_and_dequantize ops to share, this would have to live outside
// this function.
@@ -292,17 +486,18 @@ Status MakeEMAMinMaxVars(Graph* graph, const string& name_prefix, Node* input,
.Finalize(graph, &max));
- MakeInitializedEMAVariable(graph, min_name, decay, min, min_var));
- MakeInitializedEMAVariable(graph, max_name, decay, max, max_var));
+ TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, min_name, decay, min,
+ added_variables, min_var));
+ TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, max_name, decay, max,
+ added_variables, max_var));
return Status::OK();
// Makes an input min and max constant if the range is given. Otherwise, makes
// min and max variables that are updated by an EMA.
Status MakeInputMinMax(Graph* graph, const string& name_prefix,
- const EdgeToConvert& edge, Node** input_min,
+ const EdgeToConvert& edge,
+ std::vector<Node*>* added_variables, Node** input_min,
Node** input_max) {
if (edge.range_given) {
// Make constant nodes for the input_min and input_max if the range is
@@ -324,7 +519,8 @@ Status MakeInputMinMax(Graph* graph, const string& name_prefix,
} else {
// If the range is not given, estimate the range with EMA variables.
TF_RETURN_IF_ERROR(MakeEMAMinMaxVars(graph, name_prefix, edge.edge->src(),
- input_min, input_max));
+ added_variables, input_min,
+ input_max));
return Status::OK();
@@ -334,11 +530,12 @@ Status MakeInputMinMax(Graph* graph, const string& name_prefix,
// The result is stored in convert_node.
Status MakeQuantizeAndDequantizeV2(Graph* graph, const string& name_prefix,
const EdgeToConvert& edge,
+ std::vector<Node*>* added_variables,
Node** convert_node) {
Node* input_min;
Node* input_max;
- MakeInputMinMax(graph, name_prefix, edge, &input_min, &input_max));
+ TF_RETURN_IF_ERROR(MakeInputMinMax(graph, name_prefix, edge, added_variables,
+ &input_min, &input_max));
string quant_name = strings::StrCat(name_prefix, "/QuantizeAndDequantizeV2");
TF_RETURN_IF_ERROR(NodeBuilder(quant_name, "QuantizeAndDequantizeV2")
@@ -358,14 +555,15 @@ Status ProcessTargetEdges(Graph* graph,
// Remember previously converted ops to avoid duplicated conversion on the
// same input.
std::unordered_map<string, Node*, StringPiece::Hasher> name_index;
+ std::vector<Node*> added_variables;
for (const EdgeToConvert edge : target_edges) {
Node* convert_node;
string name_prefix = edge.edge->src()->name();
auto iter = name_index.find(name_prefix);
if (iter == name_index.end()) {
- MakeQuantizeAndDequantizeV2(graph, name_prefix, edge, &convert_node));
+ TF_RETURN_IF_ERROR(MakeQuantizeAndDequantizeV2(
+ graph, name_prefix, edge, &added_variables, &convert_node));
name_index[name_prefix] = convert_node;
} else {
convert_node = iter->second;
@@ -375,6 +573,8 @@ Status ProcessTargetEdges(Graph* graph,
+ TF_RETURN_IF_ERROR(AddSaveAndRestore(graph, added_variables));
return Status::OK();