diff options
author | Suharsh Sivakumar <suharshs@google.com> | 2017-03-29 18:17:36 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-29 19:31:58 -0700 |
commit | 20c86952373f61417a4a0829c57ff9640d91c177 (patch) | |
tree | bebd119ef41ad22f7bc38b7cf3d089b1c843d0f0 /tensorflow/core/graph/quantize_training.cc | |
parent | 3e239dc36147ef2d730ff4de50de59d9acfe0181 (diff) |
Make quantize training rewriter add variables to save and restore ops if a save op exists. This allows checkpointing of variables created by the rewriter.
Change: 151656461
Diffstat (limited to 'tensorflow/core/graph/quantize_training.cc')
-rw-r--r-- | tensorflow/core/graph/quantize_training.cc | 226 |
1 files changed, 213 insertions, 13 deletions
diff --git a/tensorflow/core/graph/quantize_training.cc b/tensorflow/core/graph/quantize_training.cc index 48915393ce..63294c695e 100644 --- a/tensorflow/core/graph/quantize_training.cc +++ b/tensorflow/core/graph/quantize_training.cc @@ -46,7 +46,7 @@ const std::unordered_set<string, StringPiece::Hasher> nodes_to_rewrite{ // Contains necessary parameters to convert an edge. struct EdgeToConvert { - // Edge is not owned here. + // edge is not owned here. const Edge* edge; int32 num_bits; bool signed_input; @@ -135,6 +135,195 @@ bool FindType(const Graph* graph, const Node* node, bool* signed_input, return true; } +// Find the Save op and inputs. +Status FindSaveOp(const Graph* graph, Node** save_op, + std::vector<const Edge*>* in_edges, bool* found) { + *found = false; + for (Node* node : graph->nodes()) { + if (node->type_string() == "SaveV2") { + // We found multiple save ops. + if (*found) { + return errors::InvalidArgument("Input graph has multiple SaveV2 ops."); + } + *save_op = node; + *found = true; + TF_RETURN_IF_ERROR(node->input_edges(in_edges)); + } + } + return Status::OK(); +} + +Node* FindRestoreAllOp(const Graph* graph, StringPiece save_prefix) { + for (Node* node : graph->nodes()) { + // The restore_all op should have the same prefix of the save_op. + if (node->name() == strings::StrCat(save_prefix, "/restore_all")) { + return node; + } + } + return nullptr; +} + +// Strips the last "/suffix" from a name. +// We use this to construct the name of restore ops in the same way they are +// constructed by the Saver. +StringPiece GetNodeNamePrefix(const Node* node) { + StringPiece name = node->name(); + return name.substr(0, name.rfind('/')); +} + +void FillStringTensor(Tensor* dst, const Tensor& src) { + auto dst_flat = dst->flat<string>(); + auto src_flat = src.flat<string>(); + for (int i = 0; i < src.NumElements(); i++) { + dst_flat(i) = src_flat(i); + } +} + +// Add the added_variables as an inputs to the Save op. +// We change the inputs of the SaveV2 op to include the names of the added +// variables. We also add the variables as inputs to the save op. +Status ConnectVariablesToSaveOp(Graph* graph, Node* save_op, + const std::vector<const Edge*>& in_edges, + const std::vector<Node*>& added_variables) { + Node* tensor_names_op = in_edges[1]->src(); + Node* shape_and_slices_op = in_edges[2]->src(); + + // Get the tensor_names and shape_and_slices tensors from the const op. + Tensor tensor_names; + Tensor shape_and_slices; + TF_RETURN_IF_ERROR( + GetNodeAttr(AttrSlice(tensor_names_op->def()), "value", &tensor_names)); + TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(shape_and_slices_op->def()), "value", + &shape_and_slices)); + + int tn_size = tensor_names.NumElements(); + int var_size = added_variables.size(); + + // Create a new save_op that has inputs to all the new variables. + NodeBuilder save_op_builder = + NodeBuilder(save_op->name(), save_op->type_string()); + // The first three inputs are prefix, tensor_names, and shapes_and_slices. + for (int i = 0; i < 3; i++) { + save_op_builder = save_op_builder.Input(in_edges[i]->src()); + } + std::vector<NodeBuilder::NodeOut> var_nodeouts; + var_nodeouts.reserve(tn_size + var_size); + // The rest of the inputs need to be used the construct the tensor list arg. + for (int i = 3; i < in_edges.size(); i++) { + var_nodeouts.emplace_back(in_edges[i]->src()); + } + + // Add the new values to the tensors and the op input. + Tensor new_tensor_names(DT_STRING, TensorShape({tn_size + var_size})); + Tensor new_shape_and_slices(DT_STRING, TensorShape({tn_size + var_size})); + FillStringTensor(&new_tensor_names, tensor_names); + FillStringTensor(&new_shape_and_slices, shape_and_slices); + for (int i = 0; i < var_size; i++) { + Node* var = added_variables[i]; + new_tensor_names.flat<string>()(tn_size + i) = var->name(); + new_shape_and_slices.flat<string>()(tn_size + i) = ""; + var_nodeouts.emplace_back(var); + } + save_op_builder = save_op_builder.Input(var_nodeouts); + + // Clear the old attr for the two constants and add the new ones. + tensor_names_op->ClearAttr("value"); + shape_and_slices_op->ClearAttr("value"); + tensor_names_op->AddAttr("value", new_tensor_names); + shape_and_slices_op->AddAttr("value", new_shape_and_slices); + + // Remove the old save_op and add the new one. + Node* new_save_op; + TF_RETURN_IF_ERROR(save_op_builder.Finalize(graph, &new_save_op)); + // Add outputs to the new_save_op, all outputs are control edges. + for (const Edge* edge : save_op->out_edges()) { + graph->AddControlEdge(new_save_op, edge->dst()); + } + graph->RemoveNode(save_op); + + return Status::OK(); +} + +// Add a restore subgraph for each variable and connect to the restore_all op. +// For each variable we add the following subgraph: +// Assign----restore_all +// / \ +// RestoreV2 Variable +Status AddRestoreVariableSubgraphs(Graph* graph, Node* save_op, + const std::vector<const Edge*>& in_edges, + const std::vector<Node*>& variables) { + Node* prefix_op = in_edges[0]->src(); + StringPiece name_prefix = GetNodeNamePrefix(save_op); + Node* restore_all = FindRestoreAllOp(graph, name_prefix); + if (restore_all == nullptr) { + return errors::InvalidArgument("graph has SaveOp, but no restore_all NoOp"); + } + const string restore_op_name = strings::StrCat(name_prefix, "/RestoreV2"); + const string assign_op_name = strings::StrCat(name_prefix, "/Assign"); + for (Node* var : variables) { + string new_restore_op_name = graph->NewName(restore_op_name); + string new_assign_op_name = graph->NewName(assign_op_name); + string tensor_names_op_name = + strings::StrCat(new_restore_op_name, "/tensor_names"); + string shape_and_slices_op_name = + strings::StrCat(new_restore_op_name, "/shape_and_slices"); + + // Construct the tensor_names input with the variable name. + Node* tensor_names; + Tensor tensor_names_val(DT_STRING, TensorShape({1})); + tensor_names_val.flat<string>()(0) = var->name(); + TF_RETURN_IF_ERROR(NodeBuilder(tensor_names_op_name, "Const") + .Attr("dtype", DT_STRING) + .Attr("value", tensor_names_val) + .Finalize(graph, &tensor_names)); + + // Construct the shape_and_slices input with empty string. + Node* shape_and_slices; + Tensor shape_and_slices_val(DT_STRING, TensorShape({1})); + shape_and_slices_val.flat<string>()(0) = ""; + TF_RETURN_IF_ERROR(NodeBuilder(shape_and_slices_op_name, "Const") + .Attr("dtype", DT_STRING) + .Attr("value", shape_and_slices_val) + .Finalize(graph, &shape_and_slices)); + + // Build the new Restore op for this variable. + Node* restore_op; + TF_RETURN_IF_ERROR(NodeBuilder(new_restore_op_name, "RestoreV2") + .Input(prefix_op) + .Input(tensor_names) + .Input(shape_and_slices) + .Attr("dtypes", {DT_FLOAT}) + .Finalize(graph, &restore_op)); + + // Create Assign op, attaching the variable and Restore op to it. + Node* assign_op; + TF_RETURN_IF_ERROR(NodeBuilder(new_assign_op_name, "Assign") + .Input(var) + .Input(restore_op) + .Finalize(graph, &assign_op)); + + // Add a control edge from the assign op to restore_all op. + graph->AddControlEdge(assign_op, restore_all); + } + return Status::OK(); +} + +// Adds new variables to save and restore ops matching the Save and Restore +// graphs created in tensorflow/python/training/saver.py. +Status AddSaveAndRestore(Graph* graph, const std::vector<Node*>& variables) { + Node* save_op; + std::vector<const Edge*> in_edges; + bool found = false; + TF_RETURN_IF_ERROR(FindSaveOp(graph, &save_op, &in_edges, &found)); + if (found) { + TF_RETURN_IF_ERROR( + AddRestoreVariableSubgraphs(graph, save_op, in_edges, variables)); + TF_RETURN_IF_ERROR( + ConnectVariablesToSaveOp(graph, save_op, in_edges, variables)); + } + return Status::OK(); +} + // Sets output to the Node that computes reduction axes corresponding to all // dimensions of input and return. Status MakeReductionAxes(Graph* graph, string name_prefix, Node* input, @@ -225,12 +414,16 @@ Status MakeExponentialMovingAverage(Graph* graph, string name_prefix, // | \ / // +----------- assign Status MakeInitializedEMAVariable(Graph* graph, const string& name, Node* decay, - Node* init_val, Node** var) { + Node* init_val, + std::vector<Node*>* added_variables, + Node** var) { + // TODO(suharshs): Update this to use ResourceVariables when they are ready. TF_RETURN_IF_ERROR( NodeBuilder(strings::StrCat(name, "/Variable"), "VariableV2") .Attr("shape", TensorShape()) .Attr("dtype", DT_FLOAT) .Finalize(graph, var)); + added_variables->push_back(*var); Node* is_initialized; TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/IsInitialized"), @@ -264,7 +457,8 @@ Status MakeInitializedEMAVariable(Graph* graph, const string& name, Node* decay, // Computes the min and max EMA of input and stores them in min_var and max_var. Status MakeEMAMinMaxVars(Graph* graph, const string& name_prefix, Node* input, - Node** min_var, Node** max_var) { + std::vector<Node*>* added_variables, Node** min_var, + Node** max_var) { // TODO(suharshs): The decay will be constant, so we could make only one for // all quantize_and_dequantize ops to share, this would have to live outside // this function. @@ -292,17 +486,18 @@ Status MakeEMAMinMaxVars(Graph* graph, const string& name_prefix, Node* input, .Input(input) .Input(reduction_axes) .Finalize(graph, &max)); - TF_RETURN_IF_ERROR( - MakeInitializedEMAVariable(graph, min_name, decay, min, min_var)); - TF_RETURN_IF_ERROR( - MakeInitializedEMAVariable(graph, max_name, decay, max, max_var)); + TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, min_name, decay, min, + added_variables, min_var)); + TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, max_name, decay, max, + added_variables, max_var)); return Status::OK(); } // Makes an input min and max constant if the range is given. Otherwise, makes // min and max variables that are updated by an EMA. Status MakeInputMinMax(Graph* graph, const string& name_prefix, - const EdgeToConvert& edge, Node** input_min, + const EdgeToConvert& edge, + std::vector<Node*>* added_variables, Node** input_min, Node** input_max) { if (edge.range_given) { // Make constant nodes for the input_min and input_max if the range is @@ -324,7 +519,8 @@ Status MakeInputMinMax(Graph* graph, const string& name_prefix, } else { // If the range is not given, estimate the range with EMA variables. TF_RETURN_IF_ERROR(MakeEMAMinMaxVars(graph, name_prefix, edge.edge->src(), - input_min, input_max)); + added_variables, input_min, + input_max)); } return Status::OK(); @@ -334,11 +530,12 @@ Status MakeInputMinMax(Graph* graph, const string& name_prefix, // The result is stored in convert_node. Status MakeQuantizeAndDequantizeV2(Graph* graph, const string& name_prefix, const EdgeToConvert& edge, + std::vector<Node*>* added_variables, Node** convert_node) { Node* input_min; Node* input_max; - TF_RETURN_IF_ERROR( - MakeInputMinMax(graph, name_prefix, edge, &input_min, &input_max)); + TF_RETURN_IF_ERROR(MakeInputMinMax(graph, name_prefix, edge, added_variables, + &input_min, &input_max)); string quant_name = strings::StrCat(name_prefix, "/QuantizeAndDequantizeV2"); TF_RETURN_IF_ERROR(NodeBuilder(quant_name, "QuantizeAndDequantizeV2") @@ -358,14 +555,15 @@ Status ProcessTargetEdges(Graph* graph, // Remember previously converted ops to avoid duplicated conversion on the // same input. std::unordered_map<string, Node*, StringPiece::Hasher> name_index; + std::vector<Node*> added_variables; for (const EdgeToConvert edge : target_edges) { Node* convert_node; string name_prefix = edge.edge->src()->name(); auto iter = name_index.find(name_prefix); if (iter == name_index.end()) { - TF_RETURN_IF_ERROR( - MakeQuantizeAndDequantizeV2(graph, name_prefix, edge, &convert_node)); + TF_RETURN_IF_ERROR(MakeQuantizeAndDequantizeV2( + graph, name_prefix, edge, &added_variables, &convert_node)); name_index[name_prefix] = convert_node; } else { convert_node = iter->second; @@ -375,6 +573,8 @@ Status ProcessTargetEdges(Graph* graph, graph->RemoveEdge(edge.edge); } + TF_RETURN_IF_ERROR(AddSaveAndRestore(graph, added_variables)); + return Status::OK(); } |