author     Yunxing Dai <yunxing@google.com>    2018-08-21 17:08:12 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-08-21 17:12:22 -0700
commit     3cb3a450ed845c4602080f43d7bb6cfade298a22 (patch)
tree       21e5a40fcb72737bd2cb00829bd049c5173d20e0 /tensorflow/compiler/xla/service
parent     95d718a8a41370f31ccb3b32aaac7fd00b0291e4 (diff)
[XLA] gtl::optional->absl::optional
PiperOrigin-RevId: 209686671
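
The change is mechanical: each use of tensorflow::gtl::optional, tensorflow::gtl::nullopt, and tensorflow::gtl::make_optional (from "tensorflow/core/lib/gtl/optional.h") is respelled as the absl equivalent (from "absl/types/optional.h"), and the "@com_google_absl//absl/types:optional" dependency is added to the affected BUILD targets. A minimal sketch of the before/after pattern, using a hypothetical function that does not appear in this commit:

    #include <cstdint>
    // Before: #include "tensorflow/core/lib/gtl/optional.h"
    #include "absl/types/optional.h"

    // Before: tensorflow::gtl::optional<int64> MaybeTileSize(bool enabled)
    absl::optional<int64_t> MaybeTileSize(bool enabled) {
      if (!enabled) {
        return absl::nullopt;  // Before: tensorflow::gtl::nullopt
      }
      // Before: tensorflow::gtl::make_optional<int64>(8)
      return absl::make_optional<int64_t>(8);
    }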
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r--  tensorflow/compiler/xla/service/BUILD | 13
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/batchnorm_expander.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_normalization_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_options.cc | 9
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_options.h | 5
-rw-r--r--  tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/dot_op_emitter.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/convolution_thunk.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/fft_thunk.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_executable.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_executable.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 12
-rw-r--r--  tensorflow/compiler/xla/service/gpu/nvptx_compiler.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc | 10
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h | 10
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.cc | 11
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.h | 16
-rw-r--r--  tensorflow/compiler/xla/service/hlo_lexer.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_matchers.h | 9
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_config.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_group_metadata.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_group_metadata.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding.cc | 15
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/indexed_array_analysis.cc | 11
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/BUILD | 3
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/sort_util.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/sort_util.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/local_service.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_analysis.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_analysis.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_simplifier.cc | 5
44 files changed, 123 insertions, 119 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 88065c58ae..850e965a80 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -241,6 +241,7 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -398,6 +399,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -1091,6 +1093,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -1248,6 +1251,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -1326,6 +1330,7 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -1477,8 +1482,7 @@ cc_library(
deps = [
":hlo",
":hlo_evaluator",
- "//tensorflow/compiler/xla:literal",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -1493,6 +1497,7 @@ cc_library(
":while_loop_analysis",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -2663,6 +2668,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -2746,6 +2752,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
+ "@com_google_absl//absl/types:optional",
],
alwayslink = 1,
)
@@ -3107,6 +3114,7 @@ cc_library(
"//tensorflow/core:ptr_util",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -3173,6 +3181,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
+ "@com_google_absl//absl/types:optional",
],
)
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 1d26e30651..0a040b5d16 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
@@ -43,7 +44,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index c4cd60c120..01931b2d02 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -35,7 +36,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -43,7 +43,7 @@ namespace xla {
namespace {
-using tensorflow::gtl::optional;
+using absl::optional;
// BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm
// operations into smaller operations.
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
index 7cf05ca443..a44756e136 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
@@ -236,7 +236,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b},
sum, /*replica_group_ids=*/{}, /*barrier=*/"",
- /*all_reduce_id=*/tensorflow::gtl::nullopt));
+ /*all_reduce_id=*/absl::nullopt));
HloInstruction* gte_a = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(f32_shape, crs, 0));
HloInstruction* gte_b = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index f9f1f64998..303ceac2e0 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -252,7 +252,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction,
/*replica_group_ids=*/{}, /*barrier=*/"",
- /*all_reduce_id=*/tensorflow::gtl::nullopt));
+ /*all_reduce_id=*/absl::nullopt));
HloInstruction* gte = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
index aa872d5ec9..69acca86bf 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
@@ -34,8 +34,8 @@ namespace cpu {
// instruction stream.
namespace {
-using ::tensorflow::gtl::nullopt;
-using ::tensorflow::gtl::optional;
+using ::absl::nullopt;
+using ::absl::optional;
using ShouldMakeOperandColMajorCache =
tensorflow::gtl::FlatMap<const HloInstruction*, bool>;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
index 3ed7876715..b6039b465e 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
@@ -45,8 +45,7 @@ bool VectorizedReduceDisabled(const HloModuleConfig& config) {
return extra_options_map.count(kXlaOptimizeForSizeCpuOption) > 0;
}
-tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
- const HloModuleConfig& config) {
+absl::optional<int64> LlvmIrGemvTilingFactor(const HloModuleConfig& config) {
const auto& extra_options_map =
config.debug_options().xla_backend_extra_options();
auto it = extra_options_map.find(kLlvmIrDotTilingFactor);
@@ -55,7 +54,7 @@ tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
tensorflow::strings::safe_strto64(it->second, &tiling_factor)) {
return tiling_factor;
}
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) {
@@ -71,13 +70,13 @@ static tensorflow::StringPiece RemoveSuffix(tensorflow::StringPiece str,
return str.substr(0, str.size() - suffix.size());
}
-tensorflow::gtl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize(
+absl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize(
const HloModuleConfig& config) {
const auto& extra_options_map =
config.debug_options().xla_backend_extra_options();
auto it = extra_options_map.find(kLlvmIrGemmTileSize);
if (it == extra_options_map.end()) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
std::vector<string> tile_components =
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h
index 429b9e16cb..47c7eb13b6 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h
@@ -27,9 +27,8 @@ namespace options {
bool OptimizeForSizeRequested(const HloModuleConfig& config);
bool VectorizedReduceDisabled(const HloModuleConfig& config);
bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config);
-tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
- const HloModuleConfig& config);
-tensorflow::gtl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize(
+absl::optional<int64> LlvmIrGemvTilingFactor(const HloModuleConfig& config);
+absl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize(
const HloModuleConfig& config);
} // namespace options
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 414f185fdf..797392c265 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -1620,7 +1620,7 @@ bool PotentiallyImplementedAsEigenDot(
// For vector-matrix dot products, it is always profitable to make the Rhs
// column major.
-tensorflow::gtl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
+absl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
const HloInstruction& hlo) {
if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 &&
hlo.shape().dimensions(0) == 1) {
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
index aeead3844b..05322faa75 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
@@ -38,7 +38,7 @@ bool PotentiallyImplementedAsEigenDot(
// Returns the index for an operand to `hlo` that should ideally be column
// major. Returns nullopt if there is no such operand or if `hlo` is not a dot
// or a fusion containing a dot.
-tensorflow::gtl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
+absl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
const HloInstruction& hlo);
// Returns true to indicate that we can generate a tiled LLVM IR implementation
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
index addb016b04..5ab0756219 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
namespace xla {
namespace {
-using tensorflow::gtl::nullopt;
+using absl::nullopt;
class ElementalIrEmitterExecutionTest : public HloTestBase {
protected:
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index d5c4854c89..fbef487ac8 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -186,6 +186,7 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:optional",
"@llvm//:core",
"@llvm//:support",
],
@@ -346,6 +347,7 @@ cc_library(
"//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep
"//tensorflow/stream_executor",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -382,6 +384,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -684,6 +687,7 @@ cc_library(
"//tensorflow/core:regexp_internal",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:optional",
"@llvm//:core",
],
alwayslink = True, # Contains compiler registration
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index d76ca6698d..f7952787c1 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
@@ -26,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index caeb89d78e..5a8fc76e85 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
@@ -28,10 +28,10 @@ namespace xla {
namespace gpu {
namespace {
+using absl::optional;
using se::DeviceMemoryBase;
using se::dnn::AlgorithmConfig;
using se::dnn::AlgorithmDesc;
-using tensorflow::gtl::optional;
class ScratchAllocator : public se::ScratchAllocator {
public:
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index 8b7749628a..472de2ff0f 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -16,12 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
index 8c53be5077..4adec7ee54 100644
--- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index a1fbd8022d..88be63e267 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -112,7 +112,7 @@ Status GpuExecutable::ExecuteThunks(
//
// TODO(jlebar): Should we cache the results of HloInstruction::ToString(),
// since we expect it to be an expensive call?
- tensorflow::gtl::optional<ScopedAnnotation> op_annotation;
+ absl::optional<ScopedAnnotation> op_annotation;
if (top_level_annotation.IsEnabled()) {
op_annotation.emplace(
thunk->hlo_instruction() != nullptr
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index c7ce6d0acb..09a1d9c12b 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
@@ -35,7 +36,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 362fd5913a..bda2986202 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
+#include "absl/types/optional.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
@@ -79,7 +80,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -88,11 +88,11 @@ namespace gpu {
namespace {
using absl::InlinedVector;
+using absl::nullopt;
+using absl::optional;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using tensorflow::gtl::ArraySlice;
-using tensorflow::gtl::nullopt;
-using tensorflow::gtl::optional;
using tensorflow::strings::StrCat;
// If a dimensions is smaller than this, untiled transposition may be more
@@ -2098,9 +2098,9 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
TF_RETURN_IF_ERROR(llvm_ir::EmitSortInPlace(
dimension_to_sort, GetIrArray(*sort, *sort, keys_shape_index),
- values != nullptr ? tensorflow::gtl::make_optional<IrArray>(
+ values != nullptr ? absl::make_optional<IrArray>(
GetIrArray(*sort, *sort, values_shape_index))
- : tensorflow::gtl::nullopt,
+ : absl::nullopt,
IrName(sort), xor_mask, &b_, &launch_dimensions));
}
}
@@ -2308,7 +2308,7 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
for (const auto& kv : hlo_slices) {
buffers_needed.insert(kv.second.first.allocation());
}
- tensorflow::gtl::optional<const BufferAllocation*> temp_buffer;
+ absl::optional<const BufferAllocation*> temp_buffer;
for (const BufferAllocation& alloc : buffer_assn.Allocations()) {
if (alloc.IsPreallocatedTempBuffer()) {
if (!temp_buffer.has_value()) {
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
index d4d2909f1b..08ef6ef56c 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
@@ -20,13 +20,13 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index f9cd00bd78..c0109c30f7 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -19,11 +19,11 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/core/lib/core/casts.h"
-#include "tensorflow/core/lib/gtl/optional.h"
namespace xla {
@@ -1672,8 +1672,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// 2. Using the selected index, scatter value from `source` to result. We
// do this by iterating through the window, and compare each index with
// the selected index.
- tensorflow::gtl::optional<ReturnT> selected_val;
- tensorflow::gtl::optional<std::vector<int64>> selected_index;
+ absl::optional<ReturnT> selected_val;
+ absl::optional<std::vector<int64>> selected_index;
IterateThroughWindow(
window_shape, window, operand_literal.shape(), source_index,
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index a4ea21c692..f8ade39e8c 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
@@ -37,7 +38,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -47,10 +47,10 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"
+using ::absl::nullopt;
+using ::absl::optional;
using ::tensorflow::Env;
using ::tensorflow::WriteStringToFile;
-using ::tensorflow::gtl::nullopt;
-using ::tensorflow::gtl::optional;
using ::tensorflow::io::JoinPath;
using ::tensorflow::str_util::Join;
using ::tensorflow::str_util::StringReplace;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index cb2264d08d..8a9856c1da 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -296,7 +296,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "CrossReplicaSum should have 1 called computation but sees "
<< proto.called_computation_ids_size();
- tensorflow::gtl::optional<int64> all_reduce_id;
+ absl::optional<int64> all_reduce_id;
if (proto.all_reduce_id() > 0) {
all_reduce_id = proto.all_reduce_id();
}
@@ -666,7 +666,7 @@ HloInstruction::CreateCrossReplicaSum(
HloComputation* reduce_computation,
tensorflow::gtl::ArraySlice<int64> replica_group_ids,
tensorflow::StringPiece barrier,
- const tensorflow::gtl::optional<int64>& all_reduce_id) {
+ const absl::optional<int64>& all_reduce_id) {
return absl::make_unique<HloAllReduceInstruction>(
shape, operands, reduce_computation, replica_group_ids, barrier,
all_reduce_id);
@@ -1836,7 +1836,7 @@ string HloInstruction::ToString(const HloPrintOptions& options) const {
}
bool HloInstruction::IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const {
+ const absl::optional<int64>& operand_idx) const {
switch (opcode_) {
// Unary elementwise operations.
case HloOpcode::kAbs:
@@ -2623,7 +2623,7 @@ bool HloInstruction::IsElementwiseBinary() const {
}
bool HloInstruction::IsElementwise() const {
- return IsElementwiseImpl(tensorflow::gtl::nullopt);
+ return IsElementwiseImpl(absl::nullopt);
}
bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const {
@@ -3156,7 +3156,7 @@ void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) {
barrier);
}
-tensorflow::gtl::optional<int64> HloInstruction::all_reduce_id() const {
+absl::optional<int64> HloInstruction::all_reduce_id() const {
return Cast<HloAllReduceInstruction>(this)->all_reduce_id();
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 41bb40b7bd..69397a4b37 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -449,7 +449,7 @@ class HloInstruction {
HloComputation* reduce_computation,
tensorflow::gtl::ArraySlice<int64> replica_group_ids,
tensorflow::StringPiece barrier,
- const tensorflow::gtl::optional<int64>& all_reduce_id);
+ const absl::optional<int64>& all_reduce_id);
// This op handles the communication of an Alltoall operation. On each core,
// the operands are N ops in the same shape, where N is the number of cores
@@ -1038,9 +1038,9 @@ class HloInstruction {
return sharding_ ? *sharding_ : default_;
}
// Returns the sharding unique device, if any.
- tensorflow::gtl::optional<int64> sharding_unique_device() const {
+ absl::optional<int64> sharding_unique_device() const {
if (sharding_ == nullptr) {
- return tensorflow::gtl::optional<int64>();
+ return absl::optional<int64>();
}
return sharding_->UniqueDevice();
}
@@ -1427,7 +1427,7 @@ class HloInstruction {
void set_cross_replica_sum_barrier(const string& barrier);
// Delegates to HloAllReduceInstruction::all_reduce_id.
- tensorflow::gtl::optional<int64> all_reduce_id() const;
+ absl::optional<int64> all_reduce_id() const;
// Returns data on the window in a windowed operation such as
// convolution.
@@ -1557,7 +1557,7 @@ class HloInstruction {
// NOTE: For all instructions other than kFusion, being elementwise on one of
// the operands is equivalent to being elementwise on all the operands.
virtual bool IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const;
+ const absl::optional<int64>& operand_idx) const;
// Prints an instruction to a string.
//
// The canonical string representation needs to name operands and instruction
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index e91cabbb72..dbafa35b2a 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -301,8 +301,7 @@ HloAllReduceInstruction::HloAllReduceInstruction(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* reduce_computation,
tensorflow::gtl::ArraySlice<int64> replica_group_ids,
- tensorflow::StringPiece barrier,
- const tensorflow::gtl::optional<int64>& all_reduce_id)
+ tensorflow::StringPiece barrier, const absl::optional<int64>& all_reduce_id)
: HloInstruction(HloOpcode::kCrossReplicaSum, shape),
replica_group_ids_(replica_group_ids.begin(), replica_group_ids.end()),
cross_replica_sum_barrier_(barrier.begin(), barrier.end()),
@@ -702,7 +701,7 @@ HloInstructionProto HloMapInstruction::ToProto() const {
}
bool HloMapInstruction::IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const {
+ const absl::optional<int64>& operand_idx) const {
if (!dimensions().empty()) {
// Check that the map is executed in elementwise compatible dimensions.
if (dimensions().size() != shape().dimensions_size()) {
@@ -815,7 +814,7 @@ HloInstructionProto HloConstantInstruction::ToProto() const {
}
bool HloConstantInstruction::IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const {
+ const absl::optional<int64>& operand_idx) const {
return true;
}
@@ -955,7 +954,7 @@ HloInstructionProto HloFusionInstruction::ToProto() const {
}
bool HloFusionInstruction::IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const {
+ const absl::optional<int64>& operand_idx) const {
if (!operand_idx.has_value()) {
for (auto* fused : fused_instructions()) {
if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) {
@@ -1387,7 +1386,7 @@ std::vector<string> HloRngInstruction::ExtraAttributesToStringImpl(
}
bool HloRngInstruction::IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const {
+ const absl::optional<int64>& operand_idx) const {
return true;
}
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 1152fa83ed..93e4c21b2f 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -225,7 +225,7 @@ class HloAllReduceInstruction : public HloInstruction {
HloComputation* reduce_computation,
tensorflow::gtl::ArraySlice<int64> replica_group_ids,
tensorflow::StringPiece barrier,
- const tensorflow::gtl::optional<int64>& all_reduce_id);
+ const absl::optional<int64>& all_reduce_id);
// Returns the group ids of each replica for CrossReplicaSum op.
const std::vector<int64>& replica_group_ids() const {
@@ -241,9 +241,7 @@ class HloAllReduceInstruction : public HloInstruction {
cross_replica_sum_barrier_ = barrier;
}
- tensorflow::gtl::optional<int64> all_reduce_id() const {
- return all_reduce_id_;
- }
+ absl::optional<int64> all_reduce_id() const { return all_reduce_id_; }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -271,7 +269,7 @@ class HloAllReduceInstruction : public HloInstruction {
// For Allreduce nodes from different modules, if they have the same
// all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will not be
// applied cross modules.
- tensorflow::gtl::optional<int64> all_reduce_id_;
+ absl::optional<int64> all_reduce_id_;
};
class HloAllToAllInstruction : public HloInstruction {
@@ -508,7 +506,7 @@ class HloMapInstruction : public HloInstruction {
private:
bool IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const override;
+ const absl::optional<int64>& operand_idx) const override;
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
bool IdenticalSlowPath(
@@ -601,7 +599,7 @@ class HloConstantInstruction : public HloInstruction {
private:
bool IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const override;
+ const absl::optional<int64>& operand_idx) const override;
bool IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
@@ -752,7 +750,7 @@ class HloFusionInstruction : public HloInstruction {
bool add_output = false);
bool IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const override;
+ const absl::optional<int64>& operand_idx) const override;
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
bool IdenticalSlowPath(
@@ -781,7 +779,7 @@ class HloRngInstruction : public HloInstruction {
private:
bool IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const override;
+ const absl::optional<int64>& operand_idx) const override;
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
bool IdenticalSlowPath(
diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc
index 8e0d38b6a6..4d54186e95 100644
--- a/tensorflow/compiler/xla/service/hlo_lexer.cc
+++ b/tensorflow/compiler/xla/service/hlo_lexer.cc
@@ -17,10 +17,10 @@ limitations under the License.
#include <unordered_map>
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/regexp.h"
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index c577b4359a..0a442e77f0 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/core/lib/gtl/optional.h"
namespace xla {
namespace testing {
@@ -120,8 +120,7 @@ class HloShapeAndLayoutMatcher
class HloShardingMatcher
: public ::testing::MatcherInterface<const HloInstruction*> {
public:
- explicit HloShardingMatcher(
- const tensorflow::gtl::optional<HloSharding>& sharding)
+ explicit HloShardingMatcher(const absl::optional<HloSharding>& sharding)
: sharding_(sharding) {}
bool MatchAndExplain(const HloInstruction* instruction,
@@ -129,7 +128,7 @@ class HloShardingMatcher
void DescribeTo(std::ostream* os) const override;
private:
- tensorflow::gtl::optional<HloSharding> sharding_;
+ absl::optional<HloSharding> sharding_;
};
// Matches a Dot HLO instruction with specific LHS and RHS contracting
@@ -337,7 +336,7 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
// Verifies that no HloSharding is set for an HLO instruction.
inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() {
return ::testing::MakeMatcher(
- new ::xla::testing::HloShardingMatcher(tensorflow::gtl::nullopt));
+ new ::xla::testing::HloShardingMatcher(absl::nullopt));
}
inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
index 01f3acd8fb..3f1e1cc73e 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.h
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -18,11 +18,11 @@ limitations under the License.
#include <string>
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/optional.h"
namespace xla {
@@ -104,7 +104,7 @@ class HloModuleConfig {
private:
// If you add new members, be sure to update compilation_cache_key.
- tensorflow::gtl::optional<ComputationLayout> entry_computation_layout_;
+ absl::optional<ComputationLayout> entry_computation_layout_;
// Whether this is a 'host module'.
bool is_host_module_ = false;
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index beb4c4fb8a..f52a37bc74 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -271,15 +271,14 @@ int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const {
LOG(FATAL) << "unknown module";
}
-tensorflow::gtl::optional<int64> HloModuleGroupMetadata::GetInstructionDevice(
+absl::optional<int64> HloModuleGroupMetadata::GetInstructionDevice(
const HloInstruction& instruction) const {
// The module group metadata can be created in both "single module, multiple
// devices" and "multiple modules, no explicit devices" fashions.
// The API returns an optional even though the current implementation always
// returns a device, to account for cases where we cannot guess a device.
// In such cases the VerifyChannelInstructions() will return proper errors.
- tensorflow::gtl::optional<int64> device =
- instruction.sharding_unique_device();
+ absl::optional<int64> device = instruction.sharding_unique_device();
if (!device) {
device = GetModuleId(instruction.parent()->parent());
}
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index 924c8fda71..dead6d9c20 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -29,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -159,7 +159,7 @@ class HloModuleGroupMetadata {
// Retrieves the device an instruction is assigned to. Either from the
// sharding information, or from the ordinal of the module the instruction
// is in.
- tensorflow::gtl::optional<int64> GetInstructionDevice(
+ absl::optional<int64> GetInstructionDevice(
const HloInstruction& instruction) const;
// Returns the number of modules for devices (excluding the host module).
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index aafd0e4efd..44180a881e 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -33,8 +33,8 @@ namespace xla {
namespace {
+using ::absl::optional;
using ::tensorflow::StringPiece;
-using ::tensorflow::gtl::optional;
using ::tensorflow::str_util::Join;
using ::tensorflow::str_util::Split;
using ::tensorflow::str_util::SplitAndParseAsInts;
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 0cba9ebbcb..903fbbec1a 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -244,16 +244,16 @@ StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
return Tuple(ShapeTree<HloSharding>(shape, *this));
}
-tensorflow::gtl::optional<int64> HloSharding::UniqueDevice() const {
+absl::optional<int64> HloSharding::UniqueDevice() const {
if (IsTuple()) {
if (tuple_elements_.empty()) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
- tensorflow::gtl::optional<int64> unique_device;
+ absl::optional<int64> unique_device;
for (auto& tuple_sharding : tuple_elements_) {
auto device = tuple_sharding.UniqueDevice();
if (!device || (unique_device && *device != *unique_device)) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
unique_device = device;
}
@@ -262,7 +262,7 @@ tensorflow::gtl::optional<int64> HloSharding::UniqueDevice() const {
if (!replicated_ && maximal_) {
return static_cast<int64>(*tile_assignment_.begin());
}
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
int64 HloSharding::GetUniqueDevice() const {
@@ -439,14 +439,13 @@ HloSharding HloSharding::GetSubSharding(const Shape& shape,
: sub_shape_tree.element(ShapeIndex({}));
}
-tensorflow::gtl::optional<HloSharding> HloSharding::ExtractSingleSharding()
- const {
+absl::optional<HloSharding> HloSharding::ExtractSingleSharding() const {
if (!IsTuple()) {
return *this;
}
for (int64 i = 1; i < tuple_elements_.size(); ++i) {
if (tuple_elements_[0] != tuple_elements_[i]) {
- return tensorflow::gtl::optional<HloSharding>();
+ return absl::optional<HloSharding>();
}
}
return tuple_elements_.front();
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index 894783e5d1..4c64ac60c5 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -151,7 +151,7 @@ class HloSharding {
// span a single device, the return value will be empty.
// In order for a sharding to span a single device, every leaf sharding must
// be maximal and not replicated, and the used device must match.
- tensorflow::gtl::optional<int64> UniqueDevice() const;
+ absl::optional<int64> UniqueDevice() const;
// Retrieves the unique device or fails with a CHECK.
int64 GetUniqueDevice() const;
@@ -182,7 +182,7 @@ class HloSharding {
// be returned. If it is a tuple, and all the tuple elements are common, the
// common element will be returned. Otherwise the optional will contain no
// value.
- tensorflow::gtl::optional<HloSharding> ExtractSingleSharding() const;
+ absl::optional<HloSharding> ExtractSingleSharding() const;
bool operator==(const HloSharding& other) const {
return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 6f8df2694b..256c8e5573 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -14,13 +14,14 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
+
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
@@ -971,15 +972,15 @@ namespace {
// Returns the non-contracting non-batch dimension (as per `contracting_dims`
// and `batch_dims`) if there is exactly one, otherwise returns nullopt.
-gtl::optional<int64> GetOnlyNonContractingNonBatchDim(
+absl::optional<int64> GetOnlyNonContractingNonBatchDim(
int64 rank, ArraySlice<int64> contracting_dims,
ArraySlice<int64> batch_dims) {
- gtl::optional<int64> result;
+ absl::optional<int64> result;
for (int64 dim = 0; dim < rank; dim++) {
if (!ArrayContains(contracting_dims, dim) &&
!ArrayContains(batch_dims, dim)) {
if (result.has_value()) {
- return gtl::nullopt;
+ return absl::nullopt;
}
result = dim;
}
@@ -999,7 +1000,7 @@ bool CanFoldDotIntoIndexedArray(
tensorflow::StringPiece tag,
Analysis::ScalarIndexedConstantArray* indexed_array,
ArraySlice<int64> contracting_dims, ArraySlice<int64> batch_dims) {
- gtl::optional<int64> non_contracting_non_batch_dim =
+ absl::optional<int64> non_contracting_non_batch_dim =
GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()),
contracting_dims, batch_dims);
if (!non_contracting_non_batch_dim.has_value()) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index ce2d6678a5..539a9522c1 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -134,9 +134,7 @@ cc_library(
":llvm_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"@llvm//:core",
@@ -194,6 +192,7 @@ cc_library(
"//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter",
"//tensorflow/compiler/xla/service/gpu:partition_assignment",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:optional",
"@llvm//:core",
],
)
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
index 35b3941272..cb4d1db997 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
@@ -55,10 +55,10 @@ Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs,
}
} // namespace
-tensorflow::gtl::optional<std::vector<int64> > FindTranspose021(
- const Shape& a, const Shape& b) {
+absl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
+ const Shape& b) {
if (!ShapeUtil::CompatibleIgnoringElementType(a, b)) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
std::vector<int64> perm(a.dimensions().size());
@@ -88,7 +88,7 @@ tensorflow::gtl::optional<std::vector<int64> > FindTranspose021(
return dims_021;
}
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
IrArray::Index GetUnreducedOutputIndex(
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
index ccb9b8ba3e..8bd06c42c3 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
@@ -36,8 +36,8 @@ namespace llvm_ir {
// If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the
// reduced shape of `b` or the 0-2-1 shape.
-tensorflow::gtl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
- const Shape& b);
+absl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
+ const Shape& b);
// Return the unreduced output index corresponding to the given reduced output
// index.
diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
index e546f5cc4a..c333311a7e 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "absl/types/optional.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
@@ -30,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -42,7 +42,7 @@ namespace {
void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index,
const IrArray::Index& compare_keys_index,
const IrArray& keys_array,
- const tensorflow::gtl::optional<IrArray>& values_array,
+ const absl::optional<IrArray>& values_array,
llvm::IRBuilder<>* b) {
// if (is_smaller_index &&
// compare_keys[dimension_to_sort] < dimension_to_sort_bound)
@@ -87,7 +87,7 @@ void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index,
} // namespace
Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
- const tensorflow::gtl::optional<IrArray>& values_array,
+ const absl::optional<IrArray>& values_array,
tensorflow::StringPiece name, llvm::Value* xor_mask,
llvm::IRBuilder<>* b,
const gpu::LaunchDimensions* launch_dimensions) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h
index 8458744c6b..39fffea931 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h
@@ -16,12 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_
+#include "absl/types/optional.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -31,7 +31,7 @@ namespace llvm_ir {
// implements the inner loop of BitonicSort. If 'launch_dimensions' is nullptr,
// the inner compare loop will not be parallelized.
Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
- const tensorflow::gtl::optional<IrArray>& values_array,
+ const absl::optional<IrArray>& values_array,
tensorflow::StringPiece name, llvm::Value* xor_mask,
llvm::IRBuilder<>* b,
const gpu::LaunchDimensions* launch_dimensions);
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 597a788c5d..b7cb782a7e 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -73,7 +73,7 @@ namespace {
// If the parameter number is invalid for this computation, nullopt is
// returned. When the return value has_value(), nullptr will never be
// the held value.
-tensorflow::gtl::optional<const OpMetadata*> ParameterMetadata(
+absl::optional<const OpMetadata*> ParameterMetadata(
const XlaComputation& computation, int parameter_number) {
for (const HloComputationProto& comp : computation.proto().computations()) {
if (comp.id() == computation.proto().entry_computation_id()) {
@@ -81,14 +81,14 @@ tensorflow::gtl::optional<const OpMetadata*> ParameterMetadata(
if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
instr.parameter_number() == parameter_number) {
if (!instr.has_metadata()) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
return &instr.metadata();
}
}
}
}
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
ExecutionOptions CreateExecutionOptions(
@@ -158,7 +158,7 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
TF_RETURN_IF_ERROR(
ShapeUtil::ValidateShapeWithOptionalLayout(argument_shape));
if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) {
- tensorflow::gtl::optional<const OpMetadata*> metadata =
+ absl::optional<const OpMetadata*> metadata =
ParameterMetadata(computation, /*parameter_number=*/i);
auto metadata_string = [&metadata]() -> string {
if (!metadata.has_value()) {
diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc
index af2cb6dc2a..7e4ac92a7c 100644
--- a/tensorflow/compiler/xla/service/while_loop_analysis.cc
+++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc
@@ -18,8 +18,8 @@ limitations under the License.
namespace xla {
-using tensorflow::gtl::nullopt;
-using tensorflow::gtl::optional;
+using absl::nullopt;
+using absl::optional;
// Finds and returns the non-constant operand in instr.
//
diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.h b/tensorflow/compiler/xla/service/while_loop_analysis.h
index bf59813e8c..bf497f4892 100644
--- a/tensorflow/compiler/xla/service/while_loop_analysis.h
+++ b/tensorflow/compiler/xla/service/while_loop_analysis.h
@@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/core/lib/gtl/optional.h"
namespace xla {
@@ -25,8 +25,8 @@ namespace xla {
// nullopt otherwise. max_value_returned limits the number of steps that are
// evaluated while trying to brute force a loop trip count, trip counts larger
// than max_value_returned result in nullopt.
-tensorflow::gtl::optional<int64> ComputeWhileLoopTripCount(
- HloInstruction *while_op, int64 max_value_returned = 128);
+absl::optional<int64> ComputeWhileLoopTripCount(HloInstruction *while_op,
+ int64 max_value_returned = 128);
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index dd8697e680..a24e2b0116 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -14,17 +14,16 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/while_loop_analysis.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
-using tensorflow::gtl::nullopt;
-using tensorflow::gtl::optional;
+using absl::optional;
// Determines whether the given instruction is a send/recv node, or has a
// subcomputation which contains a send/recv node.