diff options
author | 2018-05-25 17:46:19 -0700 | |
---|---|---|
committer | 2018-05-25 17:48:50 -0700 | |
commit | b4423efd55c5e463dd70d6975aa3a9d0f260011b (patch) | |
tree | 44b03f53e867a1732224d2e65a37075b2ed742d5 /tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | |
parent | a6eb244b2b8ee4d9592a705c4bc0771e4d708565 (diff) |
Add a type-erased broadcast implementation to xla::Literal
And use this in HLO evaluator. Since broadcast only moves bytes around we don't
need a type specialized implementation.
I'll use this in a later change.
PiperOrigin-RevId: 198128524
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 30 |
1 files changed, 0 insertions, 30 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index e37d651c95..82ee77e1ae 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -161,36 +161,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleRound<ReturnT>(round); } - Status HandleBroadcast(HloInstruction* broadcast) override { - const Literal& operand_to_broadcast = - parent_->GetEvaluatedLiteralFor(broadcast->operand(0)); - std::vector<int64> broadcast_indices( - ShapeUtil::Rank(broadcast->operand(0)->shape()), 0); - - TF_RET_CHECK(broadcast->dimensions().size() == - ShapeUtil::Rank(operand_to_broadcast.shape())) - << "broadcast dimensions is of size: " << broadcast->dimensions().size() - << " and rank of operand_to_broadcast is: " - << ShapeUtil::Rank(operand_to_broadcast.shape()); - // Checks that operand's dimensions are the same as the broadcast's - // dimensions along the dimensions to be broadcasted. - for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == - operand_to_broadcast.shape().dimensions(i)); - } - - auto output = MakeUnique<Literal>(broadcast->shape()); - TF_RETURN_IF_ERROR(output->Populate<ReturnT>( - [&](tensorflow::gtl::ArraySlice<int64> multi_index) { - for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - broadcast_indices[i] = multi_index[broadcast->dimensions(i)]; - } - return operand_to_broadcast.Get<ReturnT>(broadcast_indices); - })); - parent_->evaluated_[broadcast] = std::move(output); - return Status::OK(); - } - template < typename NativeT, typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> |