diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-08-03 16:48:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-03 16:52:15 -0700 |
commit | 57626dd38a7867b76c44f3933e7810190174a2ee (patch) | |
tree | cc9d6833e5107fe6fb915ca28c597b53c4b0bb3e /tensorflow/c/c_api.cc | |
parent | 11d900686fd6aa65c26d75ac92e11b437ef4c48c (diff) |
Allow specifying colocation constraints through TF_SetAttr*
Before this change, colocation constraint semantics were not
well-defined when they were set using TF_ColocateWith and
TF_SetAttrStringList and/or TF_SetAttrValueProto. One could get
an exception if multiple methods were used.
After this change all changes to colocation attribute (i.e. _class)
are executed on TF_OperationDescription.colocation_constraints leading
to consistent semantics.
PiperOrigin-RevId: 164202666
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r-- | tensorflow/c/c_api.cc | 54 |
1 files changed, 40 insertions, 14 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 663fec56f1..e3c4bb02d2 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -966,7 +966,7 @@ void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) { } void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) { - desc->colocation_constraints.emplace_back( + desc->colocation_constraints.emplace( StrCat(tensorflow::kColocationGroupPrefix, op->node.name())); } @@ -979,12 +979,20 @@ void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name, void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name, const void* const* values, const size_t* lengths, int num_values) { - std::vector<tensorflow::StringPiece> v; - v.reserve(num_values); - for (int i = 0; i < num_values; ++i) { - v.emplace_back(static_cast<const char*>(values[i]), lengths[i]); + if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) { + desc->colocation_constraints.clear(); + for (int i = 0; i < num_values; ++i) { + desc->colocation_constraints.emplace(static_cast<const char*>(values[i]), + lengths[i]); + } + } else { + std::vector<tensorflow::StringPiece> v; + v.reserve(num_values); + for (int i = 0; i < num_values; ++i) { + v.emplace_back(static_cast<const char*>(values[i]), lengths[i]); + } + desc->node_builder.Attr(attr_name, v); } - desc->node_builder.Attr(attr_name, v); } void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name, @@ -1143,12 +1151,28 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, const void* proto, size_t proto_len, TF_Status* status) { tensorflow::AttrValue attr_value; - if (attr_value.ParseFromArray(proto, proto_len)) { - desc->node_builder.Attr(attr_name, attr_value); - status->status = Status::OK(); - } else { + if (!attr_value.ParseFromArray(proto, proto_len)) { status->status = InvalidArgument("Unparseable AttrValue proto"); + return; } + + if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) { + if (attr_value.value_case() != tensorflow::AttrValue::kList && + attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) { + status->status = + InvalidArgument("Expected \"list\" field for \"", + tensorflow::kColocationAttrName, "\" attribute"); + return; + } + desc->colocation_constraints.clear(); + for (const tensorflow::string& location : attr_value.list().s()) { + desc->colocation_constraints.insert(location); + } + } else { + desc->node_builder.Attr(attr_name, attr_value); + } + + status->status = Status::OK(); } static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, @@ -1160,10 +1184,12 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, status->status = InvalidArgument("Duplicate node name in graph: '", desc->node_builder.node_name(), "'"); } else { - std::sort(desc->colocation_constraints.begin(), - desc->colocation_constraints.end()); - desc->node_builder.Attr(tensorflow::kColocationAttrName, - desc->colocation_constraints); + if (!desc->colocation_constraints.empty()) { + desc->node_builder.Attr( + tensorflow::kColocationAttrName, + std::vector<tensorflow::string>(desc->colocation_constraints.begin(), + desc->colocation_constraints.end())); + } status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret); if (status->status.ok()) { |