diff options
Diffstat (limited to 'tensorflow/core/kernels/data/dataset_utils.h')
-rw-r--r-- | tensorflow/core/kernels/data/dataset_utils.h | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h index 6ec1350cd4..d777062293 100644 --- a/tensorflow/core/kernels/data/dataset_utils.h +++ b/tensorflow/core/kernels/data/dataset_utils.h @@ -22,6 +22,26 @@ limitations under the License. namespace tensorflow { namespace data { +// This method is used to determine whether we can short-circuit the evaluation +// of the user-defined function `func`. Short-circuting is possible if every +// function output corresponds to one of its inputs (e.g. `f(x) = x`, `f(x,y) = +// (y,x)`, or `f(x) = (x,x)`). +// +// If short-circuiting is possible, the method stores the mapping from output +// indices to input indices in `indices`. Otherwise, `indices` will be empty. +// +// Returns non-ok status if analysis of the function fails. +// +// TODO(jsimsa): Extend this to support constants as well. +Status ComputeShortCircuitIndices(OpKernelContext* ctx, + const NameAttrList& func, + std::vector<int>* indices); + +// Given a vector that maps output indices to input indices, return a vector +// that identifies for which output indices can we move the input (assuming +// output indices are processed left to right). +std::vector<bool> ComputeMoveVector(const std::vector<int>& indices); + Status MakeIteratorFromInputElement( IteratorContext* ctx, const std::vector<Tensor>& input_element, int64 thread_index, CapturedFunction* captured_func, StringPiece prefix, |