about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/indexed_array_analysis.h
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-05-28 22:16:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-28 22:19:19 -0700
commite62f3e5ff68aad1ddef2b581b98a90125e740ddd (patch)
tree607fda609038bd9b064784f97988e5f2fa3c232d /tensorflow/compiler/xla/service/indexed_array_analysis.h
parentb05a6b5c4cb685b19b8c09693d40d4743af79dea (diff)
Make IndexedArrayAnalysis behave well around StatusOr
PiperOrigin-RevId: 198348355
Diffstat (limited to 'tensorflow/compiler/xla/service/indexed_array_analysis.h')
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.h36
1 file changed, 22 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index 8c1f616fab..561832ab59 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -220,7 +220,7 @@ class IndexedArrayAnalysis {
// NB! By inspecting the implementation, you may be able to infer a stronger
// caching guarantee than what is mentioned above. Nevertheless, what is
// stated above is the contract.
- Array* GetArrayFor(const HloInstruction* instr);
+ StatusOr<Array*> GetArrayFor(const HloInstruction* instr);
// Pretty-prints the expression rooted at `root`.
string ToString(Array* root, bool print_constants = false);
@@ -228,18 +228,18 @@ class IndexedArrayAnalysis {
private:
// Helper function that ensures that every HLO instruction that is
// transitively used by `root` has an entry in `cache_`.
- void TraverseAndPopulateCache(const HloInstruction* root);
+ Status TraverseAndPopulateCache(const HloInstruction* root);
// Creates an Array instance for `instr` under the assumption that all
// operations of `instr` are present in `cache_`.
- Array* ComputeArrayFor(const HloInstruction* instr);
+ StatusOr<Array*> ComputeArrayFor(const HloInstruction* instr);
- Array* ComputeArrayForConstant(const Literal& literal);
+ StatusOr<Array*> ComputeArrayForConstant(const Literal& literal);
- Array* ComputeArrayForGather(const Shape& shape,
- const GatherDimensionNumbers& dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds,
- Array* source, Array* indices);
+ StatusOr<Array*> ComputeArrayForGather(
+ const Shape& shape, const GatherDimensionNumbers& dim_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source,
+ Array* indices);
// This tries to fold a ScalarIndexedArray which has another
// ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a
@@ -262,16 +262,16 @@ class IndexedArrayAnalysis {
//
// I2 = [I0[i] for i in I1]
// G1 = [Arr[i] for i in I2]
- ScalarIndexedArray* FoldGatherOfGather(
+ StatusOr<ScalarIndexedArray*> FoldGatherOfGather(
ScalarIndexedArray* source, Array* indices, int64 source_dim,
tensorflow::gtl::ArraySlice<int64> output_dims, Shape shape);
- Array* ComputeArrayForReshape(const Shape& shape, Array* operand);
+ StatusOr<Array*> ComputeArrayForReshape(const Shape& shape, Array* operand);
- Array* ComputeArrayForElementwiseBinaryOp(const HloInstruction* instr,
- Array* lhs, Array* rhs);
- Array* ComputeArrayForElementwiseUnaryOp(const HloInstruction* instr,
- Array* operand);
+ StatusOr<Array*> ComputeArrayForElementwiseBinaryOp(
+ const HloInstruction* instr, Array* lhs, Array* rhs);
+ StatusOr<Array*> ComputeArrayForElementwiseUnaryOp(
+ const HloInstruction* instr, Array* operand);
template <typename T, typename... Args>
T* Construct(Args&&... args) {
@@ -299,6 +299,14 @@ class IndexedArrayAnalysis {
return owned_literals_.back().get();
}
+ StatusOr<Literal*> TakeOwnership(
+ StatusOr<std::unique_ptr<Literal>> literal_or_error) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ std::move(literal_or_error));
+ owned_literals_.push_back(std::move(literal));
+ return owned_literals_.back().get();
+ }
+
std::vector<std::unique_ptr<Array>> owned_tensors_;
std::vector<std::unique_ptr<Literal>> owned_literals_;
tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_;