[KERNEL_GEN] Add a pattern for hlo.dyn_broadcast->linalg to enable is_inf kernel.
PiperOrigin-RevId: 351179620
This commit is contained in:
parent
ecf1bf5132
commit
180f917446
1
BUILD
1
BUILD
|
@ -668,6 +668,7 @@ cc_library(
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:SCFDialect",
|
"@llvm-project//mlir:SCFDialect",
|
||||||
"@llvm-project//mlir:StandardOps",
|
"@llvm-project//mlir:StandardOps",
|
||||||
|
"@llvm-project//mlir:TensorDialect",
|
||||||
"@llvm-project//mlir:Transforms",
|
"@llvm-project//mlir:Transforms",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
|
|
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/AffineExpr.h"
|
#include "mlir/IR/AffineExpr.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
|
@ -437,6 +438,55 @@ class HloBroadcastInDimConverter
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class HloDynamicBroadcastInDimConverter
|
||||||
|
: public OpConversionPattern<mhlo::DynamicBroadcastInDimOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<mhlo::DynamicBroadcastInDimOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
// Convert only if the producer is an HLO constant. Ideally the pattern
|
||||||
|
// (`mhlo.constant` -> `mhlo.dynamic_broadcast_in_dim`) should be converted
|
||||||
|
// to an Tensor-dialect op similar to TF ConstantLikeOp.
|
||||||
|
if (!op.operand().getDefiningOp<mhlo::ConstOp>()) return failure();
|
||||||
|
|
||||||
|
mhlo::DynamicBroadcastInDimOp::Adaptor adaptor(op);
|
||||||
|
Value operand = adaptor.operand();
|
||||||
|
auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!operand_type || operand_type.getRank() != 0) return failure();
|
||||||
|
|
||||||
|
Value shape = adaptor.output_dimensions();
|
||||||
|
auto shape_type = shape.getType().cast<RankedTensorType>();
|
||||||
|
int64_t result_rank = shape_type.getDimSize(0);
|
||||||
|
|
||||||
|
SmallVector<Value, 2> dyn_dims;
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
for (int i = 0; i < result_rank; ++i) {
|
||||||
|
Value index = rewriter.create<ConstantIndexOp>(loc, i);
|
||||||
|
dyn_dims.push_back(rewriter.create<tensor::ExtractOp>(loc, shape, index));
|
||||||
|
}
|
||||||
|
auto result_type = op.getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
|
int64_t nloops = result_type.getRank();
|
||||||
|
Value init = rewriter.create<linalg::InitTensorOp>(
|
||||||
|
loc, dyn_dims, result_type.getShape(), result_type.getElementType());
|
||||||
|
Operation* generic = rewriter.create<linalg::GenericOp>(
|
||||||
|
loc, TypeRange{init.getType()}, ValueRange{operand},
|
||||||
|
/*outputBuffers=*/ValueRange{init},
|
||||||
|
llvm::makeArrayRef(
|
||||||
|
{AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, {},
|
||||||
|
rewriter.getContext()),
|
||||||
|
rewriter.getMultiDimIdentityMap(nloops)}),
|
||||||
|
GetNParallelLoopsAttrs(nloops),
|
||||||
|
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
|
||||||
|
nested_builder.create<linalg::YieldOp>(loc, *args.begin());
|
||||||
|
});
|
||||||
|
rewriter.replaceOp(op, generic->getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class LhloBroadcastInDimConverter
|
class LhloBroadcastInDimConverter
|
||||||
: public OpConversionPattern<lmhlo::BroadcastInDimOp> {
|
: public OpConversionPattern<lmhlo::BroadcastInDimOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -1067,7 +1117,7 @@ struct HloLegalizeToLinalgPass
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
||||||
scf::SCFDialect>();
|
tensor::TensorDialect, scf::SCFDialect>();
|
||||||
|
|
||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||||
|
@ -1091,8 +1141,8 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
OwningRewritePatternList* patterns) {
|
OwningRewritePatternList* patterns) {
|
||||||
patterns
|
patterns
|
||||||
->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
|
->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
|
||||||
ConstConverter<mhlo::ConstOp>, HloBroadcastInDimConverter,
|
ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter,
|
||||||
IotaConverter<mhlo::IotaOp, false>,
|
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
|
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::AddOp, false>,
|
PointwiseToLinalgConverter<mhlo::AddOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::AndOp, false>,
|
PointwiseToLinalgConverter<mhlo::AndOp, false>,
|
||||||
|
|
|
@ -808,3 +808,25 @@ func @integer_pow(%lhs: tensor<2x2xi32>,
|
||||||
tensor<2x2xi32>) -> tensor<2x2xi32>
|
tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||||
return %0 : tensor<2x2xi32>
|
return %0 : tensor<2x2xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0) -> ()>
|
||||||
|
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0) -> (d0)>
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @dynamic_broadcast_in_dim(
|
||||||
|
// CHECK-SAME: [[SHAPE:%.*]]: tensor<1xindex>
|
||||||
|
func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>) -> tensor<?xf32> {
|
||||||
|
%cst = mhlo.constant dense<0x7F800000> : tensor<f32>
|
||||||
|
%result = "mhlo.dynamic_broadcast_in_dim"(%cst, %shape) {
|
||||||
|
broadcast_dimensions = dense<> : tensor<0xi64>
|
||||||
|
} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
return %result : tensor<?xf32>
|
||||||
|
}
|
||||||
|
// CHECK: [[CST:%.*]] = constant
|
||||||
|
// CHECK: [[INIT:%.*]] = linalg.init_tensor
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
||||||
|
// CHECK-SAME: ins([[CST]] : tensor<f32>) outs([[INIT]] : tensor<?xf32>)
|
||||||
|
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
|
||||||
|
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
|
||||||
|
|
Loading…
Reference in New Issue