aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_module_config.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_module_config.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_config.h92
1 files changed, 92 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
new file mode 100644
index 0000000000..f081790869
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -0,0 +1,92 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_
+
+#include <string>
+
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+// This class gathers all settings and values which affect the compiled
+// executable outside of the HLO code itself. This include layouts of inputs and
+// outputs to the module and settings such as HLO profiling. Together the
+// HloModule and HloModuleConfig unambiguously determine a particular
+// executable.
+class HloModuleConfig {
+ public:
+ explicit HloModuleConfig(const ProgramShape& program_shape);
+
+ // Return a reference to the layout of the entry computation.
+ const ComputationLayout& entry_computation_layout() const {
+ return entry_computation_layout_;
+ }
+ ComputationLayout* mutable_entry_computation_layout() {
+ return &entry_computation_layout_;
+ }
+
+ // Sets/returns whether to enable HLO-level profiling.
+ bool hlo_profiling_enabled() const { return hlo_profiling_enabled_; }
+ void enable_hlo_profiling(bool enabled) { hlo_profiling_enabled_ = enabled; }
+
+ bool has_hybrid_result() const { return has_hybrid_result_; }
+ void set_has_hybrid_result(bool has_hybrid_result) {
+ has_hybrid_result_ = has_hybrid_result;
+ }
+
+ // Sets/returns the module seed set during execution.
+ void set_seed(uint64 seed) { seed_ = seed; }
+ uint64 seed() const { return seed_; }
+
+ void set_replica_count(int64 replica_count) {
+ replica_count_ = replica_count;
+ }
+ int64 replica_count() const { return replica_count_; }
+
+ // Return a string which unambiguously represents all the fields of this data
+ // structure. Used for generating a cache key for storing the compiled
+ // executable.
+ string compilation_cache_key() const;
+
+ private:
+ ComputationLayout entry_computation_layout_;
+
+ // Whether to enable HLO-level profiling.
+ bool hlo_profiling_enabled_ = false;
+
+ // If this flag is true, the generated executable will return a ShapedBuffer
+ // holding the result of the computation. In a ShapedBuffer, tuples have their
+ // structure held in host memory and the element arrays (leaves of the tuple
+ // structure) stored in device memory. The ShapedBuffer is considered "hybrid"
+ // because its leaves are on device but its structure is stored on
+ // host. Otherwise, if this flag is false, the generated executable will
+ // return a DeviceMemoryBase where the result is held entirely in device
+ // memory.
+ bool has_hybrid_result_ = false;
+
+ // Module/graph-level seed handle.
+ uint64 seed_ = 0;
+
+ // The number of replicas to compile this binary for.
+ int64 replica_count_ = 1;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_