Add support for lowering mhlo.iota to Linalg.
PiperOrigin-RevId: 322799853
This commit is contained in:
parent
4251630426
commit
8f262ae8f5
|
@ -640,25 +640,25 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class IotaConverter : public OpConversionPattern<lmhlo::IotaOp> {
|
template <typename OpTy, bool isLHLO = true>
|
||||||
|
class IotaConverter : public OpConversionPattern<OpTy> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern<lmhlo::IotaOp>::OpConversionPattern;
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(
|
LogicalResult matchAndRewrite(
|
||||||
lmhlo::IotaOp iotaOp, ArrayRef<Value> args,
|
OpTy iotaOp, ArrayRef<Value> args,
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
auto resultMemrefType =
|
ShapedType resultShapedType = getHloOpResultType<isLHLO>(iotaOp);
|
||||||
iotaOp.getOperand().getType().dyn_cast<MemRefType>();
|
if (!resultShapedType) return failure();
|
||||||
if (!resultMemrefType) return failure();
|
|
||||||
|
|
||||||
auto resultElementType = resultMemrefType.getElementType();
|
auto resultElementType = resultShapedType.getElementType();
|
||||||
if (!resultElementType.isSignlessIntOrFloat()) return failure();
|
if (!resultElementType.isSignlessIntOrFloat()) return failure();
|
||||||
|
|
||||||
// Construct the indexing maps needed for linalg.generic ops.
|
// Construct the indexing maps needed for linalg.generic ops.
|
||||||
unsigned nloops = resultMemrefType.getRank();
|
unsigned nloops = resultShapedType.getRank();
|
||||||
|
|
||||||
rewriter.create<linalg::IndexedGenericOp>(
|
auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
|
||||||
iotaOp.getLoc(), ArrayRef<Type>{}, args,
|
iotaOp.getLoc(), isLHLO ? ArrayRef<Type>{} : resultShapedType, args,
|
||||||
0, // args_in
|
0, // args_in
|
||||||
1, // args_out
|
1, // args_out
|
||||||
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
|
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
|
||||||
|
@ -669,14 +669,16 @@ class IotaConverter : public OpConversionPattern<lmhlo::IotaOp> {
|
||||||
nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()],
|
nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()],
|
||||||
nestedBuilder.getIntegerType(
|
nestedBuilder.getIntegerType(
|
||||||
resultElementType.getIntOrFloatBitWidth()));
|
resultElementType.getIntOrFloatBitWidth()));
|
||||||
if (resultElementType.isa<FloatType>()) {
|
if (resultElementType.template isa<FloatType>()) {
|
||||||
castOp = nestedBuilder.create<SIToFPOp>(nestedLoc, castOp,
|
castOp = nestedBuilder.create<SIToFPOp>(nestedLoc, castOp,
|
||||||
resultElementType);
|
resultElementType);
|
||||||
}
|
}
|
||||||
nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp);
|
nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp);
|
||||||
});
|
});
|
||||||
|
if (isLHLO)
|
||||||
rewriter.replaceOp(iotaOp, llvm::None);
|
rewriter.replaceOp(iotaOp, llvm::None);
|
||||||
|
else
|
||||||
|
rewriter.replaceOp(iotaOp, linalgOp.output_tensors());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -768,7 +770,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
|
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
|
||||||
ConstConverter,
|
ConstConverter,
|
||||||
ConvToLinalgConverter,
|
ConvToLinalgConverter,
|
||||||
IotaConverter,
|
IotaConverter<lmhlo::IotaOp>,
|
||||||
LhloBroadcastInDimConverter,
|
LhloBroadcastInDimConverter,
|
||||||
PointwiseToLinalgConverter<lmhlo::AbsOp>,
|
PointwiseToLinalgConverter<lmhlo::AbsOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::AddOp>,
|
PointwiseToLinalgConverter<lmhlo::AddOp>,
|
||||||
|
@ -870,36 +872,37 @@ namespace mhlo {
|
||||||
|
|
||||||
void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
OwningRewritePatternList* patterns) {
|
OwningRewritePatternList* patterns) {
|
||||||
patterns->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
|
patterns
|
||||||
HloBroadcastInDimConverter,
|
->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
|
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::AddOp, false>,
|
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::AndOp, false>,
|
PointwiseToLinalgConverter<mhlo::AddOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
|
PointwiseToLinalgConverter<mhlo::AndOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
|
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
|
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
|
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::CopyOp, false>,
|
PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::CosOp, false>,
|
PointwiseToLinalgConverter<mhlo::CopyOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::DivOp, false>,
|
PointwiseToLinalgConverter<mhlo::CosOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
|
PointwiseToLinalgConverter<mhlo::DivOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
|
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::LogOp, false>,
|
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
|
PointwiseToLinalgConverter<mhlo::LogOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::MinOp, false>,
|
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::MulOp, false>,
|
PointwiseToLinalgConverter<mhlo::MinOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::NegOp, false>,
|
PointwiseToLinalgConverter<mhlo::MulOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::RealOp, false>,
|
PointwiseToLinalgConverter<mhlo::NegOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::RemOp, false>,
|
PointwiseToLinalgConverter<mhlo::RealOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
|
PointwiseToLinalgConverter<mhlo::RemOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::SelectOp, false>,
|
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::SinOp, false>,
|
PointwiseToLinalgConverter<mhlo::SelectOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
|
PointwiseToLinalgConverter<mhlo::SinOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::SubOp, false>,
|
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
|
PointwiseToLinalgConverter<mhlo::SubOp, false>,
|
||||||
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
|
||||||
ReverseConverter<mhlo::ReverseOp, false>,
|
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
||||||
TransposeConverter<mhlo::TransposeOp, false>>(context);
|
ReverseConverter<mhlo::ReverseOp, false>,
|
||||||
|
TransposeConverter<mhlo::TransposeOp, false>>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
|
||||||
|
|
|
@ -557,3 +557,18 @@ func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||||
}
|
}
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
|
// CHECK-LABEL: func @iota
|
||||||
|
func @iota() -> tensor<7x10xf32> {
|
||||||
|
%result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xf32>)
|
||||||
|
return %result : tensor<7x10xf32>
|
||||||
|
}
|
||||||
|
// CHECK: linalg.indexed_generic
|
||||||
|
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
|
||||||
|
// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index):
|
||||||
|
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
|
||||||
|
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32
|
||||||
|
|
Loading…
Reference in New Issue