aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.cc111
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.h36
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis_test.cc12
3 files changed, 88 insertions, 71 deletions
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 5d870f9fc4..21af9a615c 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -33,8 +33,6 @@ using tensorflow::gtl::ArraySlice;
using tensorflow::str_util::Join;
} // namespace
-// TODO(sanjoy): Make this pass StatusOr safe.
-
string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
switch (root->kind()) {
case Array::kUnknown: {
@@ -69,18 +67,18 @@ string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
}
}
-Analysis::Array* IndexedArrayAnalysis::GetArrayFor(
+StatusOr<Analysis::Array*> IndexedArrayAnalysis::GetArrayFor(
const HloInstruction* instr) {
auto it = cache_.find(instr);
if (it != cache_.end()) {
return it->second;
}
- TraverseAndPopulateCache(instr);
+ TF_RETURN_IF_ERROR(TraverseAndPopulateCache(instr));
return FindOrDie(cache_, instr);
}
-void IndexedArrayAnalysis::TraverseAndPopulateCache(
+Status IndexedArrayAnalysis::TraverseAndPopulateCache(
const HloInstruction* root) {
// Depth first search over the DAG, invoking ComputeArrayFor in post order.
// The HLO instructions already in the cache are considered leaves.
@@ -116,32 +114,42 @@ void IndexedArrayAnalysis::TraverseAndPopulateCache(
case kVisited:
stack.pop_back();
- InsertOrDie(&cache_, instr, ComputeArrayFor(instr));
+ TF_ASSIGN_OR_RETURN(Array * array, ComputeArrayFor(instr));
+ InsertOrDie(&cache_, instr, array);
break;
}
} while (!stack.empty());
+
+ return Status::OK();
}
-Analysis::Array* IndexedArrayAnalysis::ComputeArrayFor(
+StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor(
const HloInstruction* instr) {
Array* computed_array;
if (instr->IsElementwise() && instr->operand_count() == 1) {
- computed_array = ComputeArrayForElementwiseUnaryOp(
- instr, FindOrDie(cache_, instr->operand(0)));
+ TF_ASSIGN_OR_RETURN(computed_array,
+ ComputeArrayForElementwiseUnaryOp(
+ instr, FindOrDie(cache_, instr->operand(0))));
} else if (instr->IsElementwise() && instr->operand_count() == 2) {
- computed_array = ComputeArrayForElementwiseBinaryOp(
- instr, FindOrDie(cache_, instr->operand(0)),
- FindOrDie(cache_, instr->operand(1)));
+ TF_ASSIGN_OR_RETURN(computed_array,
+ ComputeArrayForElementwiseBinaryOp(
+ instr, FindOrDie(cache_, instr->operand(0)),
+ FindOrDie(cache_, instr->operand(1))));
} else if (instr->opcode() == HloOpcode::kConstant) {
- computed_array = ComputeArrayForConstant(instr->literal());
+ TF_ASSIGN_OR_RETURN(computed_array,
+ ComputeArrayForConstant(instr->literal()));
} else if (instr->opcode() == HloOpcode::kGather) {
- computed_array = ComputeArrayForGather(
- instr->shape(), instr->gather_dimension_numbers(),
- instr->gather_window_bounds(), FindOrDie(cache_, instr->operand(0)),
- FindOrDie(cache_, instr->operand(1)));
+ TF_ASSIGN_OR_RETURN(
+ computed_array,
+ ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(),
+ instr->gather_window_bounds(),
+ FindOrDie(cache_, instr->operand(0)),
+ FindOrDie(cache_, instr->operand(1))));
} else if (instr->opcode() == HloOpcode::kReshape) {
- computed_array = ComputeArrayForReshape(
- instr->shape(), FindOrDie(cache_, instr->operand(0)));
+ TF_ASSIGN_OR_RETURN(
+ computed_array,
+ ComputeArrayForReshape(instr->shape(),
+ FindOrDie(cache_, instr->operand(0))));
} else {
computed_array = nullptr;
}
@@ -153,12 +161,12 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayFor(
return computed_array;
}
-Analysis::Array* IndexedArrayAnalysis::ComputeArrayForConstant(
+StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForConstant(
const Literal& literal) {
return Construct<ConstantArray>(&literal);
}
-ScalarIndexedArray* IndexedArrayAnalysis::FoldGatherOfGather(
+StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldGatherOfGather(
ScalarIndexedArray* source, Array* indices, int64 source_dim,
tensorflow::gtl::ArraySlice<int64> output_dims, Shape shape) {
// We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)).
@@ -224,7 +232,7 @@ ScalarIndexedArray* IndexedArrayAnalysis::FoldGatherOfGather(
std::move(shape));
}
-Analysis::Array* IndexedArrayAnalysis::ComputeArrayForGather(
+StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
const Shape& shape, const GatherDimensionNumbers& dim_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source,
Array* indices) {
@@ -397,7 +405,7 @@ int64 FindSourcePositionForPassthroughResultDim(ArraySlice<int64> operand_shape,
}; // namespace
-Analysis::Array* IndexedArrayAnalysis::ComputeArrayForReshape(
+StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForReshape(
const Shape& shape, Array* operand) {
auto* scalar_indexed = dynamic_cast<ScalarIndexedConstantArray*>(operand);
if (!scalar_indexed) {
@@ -541,10 +549,12 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayForReshape(
std::back_inserter(output_dims_for_new_scalar_indexed_node),
map_passthrough_operand_dim_to_result_dim);
- Array* new_scalar_indexed_source = ComputeArrayForConstant(
- *TakeOwnership(scalar_indexed->literal()
- .Reshape(new_scalar_indexed_source_shape)
- .ValueOrDie()));
+ TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal,
+ TakeOwnership(scalar_indexed->literal().Reshape(
+ new_scalar_indexed_source_shape)));
+ TF_ASSIGN_OR_RETURN(
+ Array * new_scalar_indexed_source,
+ ComputeArrayForConstant(*new_scalar_indexed_source_literal));
return ConstructScalarIndexedArray(
new_scalar_indexed_source, scalar_indexed->indices(),
@@ -552,7 +562,8 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayForReshape(
output_dims_for_new_scalar_indexed_node, shape);
}
-Analysis::Array* IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(
+StatusOr<Analysis::Array*>
+IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(
const HloInstruction* instr, Array* lhs, Array* rhs) {
// Try to fold BinaryOp(Broadcast(Const0), ScalarIndexed(Const1, Indices))
// => ScalarIndexed(BinaryOp(Broadcast'(Const0), Const1), Indices)
@@ -642,28 +653,25 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(
// inner_broadcast_result is the Broadcast'(Const0) bit in
// BinaryOp(Broadcast'(Const0), Const1)
- std::unique_ptr<Literal> inner_broadcast_result =
- broadcast_const_operand->literal()
- .Broadcast(scalar_indexed_const->source()->shape(),
- new_inner_broadcast_dims)
- .ConsumeValueOrDie();
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<Literal> inner_broadcast_result,
+ broadcast_const_operand->literal().Broadcast(
+ scalar_indexed_const->source()->shape(), new_inner_broadcast_dims));
// literal_for_new_source is BinaryOp(Broadcast'(Const0), Const1)
const Literal* literal_for_new_source;
if (lhs_is_indexed) {
- literal_for_new_source =
- TakeOwnership(HloEvaluator{}
- .EvaluateElementwiseBinaryOp(
- instr->opcode(), scalar_indexed_const->literal(),
- *inner_broadcast_result)
- .ConsumeValueOrDie());
+ TF_ASSIGN_OR_RETURN(
+ literal_for_new_source,
+ TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
+ instr->opcode(), scalar_indexed_const->literal(),
+ *inner_broadcast_result)));
} else {
- literal_for_new_source =
- TakeOwnership(HloEvaluator{}
- .EvaluateElementwiseBinaryOp(
- instr->opcode(), *inner_broadcast_result,
- scalar_indexed_const->literal())
- .ConsumeValueOrDie());
+ TF_ASSIGN_OR_RETURN(
+ literal_for_new_source,
+ TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
+ instr->opcode(), *inner_broadcast_result,
+ scalar_indexed_const->literal())));
}
ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
@@ -675,7 +683,8 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(
scalar_indexed_const->shape());
}
-Analysis::Array* IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(
+StatusOr<Analysis::Array*>
+IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(
const HloInstruction* instr, Array* operand) {
auto* scalar_indexed_const =
dynamic_cast<ScalarIndexedConstantArray*>(operand);
@@ -686,11 +695,9 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(
// Fold UnaryOp(ScalarIndexed(Const, Indices))
// => ScalarIndexed(UnaryOp(Const), Indices)
- Literal* literal_for_new_source =
- TakeOwnership(HloEvaluator{}
- .EvaluateElementwiseUnaryOp(
- instr->opcode(), scalar_indexed_const->literal())
- .ConsumeValueOrDie());
+ TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
+ TakeOwnership(HloEvaluator{}.EvaluateElementwiseUnaryOp(
+ instr->opcode(), scalar_indexed_const->literal())));
ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
return Construct<ScalarIndexedConstantArray>(
new_source, scalar_indexed_const->indices(),
@@ -712,7 +719,7 @@ StatusOr<bool> IndexedArrayAnalysisPrinterPass::Run(HloModule* module) {
IndexedArrayAnalysis analysis;
for (auto* computation : module->MakeNonfusionComputations()) {
for (auto* instr : computation->instructions()) {
- auto* t = analysis.GetArrayFor(instr);
+ TF_ASSIGN_OR_RETURN(Analysis::Array * t, analysis.GetArrayFor(instr));
if (!dynamic_cast<UnknownArray*>(t) && !dynamic_cast<ConstantArray*>(t)) {
VLOG(2) << instr->ToString() << " -> " << analysis.ToString(t);
}
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index 8c1f616fab..561832ab59 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -220,7 +220,7 @@ class IndexedArrayAnalysis {
// NB! By inspecting the implementation, you may be able to infer a stronger
// caching guarantee than what is mentioned above. Nevertheless, what is
// stated above is the contract.
- Array* GetArrayFor(const HloInstruction* instr);
+ StatusOr<Array*> GetArrayFor(const HloInstruction* instr);
// Pretty-prints the expression rooted at `root`.
string ToString(Array* root, bool print_constants = false);
@@ -228,18 +228,18 @@ class IndexedArrayAnalysis {
private:
// Helper function that ensures that every HLO instruction that is
// transitively used by `root` has an entry in `cache_`.
- void TraverseAndPopulateCache(const HloInstruction* root);
+ Status TraverseAndPopulateCache(const HloInstruction* root);
// Creates an Array instance for `instr` under the assumption that all
// operations of `instr` are present in `cache_`.
- Array* ComputeArrayFor(const HloInstruction* instr);
+ StatusOr<Array*> ComputeArrayFor(const HloInstruction* instr);
- Array* ComputeArrayForConstant(const Literal& literal);
+ StatusOr<Array*> ComputeArrayForConstant(const Literal& literal);
- Array* ComputeArrayForGather(const Shape& shape,
- const GatherDimensionNumbers& dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds,
- Array* source, Array* indices);
+ StatusOr<Array*> ComputeArrayForGather(
+ const Shape& shape, const GatherDimensionNumbers& dim_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source,
+ Array* indices);
// This tries to fold a ScalarIndexedArray which has another
// ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a
@@ -262,16 +262,16 @@ class IndexedArrayAnalysis {
//
// I2 = [I0[i] for i in I1]
// G1 = [Arr[i] for i in I2]
- ScalarIndexedArray* FoldGatherOfGather(
+ StatusOr<ScalarIndexedArray*> FoldGatherOfGather(
ScalarIndexedArray* source, Array* indices, int64 source_dim,
tensorflow::gtl::ArraySlice<int64> output_dims, Shape shape);
- Array* ComputeArrayForReshape(const Shape& shape, Array* operand);
+ StatusOr<Array*> ComputeArrayForReshape(const Shape& shape, Array* operand);
- Array* ComputeArrayForElementwiseBinaryOp(const HloInstruction* instr,
- Array* lhs, Array* rhs);
- Array* ComputeArrayForElementwiseUnaryOp(const HloInstruction* instr,
- Array* operand);
+ StatusOr<Array*> ComputeArrayForElementwiseBinaryOp(
+ const HloInstruction* instr, Array* lhs, Array* rhs);
+ StatusOr<Array*> ComputeArrayForElementwiseUnaryOp(
+ const HloInstruction* instr, Array* operand);
template <typename T, typename... Args>
T* Construct(Args&&... args) {
@@ -299,6 +299,14 @@ class IndexedArrayAnalysis {
return owned_literals_.back().get();
}
+ StatusOr<Literal*> TakeOwnership(
+ StatusOr<std::unique_ptr<Literal>> literal_or_error) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ std::move(literal_or_error));
+ owned_literals_.push_back(std::move(literal));
+ return owned_literals_.back().get();
+ }
+
std::vector<std::unique_ptr<Array>> owned_tensors_;
std::vector<std::unique_ptr<Literal>> owned_literals_;
tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_;
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
index 76e7e7086c..68f247bfc3 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
@@ -40,12 +40,14 @@ class IndexedArrayAnalysisTest : public HloVerifiedTestBase {
IndexedArrayAnalysis indexed_tensor_analysis;
ParseAndVerifyModule(hlo_text);
- string result = indexed_tensor_analysis.ToString(
+ TF_ASSERT_OK_AND_ASSIGN(
+ IndexedArrayAnalysis::Array* const array_result,
indexed_tensor_analysis.GetArrayFor(
- module().entry_computation()->root_instruction()),
- print_constants);
- LOG(INFO) << result;
- ASSERT_EQ(result, root_expression);
+ module().entry_computation()->root_instruction()));
+ string string_result =
+ indexed_tensor_analysis.ToString(array_result, print_constants);
+ LOG(INFO) << string_result;
+ ASSERT_EQ(string_result, root_expression);
}
};