[MLIR][KernelGen] Fix unranked codegeneration in kernel generator
PiperOrigin-RevId: 335847086
This commit is contained in:
parent
1f14527609
commit
3736c5542f
|
@ -15,9 +15,9 @@ limitations under the License.
|
||||||
|
|
||||||
include "mlir/Pass/PassBase.td"
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
def TestChloLegalizeToHloPass : Pass<"mhlo-test-chlo-legalize-to-hlo", "FuncOp"> {
|
def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "FuncOp"> {
|
||||||
let summary = "Test pass for applying chlo -> hlo legalization patterns.";
|
let summary = "Legalize CHLO to HLO.";
|
||||||
let constructor = "createTestChloLegalizeToHloPass()";
|
let constructor = "createChloLegalizeToHloPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {
|
def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {
|
||||||
|
|
|
@ -44,6 +44,9 @@ std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass();
|
||||||
/// Lowers from HLO dialect to Standard dialect.
|
/// Lowers from HLO dialect to Standard dialect.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
|
||||||
|
|
||||||
|
/// Lowers from the CHLO dialect to the HLO dialect.
|
||||||
|
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass();
|
||||||
|
|
||||||
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
|
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
|
||||||
/// buffers if necessary. If `results_escape_functions` is set to true,
|
/// buffers if necessary. If `results_escape_functions` is set to true,
|
||||||
/// allocated buffers for function results will be returned and escape the
|
/// allocated buffers for function results will be returned and escape the
|
||||||
|
|
|
@ -22,7 +22,6 @@ limitations under the License.
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace mhlo {
|
namespace mhlo {
|
||||||
|
|
||||||
std::unique_ptr<Pass> createTestChloLegalizeToHloPass();
|
|
||||||
std::unique_ptr<FunctionPass> createTestInferShapedTypeMethodsPass();
|
std::unique_ptr<FunctionPass> createTestInferShapedTypeMethodsPass();
|
||||||
std::unique_ptr<Pass> createTestMaterializeBroadcastsPass();
|
std::unique_ptr<Pass> createTestMaterializeBroadcastsPass();
|
||||||
std::unique_ptr<Pass> createTestUnfuseBatchNormPass();
|
std::unique_ptr<Pass> createTestUnfuseBatchNormPass();
|
||||||
|
|
|
@ -27,8 +27,8 @@ namespace mhlo {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct TestChloLegalizeToHloPass
|
struct ChloLegalizeToHloPass
|
||||||
: public PassWrapper<TestChloLegalizeToHloPass, FunctionPass> {
|
: public PassWrapper<ChloLegalizeToHloPass, FunctionPass> {
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<mhlo::MhloDialect, shape::ShapeDialect, scf::SCFDialect>();
|
registry.insert<mhlo::MhloDialect, shape::ShapeDialect, scf::SCFDialect>();
|
||||||
}
|
}
|
||||||
|
@ -36,11 +36,12 @@ struct TestChloLegalizeToHloPass
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
ConversionTarget conversionTarget(getContext());
|
ConversionTarget conversionTarget(getContext());
|
||||||
OwningRewritePatternList conversionPatterns;
|
OwningRewritePatternList conversionPatterns;
|
||||||
|
|
||||||
conversionTarget.addIllegalDialect<chlo::HloClientDialect>();
|
conversionTarget.addIllegalDialect<chlo::HloClientDialect>();
|
||||||
|
|
||||||
// Consider the mhlo dialect legal for tests.
|
// Consider the mhlo dialect legal for tests.
|
||||||
conversionTarget.addLegalDialect<mhlo::MhloDialect>();
|
conversionTarget.addLegalDialect<mhlo::MhloDialect>();
|
||||||
// The conversion uses helpers from the Standard dialect.
|
|
||||||
|
// The conversion uses helpers from the standard dialect.
|
||||||
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
|
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
|
||||||
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
|
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
|
||||||
conversionTarget.addLegalDialect<mlir::scf::SCFDialect>();
|
conversionTarget.addLegalDialect<mlir::scf::SCFDialect>();
|
||||||
|
@ -56,8 +57,8 @@ struct TestChloLegalizeToHloPass
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<FunctionPass> createTestChloLegalizeToHloPass() {
|
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass() {
|
||||||
return std::make_unique<TestChloLegalizeToHloPass>();
|
return std::make_unique<ChloLegalizeToHloPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mhlo
|
} // namespace mhlo
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-hlo-opt -mhlo-test-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
|
// RUN: mlir-hlo-opt -chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
|
||||||
|
|
||||||
// Check the non-broadcast case for each registered op, then just check a
|
// Check the non-broadcast case for each registered op, then just check a
|
||||||
// representative op for detailed broadcast semantics.
|
// representative op for detailed broadcast semantics.
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-hlo-opt --mhlo-test-chlo-legalize-to-hlo --split-input-file %s | FileCheck %s
|
// RUN: mlir-hlo-opt --chlo-legalize-to-hlo --split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
// Lower statically shaped `constant_like` to constant.
|
// Lower statically shaped `constant_like` to constant.
|
||||||
// CHECK-LABEL: @constant_like_static_shape
|
// CHECK-LABEL: @constant_like_static_shape
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
|
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
|
||||||
|
|
||||||
func @main() -> () {
|
func @main() -> () {
|
||||||
call @trivial_broadcast_wrapper() : () -> ()
|
call @trivial_broadcast_wrapper() : () -> ()
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -convert-scf-to-std -canonicalize -cse -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
|
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -convert-scf-to-std -canonicalize -cse -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
|
||||||
|
|
||||||
func @main() -> () {
|
func @main() -> () {
|
||||||
call @reshape_with_static_shape_size_matrix_to_1D() : () -> ()
|
call @reshape_with_static_shape_size_matrix_to_1D() : () -> ()
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s
|
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: @add
|
// CHECK-LABEL: @add
|
||||||
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
|
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
|
||||||
|
|
Loading…
Reference in New Issue