diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator.cc | 254 |
1 files changed, 163 insertions, 91 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index ebe7428052..1b3babc214 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -26,20 +26,21 @@ limitations under the License. #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/bitmap.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -53,9 +54,7 @@ std::unique_ptr<Literal> ElementWiseUnaryOp( const Literal& operand) { DCHECK(ShapeUtil::SameDimensions(shape, operand.shape())); - auto result = MakeUnique<Literal>(); - *result->mutable_shape() = shape; - LiteralUtil::Reserve(ShapeUtil::ElementsIn(shape), result.get()); + auto result = LiteralUtil::CreateFromShape(shape); std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0); do { @@ -74,9 +73,7 @@ std::unique_ptr<Literal> ElementWiseBinaryOp( DCHECK(ShapeUtil::SameDimensions(shape, rhs.shape())); DCHECK(ShapeUtil::SameDimensions(lhs.shape(), rhs.shape())); - auto result = MakeUnique<Literal>(); - *result->mutable_shape() = shape; - LiteralUtil::Reserve(ShapeUtil::ElementsIn(shape), result.get()); + auto result = LiteralUtil::CreateFromShape(shape); std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0); do { @@ -99,9 +96,7 @@ std::unique_ptr<Literal> ElementWiseTernaryOp( DCHECK(ShapeUtil::SameDimensions(lhs.shape(), rhs.shape())); DCHECK(ShapeUtil::SameDimensions(rhs.shape(), ehs.shape())); - auto result = MakeUnique<Literal>(); - *result->mutable_shape() = shape; - LiteralUtil::Reserve(ShapeUtil::ElementsIn(shape), result.get()); + auto result = LiteralUtil::CreateFromShape(shape); std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0); do { @@ -130,29 +125,130 @@ NativeT AbsoluteVal(NativeT value) { return std::abs(value); } -template <typename NativeT> -StatusOr<std::unique_ptr<Literal>> EvaluateOpForLiteralInternal( +} // namespace + +Status HloEvaluator::DefaultAction(HloInstruction* hlo) { + VLOG(2) << "Handle instruction: " << hlo->ToString(); + Shape shape = hlo->shape(); + TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); + + TF_ASSIGN_OR_RETURN(evaluated_[hlo], EvaluateBasedOnType(hlo)); + return Status::OK(); +} + +Status HloEvaluator::HandleParameter(HloInstruction* parameter) { + VLOG(2) << "HandleParameter: " << parameter->ToString(); + const Literal* input_literal = arg_literals_[parameter->parameter_number()]; + VLOG(2) << "Parameter evaluated to: " + << LiteralUtil::ToString(*input_literal); + CHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape())); + + evaluated_[parameter] = MakeUnique<Literal>(*input_literal); + return Status::OK(); +} + +Status HloEvaluator::HandleConstant(HloInstruction* constant, + const Literal& literal) { + VLOG(2) << "HandleConstant: " << constant->ToString(); + CHECK(ShapeUtil::Equal(constant->shape(), literal.shape())); + + evaluated_[constant] = MakeUnique<Literal>(literal); + return Status::OK(); +} + +StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate( + HloComputation* computation, + tensorflow::gtl::ArraySlice<const Literal*> args) { + arg_literals_ = args; + TF_RETURN_IF_ERROR(computation->Accept(this)); + return std::move(FindOrDie(evaluated_, computation->root_instruction())); +} + +StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate( + HloInstruction* instruction, + tensorflow::gtl::ArraySlice<const Literal*> args) { + DCHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); + Shape shape = instruction->shape(); + TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); + + arg_literals_ = args; + + // Evaluate operands of Parameter type against the input literals which caches + // the evaluated literal results. + for (const auto operand : instruction->operands()) { + if (operand->opcode() == HloOpcode::kParameter) { + TF_CHECK_OK(HandleParameter(operand)); + } else if (operand->opcode() == HloOpcode::kConstant) { + evaluated_[operand] = MakeUnique<Literal>(operand->literal()); + } + } + + TF_RETURN_IF_ERROR(instruction->Visit(this)); + return std::move(FindOrDie(evaluated_, instruction)); +} + +StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateBasedOnType( HloInstruction* instruction) { - DCHECK(hlo_query::AllOperandsAreConstants(*instruction)); + Shape shape = instruction->shape(); + TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); + + switch (shape.element_type()) { + case PRED: + return EvaluateSameTypedElementwise<bool>(instruction); + case U8: + return EvaluateSameTypedElementwise<uint8>(instruction); + case U16: + return Unimplemented("unhandled primitive type: %s.", + PrimitiveType_Name(U16).c_str()); + case U32: + return EvaluateSameTypedElementwise<uint32>(instruction); + case U64: + return EvaluateSameTypedElementwise<uint64>(instruction); + case S8: + return EvaluateSameTypedElementwise<int8>(instruction); + case S16: + return Unimplemented("unhandled primitive type: %s.", + PrimitiveType_Name(S16).c_str()); + case S32: + return EvaluateSameTypedElementwise<int32>(instruction); + case S64: + return EvaluateSameTypedElementwise<int64>(instruction); + case F16: + return Unimplemented("unhandled primitive type: %s.", + PrimitiveType_Name(F16).c_str()); + case F32: + return EvaluateSameTypedElementwise<float>(instruction); + case F64: + return EvaluateSameTypedElementwise<double>(instruction); + default: + return Unimplemented("unhandled primitive type: %s.", + PrimitiveType_Name(shape.element_type()).c_str()); + } +} +template <typename NativeT> +StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateSameTypedElementwise( + HloInstruction* instruction) { const std::vector<HloInstruction*>& operands = instruction->operands(); HloOpcode opcode = instruction->opcode(); const Shape& shape = instruction->shape(); switch (opcode) { // TODO(b/35950897): many of the stl function used here are not overloaded - // for all XLA primitive types. + // for every XLA primitive types. + // Unary element-wise ops. + // case HloOpcode::kAbs: CHECK_EQ(operands.size(), 1); return ElementWiseUnaryOp<NativeT>( shape, [](NativeT operand) { return AbsoluteVal(operand); }, - operands[0]->literal()); + GetEvaluatedLiteralFor(operands[0])); case HloOpcode::kCeil: CHECK_EQ(operands.size(), 1); return ElementWiseUnaryOp<NativeT>( shape, [](NativeT operand) { return std::ceil(operand); }, - operands[0]->literal()); + GetEvaluatedLiteralFor(operands[0])); case HloOpcode::kConvert: CHECK_EQ(operands.size(), 1); // TODO(b/35950897): implement Convert. @@ -162,37 +258,37 @@ StatusOr<std::unique_ptr<Literal>> EvaluateOpForLiteralInternal( CHECK_EQ(operands.size(), 1); return ElementWiseUnaryOp<NativeT>( shape, [](NativeT operand) { return operand; }, - operands[0]->literal()); + GetEvaluatedLiteralFor(operands[0])); case HloOpcode::kExp: CHECK_EQ(operands.size(), 1); return ElementWiseUnaryOp<NativeT>( shape, [](NativeT operand) { return std::exp(operand); }, - operands[0]->literal()); + GetEvaluatedLiteralFor(operands[0])); case HloOpcode::kFloor: CHECK_EQ(operands.size(), 1); return ElementWiseUnaryOp<NativeT>( shape, [](NativeT operand) { return std::floor(operand); }, - operands[0]->literal()); + GetEvaluatedLiteralFor(operands[0])); case HloOpcode::kIsFinite: CHECK_EQ(operands.size(), 1); return ElementWiseUnaryOp<NativeT>( shape, [](NativeT operand) { return std::isfinite(operand); }, - operands[0]->literal()); + GetEvaluatedLiteralFor(operands[0])); case HloOpcode::kLog: CHECK_EQ(operands.size(), 1); return ElementWiseUnaryOp<NativeT>( shape, [](NativeT operand) { return std::log(operand); }, - operands[0]->literal()); + GetEvaluatedLiteralFor(operands[0])); case HloOpcode::kLogicalNot: CHECK_EQ(operands.size(), 1); return ElementWiseUnaryOp<NativeT>( shape, [](NativeT operand) { return !operand; }, - operands[0]->literal()); + GetEvaluatedLiteralFor(operands[0])); case HloOpcode::kNegate: CHECK_EQ(operands.size(), 1); return ElementWiseUnaryOp<NativeT>( shape, [](NativeT operand) { return -operand; }, - operands[0]->literal()); + GetEvaluatedLiteralFor(operands[0])); case HloOpcode::kSign: CHECK_EQ(operands.size(), 1); CHECK(primitive_util::IsIntegralType(shape.element_type())); @@ -201,95 +297,113 @@ StatusOr<std::unique_ptr<Literal>> EvaluateOpForLiteralInternal( return (NativeT(0) < operand) - (operand < NativeT(0)); }, - operands[0]->literal()); + GetEvaluatedLiteralFor(operands[0])); case HloOpcode::kTanh: CHECK_EQ(operands.size(), 1); return ElementWiseUnaryOp<NativeT>( shape, [](NativeT operand) { return std::tanh(operand); }, - operands[0]->literal()); + GetEvaluatedLiteralFor(operands[0])); // Binary element-wise ops. + // case HloOpcode::kAdd: CHECK_EQ(operands.size(), 2); return ElementWiseBinaryOp<NativeT>( shape, [](NativeT lhs, NativeT rhs) { return lhs + rhs; }, - operands[0]->literal(), operands[1]->literal()); + GetEvaluatedLiteralFor(operands[0]), + GetEvaluatedLiteralFor(operands[1])); case HloOpcode::kDivide: CHECK_EQ(operands.size(), 2); return ElementWiseBinaryOp<NativeT>( shape, [](NativeT lhs, NativeT rhs) { return lhs / rhs; }, - operands[0]->literal(), operands[1]->literal()); + GetEvaluatedLiteralFor(operands[0]), + GetEvaluatedLiteralFor(operands[1])); case HloOpcode::kMultiply: CHECK_EQ(operands.size(), 2); return ElementWiseBinaryOp<NativeT>( shape, [](NativeT lhs, NativeT rhs) { return lhs * rhs; }, - operands[0]->literal(), operands[1]->literal()); + GetEvaluatedLiteralFor(operands[0]), + GetEvaluatedLiteralFor(operands[1])); case HloOpcode::kSubtract: CHECK_EQ(operands.size(), 2); return ElementWiseBinaryOp<NativeT>( shape, [](NativeT lhs, NativeT rhs) { return lhs - rhs; }, - operands[0]->literal(), operands[1]->literal()); + GetEvaluatedLiteralFor(operands[0]), + GetEvaluatedLiteralFor(operands[1])); case HloOpcode::kEq: CHECK_EQ(operands.size(), 2); return ElementWiseBinaryOp<bool>( shape, [](NativeT lhs, NativeT rhs) { return lhs == rhs; }, - operands[0]->literal(), operands[1]->literal()); + GetEvaluatedLiteralFor(operands[0]), + GetEvaluatedLiteralFor(operands[1])); case HloOpcode::kGe: CHECK_EQ(operands.size(), 2); return ElementWiseBinaryOp<bool>( shape, [](NativeT lhs, NativeT rhs) { return lhs >= rhs; }, - operands[0]->literal(), operands[1]->literal()); + GetEvaluatedLiteralFor(operands[0]), + GetEvaluatedLiteralFor(operands[1])); case HloOpcode::kGt: CHECK_EQ(operands.size(), 2); return ElementWiseBinaryOp<bool>( shape, [](NativeT lhs, NativeT rhs) { return lhs > rhs; }, - operands[0]->literal(), operands[1]->literal()); + GetEvaluatedLiteralFor(operands[0]), + GetEvaluatedLiteralFor(operands[1])); case HloOpcode::kLe: CHECK_EQ(operands.size(), 2); return ElementWiseBinaryOp<bool>( shape, [](NativeT lhs, NativeT rhs) { return lhs <= rhs; }, - operands[0]->literal(), operands[1]->literal()); + GetEvaluatedLiteralFor(operands[0]), + GetEvaluatedLiteralFor(operands[1])); case HloOpcode::kLt: CHECK_EQ(operands.size(), 2); return ElementWiseBinaryOp<bool>( shape, [](NativeT lhs, NativeT rhs) { return lhs < rhs; }, - operands[0]->literal(), operands[1]->literal()); + GetEvaluatedLiteralFor(operands[0]), + GetEvaluatedLiteralFor(operands[1])); case HloOpcode::kNe: CHECK_EQ(operands.size(), 2); return ElementWiseBinaryOp<bool>( shape, [](NativeT lhs, NativeT rhs) { return lhs != rhs; }, - operands[0]->literal(), operands[1]->literal()); + GetEvaluatedLiteralFor(operands[0]), + GetEvaluatedLiteralFor(operands[1])); case HloOpcode::kMaximum: CHECK_EQ(operands.size(), 2); return ElementWiseBinaryOp<NativeT>( shape, [](NativeT lhs, NativeT rhs) { return std::max(lhs, rhs); }, - operands[0]->literal(), operands[1]->literal()); + GetEvaluatedLiteralFor(operands[0]), + GetEvaluatedLiteralFor(operands[1])); case HloOpcode::kMinimum: CHECK_EQ(operands.size(), 2); return ElementWiseBinaryOp<NativeT>( shape, [](NativeT lhs, NativeT rhs) { return std::min(lhs, rhs); }, - operands[0]->literal(), operands[1]->literal()); + GetEvaluatedLiteralFor(operands[0]), + GetEvaluatedLiteralFor(operands[1])); case HloOpcode::kPower: CHECK_EQ(operands.size(), 2); return ElementWiseBinaryOp<NativeT>( shape, [](NativeT lhs, NativeT rhs) { return std::pow(lhs, rhs); }, - operands[0]->literal(), operands[1]->literal()); + GetEvaluatedLiteralFor(operands[0]), + GetEvaluatedLiteralFor(operands[1])); case HloOpcode::kRemainder: CHECK_EQ(operands.size(), 2); return ElementWiseBinaryOp<NativeT>( shape, [](NativeT lhs, NativeT rhs) { return std::remainder(lhs, rhs); }, - operands[0]->literal(), operands[1]->literal()); + GetEvaluatedLiteralFor(operands[0]), + GetEvaluatedLiteralFor(operands[1])); case HloOpcode::kLogicalAnd: CHECK_EQ(operands.size(), 2); return ElementWiseBinaryOp<NativeT>( shape, [](NativeT lhs, NativeT rhs) { return lhs && rhs; }, - operands[0]->literal(), operands[1]->literal()); + GetEvaluatedLiteralFor(operands[0]), + GetEvaluatedLiteralFor(operands[1])); case HloOpcode::kLogicalOr: CHECK_EQ(operands.size(), 2); return ElementWiseBinaryOp<NativeT>( shape, [](NativeT lhs, NativeT rhs) { return lhs || rhs; }, - operands[0]->literal(), operands[1]->literal()); + GetEvaluatedLiteralFor(operands[0]), + GetEvaluatedLiteralFor(operands[1])); // Ternary element-wise ops. + // case HloOpcode::kClamp: { CHECK_EQ(operands.size(), 3); std::function<NativeT(NativeT, NativeT, NativeT)> clamp_op = @@ -297,8 +411,9 @@ StatusOr<std::unique_ptr<Literal>> EvaluateOpForLiteralInternal( return std::max(low, std::min(value, high)); }; return ElementWiseTernaryOp<NativeT, NativeT, NativeT, NativeT>( - shape, std::move(clamp_op), operands[0]->literal(), - operands[1]->literal(), operands[2]->literal()); + shape, std::move(clamp_op), GetEvaluatedLiteralFor(operands[0]), + GetEvaluatedLiteralFor(operands[1]), + GetEvaluatedLiteralFor(operands[2])); } break; case HloOpcode::kSelect: { CHECK_EQ(operands.size(), 3); @@ -311,8 +426,9 @@ StatusOr<std::unique_ptr<Literal>> EvaluateOpForLiteralInternal( return on_false; }; return ElementWiseTernaryOp<NativeT, bool, NativeT, NativeT>( - shape, std::move(select_op), operands[0]->literal(), - operands[1]->literal(), operands[2]->literal()); + shape, std::move(select_op), GetEvaluatedLiteralFor(operands[0]), + GetEvaluatedLiteralFor(operands[1]), + GetEvaluatedLiteralFor(operands[2])); } break; default: return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", @@ -320,48 +436,4 @@ StatusOr<std::unique_ptr<Literal>> EvaluateOpForLiteralInternal( } } -} // namespace - -/* static */ StatusOr<std::unique_ptr<Literal>> -HloEvaluator::EvaluateOpForLiteral(HloInstruction* instruction) { - DCHECK(hlo_query::AllOperandsAreConstants(*instruction)); - - Shape shape = instruction->shape(); - TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); - - // REVIEW QUESTION: other than a few operations, do we need to handle the - // general case of operands being of different types in the context of the - // evaluator? - - switch (shape.element_type()) { - case PRED: - return EvaluateOpForLiteralInternal<bool>(instruction); - case U8: - return EvaluateOpForLiteralInternal<uint8>(instruction); - case U16: - LOG(FATAL) << "U16/uint16 is unimplemented."; - case U32: - return EvaluateOpForLiteralInternal<uint32>(instruction); - case U64: - return EvaluateOpForLiteralInternal<uint64>(instruction); - case S8: - return EvaluateOpForLiteralInternal<int8>(instruction); - case S16: - LOG(FATAL) << "S16/int16 is unimplemented."; - case S32: - return EvaluateOpForLiteralInternal<int32>(instruction); - case S64: - return EvaluateOpForLiteralInternal<int64>(instruction); - case F16: - LOG(FATAL) << "F16 is unimplemented."; - case F32: - return EvaluateOpForLiteralInternal<float>(instruction); - case F64: - return EvaluateOpForLiteralInternal<double>(instruction); - default: - return Unimplemented("unhandled primitive type: %s.", - PrimitiveType_Name(shape.element_type()).c_str()); - } -} - } // namespace xla |