aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/kernels/partitioned_function_ops.cc17
1 files changed, 10 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index 70e2e055c5..8db78f9784 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -261,12 +261,6 @@ class PartitionedCallOp : public AsyncOpKernel {
// device, and
// (3) records which `Arg` and `Retval` nodes live in host memory.
Status UpdateArgAndRetMetadata(const string& device, Graph* subgraph) {
- if (arg_and_ret_indices_.find(device) != arg_and_ret_indices_.end()) {
- // This function has already been partitioned, albeit for a different
- // function library.
- return Status::OK();
- }
-
ArgAndRetIndices indices;
std::vector<int>* arg_indices = &indices.first;
std::vector<int>* ret_indices = &indices.second;
@@ -274,6 +268,8 @@ class PartitionedCallOp : public AsyncOpKernel {
std::vector<std::pair<Node*, int>> ret_nodes;
const AttrValue* attr_value;
+ // Find the Arg and Retval nodes, along with their corresponding indices
+ // in the original function.
for (Node* node : subgraph->op_nodes()) {
string node_type = node->type_string();
if (node_type == FunctionLibraryDefinition::kArgOp) {
@@ -289,6 +285,8 @@ class PartitionedCallOp : public AsyncOpKernel {
}
}
+ // Rewrite the indices of the Arg and Retval nodes for this function
+ // to range from 0 to the number of Arg nodes, Retval nodes, respectively.
auto sort_by_index = [](std::pair<Node*, int> one,
std::pair<Node*, int> two) -> bool {
return one.second < two.second;
@@ -318,7 +316,12 @@ class PartitionedCallOp : public AsyncOpKernel {
arg_and_ret_alloc_attrs_[device].second.push_back(alloc_attr);
}
- arg_and_ret_indices_.emplace(device, indices);
+ // If this kernel execution corresponds to a StatefulPartitionedCallOp,
+ // `arg_and_ret_indices_` might have been populated by a previous
+ // invocation.
+ if (arg_and_ret_indices_.find(device) == arg_and_ret_indices_.end()) {
+ arg_and_ret_indices_.emplace(device, indices);
+ }
return Status::OK();
}