aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_evaluator.h
diff options
context:
space:
mode:
authorGravatar Jacques Pienaar <jpienaar@google.com>2018-02-28 00:07:55 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-28 00:11:59 -0800
commit503d9b522e28272e032bc45a10e3c0f21398a16e (patch)
tree1c635afe83ac48ed0180f5975e1fa75dbf022124 /tensorflow/compiler/xla/service/hlo_evaluator.h
parentc38a16dbcc5de5fa5579a3e48ec12be316a2cb3f (diff)
[XLA:Evaluator] Handle while loop.
* Add while loop support to HloEvaluator; * Add a max_loop_iteration argument to the interpreter's constructor to limit the number of loop iterations that will be evaluated (or no bound if -1). Maintain current constant propagation behavior by setting limit to 0 for evaluators used for CP. PiperOrigin-RevId: 187287574
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h10
1 files changed, 9 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index fc82011630..8a27cf9a3a 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -36,7 +36,10 @@ namespace xla {
// This class is not thread-safe.
class HloEvaluator : public DfsHloVisitorWithDefault {
public:
- HloEvaluator();
+ // Only evaluate up to max_loop_iterations per while-loop execution if
+ // specified.
+ explicit HloEvaluator(int64 max_loop_iterations = -1);
+
// Evaluates an HLO module and an array of pointers to literals.
// Returns the evaluated result as a literal if successful.
// Precondition: The indices of arg_literals correspond to the parameter
@@ -157,6 +160,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
Status HandleCall(HloInstruction* call) override;
+ Status HandleWhile(HloInstruction* while_hlo) override;
+
private:
// Returns the already-evaluated literal result for the instruction.
// A Constant instruction is considered evaluated and its literal will be
@@ -194,6 +199,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// Must be cleared for each evaluation.
std::vector<const Literal*> arg_literals_;
+ // Max loop iterations to execute with no maximum if negative.
+ int64 max_loop_iterations_;
+
TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator);
};