aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/dataset_utils.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/data/dataset_utils.cc')
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.cc47
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index e10833f525..a40f7f2146 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -15,10 +15,57 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
namespace tensorflow {
namespace data {
+Status ComputeShortCircuitIndices(OpKernelContext* ctx,
+ const NameAttrList& func,
+ std::vector<int>* indices) {
+ FunctionLibraryRuntime::Handle fn_handle;
+ TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
+ func.name(), AttrSlice(&func.attr()), &fn_handle));
+ auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() {
+ Status s = ctx->function_library()->ReleaseHandle(fn_handle);
+ if (!s.ok()) {
+ LOG(WARNING) << "Failed to release handle: " << s.error_message();
+ }
+ });
+
+ const FunctionBody* fn_body =
+ ctx->function_library()->GetFunctionBody(fn_handle);
+ indices->resize(fn_body->ret_nodes.size());
+ for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) {
+ Node* ret_node = fn_body->ret_nodes[i];
+ Node* ret_input_node;
+ TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node));
+ if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) {
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(ret_input_node->def(), "index", &((*indices)[i])));
+ } else {
+ indices->clear();
+ break;
+ }
+ }
+ return Status::OK();
+}
+
+std::vector<bool> ComputeMoveVector(const std::vector<int>& indices) {
+ std::map<int, int> last_use;
+ for (size_t i = 0; i < indices.size(); ++i) {
+ last_use[indices[i]] = i;
+ }
+ std::vector<bool> can_move;
+ can_move.resize(indices.size());
+ for (size_t i = 0; i < indices.size(); ++i) {
+ can_move[i] = last_use[indices[i]] == i;
+ }
+ return can_move;
+}
+
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,