aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/framework/cc_op_gen.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/cc/framework/cc_op_gen.cc')
-rw-r--r--tensorflow/cc/framework/cc_op_gen.cc71
1 files changed, 61 insertions, 10 deletions
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index d6a4f141b6..dfdef88945 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -273,6 +273,12 @@ string PrintAttrValue(const string& op, const AttrValue& attr_value) {
return "<Unknown AttrValue type>"; // Prevent missing return warning
}
+bool IsEmptyList(const AttrValue::ListValue& list) {
+ return list.s_size() == 0 && list.i_size() == 0 && list.f_size() == 0 &&
+ list.b_size() == 0 && list.type_size() == 0 &&
+ list.shape_size() == 0 && list.tensor_size() == 0;
+}
+
string ToCamelCase(const string& str) {
string result;
const char joiner = '_';
@@ -297,9 +303,9 @@ string ToCamelCase(const string& str) {
// indicate whether to treat the type as const when accepting the C++ type as an
// argument to a function.
std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) {
- static const std::unordered_map<StringPiece, std::pair<const char*, bool>,
- StringPieceHasher>
- attr_type_map{
+ static const auto* attr_type_map =
+ new std::unordered_map<StringPiece, std::pair<const char*, bool>,
+ StringPieceHasher>{
{"string", {"StringPiece", false}},
{"list(string)", {"gtl::ArraySlice<string>", true}},
{"int", {"int64", false}},
@@ -317,14 +323,34 @@ std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) {
{"func", {"NameAttrList", true}},
};
- auto entry = attr_type_map.find(attr_type);
- if (entry == attr_type_map.end()) {
+ auto entry = attr_type_map->find(attr_type);
+ if (entry == attr_type_map->end()) {
LOG(FATAL) << "Unsupported Attr type: " << attr_type;
return {"", false};
}
return entry->second;
}
+const char* ListElementTypeName(StringPiece attr_type) {
+ static const auto* attr_list_type_map =
+ new std::unordered_map<StringPiece, const char*, StringPieceHasher>{
+ {"list(string)", "string"},
+ {"list(int)", "int"},
+ {"list(float)", "float"},
+ {"list(bool)", "bool"},
+ {"list(type)", "DataType"},
+ {"list(shape)", "PartialTensorShape"},
+ {"list(tensor)", "TensorProto"},
+ };
+
+ auto entry = attr_list_type_map->find(attr_type);
+ if (entry == attr_list_type_map->end()) {
+ LOG(FATAL) << "Unsupported or non-list Attr type: " << attr_type;
+ return "";
+ }
+ return entry->second;
+}
+
bool IsCPPKeyword(StringPiece name) {
static const std::unordered_set<StringPiece, StringPieceHasher>
// Keywords obtained from http://en.cppreference.com/w/cpp/keyword
@@ -668,6 +694,7 @@ OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
string OpInfo::GetOpAttrStruct() const {
string struct_fields;
string setters;
+ string defaults_static_storage;
for (int i = 0; i < graph_op_def.attr_size(); ++i) {
const auto& attr(graph_op_def.attr(i));
@@ -705,11 +732,32 @@ string OpInfo::GetOpAttrStruct() const {
"_ = x;\n");
strings::StrAppend(&setters, " return ret;\n }\n\n");
- strings::StrAppend(
- &struct_fields, " ", attr_type_name, " ", api_def_attr.rename_to(),
- "_ = ",
- PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()),
- ";\n");
+ string field_initiliazer;
+ auto& default_value = api_def_attr.default_value();
+ if (default_value.value_case() == AttrValue::kList &&
+ !IsEmptyList(default_value.list())) {
+ // Non-empty lists need static storage for their defaults. Define a
+ // function with static local variable that stores the array.
+ strings::StrAppend(&defaults_static_storage, " static ",
+ attr_type_name, " Default_", api_def_attr.rename_to(),
+ "() {\n");
+ strings::StrAppend(
+ &defaults_static_storage, " static const ",
+ ListElementTypeName(attr.type()), " kStorage[] = ",
+ PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()),
+ ";\n");
+ strings::StrAppend(&defaults_static_storage, " return ",
+ attr_type_name, "(kStorage);\n }\n");
+ // Set the field_initializer to call the defined function.
+ strings::StrAppend(&field_initiliazer, "Default_",
+ api_def_attr.rename_to(), "()");
+ } else {
+ field_initiliazer =
+ PrintAttrValue(graph_op_def.name(), api_def_attr.default_value());
+ }
+ strings::StrAppend(&struct_fields, " ", attr_type_name, " ",
+ api_def_attr.rename_to(), "_ = ", field_initiliazer,
+ ";\n");
}
if (struct_fields.empty()) {
@@ -721,6 +769,9 @@ string OpInfo::GetOpAttrStruct() const {
string struct_decl = MakeComment(attrs_comment, " ");
strings::StrAppend(&struct_decl, " struct Attrs {\n");
strings::StrAppend(&struct_decl, setters, struct_fields);
+ if (!defaults_static_storage.empty()) {
+ strings::StrAppend(&struct_decl, " private:\n", defaults_static_storage);
+ }
strings::StrAppend(&struct_decl, " };\n");
return struct_decl;