/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include #include #include #include "tensorflow/c/checkpoint_reader.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/subgraph.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" #include "tensorflow/tools/graph_transforms/transform_utils.h" namespace tensorflow { using str_util::Join; using str_util::Split; using str_util::StringReplace; using strings::StrCat; namespace graph_transforms { // Sparsify Tensor of shape [N, 1]. Return the indices and values vectors for // non-zero tensor content. Status SparsifyWeights(const Tensor& tensor, Tensor* indices_tensor, Tensor* values_tensor) { if (tensor.dims() != 2 || tensor.dim_size(1) != 1) { return tensorflow::errors::FailedPrecondition( "Transform only applicable to subgraph with 'Const' with " "tensor of shape [N, 1]. But instead get shape ", tensor.shape().DebugString(), "."); } auto flat = tensor.flat(); std::vector indices; std::vector values; for (int64 i = 0; i < flat.size(); i++) { float val = flat(i); if (std::abs(val) >= 1.0e-5) { indices.push_back(i); values.push_back(val); } } // During model initialization, InitializeTableOp makes use of // KeyValueTensorIterator, which does not accept empty keys or values. // Consequently, adding a dummy pair of indices and values as a walkaround. if (indices.empty() || values.empty()) { indices.push_back(0); values.push_back(0); } *indices_tensor = Tensor(DataTypeToEnum::value, {static_cast(indices.size())}); std::copy_n(indices.begin(), indices.size(), indices_tensor->flat().data()); *values_tensor = Tensor(DataTypeToEnum::value, {static_cast(values.size())}); std::copy_n(values.begin(), values.size(), values_tensor->flat().data()); return Status::OK(); } void CreateConstNode(const Tensor& tensor, const string& name, NodeDef* node_def) { node_def->set_op("Const"); node_def->set_name(name); SetNodeTensorAttr("value", tensor, node_def); } string GetMonolithicTensorKey(const string& tensor_slice_name) { std::vector names = Split(tensor_slice_name, "/"); if (str_util::StartsWith(names[names.size() - 1], "part_")) { CHECK_GE(names.size(), 2); names.pop_back(); } return Join(names, "/"); } Status ObtainTensorSlice(const GraphDef& input_graph_def, const string& target_name, string* shape_slice_string) { string restore_node_name; for (const auto& node : input_graph_def.node()) { std::vector node_name_parts = Split(node.name(), "/"); if (node_name_parts.size() == 2 && str_util::StartsWith(node_name_parts[0], "save") && str_util::StartsWith(node_name_parts[1], "Assign") && node.input(0) == target_name) { restore_node_name = node.input(1); break; } } std::vector restore_node_parts = Split(restore_node_name, ":"); CHECK_LE(restore_node_parts.size(), 2); string tensor_names_node; string shape_and_slices_node; for (const auto& node : input_graph_def.node()) { if ((node.name() == restore_node_parts[0]) && (node.op() == "RestoreV2")) { tensor_names_node = node.input(1); shape_and_slices_node = node.input(2); break; } } int offset = -1; for (const auto& node : input_graph_def.node()) { if (node.name() == tensor_names_node) { Tensor tensor_names_tensor; TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor_names_tensor)); const auto& tensor_names_value = tensor_names_tensor.flat(); for (int i = 0; i < tensor_names_value.size(); i++) { if (tensor_names_value(i) == GetMonolithicTensorKey(target_name)) { offset = i; break; } } } } if (offset == -1) { return errors::Internal("Unable to find RestoreV2 entry for variable: ", target_name); } for (const auto& node : input_graph_def.node()) { if (node.name() == shape_and_slices_node) { Tensor shape_and_slices_tensor; TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &shape_and_slices_tensor)); const auto& shape_and_slices_value = shape_and_slices_tensor.flat(); *shape_slice_string = shape_and_slices_value(offset); return Status::OK(); } } return errors::Internal("Unable to find slice for variable: ", target_name); } Status ReadTensorFromCheckpoint( const string& tensor_name, const std::unique_ptr& ckpt_reader, const string& shape_and_slice, Tensor* tensor) { if (ckpt_reader) { TensorShape parsed_full_shape; TensorSlice parsed_slice; TensorShape parsed_slice_shape; bool get_slice = false; if (!shape_and_slice.empty()) { TF_RETURN_IF_ERROR( checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape, &parsed_slice, &parsed_slice_shape)); get_slice = (parsed_full_shape != parsed_slice_shape); } if (get_slice) { TF_RETURN_IF_ERROR(ckpt_reader->LookupSlice( GetMonolithicTensorKey(tensor_name), parsed_slice, tensor)); } else { TF_RETURN_IF_ERROR( ckpt_reader->Lookup(GetMonolithicTensorKey(tensor_name), tensor)); } return Status::OK(); } return errors::Internal("Checkpoint reader was not initialized. "); } Status InitializeCheckpointReader(const TransformFuncContext& context, std::unique_ptr* ckpt_reader) { if (context.params.count("input_checkpoint")) { const string input_checkpoint = context.params.at("input_checkpoint")[0]; ckpt_reader->reset(new BundleReader(Env::Default(), input_checkpoint)); TF_RETURN_IF_ERROR((*ckpt_reader)->status()); } return Status::OK(); } Status ObtainVariableInfo( const GraphDef& input_graph_def, std::unique_ptr >* shapes_and_slices) { shapes_and_slices->reset(new std::unordered_map()); for (const auto& node : input_graph_def.node()) { if ((node.op() == "Variable") || (node.op() == "VariableV2")) { string s; TF_RETURN_IF_ERROR(ObtainTensorSlice(input_graph_def, node.name(), &s)); (**shapes_and_slices)[node.name()] = s; } } return Status::OK(); } Status RemoveInputAtIndex(NodeDef* n, int index) { for (int i = index; i < n->input_size() - 1; i++) { n->mutable_input()->SwapElements(i, i + 1); } n->mutable_input()->RemoveLast(); return Status::OK(); } Status RemoveNodeAtIndex(GraphDef* g, int index) { for (int i = index; i < g->node_size() - 1; i++) { g->mutable_node()->SwapElements(i, i + 1); } g->mutable_node()->RemoveLast(); return Status::OK(); } Status SparsifyGatherInternal( const GraphDef& input_graph_def, const std::unique_ptr >& shapes_and_slices, const TransformFuncContext& context, const OpTypePattern& pattern, const std::unique_ptr& ckpt_reader, GraphDef* output_graph_def) { string group_init_node = "group_deps"; if (context.params.count("group_init_node")) { group_init_node = context.params.at("group_init_node")[0]; } GraphDef current_graph_def = input_graph_def; bool any_match_found = false; // Populate references. std::unordered_map refs; for (const auto& node : current_graph_def.node()) { for (const auto& input : node.input()) { auto parsed_input = StringReplace(input, "^", "", true); refs[parsed_input] += 1; } } // The subgraphs may have overlapping components, therefore GraphMatcher // doesn't return all subgraphs in one round -- this has to be multi-round // update. do { any_match_found = false; GraphDef replaced_graph_def = current_graph_def; std::vector init_table_node_names; std::vector removed_node_names; TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( current_graph_def, pattern, [&ckpt_reader, &any_match_found, &init_table_node_names, &shapes_and_slices, &removed_node_names, &refs](const NodeMatch& match, const std::set& input_nodes, const std::set& output_nodes, std::vector* new_nodes) { any_match_found = true; // The captured subgraph should be of the following pattern: // Const --> Identity --> Gather --> ... // ^ // | // (ids) // // After transform, it becomes: // --> NoOp(group_deps) // | // Const --> InitializeTable --> HashTable // ^ | // | | // Const ------------- | // v // (ids) ---> LookupTableFind <--- Const(default) // | // v // ... // clang-format off // For each subgraph, do the following // 1. Sparsify the `Const`, creating two `Const`, for hashtable // key/val. // 2. Create a `InitializeTable` op connecting to the above 2 `Const`. // 3. Create a `HashTable` op connecting to `InitializeTable` op. // 4. Replace the `Gather` with a `LookupTableFind` op. // 5. Connect the `LookupTableFind` with // a. `HashTable` // b. `Gather`'s ids input // c. a `default_val` arg, valued at 0 // clang-format on const NodeDef& gather_node = match.node; // GatherV2 adds an "axis" parameter. sparsify_gather only supports // axis 0 gathers. if (gather_node.op() == "GatherV2") { // Per the OpTypePattern, the 3rd input to Gather must be a Const. const NodeDef& axis_node = match.inputs[2].node; Tensor axis_t; TF_RETURN_IF_ERROR(GetNodeAttr(axis_node, "value", &axis_t)); int64 axis = 0; if (axis_t.dtype() == DT_INT32) { axis = axis_t.scalar()(); } else if (axis_t.dtype() == DT_INT64) { axis = axis_t.scalar()(); } else { return tensorflow::errors::FailedPrecondition( "Gather axis was not int32 or int64."); } if (axis != 0) { return tensorflow::errors::FailedPrecondition( "Transform only applicable to subgraph with GatherV2 over " "axis 0. Found axis ", axis, "."); } } const NodeDef& weights_node = match.inputs[0].inputs[0].node; DataType data_type; TF_RETURN_IF_ERROR(GetNodeAttr(weights_node, "dtype", &data_type)); if (data_type != DT_FLOAT) { return tensorflow::errors::FailedPrecondition( "Transform only applicable to subgraph with 'Const'," "'Variable', or 'VariableV2' of dtype " "'DT_FLOAT'. Found '" + weights_node.op() + "' with name '", weights_node.name(), "' and dtype '", data_type, "'."); } Tensor weight; if (weights_node.op() == "Const") { weight = GetNodeTensorAttr(weights_node, "value"); } else { TF_RETURN_IF_ERROR(ReadTensorFromCheckpoint( weights_node.name(), ckpt_reader, (*shapes_and_slices)[weights_node.name()], &weight)); } // Add both both weight and identity node names. removed_node_names.push_back(weights_node.name()); removed_node_names.push_back(match.inputs[0].node.name()); for (auto input_node : match.inputs[0].node.input()) { auto parsed_input = StringReplace(input_node, "^", "", true); refs[parsed_input]--; } Tensor indices_tensor; Tensor values_tensor; TF_RETURN_IF_ERROR( SparsifyWeights(weight, &indices_tensor, &values_tensor)); // indices and values of sparsified `Const` DataType key_dtype = DT_INT64; NodeDef indices_node; CreateConstNode(indices_tensor, StrCat(weights_node.name(), "/indices"), &indices_node); SetNodeAttr("dtype", key_dtype, &indices_node); NodeDef values_node; CreateConstNode(values_tensor, StrCat(weights_node.name(), "/values"), &values_node); SetNodeAttr("dtype", data_type, &values_node); // HashTable node NodeDef hashtable_node; hashtable_node.set_op("HashTable"); hashtable_node.set_name(StrCat(weights_node.name(), "/HashTable")); SetNodeAttr("key_dtype", key_dtype, &hashtable_node); SetNodeAttr("value_dtype", data_type, &hashtable_node); // InitializeTable node NodeDef init_table_node; init_table_node.set_op("InitializeTable"); init_table_node.set_name( StrCat(weights_node.name(), "/InitializeTable")); SetNodeAttr("Tkey", key_dtype, &init_table_node); SetNodeAttr("Tval", data_type, &init_table_node); init_table_node_names.push_back(init_table_node.name()); // LookupTableFind node NodeDef lookup_node; lookup_node.set_op("LookupTableFind"); lookup_node.set_name(StrCat(gather_node.name(), "/LookupTableFind")); SetNodeAttr("Tin", key_dtype, &lookup_node); SetNodeAttr("Tout", data_type, &lookup_node); // Default return value of hashtable lookup Tensor zero_tensor(data_type, TensorShape({})); zero_tensor.flat()(0) = 0.0; NodeDef default_value_node; CreateConstNode(zero_tensor, StrCat(gather_node.name(), "/Const"), &default_value_node); SetNodeAttr("dtype", data_type, &default_value_node); // ExpandDims argument Tensor dim_idx(DT_INT32, TensorShape({})); dim_idx.flat()(0) = -1; NodeDef dim_idx_node; dim_idx_node.set_op("Const"); dim_idx_node.set_name( StrCat(gather_node.name(), "/ExpandDims/Const")); SetNodeAttr("value", dim_idx, &dim_idx_node); SetNodeAttr("dtype", DT_INT32, &dim_idx_node); // ExpandDims node NodeDef expand_dims_node; expand_dims_node.set_op("ExpandDims"); // Reuse gather_node's name so not to change dependent's inputs expand_dims_node.set_name(gather_node.name()); SetNodeAttr("T", data_type, &expand_dims_node); // Connect nodes AddNodeInput(hashtable_node.name(), &init_table_node); refs[hashtable_node.name()]++; AddNodeInput(indices_node.name(), &init_table_node); refs[indices_node.name()]++; AddNodeInput(values_node.name(), &init_table_node); refs[values_node.name()]++; AddNodeInput(hashtable_node.name(), &lookup_node); refs[hashtable_node.name()]++; AddNodeInput(gather_node.input(1), &lookup_node); refs[gather_node.input(1)]++; AddNodeInput(default_value_node.name(), &lookup_node); refs[default_value_node.name()]++; AddNodeInput(lookup_node.name(), &expand_dims_node); refs[lookup_node.name()]++; AddNodeInput(dim_idx_node.name(), &expand_dims_node); refs[dim_idx_node.name()]++; // Copy 'ids' input of original 'Gather' new_nodes->push_back(match.inputs[1].node); new_nodes->push_back(indices_node); new_nodes->push_back(values_node); new_nodes->push_back(hashtable_node); new_nodes->push_back(init_table_node); new_nodes->push_back(lookup_node); new_nodes->push_back(default_value_node); new_nodes->push_back(dim_idx_node); new_nodes->push_back(expand_dims_node); return Status::OK(); }, {true}, &replaced_graph_def)); NodeDef* init_op = nullptr; for (int i = 0; i < replaced_graph_def.node_size(); i++) { if (replaced_graph_def.node(i).name() == group_init_node && replaced_graph_def.node(i).op() == "NoOp") { init_op = replaced_graph_def.mutable_node(i); break; } } if (!init_op) { // Init node init_op = replaced_graph_def.mutable_node()->Add(); init_op->set_op("NoOp"); init_op->set_name(group_init_node); } for (const string& name : init_table_node_names) { // Add control dependence from init_table_node to group_deps_node AddNodeInput(StrCat("^", name), init_op); refs[name]++; } // Erase inputs and outputs as they are not considered for deletion. for (const auto& output : context.output_names) { refs.erase(output); } for (const auto& input : context.input_names) { refs.erase(input); } // Add nodes with a reference count of 0 for deletion. for (auto entry : refs) { if (entry.second == 0) { removed_node_names.push_back(entry.first); } } while (!removed_node_names.empty()) { auto name = removed_node_names.back(); removed_node_names.pop_back(); int i = 0; while (i < replaced_graph_def.node_size()) { // Revisit this to see if we can safely remove RestoreV2 nodes. if ((replaced_graph_def.node(i).name() == name) && (replaced_graph_def.node(i).op() != "RestoreV2")) { for (const auto& input : replaced_graph_def.node(i).input()) { auto parsed_input = StringReplace(input, "^", "", true); refs[parsed_input] -= 1; if (refs[parsed_input] == 0) { removed_node_names.push_back(parsed_input); } } TF_RETURN_IF_ERROR(RemoveNodeAtIndex(&replaced_graph_def, i)); continue; } int j = 0; bool deleted_inputs = false; while (j < replaced_graph_def.node(i).input_size()) { if (replaced_graph_def.node(i).input(j) == name || replaced_graph_def.node(i).input(j) == ("^" + name)) { TF_RETURN_IF_ERROR( RemoveInputAtIndex(replaced_graph_def.mutable_node(i), j)); deleted_inputs = true; continue; } j++; } if (deleted_inputs) { if (replaced_graph_def.node(i).op() == "ConcatV2") { if (replaced_graph_def.node(i).input_size() > 2) { SetNodeAttr("N", replaced_graph_def.node(i).input_size() - 1, replaced_graph_def.mutable_node(i)); } else if (replaced_graph_def.node(i).input_size() == 2) { if (refs[replaced_graph_def.node(i).input(1)] != 1) { return errors::Internal( "Expect axis tensor of ConcatV2 node to only be referenced " "once."); } refs[replaced_graph_def.node(i).input(1)] -= 1; removed_node_names.push_back(replaced_graph_def.node(i).input(1)); replaced_graph_def.mutable_node(i)->mutable_input()->RemoveLast(); replaced_graph_def.mutable_node(i)->mutable_attr()->erase("N"); replaced_graph_def.mutable_node(i)->set_op("Identity"); } else { return errors::Internal( "ConcatV2 should have at least two elements"); } } if ((replaced_graph_def.node(i).op() == "Assign" || replaced_graph_def.node(i).op() == "Reshape" || replaced_graph_def.node(i).op() == "Equal" || replaced_graph_def.node(i).op() == "Mean" || replaced_graph_def.node(i).op() == "ScalarSummary") && replaced_graph_def.node(i).input_size() == 1) { removed_node_names.push_back(replaced_graph_def.node(i).name()); } if (!replaced_graph_def.node(i).input_size()) { removed_node_names.push_back(replaced_graph_def.node(i).name()); } } i++; } } current_graph_def = replaced_graph_def; } while (any_match_found); *output_graph_def = current_graph_def; return Status::OK(); } Status SparsifyGather(const GraphDef& input_graph_def, const TransformFuncContext& context, GraphDef* output_graph_def) { // clang-format off const OpTypePattern gather_pattern = {"Gather", { {"Identity", { {"Const|Variable|VariableV2"} } }, {"*"}, } }; const OpTypePattern gather_v2_pattern = {"GatherV2", { {"Identity", { {"Const|Variable|VariableV2"} } }, {"*"}, // GatherV2's axis must be constant. {"Const"}, } }; // clang-format on GraphDef cleaned_input_graph_def; RemoveAttributes(input_graph_def, {"_output_shapes"}, &cleaned_input_graph_def); GraphDef temp_output; std::unique_ptr ckpt_reader; TF_RETURN_IF_ERROR(InitializeCheckpointReader(context, &ckpt_reader)); std::unique_ptr > shapes_and_slices; TF_RETURN_IF_ERROR( ObtainVariableInfo(cleaned_input_graph_def, &shapes_and_slices)); TF_RETURN_IF_ERROR(SparsifyGatherInternal( cleaned_input_graph_def, shapes_and_slices, context, gather_pattern, ckpt_reader, &temp_output)); TF_RETURN_IF_ERROR(SparsifyGatherInternal(temp_output, shapes_and_slices, context, gather_v2_pattern, ckpt_reader, output_graph_def)); return Status::OK(); } REGISTER_GRAPH_TRANSFORM("sparsify_gather", SparsifyGather); } // namespace graph_transforms } // namespace tensorflow