[MLIR][KernelGen] Fix unranked codegeneration in kernel generator

PiperOrigin-RevId: 335847086
This commit is contained in:
A. Unique TensorFlower 2020-10-07 05:39:04 -07:00 committed by TensorFlow MLIR Team
parent 1f14527609
commit 3736c5542f
9 changed files with 18 additions and 15 deletions

View File

@ -15,9 +15,9 @@ limitations under the License.
include "mlir/Pass/PassBase.td" include "mlir/Pass/PassBase.td"
def TestChloLegalizeToHloPass : Pass<"mhlo-test-chlo-legalize-to-hlo", "FuncOp"> { def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "FuncOp"> {
let summary = "Test pass for applying chlo -> hlo legalization patterns."; let summary = "Legalize CHLO to HLO.";
let constructor = "createTestChloLegalizeToHloPass()"; let constructor = "createChloLegalizeToHloPass()";
} }
def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> { def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {

View File

@ -44,6 +44,9 @@ std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass();
/// Lowers from HLO dialect to Standard dialect. /// Lowers from HLO dialect to Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass(); std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
/// Lowers from the CHLO dialect to the HLO dialect.
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass();
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary /// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
/// buffers if necessary. If `results_escape_functions` is set to true, /// buffers if necessary. If `results_escape_functions` is set to true,
/// allocated buffers for function results will be returned and escape the /// allocated buffers for function results will be returned and escape the

View File

@ -22,7 +22,6 @@ limitations under the License.
namespace mlir { namespace mlir {
namespace mhlo { namespace mhlo {
std::unique_ptr<Pass> createTestChloLegalizeToHloPass();
std::unique_ptr<FunctionPass> createTestInferShapedTypeMethodsPass(); std::unique_ptr<FunctionPass> createTestInferShapedTypeMethodsPass();
std::unique_ptr<Pass> createTestMaterializeBroadcastsPass(); std::unique_ptr<Pass> createTestMaterializeBroadcastsPass();
std::unique_ptr<Pass> createTestUnfuseBatchNormPass(); std::unique_ptr<Pass> createTestUnfuseBatchNormPass();

View File

@ -27,8 +27,8 @@ namespace mhlo {
namespace { namespace {
struct TestChloLegalizeToHloPass struct ChloLegalizeToHloPass
: public PassWrapper<TestChloLegalizeToHloPass, FunctionPass> { : public PassWrapper<ChloLegalizeToHloPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override { void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<mhlo::MhloDialect, shape::ShapeDialect, scf::SCFDialect>(); registry.insert<mhlo::MhloDialect, shape::ShapeDialect, scf::SCFDialect>();
} }
@ -36,11 +36,12 @@ struct TestChloLegalizeToHloPass
void runOnFunction() override { void runOnFunction() override {
ConversionTarget conversionTarget(getContext()); ConversionTarget conversionTarget(getContext());
OwningRewritePatternList conversionPatterns; OwningRewritePatternList conversionPatterns;
conversionTarget.addIllegalDialect<chlo::HloClientDialect>(); conversionTarget.addIllegalDialect<chlo::HloClientDialect>();
// Consider the mhlo dialect legal for tests. // Consider the mhlo dialect legal for tests.
conversionTarget.addLegalDialect<mhlo::MhloDialect>(); conversionTarget.addLegalDialect<mhlo::MhloDialect>();
// The conversion uses helpers from the Standard dialect.
// The conversion uses helpers from the standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>(); conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>(); conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
conversionTarget.addLegalDialect<mlir::scf::SCFDialect>(); conversionTarget.addLegalDialect<mlir::scf::SCFDialect>();
@ -56,8 +57,8 @@ struct TestChloLegalizeToHloPass
} // namespace } // namespace
std::unique_ptr<FunctionPass> createTestChloLegalizeToHloPass() { std::unique_ptr<FunctionPass> createChloLegalizeToHloPass() {
return std::make_unique<TestChloLegalizeToHloPass>(); return std::make_unique<ChloLegalizeToHloPass>();
} }
} // namespace mhlo } // namespace mhlo

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -mhlo-test-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s // RUN: mlir-hlo-opt -chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// Check the non-broadcast case for each registered op, then just check a // Check the non-broadcast case for each registered op, then just check a
// representative op for detailed broadcast semantics. // representative op for detailed broadcast semantics.

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt --mhlo-test-chlo-legalize-to-hlo --split-input-file %s | FileCheck %s // RUN: mlir-hlo-opt --chlo-legalize-to-hlo --split-input-file %s | FileCheck %s
// Lower statically shaped `constant_like` to constant. // Lower statically shaped `constant_like` to constant.
// CHECK-LABEL: @constant_like_static_shape // CHECK-LABEL: @constant_like_static_shape

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s // RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
func @main() -> () { func @main() -> () {
call @trivial_broadcast_wrapper() : () -> () call @trivial_broadcast_wrapper() : () -> ()

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -convert-scf-to-std -canonicalize -cse -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s // RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -convert-scf-to-std -canonicalize -cse -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
func @main() -> () { func @main() -> () {
call @reshape_with_static_shape_size_matrix_to_1D() : () -> () call @reshape_with_static_shape_size_matrix_to_1D() : () -> ()

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s // RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s
// CHECK-LABEL: @add // CHECK-LABEL: @add
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {