aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-26 21:53:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-26 21:57:20 -0700
commita52470172324e77dff9f548cb40e45d6a3a156b5 (patch)
treee273c480fe6861bd4af5fe0cd0d3e7f856039e90
parenta49fe0366880ec51f452bf3106d342bb586d5e93 (diff)
Sets the incarnation number even when the attribute is set.
PiperOrigin-RevId: 163299121
-rw-r--r--tensorflow/core/graph/graph_partition.cc9
-rw-r--r--tensorflow/core/graph/graph_partition_test.cc35
2 files changed, 42 insertions, 2 deletions
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index 750e18a9ca..bf8dcb2fcf 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -909,8 +909,13 @@ void SetIncarnation(const PartitionOptions& opts, NodeDef* ndef) {
// No known send_device. The runtime will detect it later.
return;
}
- int64 incarnation = opts.get_incarnation(send_device);
- AddNodeAttr("send_device_incarnation", incarnation, ndef);
+ int64 incarnation = PartitionOptions::kIllegalIncarnation;
+ if (!GetNodeAttr(*ndef, "send_device_incarnation", &incarnation).ok() ||
+ (incarnation == PartitionOptions::kIllegalIncarnation)) {
+ incarnation = opts.get_incarnation(send_device);
+ SetAttrValue(incarnation,
+ &((*ndef->mutable_attr())["send_device_incarnation"]));
+ }
}
// Sets attribute send_device_incarnation of all Send/Recv nodes in
diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc
index 9c49b0b67b..3c12ed2689 100644
--- a/tensorflow/core/graph/graph_partition_test.cc
+++ b/tensorflow/core/graph/graph_partition_test.cc
@@ -445,6 +445,41 @@ TEST_F(GraphPartitionTest, Functions) {
ExpectFunctions(partitions_[b].library(), {"XTimesTwo", "XTimesFour"});
}
+TEST_F(GraphPartitionTest, SetIncarnation) {
+ GraphDef gdef;
+ const char* const kSendRecvAttrs = R"proto(
+ attr { key: 'T' value { type: DT_FLOAT } }
+ attr { key: 'client_terminated' value { b: false } }
+ attr { key: 'recv_device' value { s: 'B' } }
+ attr { key: 'send_device' value { s: 'A' } }
+ attr { key: 'send_device_incarnation' value { i: 0 } }
+ attr { key: 'tensor_name' value { s: 'test' } }
+)proto";
+ CHECK(protobuf::TextFormat::ParseFromString(
+ StrCat("node { name: 'A/Pi' op: 'Const' ",
+ " attr { key: 'dtype' value { type: DT_FLOAT } } ",
+ " attr { key: 'value' value { tensor { ",
+ " dtype: DT_FLOAT tensor_shape {} float_val: 3.14 } } } }",
+ "node { name: 'A' op: '_Send' input: 'A/Pi' ", kSendRecvAttrs, "}",
+ "node { name: 'B' op: '_Recv' ", kSendRecvAttrs,
+ " attr { key: 'tensor_type' value { type:DT_FLOAT}}}"),
+ &gdef));
+ gdef.mutable_versions()->set_producer(TF_GRAPH_DEF_VERSION);
+ Partition(gdef, &partitions_);
+ EXPECT_EQ(2, partitions_.size());
+
+ for (const auto& kv : partitions_) {
+ const GraphDef& gdef = kv.second;
+ for (const NodeDef& ndef : gdef.node()) {
+ if (ndef.name() == "A" || ndef.name() == "B") {
+ int64 val;
+ TF_CHECK_OK(GetNodeAttr(ndef, "send_device_incarnation", &val));
+ EXPECT_EQ(val, 100); // Send device is "A".
+ }
+ }
+ }
+}
+
TEST(TopologicalSortNodesWithTimePriorityTest, NoDependencies) {
// Create placeholders, shuffle them so the order in the graph is not strictly
// increasing.