aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_evaluator.h
diff options
context:
space:
mode:
authorGravatar Kay Zhu <kayzhu@google.com>2017-06-07 15:56:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-07 16:01:11 -0700
commitbeeaade460a125975b6fe34d23ff0465183f8b4a (patch)
tree705f461333574cd5513076e5111c6ca4fd9655e5 /tensorflow/compiler/xla/service/hlo_evaluator.h
parentb94540e6f7ea130674b8122ec192c3d9a07a6752 (diff)
Resubmit a reverted change. Original description:
[XLA] Enable HloEvaluator for constant folding, also merged a few operations from hlo_constant_folding to hlo_evaluator. Additionally: - In ShapeUtil::ForEachIndex: * fix a bug where visitor is called when the shape has zero elements (e.g., F32{1,0}) * added test case for ForEachIndex. - In HloEvaluator: * Instead of copying and caching a Constant instruction, return the literal directly if the instruction is constant. * Fix an issue where TUPLE and OPAQUE primitives are not keyed in the templated typed_visitor. * Use (fixed) LiteralUtil::Populate to populate resulting literal, fixes the preexisting bug in the evaluator where R0 and shape with zero size dimensions are not handled. * Refactor ElementWiseUnaryOp and HandleCompare to be templatized on the operand's type. * Refactor IsFinite to be top level since it is only applicable to floats and the return type is always boolean. * Change from std::remainder to std::fmod for kRemainder to be compliant with existing XLA behavior. * Change from std::max and std::min to std::fmax and std::fmin to handle NaNs. * Minor comments fix. PiperOrigin-RevId: 158330052
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h46
1 files changed, 40 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 50cb32eb85..e6798a35a0 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -57,21 +57,32 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// Evaluates a single HLO instruction and an array of pointers to literals.
// Return the evaluated result as literal if successful.
// Precondition:
- // 1. argument literals are corresponds to the input instruction's
- // parameters in their post-orderring.
+ // 1. argument literals correspond to the input instruction's parameters in
+ // their post-ordering.
// 2. the instruction's operands must be of either Parameter or Constant type.
// TODO(b/35950897): implement more ops other than element-wise ops.
StatusOr<std::unique_ptr<Literal>> Evaluate(
HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
+ // Evaluates a single HLO instruction with constant operands.
+ // Returns the evaluated result as literal if successful.
+ // Precondition:
+ // 1. all operands of the input instruction are constants.
+ // 2. the instruction is not a Parameter operation.
+ StatusOr<std::unique_ptr<Literal>> Evaluate(HloInstruction* instruction);
+
+ // Same as Evaluate, except returning nullptr on error.
+ std::unique_ptr<Literal> TryEvaluate(HloInstruction* instruction);
+
protected:
// Templated DfsHloVisitor. Typically ReturnT here indicates the resulting
- // literal type of each evaluated Handle* method of a TypedVisitor. One
- // exception to this is HandleCompare, where the resulting literal type is
+ // literal type of each evaluated Handle* method of a TypedVisitor.
+ // There are, however, a few exceptions to this rule, notably:
+ // - HandleCompare and HandleIsFinite: where the resulting literal type is
// always boolean.
- // Note the forward declaration here is necessary to enable TypedVisitor to
- // access parent members.
+ // These operations are handled by the parent HloEvaluator's own handlers
+ // instead of from within TypedVisitor.
template <typename ReturnT>
class TypedVisitor;
@@ -81,15 +92,38 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get());
}
+ // Operations that are type-agnostic.
+ //
Status HandleParameter(HloInstruction* parameter) override;
Status HandleConstant(HloInstruction* constant,
const Literal& literal) override;
+ Status HandleConcatenate(
+ HloInstruction* concatenate,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
+
+ Status HandleReshape(HloInstruction* reshape) override;
+
+ Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;
+
+ Status HandleTranspose(HloInstruction* transpose) override;
+
+ Status HandleIsFinite(HloInstruction* is_finite,
+ HloInstruction* operand) override;
+
+ Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
+ HloInstruction* lhs, HloInstruction* rhs) override;
+
private:
// Returns the already-evaluated literal result for the instruction.
+ // A Constant instruction is considered evaluated and its literal will be
+ // returned directly without looking up the cache.
// Crash with log if the given instruction has not been evaluated previously.
const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) {
+ if (hlo->IsConstant()) {
+ return hlo->literal();
+ }
auto it = evaluated_.find(hlo);
CHECK(it != evaluated_.end())
<< "could not find evaluated value for: " << hlo->ToString();