/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

import "tensorflow/compiler/xla/xla_data.proto";
import "tensorflow/compiler/xla/service/hlo.proto";

package xla;

// Options for the HLO insert-reduce-precision-operations pass.
message HloReducePrecisionOptions {
  // Where and when the reduce-precision operations will be added.
  enum Location {
    // Add reduce-precision operations to the inputs of selected instructions.
    // This is done before any optimization occurs.
    OP_INPUTS = 0;

    // Add reduce-precision operations to the outputs of selected instructions.
    // This is done before any optimization occurs.
    OP_OUTPUTS = 1;

    // After operation-fusion occurs, add reduce-precision operations to the
    // outputs of any selected instructions that have not been fused into
    // fusion instructions.
    UNFUSED_OP_OUTPUTS = 2;

    // After operation-fusion occurs, add reduce-precision operations to the
    // inputs of any fusion instructions that contain operations matching the
    // selection criteria.
    // NOTE(review): this comment previously said "outputs", duplicating the
    // FUSION_OUTPUTS_BY_CONTENT text; "inputs" is inferred from the value
    // name - confirm against the pass implementation.
    FUSION_INPUTS_BY_CONTENT = 3;

    // After operation-fusion occurs, add reduce-precision operations to the
    // outputs of any fusion instructions that contain operations matching the
    // selection criteria.
    FUSION_OUTPUTS_BY_CONTENT = 4;
  }
  Location location = 1;

  // Exponent and mantissa bit counts for the reduced precision.
  uint32 exponent_bits = 2;
  uint32 mantissa_bits = 3;

  // Operations matching these opcodes should be suffixed with reduce-precision
  // operations.
  repeated uint32 opcodes_to_suffix = 4;

  // Operations with names containing these substrings should be suffixed with
  // reduce-precision operations.
  repeated string opname_substrings_to_suffix = 5;
}

// Debugging options for XLA. These options may change at any time - there are
// no guarantees about backward or forward compatibility for these fields.
message DebugOptions {
  // HLO modules matching this regex will be dumped to a .dot file throughout
  // various stages in compilation (file names are LOG(INFO)'d). Set to ".*" to
  // dump *all* HLO modules.
  string xla_generate_hlo_graph = 1;

  // Show addresses of HLO ops in graph dump.
  bool xla_hlo_graph_addresses = 2;

  // Path to dump HLO graphs to.
  string xla_hlo_graph_path = 4;

  // Dump HLO graphs as TensorFlow GraphDefs.
  bool xla_hlo_dump_as_graphdef = 5;

  // HLO modules matching this regex will be dumped to LOG(INFO). Set to ".*" to
  // dump *all* HLO modules.
  string xla_log_hlo_text = 6;

  // Dump all HLO modules as text into the provided directory path.
  string xla_generate_hlo_text_to = 7;

  // Dump Hlo after all hlo passes are executed as proto binary into this
  // directory.
  string xla_dump_optimized_hlo_proto_to = 8;

  // Instrument the computation to collect per-HLO cycle counts.
  bool xla_hlo_profile = 9;

  // Dumps computations that XLA executes into the provided directory path.
  string xla_dump_computations_to = 10;

  // Dumps parameters and results of computations that XLA executes into the
  // provided directory path.
  string xla_dump_executions_to = 11;

  // List of HLO passes to disable. These names must exactly match the pass
  // names as specified by the HloPassInterface::name() method.
  repeated string xla_disable_hlo_passes = 30;

  // Numerical optimization level for the XLA compiler backend; the specific
  // interpretation of this value is left to the backends.
  int32 xla_backend_optimization_level = 31;

  // Embed the compiler IR as a string in the executable.
  bool xla_embed_ir_in_executable = 33;

  // Dump the compiler IR into this directory as individual files.
  string xla_dump_ir_to = 34;

  // Eliminate implicit broadcasts when lowering user computations to HLO
  // instructions; use explicit broadcast instead.
  bool xla_eliminate_hlo_implicit_broadcast = 35;

  // When generating calls to Eigen in the CPU backend, use multi-threaded Eigen
  // mode.
  bool xla_cpu_multi_thread_eigen = 60;

  // Path to directory with cuda/ptx tools and libraries.
  string xla_gpu_cuda_data_dir = 61;

  // Enable flush-to-zero semantics in the GPU backend.
  bool xla_gpu_ftz = 62;

  // Disable multi-streaming in the GPU backend.
  bool xla_gpu_disable_multi_streaming = 63;

  // If true, in LLVM-based backends, emit !alias.scope metadata in
  // generated IR.
  bool xla_llvm_enable_alias_scope_metadata = 70;

  // If true, in LLVM-based backends, emit !noalias metadata in the
  // generated IR.
  bool xla_llvm_enable_noalias_metadata = 71;

  // If true, in LLVM-based backends, emit !invariant.load metadata in
  // the generated IR.
  bool xla_llvm_enable_invariant_load_metadata = 72;

  // If true, a set of expensive LLVM optimization passes will not be run.
  bool xla_llvm_disable_expensive_passes = 73;

  // Options for inserting reduce-precision operations for numerical
  // experimentation. This is a repeated field, as we may want to have
  // multiple passes with different parameters.
  repeated HloReducePrecisionOptions hlo_reduce_precision_options = 80;

  // This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
  // computation will run n! times with all permutations of layouts for the
  // output shape in rank n. For example, with a 3D shape, all permutations of
  // the set {0, 1, 2} are tried.
  bool xla_test_all_output_layouts = 90;

  // This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
  // computation will run for all permutations of layouts of all input
  // arguments. For example, with 2 input arguments in 2D and 4D shapes, the
  // computation will run 2! * 4! times.
  bool xla_test_all_input_layouts = 91;

  // Assign colors based on sharding information when generating the Graphviz
  // HLO graph.
  bool xla_hlo_graph_sharding_color = 92;

  // Prefix the name scopes of the TF graph exports with "devX" device
  // assignments, if available.
  bool xla_hlo_tfgraph_device_scopes = 93;

  // If true, the GPU backend is free to use cudnn for HLO batch normalization
  // ops.
  bool xla_gpu_use_cudnn_batchnorm = 94;

  // Dump HLO before any hlo passes are executed as proto binary into this
  // directory.
  string xla_dump_unoptimized_hlo_proto_to = 95;

  // Dump HLO after each pass as an HloProto in binary file format into this
  // directory.
  string xla_dump_per_pass_hlo_proto_to = 96;

  // Generate calls to MKL-DNN in the CPU backend.
  bool xla_cpu_use_mkl_dnn = 97;

  // Maximum kernel unroll factor for the GPU backend.
  int32 xla_gpu_max_kernel_unroll_factor = 98;

  // When true, "unsafe" mathematical optimizations are enabled. These
  // transformations include but are not limited to:
  //
  //  - Reducing the precision of operations (e.g. using an approximate sin
  //    function, or transforming x/y into x * (1/y)).
  //  - Assuming that operations never produce or consume NaN or +/- Inf.
  //  - Assuming that +0 and -0 are indistinguishable.
  //
  // This description covers both the CPU and the GPU field below.
  bool xla_cpu_enable_fast_math = 99;
  bool xla_gpu_enable_fast_math = 100;

  // Crashes the program when any kind of verification fails, instead of just
  // logging the failures. One example is cross checking of convolution results
  // among different algorithms.
  bool xla_gpu_crash_on_verification_failures = 101;

  // Force the host platform to pretend that there are these many host
  // "devices". All these devices are backed by the same threadpool. Defaults
  // to 1.
  //
  // Setting this to anything other than 1 can increase overhead from context
  // switching but we let the user override this behavior to help run tests on
  // the host that run models in parallel across multiple devices.
  int32 xla_force_host_platform_device_count = 102;

  // Extra options to pass to the compilation backend (e.g. LLVM); specific
  // interpretation of these values is left to the backend.
  // NOTE(review): the key/value types were missing from the declaration
  // (bare `map`), which is not valid proto3; string -> string is assumed -
  // confirm against the code that reads this field.
  map<string, string> xla_backend_extra_options = 500;
}

// These settings control how XLA compiles and/or runs code.  Not all settings
// will have an effect on every platform.
//
// When adding new fields, keep in mind that boolean fields default to false.
message ExecutionOptions {
  // This optional field's layout is used as a hint when storing the output of
  // this computation.  Subsequent transfers of this output array to the client
  // may be faster when using this layout.
  //
  // We use a Shape here to accommodate computations that return a tuple.
  Shape shape_with_output_layout = 2;

  // Used to seed random-number generators used in this computation.  If this is
  // 0, we generate a seed ourselves.
  //
  // TODO(b/32083678): Changing the seed unnecessarily forces a recompilation.
  uint64 seed = 3;

  DebugOptions debug_options = 4;

  // This optional field specifies a particular set of devices to run the
  // computation on. The computation will be partitioned across these devices.
  // If not provided, the default device will be chosen.
  repeated DeviceHandle device_handles = 5;
}

// Request and response messages follow. Each *Request message is paired with
// the *Response message of the matching name where one exists.

message GetDeviceHandlesRequest {
  int64 device_count = 1;
}

message GetDeviceHandlesResponse {
  repeated DeviceHandle device_handles = 1;
}

message TransferToClientRequest {
  GlobalDataHandle data = 1;

  // This optional field directs the service to return the literal in this
  // layout. A shape is used to hold the layout to accommodate tuples.
  Shape shape_with_layout = 2;
}

message TransferToClientResponse {
  LiteralProto literal = 1;
}

message TransferToServerRequest {
  LiteralProto literal = 1;
  DeviceHandle device_handle = 2;
}

message TransferToServerResponse {
  GlobalDataHandle data = 1;
}

message TransferToInfeedRequest {
  LiteralProto literal = 1;
  int64 replica_id = 2;
  DeviceHandle device_handle = 3;
}

// Intentionally empty.
message TransferToInfeedResponse {
}

message TransferFromOutfeedRequest {
  // This optional field directs the service to return the literal in this
  // layout. A shape is used to hold the layout to accommodate tuples.
  Shape shape_with_layout = 1;

  int64 replica_id = 2;
  DeviceHandle device_handle = 3;
}

message TransferFromOutfeedResponse {
  LiteralProto literal = 1;
}

message ResetDeviceRequest {
  DeviceHandle device_handle = 1;
}

// Intentionally empty.
message ResetDeviceResponse {
}

message ComputationGraphStatsRequest {
  HloModuleProto computation = 1;
  DebugOptions debug_options = 2;
}

message ComputationStatsResponse {
  ComputationStats stats = 1;
}

message CreateChannelHandleRequest {
  ChannelHandle.ChannelType channel_type = 1;
}

message CreateChannelHandleResponse {
  ChannelHandle channel = 1;
}

message UnregisterRequest {
  GlobalDataHandle data = 1;
}

// Intentionally empty.
message UnregisterResponse {
}

message ExecuteGraphRequest {
  HloModuleProto computation = 1;
  repeated GlobalDataHandle arguments = 2;

  // Options that affect how XLA compiles and runs code to service this request.
  ExecutionOptions execution_options = 3;
}

// Bundles several ExecuteGraphRequest messages into a single request.
message ExecuteGraphParallelRequest {
  repeated ExecuteGraphRequest requests = 1;
}

message ExecuteResponse {
  GlobalDataHandle output = 1;
  ExecutionProfile profile = 2;
}

// Bundles several ExecuteResponse messages into a single response.
message ExecuteParallelResponse {
  repeated ExecuteResponse responses = 1;
}

message WaitForExecutionRequest {
  ExecutionHandle execution = 1;
}

message WaitForExecutionResponse {
  GlobalDataHandle output = 1;
  ExecutionProfile profile = 2;
}

message ComputeConstantGraphRequest {
  HloModuleProto computation = 1;
  Layout output_layout = 2;
}

message ComputeConstantResponse {
  // A LiteralProto is returned directly for this request.
  LiteralProto literal = 1;
}

message DeconstructTupleRequest {
  // Note: this is field number 2; number 1 is not used by this message.
  GlobalDataHandle tuple_handle = 2;
}

message DeconstructTupleResponse {
  repeated GlobalDataHandle element_handles = 1;
}

message LoadDataRequest {
  // Describes the path of the ColumnIO tablet to load.
  string columnio_tablet_path = 1;

  // Describes the field to load within the ColumnIO tablet.
  string columnio_field = 2;

  // Individual element shape, excluding rows.
  Shape element_shape = 3;

  // Warning: ColumnIO does not support random-access, so use offset with
  // caution in performance-critical scenarios.
  int64 offset = 4;

  // Maximum number of elements (with shape element_shape) to load.
  int64 limit = 5;

  // If more than one item is requested (via limit > 1), then this request
  // attribute zips together the produced vectors.
  bool zip = 6;
}

message LoadDataResponse {
  GlobalDataHandle data = 1;
  Shape data_shape = 2;
  int64 available_rows = 3;
  int64 rows_loaded = 4;
  int64 nanoseconds = 5;
}

message GetShapeRequest {
  GlobalDataHandle data = 1;
}

message GetShapeResponse {
  Shape shape = 1;
}

message UnpackRequest {
  GlobalDataHandle data = 1;
}

message UnpackResponse {
  repeated GlobalDataHandle tied_data = 1;
}