diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/BUILD')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/BUILD | 33 |
1 files changed, 25 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 64b9683628..51968d13d4 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -68,9 +68,7 @@ cc_library( # srcs = [ # "partition_assignment_test.cc", # ], -# tags = [ -# "requires-gpu-sm35", -# ], +# tags = tf_cuda_tests_tags(), # deps = [ # ":partition_assignment", # "//tensorflow/core:stream_executor_no_cuda", @@ -373,7 +371,6 @@ cc_library( hdrs = ["ir_emission_utils.h"], deps = [ ":backend_configs", - ":cudnn_convolution_runner", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", @@ -414,6 +411,8 @@ cc_library( srcs = ["cudnn_convolution_runner.cc"], hdrs = ["cudnn_convolution_runner.h"], deps = [ + ":backend_configs", + ":ir_emission_utils", ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -422,8 +421,10 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -432,6 +433,7 @@ cc_library( srcs = ["cudnn_convolution_rewriter.cc"], hdrs = ["cudnn_convolution_rewriter.h"], deps = [ + ":backend_configs", ":ir_emission_utils", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:util", @@ -596,14 +598,11 @@ cc_library( hdrs = ["pad_for_tensor_cores.h"], deps = [ ":ir_emission_utils", - "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo_creation_utils", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/compiler/xla/service:shape_inference", ], ) @@ -656,6 +655,7 @@ cc_library( deps = [ ":cudnn_convolution_algorithm_picker", ":cudnn_convolution_rewriter", + ":cudnn_fused_convolution_rewriter", ":fusion_merger", ":gpu_constants", ":gpu_copy_insertion", @@ -783,6 +783,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -967,3 +968,19 @@ tf_cc_test( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "cudnn_fused_convolution_rewriter", + srcs = ["cudnn_fused_convolution_rewriter.cc"], + hdrs = ["cudnn_fused_convolution_rewriter.h"], + deps = [ + ":backend_configs", + ":ir_emission_utils", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/core:stream_executor_no_cuda", + ], +) |