diff options
author | Jacques Pienaar <jpienaar@google.com> | 2018-02-28 00:07:55 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-28 00:11:59 -0800 |
commit | 503d9b522e28272e032bc45a10e3c0f21398a16e (patch) | |
tree | 1c635afe83ac48ed0180f5975e1fa75dbf022124 /tensorflow/compiler/xla/service/hlo_evaluator.h | |
parent | c38a16dbcc5de5fa5579a3e48ec12be316a2cb3f (diff) |
[XLA:Evaluator] Handle while loop.
* Add while loop support to HloEvaluator;
* Add a max_loop_iteration argument to the interpreter's constructor to limit
the number of loop iterations that will be evaluated (or no bound if -1).
Maintain current constant propagation behavior by setting limit to 0 for evaluators used for CP.
PiperOrigin-RevId: 187287574
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator.h | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index fc82011630..8a27cf9a3a 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -36,7 +36,10 @@ namespace xla { // This class is not thread-safe. class HloEvaluator : public DfsHloVisitorWithDefault { public: - HloEvaluator(); + // Only evaluate up to max_loop_iterations per while-loop execution if + // specified. + explicit HloEvaluator(int64 max_loop_iterations = -1); + // Evaluates an HLO module and an array of pointers to literals. // Returns the evaluated result as a literal if successful. // Precondition: The indices of arg_literals correspond to the parameter @@ -157,6 +160,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleCall(HloInstruction* call) override; + Status HandleWhile(HloInstruction* while_hlo) override; + private: // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be @@ -194,6 +199,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Must be cleared for each evaluation. std::vector<const Literal*> arg_literals_; + // Max loop iterations to execute with no maximum if negative. + int64 max_loop_iterations_; + TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator); }; |