aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_evaluator.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h9
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 2ad56080d8..a4c37ef328 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -115,6 +116,10 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseUnaryOp(
HloOpcode opcode, const Literal& operand);
+ StatusOr<std::unique_ptr<Literal>> EvaluateDotOp(
+ const DotDimensionNumbers& dim_numbers, const Literal& lhs,
+ const Literal& rhs);
+
protected:
// Make HloEvaluatorTypedVisitor a friend because it is logically part of this
// class.
@@ -172,10 +177,14 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
Status HandleSelect(HloInstruction* select) override;
+ Status HandleTupleSelect(HloInstruction* tuple_select) override;
+
Status HandleBroadcast(HloInstruction* broadcast) override;
Status HandleAfterAll(HloInstruction* token) override;
+ Status HandleSort(HloInstruction* sort) override;
+
// Returns the already-evaluated literal result for the instruction.
// A Constant instruction is considered evaluated and its literal will be
// returned directly without looking up the cache.