diff options
-rw-r--r-- | tensorflow/core/kernels/partitioned_function_ops.cc | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index 70e2e055c5..8db78f9784 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -261,12 +261,6 @@ class PartitionedCallOp : public AsyncOpKernel {
   // device, and
   // (3) records which `Arg` and `Retval` nodes live in host memory.
   Status UpdateArgAndRetMetadata(const string& device, Graph* subgraph) {
-    if (arg_and_ret_indices_.find(device) != arg_and_ret_indices_.end()) {
-      // This function has already been partitioned, albeit for a different
-      // function library.
-      return Status::OK();
-    }
-
     ArgAndRetIndices indices;
     std::vector<int>* arg_indices = &indices.first;
     std::vector<int>* ret_indices = &indices.second;
@@ -274,6 +268,8 @@ class PartitionedCallOp : public AsyncOpKernel {
     std::vector<std::pair<Node*, int>> ret_nodes;
     const AttrValue* attr_value;
 
+    // Find the Arg and Retval nodes, along with their corresponding indices
+    // in the original function.
     for (Node* node : subgraph->op_nodes()) {
       string node_type = node->type_string();
       if (node_type == FunctionLibraryDefinition::kArgOp) {
@@ -289,6 +285,8 @@ class PartitionedCallOp : public AsyncOpKernel {
       }
     }
 
+    // Rewrite the indices of the Arg and Retval nodes for this function
+    // to range from 0 to the number of Arg nodes, Retval nodes, respectively.
     auto sort_by_index = [](std::pair<Node*, int> one,
                             std::pair<Node*, int> two) -> bool {
       return one.second < two.second;
@@ -318,7 +316,12 @@ class PartitionedCallOp : public AsyncOpKernel {
       arg_and_ret_alloc_attrs_[device].second.push_back(alloc_attr);
     }
 
-    arg_and_ret_indices_.emplace(device, indices);
+    // If this kernel execution corresponds to a StatefulPartitionedCallOp,
+    // `arg_and_ret_indices_` might have been populated by a previous
+    // invocation.
+    if (arg_and_ret_indices_.find(device) == arg_and_ret_indices_.end()) {
+      arg_and_ret_indices_.emplace(device, indices);
+    }
 
     return Status::OK();
   }