/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ #include #include #include #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_module.h" namespace xla { // An abstraction representing a ordered set of HLO module built to run // concurrently across different devices. class HloModuleGroup { public: // Construct an empty module group. explicit HloModuleGroup(absl::string_view name) : name_(name) {} // Construct a module group containing a single module. HloModuleGroup(absl::string_view name, std::unique_ptr module); // Construct a module group containing any number of modules. HloModuleGroup(absl::string_view name, absl::Span> modules); // Returns the modules contained in the group. const std::vector& modules() const { return module_ptrs_; } // Returns a module at a particular index. HloModule& module(int index) const { return *module_ptrs_.at(index); } // Add a module to the back of vector of modules in the group. void push_back(std::unique_ptr module); // Moves all modules from the group into the returned vector. After this // method runs, the module group will be empty. std::vector> ConsumeModules(); string name() const { return name_; } string ToString() const; // Serialize the module group to/from a proto. HloModuleGroupProto ToProto() const; static StatusOr CreateFromProto( const HloModuleGroupProto& proto, absl::Span module_configs); private: string name_; // Vector of modules as std::unique_ptrs. std::vector> modules_; // Vector of modules as normal pointers. This vector is kept in sync with // modules_ as modules are added to the group with push_back. std::vector module_ptrs_; }; std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group); } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_