diff options
author | Ilya Biryukov <ibiryukov@google.com> | 2018-06-11 14:16:30 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-11 14:19:38 -0700 |
commit | 0912bc8cc7f491cdcc5b8a74600292c6e810247b (patch) | |
tree | 13c0ff0c4d4d33784408550a4f4db75a76446552 /tensorflow/cc | |
parent | 21aa82e1a12eb53fe4c94006f957c1adab9aa662 (diff) |
Fix 'cc_op_gen' to use static storage for constant arrays.
Previously, the generate would emit code like this:
struct Attrs {
ArraySlice<int> dilations_ = {1, 1, 1, 1};
};
This code is incorrect, since the array slice references a temporary object
that dies after initialization finishes.
After this change change the generator will produce static functions to
initialize the values:
struct Attrs {
ArraySlice<int> dilations_ = Default_dilations();
private:
ArraySlice<int> Default_dilations() {
static int kStorage[] = {1, 1, 1, 1};
return ArraySlice<int>(kStorage);
}
};
Presumably, it used to work because all compilers chose to use static storage
in those cases anyway. However, new versions of clang tend to miscompile this
code, causing test failures. (This error was found when trying to upgrade our
clang revision from r328903 to r331746).
PiperOrigin-RevId: 200110952
Diffstat (limited to 'tensorflow/cc')
-rw-r--r-- | tensorflow/cc/framework/cc_op_gen.cc | 71 |
1 files changed, 61 insertions, 10 deletions
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index d6a4f141b6..dfdef88945 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -273,6 +273,12 @@ string PrintAttrValue(const string& op, const AttrValue& attr_value) { return "<Unknown AttrValue type>"; // Prevent missing return warning } +bool IsEmptyList(const AttrValue::ListValue& list) { + return list.s_size() == 0 && list.i_size() == 0 && list.f_size() == 0 && + list.b_size() == 0 && list.type_size() == 0 && + list.shape_size() == 0 && list.tensor_size() == 0; +} + string ToCamelCase(const string& str) { string result; const char joiner = '_'; @@ -297,9 +303,9 @@ string ToCamelCase(const string& str) { // indicate whether to treat the type as const when accepting the C++ type as an // argument to a function. std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) { - static const std::unordered_map<StringPiece, std::pair<const char*, bool>, - StringPieceHasher> - attr_type_map{ + static const auto* attr_type_map = + new std::unordered_map<StringPiece, std::pair<const char*, bool>, + StringPieceHasher>{ {"string", {"StringPiece", false}}, {"list(string)", {"gtl::ArraySlice<string>", true}}, {"int", {"int64", false}}, @@ -317,14 +323,34 @@ std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) { {"func", {"NameAttrList", true}}, }; - auto entry = attr_type_map.find(attr_type); - if (entry == attr_type_map.end()) { + auto entry = attr_type_map->find(attr_type); + if (entry == attr_type_map->end()) { LOG(FATAL) << "Unsupported Attr type: " << attr_type; return {"", false}; } return entry->second; } +const char* ListElementTypeName(StringPiece attr_type) { + static const auto* attr_list_type_map = + new std::unordered_map<StringPiece, const char*, StringPieceHasher>{ + {"list(string)", "string"}, + {"list(int)", "int"}, + {"list(float)", "float"}, + {"list(bool)", "bool"}, + {"list(type)", "DataType"}, + {"list(shape)", "PartialTensorShape"}, + {"list(tensor)", "TensorProto"}, + }; + + auto entry = attr_list_type_map->find(attr_type); + if (entry == attr_list_type_map->end()) { + LOG(FATAL) << "Unsupported or non-list Attr type: " << attr_type; + return ""; + } + return entry->second; +} + bool IsCPPKeyword(StringPiece name) { static const std::unordered_set<StringPiece, StringPieceHasher> // Keywords obtained from http://en.cppreference.com/w/cpp/keyword @@ -668,6 +694,7 @@ OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def, string OpInfo::GetOpAttrStruct() const { string struct_fields; string setters; + string defaults_static_storage; for (int i = 0; i < graph_op_def.attr_size(); ++i) { const auto& attr(graph_op_def.attr(i)); @@ -705,11 +732,32 @@ string OpInfo::GetOpAttrStruct() const { "_ = x;\n"); strings::StrAppend(&setters, " return ret;\n }\n\n"); - strings::StrAppend( - &struct_fields, " ", attr_type_name, " ", api_def_attr.rename_to(), - "_ = ", - PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()), - ";\n"); + string field_initiliazer; + auto& default_value = api_def_attr.default_value(); + if (default_value.value_case() == AttrValue::kList && + !IsEmptyList(default_value.list())) { + // Non-empty lists need static storage for their defaults. Define a + // function with static local variable that stores the array. + strings::StrAppend(&defaults_static_storage, " static ", + attr_type_name, " Default_", api_def_attr.rename_to(), + "() {\n"); + strings::StrAppend( + &defaults_static_storage, " static const ", + ListElementTypeName(attr.type()), " kStorage[] = ", + PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()), + ";\n"); + strings::StrAppend(&defaults_static_storage, " return ", + attr_type_name, "(kStorage);\n }\n"); + // Set the field_initializer to call the defined function. + strings::StrAppend(&field_initiliazer, "Default_", + api_def_attr.rename_to(), "()"); + } else { + field_initiliazer = + PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()); + } + strings::StrAppend(&struct_fields, " ", attr_type_name, " ", + api_def_attr.rename_to(), "_ = ", field_initiliazer, + ";\n"); } if (struct_fields.empty()) { @@ -721,6 +769,9 @@ string OpInfo::GetOpAttrStruct() const { string struct_decl = MakeComment(attrs_comment, " "); strings::StrAppend(&struct_decl, " struct Attrs {\n"); strings::StrAppend(&struct_decl, setters, struct_fields); + if (!defaults_static_storage.empty()) { + strings::StrAppend(&struct_decl, " private:\n", defaults_static_storage); + } strings::StrAppend(&struct_decl, " };\n"); return struct_decl; |