about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
diff options
context:
space:
mode:
author: David Majnemer <majnemer@google.com> 2018-08-28 13:56:25 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-08-28 14:01:14 -0700
commit: 96de2f020130fc0afbcc5087b520aa46ced5b5af (patch)
tree: 759a2b4d58448a1cd816a3b84979dc6d40d36467 /tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
parent: 9c1f14322484e44a93b77619ffd2e24b9b7a9b1d (diff)
[XLA] Implement kIota for CPU & GPU, extend it w/ broadcast semantics
This extends the Iota HLO to have a broadcast field. This allows for higher-rank kIota operations.

PiperOrigin-RevId: 210600435
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h')
-rw-r--r-- tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h 20
1 file changed, 16 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index b6566ebefe..f682e69ee9 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -21,7 +21,9 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/core/lib/core/casts.h"
@@ -2493,11 +2495,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::is_same<NativeT, float>::value ||
std::is_same<NativeT, int32>::value ||
std::is_same<NativeT, uint32>::value>::type* = nullptr>
- Status HandleIota(HloInstruction* iota) {
- auto result = absl::make_unique<Literal>(iota->shape());
- auto data = result->data<ReturnT>();
+ Status HandleIota(HloInstruction* instruction) {
+ auto* iota = Cast<HloIotaInstruction>(instruction);
+ std::vector<NativeT> data(iota->shape().dimensions(iota->iota_dimension()));
std::iota(data.begin(), data.end(), 0);
- parent_->evaluated_[iota] = std::move(result);
+ auto result = LiteralUtil::CreateR1<NativeT>(data);
+
+ if (ShapeUtil::Rank(iota->shape()) > 1) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[iota],
+ result->Broadcast(iota->shape(), {iota->iota_dimension()}));
+ } else {
+ TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1);
+ parent_->evaluated_[iota] = std::move(result);
+ }
+
return Status::OK();
}
template <typename NativeT,