aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_evaluator.cc
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-07-09 16:54:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-09 16:59:17 -0700
commit3493816ae9bf83e6a1a7639ebce3a05b944e2cd2 (patch)
treee9df013e59437a172efa0e3ba70f7b6c58f4d232 /tensorflow/compiler/xla/service/hlo_evaluator.cc
parent1e0e804ca57791b48d394ad6f7fb536774e8c220 (diff)
Teach the indexed array analysis about dot operations
PiperOrigin-RevId: 203855406
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc18
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index f68b4ca353..9dd56fcbf7 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -330,6 +330,24 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
return result;
}
+StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
+ const DotDimensionNumbers& dim_numbers, const Literal& lhs,
+ const Literal& rhs) {
+ std::unique_ptr<HloInstruction> lhs_instr =
+ HloInstruction::CreateConstant(lhs.CloneToUnique());
+ std::unique_ptr<HloInstruction> rhs_instr =
+ HloInstruction::CreateConstant(rhs.CloneToUnique());
+
+ TF_ASSIGN_OR_RETURN(
+ Shape dot_shape,
+ ShapeInference::InferDotOpShape(lhs.shape(), rhs.shape(), dim_numbers));
+
+ std::unique_ptr<HloInstruction> cloned_instruction =
+ HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
+ dim_numbers);
+ return Evaluate(cloned_instruction.get());
+}
+
Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
CHECK_LT(parameter->parameter_number(), arg_literals_.size());
const Literal* input_literal = arg_literals_[parameter->parameter_number()];