From f41959ccb2d9d4c722fe8fc3351401d53bcf4900 Mon Sep 17 00:00:00 2001 From: Manjunath Kudlur Date: Fri, 6 Nov 2015 16:27:58 -0800 Subject: TensorFlow: Initial commit of TensorFlow library. TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108 --- tensorflow/core/framework/attr_value_util.cc | 382 +++++++++++++++++++++++++++ 1 file changed, 382 insertions(+) create mode 100644 tensorflow/core/framework/attr_value_util.cc (limited to 'tensorflow/core/framework/attr_value_util.cc') diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc new file mode 100644 index 0000000000..400ef118b8 --- /dev/null +++ b/tensorflow/core/framework/attr_value_util.cc @@ -0,0 +1,382 @@ +#include "tensorflow/core/framework/attr_value_util.h" + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/regexp.h" + +namespace tensorflow { + +namespace { + +string SummarizeString(const string& str) { + return strings::StrCat("\"", str_util::CEscape(str), "\""); +} + +string SummarizeShape(const TensorShapeProto& proto) { + TensorShape shape(proto); + return shape.ShortDebugString(); +} + +string SummarizeTensor(const TensorProto& tensor_proto) { + Tensor t; + if (!t.FromProto(tensor_proto)) { + return strings::StrCat(""); + } + return t.DebugString(); +} + +} // namespace + +string SummarizeAttrValue(const AttrValue& attr_value) { + switch (attr_value.value_case()) { + case AttrValue::kS: + return SummarizeString(attr_value.s()); + case AttrValue::kI: + return strings::StrCat(attr_value.i()); + case AttrValue::kF: + return strings::StrCat(attr_value.f()); + case AttrValue::kB: + return attr_value.b() ? "true" : "false"; + case AttrValue::kType: + return DataType_Name(attr_value.type()); + case AttrValue::kShape: + return SummarizeShape(attr_value.shape()); + case AttrValue::kTensor: + return SummarizeTensor(attr_value.tensor()); + case AttrValue::kList: { + string ret = "["; + if (attr_value.list().s_size() > 0) { + for (int i = 0; i < attr_value.list().s_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, SummarizeString(attr_value.list().s(i))); + } + } else if (attr_value.list().i_size() > 0) { + for (int i = 0; i < attr_value.list().i_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, attr_value.list().i(i)); + } + } else if (attr_value.list().f_size() > 0) { + for (int i = 0; i < attr_value.list().f_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, attr_value.list().f(i)); + } + } else if (attr_value.list().b_size() > 0) { + for (int i = 0; i < attr_value.list().b_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, attr_value.list().b(i) ? "true" : "false"); + } + } else if (attr_value.list().type_size() > 0) { + for (int i = 0; i < attr_value.list().type_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, DataType_Name(attr_value.list().type(i))); + } + } else if (attr_value.list().shape_size() > 0) { + for (int i = 0; i < attr_value.list().shape_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, SummarizeShape(attr_value.list().shape(i))); + } + } else if (attr_value.list().tensor_size() > 0) { + for (int i = 0; i < attr_value.list().tensor_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, + SummarizeTensor(attr_value.list().tensor(i))); + } + } + strings::StrAppend(&ret, "]"); + return ret; + } + case AttrValue::kFunc: { + std::vector entries; + for (auto p : attr_value.func().attr()) { + entries.push_back( + strings::StrCat(p.first, "=", SummarizeAttrValue(p.second))); + } + sort(entries.begin(), entries.end()); + return strings::StrCat(attr_value.func().name(), "[", + str_util::Join(entries, ", "), "]"); + } + case AttrValue::kPlaceholder: + return strings::StrCat("$", attr_value.placeholder()); + case AttrValue::VALUE_NOT_SET: + return ""; + } + return ""; // Prevent missing return warning +} + +Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) { + int num_set = 0; + +#define VALIDATE_FIELD(name, type_string, oneof_case) \ + do { \ + if (attr_value.has_list()) { \ + if (attr_value.list().name##_size() > 0) { \ + if (type != "list(" type_string ")") { \ + return errors::InvalidArgument( \ + "AttrValue had value with type list(" type_string ") when ", \ + type, " expected"); \ + } \ + ++num_set; \ + } \ + } else if (attr_value.value_case() == AttrValue::oneof_case) { \ + if (type != type_string) { \ + return errors::InvalidArgument( \ + "AttrValue had value with type " type_string " when ", type, \ + " expected"); \ + } \ + ++num_set; \ + } \ + } while (false) + + VALIDATE_FIELD(s, "string", kS); + VALIDATE_FIELD(i, "int", kI); + VALIDATE_FIELD(f, "float", kF); + VALIDATE_FIELD(b, "bool", kB); + VALIDATE_FIELD(type, "type", kType); + VALIDATE_FIELD(shape, "shape", kShape); + VALIDATE_FIELD(tensor, "tensor", kTensor); + +#undef VALIDATE_FIELD + + if (attr_value.value_case() == AttrValue::kFunc) { + if (type != "func") { + return errors::InvalidArgument( + "AttrValue had value with type 'func' when ", type, " expected"); + } + ++num_set; + } + + if (attr_value.value_case() == AttrValue::kPlaceholder) { + return errors::InvalidArgument( + "AttrValue had value with unexpected type 'placeholder"); + } + + // If the attr type is 'list', we expect attr_value.has_list() to be true. + // However, proto3's attr_value.has_list() can be false when set to an empty + // list. So we simply check if has_list is false and some other field in + // attr_value is set to flag the error. + if (StringPiece(type).starts_with("list(") && !attr_value.has_list()) { + if (num_set) { + return errors::InvalidArgument( + "AttrValue missing value with expected type ", type); + } else { + // Indicate that we have a list, but an empty one. + ++num_set; + } + } + + // Okay to have an empty list, but not to be missing a non-list value. + if (num_set == 0 && !StringPiece(type).starts_with("list(")) { + return errors::InvalidArgument( + "AttrValue missing value with expected type ", type); + } + + // Ref types and DT_INVALID are illegal. + if (type == "type") { + if (IsRefType(attr_value.type())) { + return errors::InvalidArgument( + "AttrValue must not have reference type value of ", + DataTypeString(attr_value.type())); + } + if (attr_value.type() == DT_INVALID) { + return errors::InvalidArgument("AttrValue has invalid DataType"); + } + } else if (type == "list(type)") { + for (auto as_int : attr_value.list().type()) { + const DataType dtype = static_cast(as_int); + if (IsRefType(dtype)) { + return errors::InvalidArgument( + "AttrValue must not have reference type value of ", + DataTypeString(dtype)); + } + if (dtype == DT_INVALID) { + return errors::InvalidArgument("AttrValue contains invalid DataType"); + } + } + } + + return Status::OK(); +} + +bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) { + // Parse type. + string field_name; + bool is_list = type.Consume("list("); + if (type.Consume("string")) { + field_name = "s"; + } else if (type.Consume("int")) { + field_name = "i"; + } else if (type.Consume("float")) { + field_name = "f"; + } else if (type.Consume("bool")) { + field_name = "b"; + } else if (type.Consume("type")) { + field_name = "type"; + } else if (type.Consume("shape")) { + field_name = "shape"; + } else if (type.Consume("tensor")) { + field_name = "tensor"; + } else if (type.Consume("func")) { + field_name = "func"; + } else if (type.Consume("placeholder")) { + field_name = "placeholder"; + } else { + return false; + } + if (is_list && !type.Consume(")")) { + return false; + } + + // Construct a valid text proto message to parse. + string to_parse; + if (is_list) { + // TextFormat parser considers "i: 7" to be the same as "i: [7]", + // but we only want to allow list values with []. + if (!RE2::FullMatch(ToRegexpStringPiece(text), "\\s*\\[.*\\]\\s*")) { + return false; + } + if (RE2::FullMatch(ToRegexpStringPiece(text), "\\s*\\[\\s*\\]\\s*")) { + // User wrote "[]", so return empty list without invoking the TextFormat + // parse which returns an error for "i: []". + out->Clear(); + out->mutable_list(); + return true; + } + to_parse = strings::StrCat("list { ", field_name, ": ", text, " }"); + } else { + to_parse = strings::StrCat(field_name, ": ", text); + } + + // Parse if we can. + return protobuf::TextFormat::ParseFromString(to_parse, out); +} + +#define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \ + void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); } + +#define DEFINE_SET_ATTR_VALUE_LIST(ARG_TYPE, FIELD) \ + void SetAttrValue(ARG_TYPE value, AttrValue* out) { \ + out->mutable_list(); /* create list() even if value empty */ \ + for (const auto& v : value) { \ + out->mutable_list()->add_##FIELD(v); \ + } \ + } + +#define DEFINE_SET_ATTR_VALUE_BOTH(ARG_TYPE, FIELD) \ + DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \ + DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice, FIELD) + +DEFINE_SET_ATTR_VALUE_ONE(const string&, s) +DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice, s) +DEFINE_SET_ATTR_VALUE_BOTH(const char*, s) +DEFINE_SET_ATTR_VALUE_BOTH(int64, i) +DEFINE_SET_ATTR_VALUE_BOTH(int32, i) +DEFINE_SET_ATTR_VALUE_BOTH(float, f) +DEFINE_SET_ATTR_VALUE_BOTH(double, f) +DEFINE_SET_ATTR_VALUE_BOTH(bool, b) +DEFINE_SET_ATTR_VALUE_LIST(const std::vector&, b) +DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list, b) +DEFINE_SET_ATTR_VALUE_BOTH(DataType, type) + +void SetAttrValue(StringPiece value, AttrValue* out) { + out->set_s(value.data(), value.size()); +} + +void SetAttrValue(const TensorShape& value, AttrValue* out) { + value.AsProto(out->mutable_shape()); +} + +void SetAttrValue(const gtl::ArraySlice value, AttrValue* out) { + out->mutable_list(); // Create list() even if value empty. + for (const auto& v : value) { + v.AsProto(out->mutable_list()->add_shape()); + } +} + +void SetAttrValue(const Tensor& value, AttrValue* out) { + if (value.NumElements() > 1) { + value.AsProtoTensorContent(out->mutable_tensor()); + } else { + value.AsProtoField(out->mutable_tensor()); + } +} + +void SetAttrValue(const gtl::ArraySlice value, AttrValue* out) { + out->mutable_list(); // Create list() even if value empty. + for (const auto& v : value) { + if (v.NumElements() > 1) { + v.AsProtoTensorContent(out->mutable_list()->add_tensor()); + } else { + v.AsProtoField(out->mutable_list()->add_tensor()); + } + } +} + +void SetAttrValue(const TensorProto& value, AttrValue* out) { + *out->mutable_tensor() = value; +} + +void SetAttrValue(const gtl::ArraySlice value, AttrValue* out) { + out->mutable_list(); // Create list() even if value empty. + for (const auto& v : value) { + *out->mutable_list()->add_tensor() = v; + } +} + +void SetAttrValue(const NameAttrList& value, AttrValue* out) { + *out->mutable_func() = value; +} + +bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) { + string a_str, b_str; + a.SerializeToString(&a_str); + b.SerializeToString(&b_str); + // Note: it should be safe to compare proto serializations of the attr + // values since at most one field should be set in each (indeed, it + // must be the same field if they are to compare equal). + // Exception: there are multiple equivalent representations of + // TensorProtos. So a return value of true implies a == b, but not the + // converse. + return a_str == b_str; +} + +bool HasPlaceHolder(const AttrValue& val) { + switch (val.value_case()) { + case AttrValue::kFunc: + for (const auto& p : val.func().attr()) { + if (HasPlaceHolder(p.second)) { + return true; + } + } + break; + case AttrValue::kPlaceholder: + return true; + default: + break; + } + return false; +} + +bool SubstitutePlaceholders(SubstituteFunc substitute, AttrValue* value) { + switch (value->value_case()) { + case AttrValue::kFunc: + for (auto& p : *(value->mutable_func()->mutable_attr())) { + if (!SubstitutePlaceholders(substitute, &p.second)) { + return false; + } + } + break; + case AttrValue::kPlaceholder: + return substitute(value->placeholder(), value); + case AttrValue::VALUE_NOT_SET: + return false; + default: + break; + } + return true; +} + +} // namespace tensorflow -- cgit v1.2.3