about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/hlo_evaluator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator.cc')
-rw-r--r-- tensorflow/compiler/xla/service/hlo_evaluator.cc 232
1 files changed, 218 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index deb7f28d84..51353eea6e 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
@@ -135,7 +136,6 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
} // namespace
-
HloEvaluator::HloEvaluator(int64 max_loop_iterations)
: max_loop_iterations_(max_loop_iterations) {
typed_visitors_[PRED] = MakeUnique<HloEvaluatorTypedVisitor<bool>>(this);
@@ -330,6 +330,24 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
return result;
}
+StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
+ const DotDimensionNumbers& dim_numbers, const Literal& lhs,
+ const Literal& rhs) {
+ std::unique_ptr<HloInstruction> lhs_instr =
+ HloInstruction::CreateConstant(lhs.CloneToUnique());
+ std::unique_ptr<HloInstruction> rhs_instr =
+ HloInstruction::CreateConstant(rhs.CloneToUnique());
+
+ TF_ASSIGN_OR_RETURN(
+ Shape dot_shape,
+ ShapeInference::InferDotOpShape(lhs.shape(), rhs.shape(), dim_numbers));
+
+ std::unique_ptr<HloInstruction> cloned_instruction =
+ HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
+ dim_numbers);
+ return Evaluate(cloned_instruction.get());
+}
+
Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
CHECK_LT(parameter->parameter_number(), arg_literals_.size());
const Literal* input_literal = arg_literals_[parameter->parameter_number()];
@@ -382,7 +400,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
ShapeUtil::GetDimension(operand_shape, concat_dim);
}
- auto result_literal = Literal::CreateFromDimensions(
+ auto result_literal = LiteralUtil::CreateFromDimensions(
reference_shape.element_type(), concat_dimensions);
DimensionVector source_indices(rank, 0);
DimensionVector dest_indices(concat_dimensions.size(), 0);
@@ -533,7 +551,7 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
operand_literals.push_back(&GetEvaluatedLiteralFor(operand));
}
- evaluated_[tuple] = Literal::MakeTuple(operand_literals);
+ evaluated_[tuple] = LiteralUtil::MakeTuple(operand_literals);
return Status::OK();
}
@@ -757,6 +775,12 @@ class OutputWindowIndexToInputIndex {
return ArraySlice<int64>(input_index_);
}
+ // Returns for a given 'input_dim' the corresponding output dimension index,
+ // or -1 if 'input_dim' is an elided window dimension.
+ int64 input_dim_value_to_output_index(int64 input_dim) {
+ return input_dim_value_to_output_index_[input_dim];
+ }
+
private:
// Propagates window dimensions from the output index to input_index_ by
// mutating input_index_ in place.
@@ -774,7 +798,7 @@ class OutputWindowIndexToInputIndex {
// input_dim_value_to_index_vector_[i] tells us how to compute dimension i of
// the input index from the output index. See
- // PropagateOutputIndexToInputIndex.
+ // PropagateOutputIndexWindowDimsToInputIndex.
std::vector<int64> input_dim_value_to_output_index_;
// The result computed by this functor. operator() returns an ArraySlice into
@@ -827,6 +851,8 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
// corresponding index in the input shape.
std::vector<int64> input_index(operand.shape().dimensions_size());
std::vector<int64> output_index(gather->shape().dimensions_size());
+ std::vector<int64> input_gather_index_clamped(
+ operand.shape().dimensions_size());
OutputGatherIndexToInputIndex output_gather_index_to_input_index(
&gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
@@ -848,14 +874,26 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
output_index[i] = output_gather_index[i] + output_window_index[i];
DCHECK_LT(output_index[i], shape.dimensions(i));
}
+ for (int i = 0, e = input_gather_index.size(); i < e; i++) {
+ int64 output_dim =
+ output_window_index_to_input_index.input_dim_value_to_output_index(i);
+ // If 'output_dim' is -1, it means 'i' is an elided window dim. This means
+ // we set the iteration index to 0, so for the purpose of the following
+ // calculations we can consider the output dimension size to be 1.
+ int64 output_dim_size =
+ output_dim == -1 ? 1 : shape.dimensions(output_dim);
+ // Clamp the gather index so that the gather region fits in the operand.
+ // input_gather_index_clamped[i] = clamp(input_gather_index[i], 0,
+ // operand_shape.dimensions(i) -
+ // output_dim_size);
+ input_gather_index_clamped[i] =
+ std::min(operand_shape.dimensions(i) - output_dim_size,
+ std::max(0LL, input_gather_index[i]));
+ }
for (int i = 0, e = input_index.size(); i < e; i++) {
- // TODO(b/74360564): We should implement whatever out of bounds behavior
- // we decide for dynamic-slice here as well.
- input_index[i] = (input_gather_index[i] + input_window_index[i]) %
- operand_shape.dimensions(i);
- if (input_index[i] < 0) {
- input_index[i] += operand_shape.dimensions(i);
- }
+ input_index[i] = input_gather_index_clamped[i] + input_window_index[i];
+ DCHECK_GE(input_index[i], 0);
+ DCHECK_LT(input_index[i], operand_shape.dimensions(i));
}
TF_RETURN_IF_ERROR(
result->CopyElementFrom(operand, input_index, output_index));
@@ -903,7 +941,7 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
}
Status HloEvaluator::HandleAfterAll(HloInstruction* token) {
- evaluated_[token] = Literal::CreateToken();
+ evaluated_[token] = LiteralUtil::CreateToken();
return Status::OK();
}
@@ -1024,8 +1062,6 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) {
const auto& on_false = GetEvaluatedLiteralFor(select->operand(2));
// If predicate is of scalar type, no element-wise selection would be needed.
- // This would also handle output array of tuple types as the DefaultAction
- // would go through the HloEvaluatorTypedVisitor which doesn't handle tuples.
if (ShapeUtil::IsScalar(pred.shape())) {
if (pred.Get<bool>({})) {
evaluated_[select] = on_true.CloneToUnique();
@@ -1038,6 +1074,19 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) {
return DefaultAction(select);
}
+Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) {
+ const auto& pred = GetEvaluatedLiteralFor(tuple_select->operand(0));
+ const auto& on_true = GetEvaluatedLiteralFor(tuple_select->operand(1));
+ const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2));
+
+ if (pred.Get<bool>({})) {
+ evaluated_[tuple_select] = on_true.CloneToUnique();
+ } else {
+ evaluated_[tuple_select] = on_false.CloneToUnique();
+ }
+ return Status::OK();
+}
+
Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
HloComputation* cond_comp = while_hlo->while_condition();
HloComputation* body_comp = while_hlo->while_body();
@@ -1068,6 +1117,161 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
return Status::OK();
}
+// Key-value sort is a special snowflake: it's templated on two different
+// element types, one for the keys, and one for the values. Jump through some
+// hoops to make this work.
+namespace {
+template <typename KeyType, typename ValueType>
+StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
+ HloInstruction* sort, const Literal& keys_literal,
+ const Literal& values_literal) {
+ auto rank = ShapeUtil::Rank(keys_literal.shape());
+ TF_RET_CHECK(
+ ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape()))
+ << "Sort keys and values must have the same dimensions";
+ TF_RET_CHECK(rank > 0 && rank <= 2)
+ << "Sort is only supported for rank-1 and rank-2 shapes, rank is: "
+ << rank;
+ TF_RET_CHECK(sort->operand_count() == 2) << "Expected key-value sort";
+ // We need to sort and array of keys and an array of values, where the
+ // sorted order of the values is determined by the keys. The simplest(?)
+ // way to do this is to go to an array-of-pairs representation, sort the
+ // array using the keys, and then go back to pair-of-arrays.
+ VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString();
+ VLOG(3) << "HandleSort values_literal: " << values_literal.ToString();
+
+ auto sort_r1 = [](const Literal& keys_literal,
+ const Literal& values_literal) {
+ const auto& keys_data = keys_literal.data<KeyType>();
+ const auto& values_data = values_literal.data<ValueType>();
+
+ using kv_pair = std::pair<KeyType, ValueType>;
+ std::vector<kv_pair> key_value_vector;
+ CHECK_EQ(keys_data.size(), values_data.size());
+ key_value_vector.reserve(keys_data.size());
+ for (int i = 0; i < keys_data.size(); ++i) {
+ key_value_vector.push_back(std::make_pair(keys_data[i], values_data[i]));
+ }
+ std::sort(key_value_vector.begin(), key_value_vector.end(),
+ [](const kv_pair& a, const kv_pair& b) {
+ return SafeLess<KeyType>(a.first, b.first);
+ });
+ std::vector<KeyType> result_keys;
+ std::vector<ValueType> result_values;
+ for (const auto& key_value : key_value_vector) {
+ result_keys.push_back(key_value.first);
+ result_values.push_back(key_value.second);
+ }
+ auto result_keys_literal = MakeUnique<Literal>(keys_literal.shape());
+ result_keys_literal->PopulateR1(
+ tensorflow::gtl::ArraySlice<KeyType>(result_keys));
+ auto result_values_literal = MakeUnique<Literal>(values_literal.shape());
+ result_values_literal->PopulateR1(
+ tensorflow::gtl::ArraySlice<ValueType>(result_values));
+ return std::make_pair(std::move(result_keys_literal),
+ std::move(result_values_literal));
+ };
+
+ std::unique_ptr<Literal> result_tuple;
+ if (rank == 1) {
+ auto result_pair = sort_r1(keys_literal, values_literal);
+ result_tuple = LiteralUtil::MakeTuple(
+ {result_pair.first.get(), result_pair.second.get()});
+ } else {
+ // For R2 sort, the desired semantics are to sort each matrix row
+ // independently.
+ auto keys_result_literal = MakeUnique<Literal>(keys_literal.shape());
+ auto values_result_literal = MakeUnique<Literal>(values_literal.shape());
+ int64 r1_length = keys_literal.shape().dimensions(1);
+ for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) {
+ TF_ASSIGN_OR_RETURN(auto keys_r1_slice,
+ keys_literal.Slice({row, 0}, {row + 1, r1_length})
+ ->Reshape({r1_length}));
+ TF_ASSIGN_OR_RETURN(auto values_r1_slice,
+ values_literal.Slice({row, 0}, {row + 1, r1_length})
+ ->Reshape({r1_length}));
+ auto r1_result_pair = sort_r1(*keys_r1_slice, *values_r1_slice);
+ TF_ASSIGN_OR_RETURN(auto sorted_keys,
+ r1_result_pair.first->Reshape({1, r1_length}));
+ TF_ASSIGN_OR_RETURN(auto sorted_values,
+ r1_result_pair.second->Reshape({1, r1_length}));
+ TF_RETURN_IF_ERROR(keys_result_literal->CopySliceFrom(
+ *sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
+ TF_RETURN_IF_ERROR(values_result_literal->CopySliceFrom(
+ *sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
+ }
+ result_tuple = LiteralUtil::MakeTuple(
+ {keys_result_literal.get(), values_result_literal.get()});
+ }
+
+ VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString();
+ return std::move(result_tuple);
+}
+
+template <typename KeyType>
+StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried(
+ HloInstruction* sort, const Literal& keys_literal,
+ const Literal& values_literal) {
+ switch (sort->operand(1)->shape().element_type()) {
+ case F32:
+ return EvaluateSortInternal<KeyType, float>(sort, keys_literal,
+ values_literal);
+ case U32:
+ return EvaluateSortInternal<KeyType, uint32>(sort, keys_literal,
+ values_literal);
+ case S32:
+ return EvaluateSortInternal<KeyType, int32>(sort, keys_literal,
+ values_literal);
+ case BF16:
+ return EvaluateSortInternal<KeyType, bfloat16>(sort, keys_literal,
+ values_literal);
+ default:
+ return InvalidArgument("Unsupported type for Sort");
+ }
+}
+
+StatusOr<std::unique_ptr<Literal>> EvaluateSort(HloInstruction* sort,
+ const Literal& keys_literal,
+ const Literal& values_literal) {
+ switch (sort->operand(0)->shape().element_type()) {
+ case F32:
+ return EvaluateSortCurried<float>(sort, keys_literal, values_literal);
+ case U32:
+ return EvaluateSortCurried<uint32>(sort, keys_literal, values_literal);
+ case S32:
+ return EvaluateSortCurried<int32>(sort, keys_literal, values_literal);
+ case BF16:
+ return EvaluateSortCurried<bfloat16>(sort, keys_literal, values_literal);
+ default:
+ return InvalidArgument("Unsupported type for Sort");
+ }
+}
+} // namespace
+
+Status HloEvaluator::HandleSort(HloInstruction* sort) {
+ const int64 sort_dim = sort->dimensions(0);
+ const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape());
+ if (sort_dim != rank - 1) {
+ return Unimplemented(
+ "Trying to support along dimension %lld, which is not the last "
+ "dimension",
+ sort_dim);
+ }
+
+ if (!ShapeUtil::IsTuple(sort->shape())) {
+ return DefaultAction(sort);
+ } else {
+ auto result = EvaluateSort(sort, GetEvaluatedLiteralFor(sort->operand(0)),
+ GetEvaluatedLiteralFor(sort->operand(1)));
+ if (result.ok()) {
+ evaluated_[sort] = std::move(result.ValueOrDie());
+ return Status::OK();
+ } else {
+ return result.status();
+ }
+ }
+}
+
Status HloEvaluator::Preprocess(HloInstruction* hlo) {
VLOG(2) << "About to visit HLO: " << hlo->ToString();
return Status::OK();