diff options
author | 2018-08-28 13:56:25 -0700 | |
---|---|---|
committer | 2018-08-28 14:01:14 -0700 | |
commit | 96de2f020130fc0afbcc5087b520aa46ced5b5af (patch) | |
tree | 759a2b4d58448a1cd816a3b84979dc6d40d36467 /tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | |
parent | 9c1f14322484e44a93b77619ffd2e24b9b7a9b1d (diff) |
[XLA] Implement kIota for CPU & GPU, extend it w/ broadcast semantics
This extends the Iota HLO to have a broadcast field, which allows for
higher-rank kIota operations.
PiperOrigin-RevId: 210600435
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 20 |
1 files changed, 16 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index b6566ebefe..f682e69ee9 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -21,7 +21,9 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/core/lib/core/casts.h" @@ -2493,11 +2495,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::is_same<NativeT, float>::value || std::is_same<NativeT, int32>::value || std::is_same<NativeT, uint32>::value>::type* = nullptr> - Status HandleIota(HloInstruction* iota) { - auto result = absl::make_unique<Literal>(iota->shape()); - auto data = result->data<ReturnT>(); + Status HandleIota(HloInstruction* instruction) { + auto* iota = Cast<HloIotaInstruction>(instruction); + std::vector<NativeT> data(iota->shape().dimensions(iota->iota_dimension())); std::iota(data.begin(), data.end(), 0); - parent_->evaluated_[iota] = std::move(result); + auto result = LiteralUtil::CreateR1<NativeT>(data); + + if (ShapeUtil::Rank(iota->shape()) > 1) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[iota], + result->Broadcast(iota->shape(), {iota->iota_dimension()})); + } else { + TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1); + parent_->evaluated_[iota] = std::move(result); + } + return Status::OK(); } template <typename NativeT, |