Upstream mhlo.dot lowering to Linalg to MHLO repo.
We prototyped the lowering from mhlo.dot to linalg.matmul in IREE. Since Linalg now supports matmul in tensors world, we can move the lowering logic to tensors world, and upstream to legalize_to_linalg.cc. The patch lowers the mhlo.dot to the linalg.matmul/matvec/dot in tensors world. PiperOrigin-RevId: 351184911
This commit is contained in:
parent
180f917446
commit
8f58f844e5
|
@ -1014,6 +1014,123 @@ class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
|
|||
}
|
||||
};
|
||||
|
||||
enum class DotOperationType {
|
||||
kVectorDot = 0,
|
||||
kMatrixVector = 1,
|
||||
kMatrixMatrix = 2,
|
||||
kUnsupported = 3
|
||||
};
|
||||
|
||||
DotOperationType GetDotOperationType(mhlo::DotOp dot_op) {
|
||||
ArrayRef<int64_t> lhs_shape =
|
||||
dot_op.lhs().getType().cast<ShapedType>().getShape();
|
||||
ArrayRef<int64_t> rhs_shape =
|
||||
dot_op.rhs().getType().cast<ShapedType>().getShape();
|
||||
auto shape_matches = [](int64_t a, int64_t b) {
|
||||
return a == ShapedType::kDynamicSize || b == ShapedType::kDynamicSize ||
|
||||
a == b;
|
||||
};
|
||||
if (lhs_shape.size() == 1 && rhs_shape.size() == 1 &&
|
||||
shape_matches(lhs_shape[0], rhs_shape[0])) {
|
||||
return DotOperationType::kVectorDot;
|
||||
}
|
||||
if (lhs_shape.size() == 2 && rhs_shape.size() == 1 &&
|
||||
shape_matches(lhs_shape[1], rhs_shape[0])) {
|
||||
return DotOperationType::kMatrixVector;
|
||||
}
|
||||
if (rhs_shape.size() == 2 && rhs_shape.size() == 2 &&
|
||||
shape_matches(lhs_shape[1], rhs_shape[0])) {
|
||||
return DotOperationType::kMatrixMatrix;
|
||||
}
|
||||
return DotOperationType::kUnsupported;
|
||||
}
|
||||
|
||||
SmallVector<Value, 8> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
|
||||
Value lhs, Value rhs,
|
||||
ShapedType result_type,
|
||||
DotOperationType type) {
|
||||
SmallVector<Value, 8> dyn_shape;
|
||||
switch (type) {
|
||||
case DotOperationType::kMatrixMatrix: {
|
||||
if (result_type.isDynamicDim(0))
|
||||
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
|
||||
if (result_type.isDynamicDim(1))
|
||||
dyn_shape.push_back(b.create<DimOp>(loc, rhs, 1));
|
||||
break;
|
||||
}
|
||||
case DotOperationType::kMatrixVector: {
|
||||
if (result_type.isDynamicDim(0))
|
||||
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
|
||||
break;
|
||||
}
|
||||
case DotOperationType::kVectorDot:
|
||||
case DotOperationType::kUnsupported:
|
||||
default: {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return dyn_shape;
|
||||
}
|
||||
|
||||
class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
|
||||
public:
|
||||
using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
|
||||
LogicalResult matchAndRewrite(
|
||||
mhlo::DotOp op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
|
||||
return failure();
|
||||
}
|
||||
Location loc = op.getLoc();
|
||||
mhlo::DotOp::Adaptor adaptor(args);
|
||||
Type result_type = op.getResult().getType();
|
||||
auto shaped_type = result_type.cast<ShapedType>();
|
||||
DotOperationType op_type = GetDotOperationType(op);
|
||||
SmallVector<Value, 8> dyn_shape = GetDotOpInitTensorDynSizes(
|
||||
rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type, op_type);
|
||||
auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
|
||||
Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
|
||||
auto init_tensor = rewriter.create<DynamicTensorFromElementsOp>(
|
||||
loc, result_type, dyn_shape);
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
SmallVector<Type, 4> arg_types(shaped_type.getRank(),
|
||||
rewriter.getIndexType());
|
||||
Region& region = init_tensor.body();
|
||||
Block* block = rewriter.createBlock(®ion, region.begin(), arg_types);
|
||||
rewriter.setInsertionPointToEnd(block);
|
||||
rewriter.create<YieldOp>(loc, zero);
|
||||
}
|
||||
linalg::LinalgOp linalg_op;
|
||||
switch (op_type) {
|
||||
case DotOperationType::kMatrixMatrix: {
|
||||
linalg_op = rewriter.create<linalg::MatmulOp>(
|
||||
loc, TypeRange{result_type},
|
||||
ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{init_tensor});
|
||||
break;
|
||||
}
|
||||
case DotOperationType::kMatrixVector: {
|
||||
linalg_op = rewriter.create<linalg::MatvecOp>(
|
||||
loc, TypeRange{result_type},
|
||||
ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{init_tensor});
|
||||
break;
|
||||
}
|
||||
case DotOperationType::kVectorDot: {
|
||||
linalg_op = rewriter.create<linalg::DotOp>(
|
||||
loc, TypeRange{result_type},
|
||||
ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{init_tensor});
|
||||
break;
|
||||
}
|
||||
case DotOperationType::kUnsupported:
|
||||
default: {
|
||||
return op.emitError("unsupported dot operation type");
|
||||
}
|
||||
}
|
||||
rewriter.replaceOp(op, linalg_op->getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
OwningRewritePatternList* patterns) {
|
||||
// clang-format off
|
||||
|
@ -1181,7 +1298,8 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
PointwiseToLinalgConverter<mhlo::XorOp, false>,
|
||||
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
||||
ReverseConverter<mhlo::ReverseOp, false>,
|
||||
TransposeConverter<mhlo::TransposeOp, false>>(context);
|
||||
TransposeConverter<mhlo::TransposeOp, false>,
|
||||
DotOpOnTensorsConversion>(context);
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
|
||||
|
|
|
@ -830,3 +830,44 @@ func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>) -> tensor<?xf32> {
|
|||
// CHECK-SAME: ins([[CST]] : tensor<f32>) outs([[INIT]] : tensor<?xf32>)
|
||||
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
|
||||
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
|
||||
|
||||
// -----
|
||||
|
||||
func @dot_matmul(%arg0: tensor<2x3xf32>,
|
||||
%arg1: tensor<3x?xf32>) -> tensor<2x?xf32> {
|
||||
%0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xf32>,
|
||||
tensor<3x?xf32>) -> tensor<2x?xf32>
|
||||
return %0 : tensor<2x?xf32>
|
||||
}
|
||||
// CHECK: func @dot_matmul(%[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
|
||||
// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xf32>, tensor<3x?xf32>)
|
||||
// CHECK-SAME: outs(%[[INIT]] : tensor<2x?xf32>)
|
||||
|
||||
// -----
|
||||
|
||||
func @dot_matvec(%arg0: tensor<?x3xf32>,
|
||||
%arg1: tensor<3xf32>) -> tensor<?xf32> {
|
||||
%0 = "mhlo.dot"(%arg0, %arg1) : (tensor<?x3xf32>,
|
||||
tensor<3xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
// CHECK: func @dot_matvec(%[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>)
|
||||
// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
|
||||
// CHECK: linalg.matvec
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x3xf32>, tensor<3xf32>)
|
||||
// CHECK-SAME: outs(%[[INIT]] : tensor<?xf32>)
|
||||
|
||||
// -----
|
||||
|
||||
func @dot_dot(%arg0: tensor<?xf32>,
|
||||
%arg1: tensor<?xf32>) -> tensor<f32> {
|
||||
%0 = "mhlo.dot"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
// CHECK: func @dot_dot(%[[ARG0:.*]]: tensor<?xf32>, %[[ARG1:.*]]: tensor<?xf32>)
|
||||
// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
|
||||
// CHECK: linalg.dot
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?xf32>, tensor<?xf32>)
|
||||
// CHECK-SAME: outs(%[[INIT]] : tensor<f32>)
|
||||
|
|
Loading…
Reference in New Issue