aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: Justin Lebar <jlebar@google.com> 2018-10-09 17:19:24 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-10-09 17:23:55 -0700
commit: 5be479930d3dcfa3edb863703b1d73b89d45f03c (patch)
tree: 4d89676c2a1b6ddf0cc1da3873536d6471b50321
parent: 9bd459e4ceba14f9bb1af98d52a109325de952e8 (diff)
[XLA:GPU] Use CudnnConvKind in more places.
No functional change. PiperOrigin-RevId: 216451881
-rw-r--r-- tensorflow/compiler/xla/service/gpu/BUILD | 1
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc | 99
-rw-r--r-- tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc | 84
-rw-r--r-- tensorflow/compiler/xla/service/gpu/pad_insertion.cc | 31
4 files changed, 116 insertions, 99 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 0144d59097..62da43d68a 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -591,6 +591,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_creation_utils",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:shape_inference",
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 89dd1bb272..a809c22b33 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -312,11 +312,12 @@ StatusOr<CudnnConvParams> GetCudnnConvParams(
TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
conv->backend_config<CudnnConvBackendConfig>());
- const auto& target = conv->custom_call_target();
+ TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(conv));
const auto& lhs_shape = conv->operand(0)->shape();
const auto& rhs_shape = conv->operand(1)->shape();
const auto& conv_result_shape = conv->shape().tuple_shapes(0);
+ params.kind = kind;
params.window = &conv->window();
params.dnums = &conv->convolution_dimension_numbers();
params.feature_group_count = conv->feature_group_count();
@@ -324,55 +325,55 @@ StatusOr<CudnnConvParams> GetCudnnConvParams(
backend_config.algorithm(), backend_config.tensor_ops_enabled()));
params.conv_result_scale = backend_config.conv_result_scale();
- if (target == kCudnnConvForwardCallTarget) {
- params.kind = CudnnConvKind::kForward;
- params.input_shape = &lhs_shape;
- params.filter_shape = &rhs_shape;
- params.output_shape = &conv_result_shape;
- params.input_buf = operand_buffers[0];
- params.filter_buf = operand_buffers[1];
- params.output_buf = result_buffer;
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- params.kind = CudnnConvKind::kBackwardInput;
- params.input_shape = &conv_result_shape;
- params.filter_shape = &rhs_shape;
- params.output_shape = &lhs_shape;
- params.input_buf = result_buffer;
- params.filter_buf = operand_buffers[1];
- params.output_buf = operand_buffers[0];
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- params.kind = CudnnConvKind::kBackwardFilter;
- params.input_shape = &lhs_shape;
- params.filter_shape = &conv_result_shape;
- params.output_shape = &rhs_shape;
- params.input_buf = operand_buffers[0];
- params.filter_buf = result_buffer;
- params.output_buf = operand_buffers[1];
- } else if (target == kCudnnConvBiasActivationForwardCallTarget) {
- params.kind = CudnnConvKind::kForwardActivation;
- params.input_shape = &lhs_shape;
- params.filter_shape = &rhs_shape;
- params.output_shape = &conv_result_shape;
- params.fusion.emplace();
- auto& fusion = *params.fusion;
- if (backend_config.activation_mode() <
- static_cast<int64>(se::dnn::ActivationMode::kNumActivationModes)) {
- fusion.mode = static_cast<se::dnn::ActivationMode>(
- backend_config.activation_mode());
- } else {
- return InternalError("Bad activation mode: %s",
- backend_config.ShortDebugString());
- }
- fusion.side_input_scale = backend_config.side_input_scale();
- params.input_buf = operand_buffers[0];
- params.filter_buf = operand_buffers[1];
- params.output_buf = result_buffer;
- params.fusion->bias_buf = operand_buffers[2];
- if (operand_buffers.size() >= 4) {
- params.fusion->side_input_buf = operand_buffers[3];
+ switch (kind) {
+ case CudnnConvKind::kForward:
+ params.input_shape = &lhs_shape;
+ params.filter_shape = &rhs_shape;
+ params.output_shape = &conv_result_shape;
+ params.input_buf = operand_buffers[0];
+ params.filter_buf = operand_buffers[1];
+ params.output_buf = result_buffer;
+ break;
+ case CudnnConvKind::kBackwardInput:
+ params.input_shape = &conv_result_shape;
+ params.filter_shape = &rhs_shape;
+ params.output_shape = &lhs_shape;
+ params.input_buf = result_buffer;
+ params.filter_buf = operand_buffers[1];
+ params.output_buf = operand_buffers[0];
+ break;
+ case CudnnConvKind::kBackwardFilter:
+ params.input_shape = &lhs_shape;
+ params.filter_shape = &conv_result_shape;
+ params.output_shape = &rhs_shape;
+ params.input_buf = operand_buffers[0];
+ params.filter_buf = result_buffer;
+ params.output_buf = operand_buffers[1];
+ break;
+ case CudnnConvKind::kForwardActivation: {
+ params.kind = CudnnConvKind::kForwardActivation;
+ params.input_shape = &lhs_shape;
+ params.filter_shape = &rhs_shape;
+ params.output_shape = &conv_result_shape;
+ params.fusion.emplace();
+ auto& fusion = *params.fusion;
+ if (backend_config.activation_mode() <
+ static_cast<int64>(se::dnn::ActivationMode::kNumActivationModes)) {
+ fusion.mode = static_cast<se::dnn::ActivationMode>(
+ backend_config.activation_mode());
+ } else {
+ return InternalError("Bad activation mode: %s",
+ backend_config.ShortDebugString());
+ }
+ fusion.side_input_scale = backend_config.side_input_scale();
+ params.input_buf = operand_buffers[0];
+ params.filter_buf = operand_buffers[1];
+ params.output_buf = result_buffer;
+ params.fusion->bias_buf = operand_buffers[2];
+ if (operand_buffers.size() >= 4) {
+ params.fusion->side_input_buf = operand_buffers[3];
+ }
}
- } else {
- return InternalError("Unexpected custom call target: %s", target);
}
return params;
}
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
index e3869b5c36..8f1f5a7bf5 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
@@ -105,38 +105,45 @@ static HloInstruction* PadInstruction(HloInstruction* instr,
// Pads the input/output feature dimensions of the given cudnn convolution
// custom-call to be multiples of kDesiredNumFeaturesFactor.
-static StatusOr<bool> PadFeaturesDims(HloInstruction* conv) {
+static StatusOr<bool> PadFeaturesDims(HloCustomCallInstruction* conv) {
CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0))
<< "conv must use 0 scratch bytes, i.e. this pass must be run "
"before CudnnConvolutionAlgorithmPicker.";
- const auto& target = conv->custom_call_target();
+ TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv));
const auto& dnums = conv->convolution_dimension_numbers();
auto* lhs = conv->mutable_operand(0);
auto* rhs = conv->mutable_operand(1);
const Shape& result_shape = conv->shape().tuple_shapes(0);
Shape new_lhs_shape = [&] {
- if (target == kCudnnConvForwardCallTarget ||
- target == kCudnnConvBackwardFilterCallTarget) {
- // LHS is "input".
- return PadShape(lhs->shape(), {dnums.input_feature_dimension()});
+ switch (kind) {
+ case CudnnConvKind::kForward:
+ case CudnnConvKind::kBackwardFilter:
+ // LHS is "input".
+ return PadShape(lhs->shape(), {dnums.input_feature_dimension()});
+ case CudnnConvKind::kBackwardInput:
+ // LHS is "output".
+ return PadShape(lhs->shape(), {dnums.output_feature_dimension()});
+ case CudnnConvKind::kForwardActivation:
+ LOG(FATAL) << "Not yet implemented.";
}
- CHECK_EQ(target, kCudnnConvBackwardInputCallTarget);
- // LHS is "output".
- return PadShape(lhs->shape(), {dnums.output_feature_dimension()});
}();
Shape new_rhs_shape = [&] {
- if (target == kCudnnConvForwardCallTarget ||
- target == kCudnnConvBackwardInputCallTarget) {
- // RHS is "filter".
- return PadShape(rhs->shape(), {dnums.kernel_input_feature_dimension(),
- dnums.kernel_output_feature_dimension()});
+ switch (kind) {
+ case CudnnConvKind::kForward:
+ case CudnnConvKind::kBackwardInput:
+ // RHS is "filter".
+ return PadShape(rhs->shape(),
+ {dnums.kernel_input_feature_dimension(),
+ dnums.kernel_output_feature_dimension()});
+ case CudnnConvKind::kBackwardFilter:
+ // RHS is "output".
+ return PadShape(rhs->shape(), {dnums.output_feature_dimension()});
+ case CudnnConvKind::kForwardActivation:
+ LOG(FATAL) << "Not yet implemented.";
}
- CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget);
- // RHS is "output".
- return PadShape(rhs->shape(), {dnums.output_feature_dimension()});
}();
if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) &&
@@ -146,18 +153,21 @@ static StatusOr<bool> PadFeaturesDims(HloInstruction* conv) {
}
Shape new_result_shape = [&] {
- if (target == kCudnnConvForwardCallTarget) {
- // Result is "output".
- return PadShape(result_shape, {dnums.output_feature_dimension()});
+ switch (kind) {
+ case CudnnConvKind::kForward:
+ // Result is "output".
+ return PadShape(result_shape, {dnums.output_feature_dimension()});
+ case CudnnConvKind::kBackwardInput:
+ // Result is "input".
+ return PadShape(result_shape, {dnums.input_feature_dimension()});
+ case CudnnConvKind::kBackwardFilter:
+ // Result is "filter".
+ return PadShape(result_shape,
+ {dnums.kernel_input_feature_dimension(),
+ dnums.kernel_output_feature_dimension()});
+ case CudnnConvKind::kForwardActivation:
+ LOG(FATAL) << "Not yet implemented.";
}
- if (target == kCudnnConvBackwardInputCallTarget) {
- // Result is "input".
- return PadShape(result_shape, {dnums.input_feature_dimension()});
- }
- CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget);
- // Result is "filter".
- return PadShape(result_shape, {dnums.kernel_input_feature_dimension(),
- dnums.kernel_output_feature_dimension()});
}();
// Check that padding wouldn't increase the total bytes read/written by this
@@ -223,16 +233,20 @@ static StatusOr<bool> PadFeaturesDims(HloInstruction* conv) {
return true;
}
-static std::vector<HloInstruction*> GetRelevantConvs(HloComputation* comp) {
- std::vector<HloInstruction*> convs;
+static std::vector<HloCustomCallInstruction*> GetRelevantConvs(
+ HloComputation* comp) {
+ std::vector<HloCustomCallInstruction*> convs;
for (HloInstruction* instr : comp->instructions()) {
- if (IsCustomCallToDnnConvolution(*instr) &&
- instr->operand(0)->shape().element_type() == F16 &&
+ if (!IsCustomCallToDnnConvolution(*instr)) {
+ continue;
+ }
+ auto* custom_call = Cast<HloCustomCallInstruction>(instr);
+ if (custom_call->operand(0)->shape().element_type() == F16 &&
// TODO(timshen): Disable for fused conv for now. Implement it if it's
// needed.
- Cast<HloCustomCallInstruction>(instr)->custom_call_target() !=
+ custom_call->custom_call_target() !=
kCudnnConvBiasActivationForwardCallTarget) {
- convs.push_back(instr);
+ convs.push_back(custom_call);
}
}
return convs;
@@ -241,7 +255,7 @@ static std::vector<HloInstruction*> GetRelevantConvs(HloComputation* comp) {
StatusOr<bool> PadForTensorCores::Run(HloModule* module) {
bool changed = false;
for (HloComputation* comp : module->MakeNonfusionComputations()) {
- for (HloInstruction* conv : GetRelevantConvs(comp)) {
+ for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
TF_ASSIGN_OR_RETURN(bool result, PadFeaturesDims(conv));
changed |= result;
}
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index b42a19e3a2..ae7abca7c6 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
@@ -378,25 +379,25 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
StatusOr<bool> PadInsertion::RunOnComputation(HloComputation* computation) {
bool changed = false;
- std::vector<HloInstruction*> convs;
+ std::vector<HloCustomCallInstruction*> convs;
for (auto* instr : computation->instructions()) {
if (IsCustomCallToDnnConvolution(*instr)) {
- convs.push_back(instr);
+ convs.push_back(Cast<HloCustomCallInstruction>(instr));
}
}
- for (HloInstruction* instruction : convs) {
- const auto& target = instruction->custom_call_target();
- if (target == kCudnnConvForwardCallTarget ||
- target == kCudnnConvBiasActivationForwardCallTarget) {
- changed |= CanonicalizeForwardConvolution(instruction);
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- changed |= CanonicalizeBackwardFilterConvolution(instruction);
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- changed |= CanonicalizeBackwardInputConvolution(instruction);
- } else {
- LOG(FATAL) << "Unknown custom call target for cudnn conv: "
- << instruction->ToString();
- }
+ for (HloCustomCallInstruction* instruction : convs) {
+ TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instruction));
+ changed |= [&] {
+ switch (kind) {
+ case CudnnConvKind::kForward:
+ case CudnnConvKind::kForwardActivation:
+ return CanonicalizeForwardConvolution(instruction);
+ case CudnnConvKind::kBackwardInput:
+ return CanonicalizeBackwardInputConvolution(instruction);
+ case CudnnConvKind::kBackwardFilter:
+ return CanonicalizeBackwardFilterConvolution(instruction);
+ }
+ }();
}
return changed;
}