[MLIR][KernelGen] Fix unranked codegeneration in kernel generator

PiperOrigin-RevId: 335847086
This commit is contained in:
A. Unique TensorFlower 2020-10-07 05:39:04 -07:00 committed by TensorFlow MLIR Team
parent 1f14527609
commit 3736c5542f
9 changed files with 18 additions and 15 deletions

View File

@ -15,9 +15,9 @@ limitations under the License.
include "mlir/Pass/PassBase.td"
def TestChloLegalizeToHloPass : Pass<"mhlo-test-chlo-legalize-to-hlo", "FuncOp"> {
let summary = "Test pass for applying chlo -> hlo legalization patterns.";
let constructor = "createTestChloLegalizeToHloPass()";
def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "FuncOp"> {
let summary = "Legalize CHLO to HLO.";
let constructor = "createChloLegalizeToHloPass()";
}
def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {

View File

@ -44,6 +44,9 @@ std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass();
/// Lowers from HLO dialect to Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
/// Lowers from the CHLO dialect to the HLO dialect.
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass();
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
/// buffers if necessary. If `results_escape_functions` is set to true,
/// allocated buffers for function results will be returned and escape the

View File

@ -22,7 +22,6 @@ limitations under the License.
namespace mlir {
namespace mhlo {
std::unique_ptr<Pass> createTestChloLegalizeToHloPass();
std::unique_ptr<FunctionPass> createTestInferShapedTypeMethodsPass();
std::unique_ptr<Pass> createTestMaterializeBroadcastsPass();
std::unique_ptr<Pass> createTestUnfuseBatchNormPass();

View File

@ -27,8 +27,8 @@ namespace mhlo {
namespace {
struct TestChloLegalizeToHloPass
: public PassWrapper<TestChloLegalizeToHloPass, FunctionPass> {
struct ChloLegalizeToHloPass
: public PassWrapper<ChloLegalizeToHloPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<mhlo::MhloDialect, shape::ShapeDialect, scf::SCFDialect>();
}
@ -36,11 +36,12 @@ struct TestChloLegalizeToHloPass
void runOnFunction() override {
ConversionTarget conversionTarget(getContext());
OwningRewritePatternList conversionPatterns;
conversionTarget.addIllegalDialect<chlo::HloClientDialect>();
// Consider the mhlo dialect legal for tests.
conversionTarget.addLegalDialect<mhlo::MhloDialect>();
// The conversion uses helpers from the Standard dialect.
// The conversion uses helpers from the standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
conversionTarget.addLegalDialect<mlir::scf::SCFDialect>();
@ -56,8 +57,8 @@ struct TestChloLegalizeToHloPass
} // namespace
std::unique_ptr<FunctionPass> createTestChloLegalizeToHloPass() {
return std::make_unique<TestChloLegalizeToHloPass>();
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass() {
return std::make_unique<ChloLegalizeToHloPass>();
}
} // namespace mhlo

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -mhlo-test-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// RUN: mlir-hlo-opt -chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// Check the non-broadcast case for each registered op, then just check a
// representative op for detailed broadcast semantics.

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt --mhlo-test-chlo-legalize-to-hlo --split-input-file %s | FileCheck %s
// RUN: mlir-hlo-opt --chlo-legalize-to-hlo --split-input-file %s | FileCheck %s
// Lower statically shaped `constant_like` to constant.
// CHECK-LABEL: @constant_like_static_shape

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
func @main() -> () {
call @trivial_broadcast_wrapper() : () -> ()

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -convert-scf-to-std -canonicalize -cse -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -convert-scf-to-std -canonicalize -cse -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
func @main() -> () {
call @reshape_with_static_shape_size_matrix_to_1D() : () -> ()

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s
// CHECK-LABEL: @add
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {