diff options
author | Sanjoy Das <sanjoy@google.com> | 2018-07-09 16:54:51 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-09 16:59:17 -0700 |
commit | 3493816ae9bf83e6a1a7639ebce3a05b944e2cd2 (patch) | |
tree | e9df013e59437a172efa0e3ba70f7b6c58f4d232 /tensorflow/compiler/xla/service/hlo_evaluator.cc | |
parent | 1e0e804ca57791b48d394ad6f7fb536774e8c220 (diff) |
Teach the indexed array analysis about dot operations
PiperOrigin-RevId: 203855406
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator.cc | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index f68b4ca353..9dd56fcbf7 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -330,6 +330,24 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp( return result; } +StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp( + const DotDimensionNumbers& dim_numbers, const Literal& lhs, + const Literal& rhs) { + std::unique_ptr<HloInstruction> lhs_instr = + HloInstruction::CreateConstant(lhs.CloneToUnique()); + std::unique_ptr<HloInstruction> rhs_instr = + HloInstruction::CreateConstant(rhs.CloneToUnique()); + + TF_ASSIGN_OR_RETURN( + Shape dot_shape, + ShapeInference::InferDotOpShape(lhs.shape(), rhs.shape(), dim_numbers)); + + std::unique_ptr<HloInstruction> cloned_instruction = + HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(), + dim_numbers); + return Evaluate(cloned_instruction.get()); +} + Status HloEvaluator::HandleParameter(HloInstruction* parameter) { CHECK_LT(parameter->parameter_number(), arg_literals_.size()); const Literal* input_literal = arg_literals_[parameter->parameter_number()]; |