diff options
author | 2017-07-26 21:53:47 -0700 | |
---|---|---|
committer | 2017-07-26 21:57:20 -0700 | |
commit | a52470172324e77dff9f548cb40e45d6a3a156b5 (patch) | |
tree | e273c480fe6861bd4af5fe0cd0d3e7f856039e90 | |
parent | a49fe0366880ec51f452bf3106d342bb586d5e93 (diff) |
Sets the incarnation number even when the attribute is set.
PiperOrigin-RevId: 163299121
-rw-r--r-- | tensorflow/core/graph/graph_partition.cc | 9 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_partition_test.cc | 35 |
2 files changed, 42 insertions, 2 deletions
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index 750e18a9ca..bf8dcb2fcf 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -909,8 +909,13 @@ void SetIncarnation(const PartitionOptions& opts, NodeDef* ndef) { // No known send_device. The runtime will detect it later. return; } - int64 incarnation = opts.get_incarnation(send_device); - AddNodeAttr("send_device_incarnation", incarnation, ndef); + int64 incarnation = PartitionOptions::kIllegalIncarnation; + if (!GetNodeAttr(*ndef, "send_device_incarnation", &incarnation).ok() || + (incarnation == PartitionOptions::kIllegalIncarnation)) { + incarnation = opts.get_incarnation(send_device); + SetAttrValue(incarnation, + &((*ndef->mutable_attr())["send_device_incarnation"])); + } } // Sets attribute send_device_incarnation of all Send/Recv nodes in diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc index 9c49b0b67b..3c12ed2689 100644 --- a/tensorflow/core/graph/graph_partition_test.cc +++ b/tensorflow/core/graph/graph_partition_test.cc @@ -445,6 +445,41 @@ TEST_F(GraphPartitionTest, Functions) { ExpectFunctions(partitions_[b].library(), {"XTimesTwo", "XTimesFour"}); } +TEST_F(GraphPartitionTest, SetIncarnation) { + GraphDef gdef; + const char* const kSendRecvAttrs = R"proto( + attr { key: 'T' value { type: DT_FLOAT } } + attr { key: 'client_terminated' value { b: false } } + attr { key: 'recv_device' value { s: 'B' } } + attr { key: 'send_device' value { s: 'A' } } + attr { key: 'send_device_incarnation' value { i: 0 } } + attr { key: 'tensor_name' value { s: 'test' } } +)proto"; + CHECK(protobuf::TextFormat::ParseFromString( + StrCat("node { name: 'A/Pi' op: 'Const' ", + " attr { key: 'dtype' value { type: DT_FLOAT } } ", + " attr { key: 'value' value { tensor { ", + " dtype: DT_FLOAT tensor_shape {} float_val: 3.14 } } } }", + "node { name: 'A' op: '_Send' input: 'A/Pi' ", kSendRecvAttrs, "}", + "node { name: 'B' op: '_Recv' ", kSendRecvAttrs, + " attr { key: 'tensor_type' value { type:DT_FLOAT}}}"), + &gdef)); + gdef.mutable_versions()->set_producer(TF_GRAPH_DEF_VERSION); + Partition(gdef, &partitions_); + EXPECT_EQ(2, partitions_.size()); + + for (const auto& kv : partitions_) { + const GraphDef& gdef = kv.second; + for (const NodeDef& ndef : gdef.node()) { + if (ndef.name() == "A" || ndef.name() == "B") { + int64 val; + TF_CHECK_OK(GetNodeAttr(ndef, "send_device_incarnation", &val)); + EXPECT_EQ(val, 100); // Send device is "A". + } + } + } +} + TEST(TopologicalSortNodesWithTimePriorityTest, NoDependencies) { // Create placeholders, shuffle them so the order in the graph is not strictly // increasing. |