diff options
author | Sanjoy Das <sanjoy@google.com> | 2018-05-28 22:16:46 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-28 22:19:19 -0700 |
commit | e62f3e5ff68aad1ddef2b581b98a90125e740ddd (patch) | |
tree | 607fda609038bd9b064784f97988e5f2fa3c232d /tensorflow/compiler/xla/service/indexed_array_analysis.h | |
parent | b05a6b5c4cb685b19b8c09693d40d4743af79dea (diff) |
Make IndexedArrayAnalysis behave well around StatusOr
PiperOrigin-RevId: 198348355
Diffstat (limited to 'tensorflow/compiler/xla/service/indexed_array_analysis.h')
-rw-r--r-- | tensorflow/compiler/xla/service/indexed_array_analysis.h | 36 |
1 files changed, 22 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index 8c1f616fab..561832ab59 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -220,7 +220,7 @@ class IndexedArrayAnalysis { // NB! By inspecting the implementation, you may be able to infer a stronger // caching guarantee than what is mentioned above. Nevertheless, what is // stated above is the contract. - Array* GetArrayFor(const HloInstruction* instr); + StatusOr<Array*> GetArrayFor(const HloInstruction* instr); // Pretty-prints the expression rooted at `root`. string ToString(Array* root, bool print_constants = false); @@ -228,18 +228,18 @@ class IndexedArrayAnalysis { private: // Helper function that ensures that every HLO instruction that is // transitively used by `root` has an entry in `cache_`. - void TraverseAndPopulateCache(const HloInstruction* root); + Status TraverseAndPopulateCache(const HloInstruction* root); // Creates an Array instance for `instr` under the assumption that all // operations of `instr` are present in `cache_`. - Array* ComputeArrayFor(const HloInstruction* instr); + StatusOr<Array*> ComputeArrayFor(const HloInstruction* instr); - Array* ComputeArrayForConstant(const Literal& literal); + StatusOr<Array*> ComputeArrayForConstant(const Literal& literal); - Array* ComputeArrayForGather(const Shape& shape, - const GatherDimensionNumbers& dim_numbers, - tensorflow::gtl::ArraySlice<int64> window_bounds, - Array* source, Array* indices); + StatusOr<Array*> ComputeArrayForGather( + const Shape& shape, const GatherDimensionNumbers& dim_numbers, + tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source, + Array* indices); // This tries to fold a ScalarIndexedArray which has another // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a @@ -262,16 +262,16 @@ class IndexedArrayAnalysis { // // I2 = [I0[i] for i in I1] // G1 = [Arr[i] for i in I2] - ScalarIndexedArray* FoldGatherOfGather( + StatusOr<ScalarIndexedArray*> FoldGatherOfGather( ScalarIndexedArray* source, Array* indices, int64 source_dim, tensorflow::gtl::ArraySlice<int64> output_dims, Shape shape); - Array* ComputeArrayForReshape(const Shape& shape, Array* operand); + StatusOr<Array*> ComputeArrayForReshape(const Shape& shape, Array* operand); - Array* ComputeArrayForElementwiseBinaryOp(const HloInstruction* instr, - Array* lhs, Array* rhs); - Array* ComputeArrayForElementwiseUnaryOp(const HloInstruction* instr, - Array* operand); + StatusOr<Array*> ComputeArrayForElementwiseBinaryOp( + const HloInstruction* instr, Array* lhs, Array* rhs); + StatusOr<Array*> ComputeArrayForElementwiseUnaryOp( + const HloInstruction* instr, Array* operand); template <typename T, typename... Args> T* Construct(Args&&... args) { @@ -299,6 +299,14 @@ class IndexedArrayAnalysis { return owned_literals_.back().get(); } + StatusOr<Literal*> TakeOwnership( + StatusOr<std::unique_ptr<Literal>> literal_or_error) { + TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal, + std::move(literal_or_error)); + owned_literals_.push_back(std::move(literal)); + return owned_literals_.back().get(); + } + std::vector<std::unique_ptr<Array>> owned_tensors_; std::vector<std::unique_ptr<Literal>> owned_literals_; tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_; |