aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-08-03 16:48:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-03 16:52:15 -0700
commit57626dd38a7867b76c44f3933e7810190174a2ee (patch)
treecc9d6833e5107fe6fb915ca28c597b53c4b0bb3e /tensorflow/c/c_api.cc
parent11d900686fd6aa65c26d75ac92e11b437ef4c48c (diff)
Allow specifying colocation constraints through TF_SetAttr*
Before this change, colocation constraint semantics were not well-defined when they were set using TF_ColocateWith and TF_SetAttrStringList and/or TF_SetAttrValueProto. One could get an exception if multiple methods were used. After this change all changes to colocation attribute (i.e. _class) are executed on TF_OperationDescription.colocation_constraints leading to consistent semantics. PiperOrigin-RevId: 164202666
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r--tensorflow/c/c_api.cc54
1 files changed, 40 insertions, 14 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 663fec56f1..e3c4bb02d2 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -966,7 +966,7 @@ void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
}
void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) {
- desc->colocation_constraints.emplace_back(
+ desc->colocation_constraints.emplace(
StrCat(tensorflow::kColocationGroupPrefix, op->node.name()));
}
@@ -979,12 +979,20 @@ void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
const void* const* values, const size_t* lengths,
int num_values) {
- std::vector<tensorflow::StringPiece> v;
- v.reserve(num_values);
- for (int i = 0; i < num_values; ++i) {
- v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
+ if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
+ desc->colocation_constraints.clear();
+ for (int i = 0; i < num_values; ++i) {
+ desc->colocation_constraints.emplace(static_cast<const char*>(values[i]),
+ lengths[i]);
+ }
+ } else {
+ std::vector<tensorflow::StringPiece> v;
+ v.reserve(num_values);
+ for (int i = 0; i < num_values; ++i) {
+ v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
+ }
+ desc->node_builder.Attr(attr_name, v);
}
- desc->node_builder.Attr(attr_name, v);
}
void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
@@ -1143,12 +1151,28 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
const void* proto, size_t proto_len,
TF_Status* status) {
tensorflow::AttrValue attr_value;
- if (attr_value.ParseFromArray(proto, proto_len)) {
- desc->node_builder.Attr(attr_name, attr_value);
- status->status = Status::OK();
- } else {
+ if (!attr_value.ParseFromArray(proto, proto_len)) {
status->status = InvalidArgument("Unparseable AttrValue proto");
+ return;
}
+
+ if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
+ if (attr_value.value_case() != tensorflow::AttrValue::kList &&
+ attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) {
+ status->status =
+ InvalidArgument("Expected \"list\" field for \"",
+ tensorflow::kColocationAttrName, "\" attribute");
+ return;
+ }
+ desc->colocation_constraints.clear();
+ for (const tensorflow::string& location : attr_value.list().s()) {
+ desc->colocation_constraints.insert(location);
+ }
+ } else {
+ desc->node_builder.Attr(attr_name, attr_value);
+ }
+
+ status->status = Status::OK();
}
static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
@@ -1160,10 +1184,12 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
status->status = InvalidArgument("Duplicate node name in graph: '",
desc->node_builder.node_name(), "'");
} else {
- std::sort(desc->colocation_constraints.begin(),
- desc->colocation_constraints.end());
- desc->node_builder.Attr(tensorflow::kColocationAttrName,
- desc->colocation_constraints);
+ if (!desc->colocation_constraints.empty()) {
+ desc->node_builder.Attr(
+ tensorflow::kColocationAttrName,
+ std::vector<tensorflow::string>(desc->colocation_constraints.begin(),
+ desc->colocation_constraints.end()));
+ }
status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret);
if (status->status.ok()) {