diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator.h | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 2ad56080d8..a4c37ef328 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -115,6 +116,10 @@ class HloEvaluator : public DfsHloVisitorWithDefault { StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseUnaryOp( HloOpcode opcode, const Literal& operand); + StatusOr<std::unique_ptr<Literal>> EvaluateDotOp( + const DotDimensionNumbers& dim_numbers, const Literal& lhs, + const Literal& rhs); + protected: // Make HloEvaluatorTypedVisitor a friend because it is logically part of this // class. @@ -172,10 +177,14 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; + Status HandleBroadcast(HloInstruction* broadcast) override; Status HandleAfterAll(HloInstruction* token) override; + Status HandleSort(HloInstruction* sort) override; + // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be // returned directly without looking up the cache. |