about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-05-25 17:46:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-25 17:48:50 -0700
commitb4423efd55c5e463dd70d6975aa3a9d0f260011b (patch)
tree44b03f53e867a1732224d2e65a37075b2ed742d5 /tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
parenta6eb244b2b8ee4d9592a705c4bc0771e4d708565 (diff)
Add a type-erased broadcast implementation to xla::Literal
And use this in the HLO evaluator. Since broadcast only moves bytes around, we don't need a type-specialized implementation. I'll use this in a later change.
PiperOrigin-RevId: 198128524
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h30
1 file changed, 0 insertions, 30 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index e37d651c95..82ee77e1ae 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -161,36 +161,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return HandleRound<ReturnT>(round);
}
- Status HandleBroadcast(HloInstruction* broadcast) override {
- const Literal& operand_to_broadcast =
- parent_->GetEvaluatedLiteralFor(broadcast->operand(0));
- std::vector<int64> broadcast_indices(
- ShapeUtil::Rank(broadcast->operand(0)->shape()), 0);
-
- TF_RET_CHECK(broadcast->dimensions().size() ==
- ShapeUtil::Rank(operand_to_broadcast.shape()))
- << "broadcast dimensions is of size: " << broadcast->dimensions().size()
- << " and rank of operand_to_broadcast is: "
- << ShapeUtil::Rank(operand_to_broadcast.shape());
- // Checks that operand's dimensions are the same as the broadcast's
- // dimensions along the dimensions to be broadcasted.
- for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
- TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) ==
- operand_to_broadcast.shape().dimensions(i));
- }
-
- auto output = MakeUnique<Literal>(broadcast->shape());
- TF_RETURN_IF_ERROR(output->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
- for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
- broadcast_indices[i] = multi_index[broadcast->dimensions(i)];
- }
- return operand_to_broadcast.Get<ReturnT>(broadcast_indices);
- }));
- parent_->evaluated_[broadcast] = std::move(output);
- return Status::OK();
- }
-
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>