diff options
Diffstat (limited to 'tensorflow/core/kernels/data/dataset_utils.cc')
-rw-r--r-- | tensorflow/core/kernels/data/dataset_utils.cc | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc index e10833f525..a40f7f2146 100644 --- a/tensorflow/core/kernels/data/dataset_utils.cc +++ b/tensorflow/core/kernels/data/dataset_utils.cc @@ -15,10 +15,57 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/gtl/cleanup.h" namespace tensorflow { namespace data { +Status ComputeShortCircuitIndices(OpKernelContext* ctx, + const NameAttrList& func, + std::vector<int>* indices) { + FunctionLibraryRuntime::Handle fn_handle; + TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate( + func.name(), AttrSlice(&func.attr()), &fn_handle)); + auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() { + Status s = ctx->function_library()->ReleaseHandle(fn_handle); + if (!s.ok()) { + LOG(WARNING) << "Failed to release handle: " << s.error_message(); + } + }); + + const FunctionBody* fn_body = + ctx->function_library()->GetFunctionBody(fn_handle); + indices->resize(fn_body->ret_nodes.size()); + for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) { + Node* ret_node = fn_body->ret_nodes[i]; + Node* ret_input_node; + TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node)); + if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) { + TF_RETURN_IF_ERROR( + GetNodeAttr(ret_input_node->def(), "index", &((*indices)[i]))); + } else { + indices->clear(); + break; + } + } + return Status::OK(); +} + +std::vector<bool> ComputeMoveVector(const std::vector<int>& indices) { + std::map<int, int> last_use; + for (size_t i = 0; i < indices.size(); ++i) { + last_use[indices[i]] = i; + } + std::vector<bool> can_move; + can_move.resize(indices.size()); + for (size_t i = 0; i < indices.size(); ++i) { + can_move[i] = last_use[indices[i]] == i; + } + return can_move; +} + Status MakeIteratorFromInputElement( IteratorContext* ctx, const std::vector<Tensor>& input_element, int64 thread_index, CapturedFunction* captured_func, StringPiece prefix, |