aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/dataset_utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/data/dataset_utils.h')
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.h20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index 6ec1350cd4..d777062293 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -22,6 +22,26 @@ limitations under the License.
namespace tensorflow {
namespace data {
+// This method is used to determine whether we can short-circuit the evaluation
+// of the user-defined function `func`. Short-circuting is possible if every
+// function output corresponds to one of its inputs (e.g. `f(x) = x`, `f(x,y) =
+// (y,x)`, or `f(x) = (x,x)`).
+//
+// If short-circuiting is possible, the method stores the mapping from output
+// indices to input indices in `indices`. Otherwise, `indices` will be empty.
+//
+// Returns non-ok status if analysis of the function fails.
+//
+// TODO(jsimsa): Extend this to support constants as well.
+Status ComputeShortCircuitIndices(OpKernelContext* ctx,
+ const NameAttrList& func,
+ std::vector<int>* indices);
+
+// Given a vector that maps output indices to input indices, return a vector
+// that identifies for which output indices can we move the input (assuming
+// output indices are processed left to right).
+std::vector<bool> ComputeMoveVector(const std::vector<int>& indices);
+
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,