diff options
author | 2017-06-07 15:56:12 -0700 | |
---|---|---|
committer | 2017-06-07 16:01:11 -0700 | |
commit | beeaade460a125975b6fe34d23ff0465183f8b4a (patch) | |
tree | 705f461333574cd5513076e5111c6ca4fd9655e5 /tensorflow/compiler/xla/service/hlo_evaluator.h | |
parent | b94540e6f7ea130674b8122ec192c3d9a07a6752 (diff) |
Resubmit a reverted change. Original description:
[XLA] Enable HloEvaluator for constant folding, also merged a few operations
from hlo_constant_folding to hlo_evaluator.
Additionally:
- In ShapeUtil::ForEachIndex:
* fix a bug where visitor is called when the shape has zero elements (e.g., F32{1,0})
* added test case for ForEachIndex.
- In HloEvaluator:
* Instead of copying and caching a Constant instruction, return the literal directly if the instruction is constant.
* Fix an issue where TUPLE and OPAQUE primitives are not keyed in the templated typed_visitor.
* Use the (fixed) LiteralUtil::Populate to populate the resulting literal; this fixes a preexisting bug in the evaluator where R0 shapes and shapes with zero-sized dimensions were not handled.
* Refactor ElementWiseUnaryOp and HandleCompare to be templatized on the operand's type.
* Refactor IsFinite to be top level since it is only applicable to floats and the return type is always boolean.
* Change from std::remainder to std::fmod for kRemainder to be compliant with existing XLA behavior.
* Change from std::max and std::min to std::fmax and std::fmin to handle NaNs.
* Minor comments fix.
PiperOrigin-RevId: 158330052
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator.h | 46 |
1 files changed, 40 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 50cb32eb85..e6798a35a0 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -57,21 +57,32 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Evaluates a single HLO instruction and an array of pointers to literals. // Return the evaluated result as literal if successful. // Precondition: - // 1. argument literals are corresponds to the input instruction's - // parameters in their post-orderring. + // 1. argument literals correspond to the input instruction's parameters in + // their post-ordering. // 2. the instruction's operands must be of either Parameter or Constant type. // TODO(b/35950897): implement more ops other than element-wise ops. StatusOr<std::unique_ptr<Literal>> Evaluate( HloInstruction* instruction, tensorflow::gtl::ArraySlice<const Literal*> arg_literals); + // Evaluates a single HLO instruction with constant operands. + // Returns the evaluated result as literal if successful. + // Precondition: + // 1. all operands of the input instruction are constants. + // 2. the instruction is not a Parameter operation. + StatusOr<std::unique_ptr<Literal>> Evaluate(HloInstruction* instruction); + + // Same as Evaluate, except returning nullptr on error. + std::unique_ptr<Literal> TryEvaluate(HloInstruction* instruction); + protected: // Templated DfsHloVisitor. Typically ReturnT here indicates the resulting - // literal type of each evaluated Handle* method of a TypedVisitor. One - // exception to this is HandleCompare, where the resulting literal type is + // literal type of each evaluated Handle* method of a TypedVisitor. + // There are however a few notable exceptions to this rule, notably: + // - HandleCompare and HandleIsFinite: where the resulting literal type is // always boolean. 
- // Note the forward declaration here is necessary to enable TypedVisitor to - // access parent members. + // These operations are handled by the parent HloEvaluator handlers + // instead of from within TypedVisitor. template <typename ReturnT> class TypedVisitor; @@ -81,15 +92,38 @@ class HloEvaluator : public DfsHloVisitorWithDefault { return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get()); } + // Operations that are type-agnostic. + // Status HandleParameter(HloInstruction* parameter) override; Status HandleConstant(HloInstruction* constant, const Literal& literal) override; + Status HandleConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice<HloInstruction*> operands) override; + + Status HandleReshape(HloInstruction* reshape) override; + + Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override; + + Status HandleTranspose(HloInstruction* transpose) override; + + Status HandleIsFinite(HloInstruction* is_finite, + HloInstruction* operand) override; + + Status HandleCompare(HloInstruction* compare, HloOpcode opcode, + HloInstruction* lhs, HloInstruction* rhs) override; + private: // Returns the already-evaluated literal result for the instruction. + // A Constant instruction is considered evaluated and its literal will be + // returned directly without looking up the cache. // Crash with log if the given instruction has not been evaluated previously. const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) { + if (hlo->IsConstant()) { + return hlo->literal(); + } auto it = evaluated_.find(hlo); CHECK(it != evaluated_.end()) << "could not find evaluated value for: " << hlo->ToString(); |