author     mbhuiyan <mohammad.ashraf.bhuiyan@intel.com>	2018-05-16 10:49:29 -0700
committer  mbhuiyan <mohammad.ashraf.bhuiyan@intel.com>	2018-05-16 10:49:29 -0700
commit     2acf23109aabb2952ce73dee89fe1e63b0e80961 (patch)
tree       54724426fcf6d8d9a5dab57862ae749997dc5fd5 /tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
parent     7a667f694fc25691d1093019a6fe4e0cd32fd344 (diff)
parent     383e6d48dfd5037bcb5d56937366f1ba12b9a67d (diff)
resolving the conflict while merging master
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc')
-rw-r--r--	tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc	54
1 file changed, 38 insertions(+), 16 deletions(-)
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 8db4a0650d..5cdfc110af 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
+#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -541,7 +542,7 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot,
hlo_module_config_(hlo_module_config),
target_machine_features_(target_machine_features) {}
-/* static */ tensorflow::Status DotOpEmitter::EmitDotOperation(
+/* static */ Status DotOpEmitter::EmitDotOperation(
const HloInstruction& dot, const llvm_ir::IrArray& target_array,
const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
const llvm_ir::IrArray* addend_array,
@@ -690,7 +691,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
return true;
}
-tensorflow::Status DotOpEmitter::Emit() {
+Status DotOpEmitter::Emit() {
// The dot operation performs a sum of products over dimension 0 of the left
// hand side operand and dimension 1 of the right hand side operand.
//
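(For orientation only, not part of this patch: in the rank-2 case the loop nest emitted here is conceptually the textbook triple loop sketched below; the buffer layout, extents, and function name are illustrative assumptions, not XLA code.)

    #include <cstdint>

    // Conceptual sketch of the scalar loop nest a rank-2 dot lowers to.
    // Row-major float buffers and the name NaiveDot are illustrative only.
    void NaiveDot(const float* lhs, const float* rhs, float* out,
                  int64_t m, int64_t k, int64_t n) {
      for (int64_t i = 0; i < m; ++i) {
        for (int64_t j = 0; j < n; ++j) {
          float accum = 0.0f;
          for (int64_t p = 0; p < k; ++p) {      // contracted dimension
            accum += lhs[i * k + p] * rhs[p * n + j];
          }
          out[i * n + j] = accum;                // sum of products
        }
      }
    }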
@@ -734,7 +735,7 @@ tensorflow::Status DotOpEmitter::Emit() {
CHECK_EQ(addend_array_, nullptr);
- if (PotentiallyImplementedAsEigenDot(dot_)) {
+ if (PotentiallyImplementedAsEigenDot(dot_, target_machine_features_)) {
return EmitCallToRuntime();
}
@@ -868,10 +869,10 @@ tensorflow::Status DotOpEmitter::Emit() {
// loop.
ir_builder_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status DotOpEmitter::EmitScalarDot() {
+Status DotOpEmitter::EmitScalarDot() {
// A scalar dot is just a scalar multiply.
llvm::Value* result;
llvm::Value* lhs_value =
@@ -896,10 +897,10 @@ tensorflow::Status DotOpEmitter::EmitScalarDot() {
result = ir_builder_->CreateFMul(lhs_value, rhs_value);
}
target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_);
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status DotOpEmitter::EmitCallToRuntime() {
+Status DotOpEmitter::EmitCallToRuntime() {
// The signature of the Eigen runtime matmul function is:
//
// (void)(void* run_options, float* out, float* lhs, float* rhs,
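(A reading aid, not taken from this diff: expanding that comment into a C-level prototype using the call-site arguments visible in the next hunk (m, n, k, and the two transpose flags) gives roughly the shape below; the function name is hypothetical.)

    #include <cstdint>

    // Hypothetical prototype inferred from the comment above and the arguments
    // passed at the call site; the real runtime symbol and types may differ.
    extern "C" void HypotheticalEigenMatMulF32(
        const void* run_options,             // opaque ExecutableRunOptions pointer
        float* out, float* lhs, float* rhs,  // output and operand buffers
        int64_t m, int64_t n, int64_t k,     // matrix dimensions
        int32_t transpose_lhs, int32_t transpose_rhs);  // transpose flags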
@@ -1001,7 +1002,7 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() {
ir_builder_->getInt64(mat_mult_dims.k),
ir_builder_->getInt32(transpose_lhs),
ir_builder_->getInt32(transpose_rhs)});
- return tensorflow::Status::OK();
+ return Status::OK();
}
DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
@@ -1058,19 +1059,39 @@ static bool IsRank2WithNoPadding(const Shape& shape) {
// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
-static bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
- const Shape& output_shape) {
+static bool AreValidGemmShapes(
+ const Shape& lhs_shape, const Shape& rhs_shape, const Shape& output_shape,
+ const TargetMachineFeatures& target_machine_features) {
// The inputs and the output must
// 1) be matrices with no padding, and
// 2) have an allowed element type.
PrimitiveType output_primitive_type = output_shape.element_type();
- return (output_primitive_type == F64 || output_primitive_type == F32 ||
- output_primitive_type == F16) &&
- IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) &&
- IsRank2WithNoPadding(output_shape);
+ if (!(output_primitive_type == F64 || output_primitive_type == F32 ||
+ output_primitive_type == F16)) {
+ return false;
+ }
+
+ if (!(IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) &&
+ IsRank2WithNoPadding(output_shape))) {
+ return false;
+ }
+
+ auto is_aligned = [&](const Shape& shape) {
+ return GetMinimumAlignmentForArray(shape, target_machine_features) >=
+ TargetMachineFeatures::kEigenExpectedTensorAlignment;
+ };
+
+ if (!is_aligned(lhs_shape) || !is_aligned(rhs_shape) ||
+ !is_aligned(output_shape)) {
+ return false;
+ }
+
+ return true;
}
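(Illustration only, not part of the patch: the kinds of shapes the rank and element-type checks above admit or reject, before the new alignment gate is applied. ShapeUtil::MakeShape is the usual XLA helper; the variable names are made up.)

    // Assumes the XLA headers already included by this file.
    Shape ok_lhs   = ShapeUtil::MakeShape(F32, {128, 256});    // rank-2 F32: passes
    Shape ok_rhs   = ShapeUtil::MakeShape(F32, {256, 64});     // rank-2 F32: passes
    Shape bad_rank = ShapeUtil::MakeShape(F32, {2, 128, 64});  // rank-3: rejected
    Shape bad_type = ShapeUtil::MakeShape(S32, {128, 64});     // S32 element type: rejected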
-bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
+bool PotentiallyImplementedAsEigenDot(
+ const HloInstruction& hlo,
+ const TargetMachineFeatures& target_machine_features) {
// For certain types of Dot, we can call Eigen
if (hlo.opcode() == HloOpcode::kDot) {
const Shape& lhs_shape = hlo.operand(0)->shape();
@@ -1087,7 +1108,8 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
// If gemm can accept the operand shapes, use it rather than a custom
// kernel.
- if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) {
+ if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape(),
+ target_machine_features)) {
const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers();
// The size of the reduction dimension should match. The shape inference
// guarantees this invariant, so the check here is for programming
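(Purely as a sketch of the kind of check that comment describes, not the verbatim continuation of the file: the reduction-dimension invariant amounts to comparing the contracted extents of the two operands.)

    // Sketch only: dim_numbers, lhs_shape, and rhs_shape are the values already
    // in scope at this point in PotentiallyImplementedAsEigenDot.
    CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
             rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));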