From ba0346b071f4212008943c50f47766e677992b56 Mon Sep 17 00:00:00 2001
From: Tres Popp
Date: Wed, 20 Jan 2021 07:08:32 -0800
Subject: [PATCH] Integrate LLVM at llvm/llvm-project@96ef4f307df2

Updates LLVM usage to match
[96ef4f307df2](https://github.com/llvm/llvm-project/commit/96ef4f307df2)

PiperOrigin-RevId: 352786460
---
 BUILD                                         |  1 +
 WORKSPACE                                     |  4 ++--
 build_tools/llvm_version.txt                  |  2 +-
 lib/Dialect/mhlo/IR/hlo_ops.cc                |  4 +++-
 .../mhlo/transforms/legalize_to_linalg.cc     | 12 +++++------
 .../mhlo/transforms/transform_unranked_hlo.cc |  4 ++--
 .../mhlo/transforms/unfuse_batch_norm.cc      |  3 ++-
 tests/canonicalize.mlir                       |  2 +-
 tests/end2end/broadcast.mlir                  | 20 +++++++++----------
 tests/hlo-legalize-to-lhlo.mlir               |  8 ++++----
 tests/hlo-legalize-to-linalg.mlir             |  8 ++++----
 tests/hlo-transform-unranked.mlir             | 18 ++++++++---------
 tests/mhlo_infer_shape_type_methods.mlir      |  4 ++--
 tests/unfuse_batch_norm.mlir                  |  4 ++--
 14 files changed, 49 insertions(+), 45 deletions(-)

diff --git a/BUILD b/BUILD
index 466c509..9a22ef8 100644
--- a/BUILD
+++ b/BUILD
@@ -960,6 +960,7 @@ cc_library(
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:Transforms",
     ],
 )
diff --git a/WORKSPACE b/WORKSPACE
index 556d62d..59bda96 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -15,9 +15,9 @@
 load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

-LLVM_COMMIT = "8456c3a789285079ad35d146e487436b5a27b027"
+LLVM_COMMIT = "96ef4f307df27f4e0946eb344bac2703017ad073"

-LLVM_SHA256 = "cc17723a31207ffa9c0636bf83752de0e2a20cf99d9a9955c796d7e109b4c68d"
+LLVM_SHA256 = "69b6c722deed4f128318259ab8f3c511c9aea91357e52d9479e23edee78deb1a"

 LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT)
diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt
index 4e26d6c..59e4b85 100644
--- a/build_tools/llvm_version.txt
+++ b/build_tools/llvm_version.txt
@@ -1,2 +1,2 @@
-8456c3a789285079ad35d146e487436b5a27b027
+96ef4f307df27f4e0946eb344bac2703017ad073
diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc
index b00973b..663e6e0 100644
--- a/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -58,6 +58,7 @@ limitations under the License.
#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/InliningUtils.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" namespace mlir { #include "hlo_patterns.cc.inc" @@ -3064,6 +3065,7 @@ MhloDialect::MhloDialect(MLIRContext* context) >(); addInterfaces(); addTypes(); + context->loadDialect(); } Type MhloDialect::parseType(DialectAsmParser& parser) const { @@ -3111,7 +3113,7 @@ LogicalResult deriveShapeFromFirstOperand( } } *reifiedReturnShapes = SmallVector{ - builder->create(loc, shape_values)}; + builder->create(loc, shape_values)}; return success(); } diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 9c83ce7..2cee0b4 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -1098,8 +1098,8 @@ class DotOpOnTensorsConversion : public OpConversionPattern { rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type, op_type); auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType()); Value zero = rewriter.create(loc, zero_attr); - auto init_tensor = rewriter.create( - loc, result_type, dyn_shape); + auto init_tensor = + rewriter.create(loc, result_type, dyn_shape); { OpBuilder::InsertionGuard guard(rewriter); SmallVector arg_types(shaped_type.getRank(), @@ -1107,7 +1107,7 @@ class DotOpOnTensorsConversion : public OpConversionPattern { Region& region = init_tensor.body(); Block* block = rewriter.createBlock(®ion, region.begin(), arg_types); rewriter.setInsertionPointToEnd(block); - rewriter.create(loc, zero); + rewriter.create(loc, zero); } linalg::LinalgOp linalg_op; switch (op_type) { @@ -1194,8 +1194,8 @@ class DotGeneralOpOnTensorsConversion rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type); auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType()); Value zero = rewriter.create(loc, zero_attr); - auto init_tensor = rewriter.create( - loc, result_type, dyn_shape); + auto init_tensor = + rewriter.create(loc, result_type, dyn_shape); { OpBuilder::InsertionGuard guard(rewriter); SmallVector arg_types(shaped_type.getRank(), @@ -1203,7 +1203,7 @@ class DotGeneralOpOnTensorsConversion Region& region = init_tensor.body(); Block* block = rewriter.createBlock(®ion, region.begin(), arg_types); rewriter.setInsertionPointToEnd(block); - rewriter.create(loc, zero); + rewriter.create(loc, zero); } auto linalg_op = rewriter.create( loc, /*resultTensorTypes=*/TypeRange{result_type}, diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 70d5d38..b359217 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -101,7 +101,7 @@ struct ElementwiseOpConversion : public OpRewritePattern { Type indexTy = rewriter.getIndexType(); Value numElements = rewriter.create(loc, indexTy, shape); - Value flatShape = rewriter.create(loc, numElements); + Value flatShape = rewriter.create(loc, numElements); // Flatten operands. SmallVector flatOperands; @@ -176,7 +176,7 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp rewriter.create(loc, lhs_is_scalar ? rhs : lhs); Value num_elements = rewriter.create(loc, shape); Value size_tensor = - rewriter.create(loc, num_elements); + rewriter.create(loc, num_elements); Value reshaped = rewriter.create( loc, RankedTensorType::get({-1}, scalar_element_type), lhs_is_scalar ? 
rhs : lhs, size_tensor); diff --git a/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc b/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc index 0639589..962d4ca 100644 --- a/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc +++ b/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc @@ -16,6 +16,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" @@ -59,7 +60,7 @@ Value CalculateShapeValue(Location loc, Value operand, for (int64_t i = 0; i < rank; ++i) { shape_values.push_back(rewriter.create(loc, operand, i)); } - return rewriter.create(loc, shape_values); + return rewriter.create(loc, shape_values); } Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr, diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index 1127d4c..3b6cf16 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -603,7 +603,7 @@ func @dynamic_reshape_of_dynamic_reshape(%arg0: tensor, %shape: tensor, tensor) -> tensor<*xf16> %1 = shape.shape_of %0 : tensor<*xf16> -> tensor %2 = shape.num_elements %1 : tensor -> index - %3 = tensor_from_elements %2 : tensor<1xindex> + %3 = tensor.from_elements %2 : tensor<1xindex> %4 = "mhlo.dynamic_reshape"(%0, %3) : (tensor<*xf16>, tensor<1xindex>) -> tensor return %4 : tensor } diff --git a/tests/end2end/broadcast.mlir b/tests/end2end/broadcast.mlir index 932035e..1129e79 100644 --- a/tests/end2end/broadcast.mlir +++ b/tests/end2end/broadcast.mlir @@ -57,7 +57,7 @@ func @trivial_broadcast_wrapper() { // Test DynamicBroadcastInDimOp. %c3 = constant 3 : index %c4 = constant 4 : index - %shape = tensor_from_elements %c3, %c4 : tensor<2xindex> + %shape = tensor.from_elements %c3, %c4 : tensor<2xindex> %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x4xf32> @@ -106,7 +106,7 @@ func @broadcast_in_X_dim_wrapper() { // Test DynamicBroadcastInDimOp. %c4 = constant 4 : index - %shape = tensor_from_elements %c3, %c4 : tensor<2xindex> + %shape = tensor.from_elements %c3, %c4 : tensor<2xindex> %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) { broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } : (tensor<1x4xf32>, tensor<2xindex>) -> tensor<3x4xf32> @@ -153,7 +153,7 @@ func @broadcast_in_Y_dim_wrapper() { // Test DynamicBroadcastInDimOp. %c3 = constant 3 : index %c4 = constant 4 : index - %shape = tensor_from_elements %c3, %c4 : tensor<2xindex> + %shape = tensor.from_elements %c3, %c4 : tensor<2xindex> %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) { broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } : (tensor<3x1xf32>, tensor<2xindex>) -> tensor<3x4xf32> @@ -202,7 +202,7 @@ func @broadcast_in_X_dim_transpose_wrapper() { // Test DynamicBroadcastInDimOp. %c4 = constant 4 : index - %shape = tensor_from_elements %c3, %c4 : tensor<2xindex> + %shape = tensor.from_elements %c3, %c4 : tensor<2xindex> %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) { broadcast_dimensions = dense<[1, 0]> : tensor<2xi64> } : (tensor<4x1xf32>, tensor<2xindex>) -> tensor<3x4xf32> @@ -249,7 +249,7 @@ func @broadcast_in_Y_dim_transpose_wrapper() { // Test DynamicBroadcastInDimOp. 
   %c3 = constant 3 : index
   %c4 = constant 4 : index
-  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %shape = tensor.from_elements %c3, %c4 : tensor<2xindex>
   %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
     broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
   } : (tensor<1x3xf32>, tensor<2xindex>) -> tensor<3x4xf32>
@@ -290,7 +290,7 @@ func @broadcast_scalar_1d_wrapper() {
   // Test DynamicBroadcastInDimOp.
   %c3 = constant 3 : index
   %c4 = constant 4 : index
-  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %shape = tensor.from_elements %c3, %c4 : tensor<2xindex>
   %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
     broadcast_dimensions = dense<0> : tensor<1xi64>
   } : (tensor<1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
@@ -331,7 +331,7 @@ func @broadcast_scalar_2d_wrapper() {
   // Test DynamicBroadcastInDimOp.
   %c3 = constant 3 : index
   %c4 = constant 4 : index
-  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %shape = tensor.from_elements %c3, %c4 : tensor<2xindex>
   %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
     broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
   } : (tensor<1x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
@@ -381,7 +381,7 @@ func @broadcast_to_the_same_shape() {
   // CHECK-NEXT: [1, 2, 3]
   // Test DynamicBroadcastInDimOp.
-  %shape = tensor_from_elements %c2, %c3 : tensor<2xindex>
+  %shape = tensor.from_elements %c2, %c3 : tensor<2xindex>
   %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
     broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
   } : (tensor<2x3xf32>, tensor<2xindex>) -> tensor<2x3xf32>
@@ -429,7 +429,7 @@ func @broadcast_1d_to_2d() {
   // Test DynamicBroadcastInDimOp.
   %c3 = constant 3 : index
   %c4 = constant 3 : index
-  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %shape = tensor.from_elements %c3, %c4 : tensor<2xindex>
   %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
     broadcast_dimensions = dense<0> : tensor<1xi64>
   } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32>
@@ -477,7 +477,7 @@ func @broadcast_1d_to_2d_with_transpose() {
   // Test DynamicBroadcastInDimOp.
   %c3 = constant 3 : index
-  %shape = tensor_from_elements %c3, %c3 : tensor<2xindex>
+  %shape = tensor.from_elements %c3, %c3 : tensor<2xindex>
   %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
     broadcast_dimensions = dense<1> : tensor<1xi64>
   } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32>
diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir
index 11497ca..47224bc 100644
--- a/tests/hlo-legalize-to-lhlo.mlir
+++ b/tests/hlo-legalize-to-lhlo.mlir
@@ -135,13 +135,13 @@ func @broadcast(%operand: tensor<5xf32>) -> tensor<10x5xf32> {
 func @dyn_broadcast(%operand: tensor) -> tensor {
 // CHECK-SAME: %[[OPERAND:.*]]: memref
   %c1 = constant 1 : i64
-  %shape = tensor_from_elements %c1, %c1, %c1 : tensor<3xi64>
+  %shape = tensor.from_elements %c1, %c1, %c1 : tensor<3xi64>
   %result = "mhlo.dynamic_broadcast_in_dim"(%operand, %shape) {
     broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
   } : (tensor, tensor<3xi64>) -> tensor
   return %result : tensor
 }
-// CHECK: %[[SHAPE:.*]] = tensor_from_elements
+// CHECK: %[[SHAPE:.*]] = tensor.from_elements
 // CHECK: %[[C0:.*]] = constant 0 : index
 // CHECK: %[[C1:.*]] = constant 1 : index
@@ -463,7 +463,7 @@ func @add_dyn(%lhs: tensor, %rhs: tensor) -> tensor {
   // CHECK: %[[C1:.*]] = constant 1 : index
   // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref
   // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
-  // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
+  // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
   // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xi64>
   // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
   // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xi64>
@@ -487,7 +487,7 @@ func @tanh_dyn(%arg0: tensor) -> tensor {
   // CHECK: %[[C1:.*]] = constant 1 : index
   // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref
   // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
-  // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
+  // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
   // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xi64>
   // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
   // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xi64>
diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir
index 298ce7b..159a616 100644
--- a/tests/hlo-legalize-to-linalg.mlir
+++ b/tests/hlo-legalize-to-linalg.mlir
@@ -849,7 +849,7 @@ func @dot_matmul(%arg0: tensor<2x3xf32>,
   return %0 : tensor<2x?xf32>
 }
 // CHECK: func @dot_matmul(%[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
-// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
+// CHECK: %[[INIT:.*]] = tensor.generate
 // CHECK: linalg.matmul
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xf32>, tensor<3x?xf32>)
 // CHECK-SAME: outs(%[[INIT]] : tensor<2x?xf32>)
@@ -863,7 +863,7 @@ func @dot_matvec(%arg0: tensor,
   return %0 : tensor
 }
 // CHECK: func @dot_matvec(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor<3xf32>)
-// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
+// CHECK: %[[INIT:.*]] = tensor.generate
 // CHECK: linalg.matvec
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor<3xf32>)
 // CHECK-SAME: outs(%[[INIT]] : tensor)
@@ -876,7 +876,7 @@ func @dot_dot(%arg0: tensor,
   return %0 : tensor
 }
 // CHECK: func @dot_dot(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor)
-// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
+// CHECK: %[[INIT:.*]] = tensor.generate
 // CHECK: linalg.dot
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor)
 // CHECK-SAME: outs(%[[INIT]] : tensor)
@@ -897,7 +897,7 @@ func @dot_general(%arg0: tensor,
   return %0 : tensor
 }
 // CHECK: func @dot_general(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor)
-// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
+// CHECK: %[[INIT:.*]] = tensor.generate
 // CHECK: linalg.batch_matmul
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor)
 // CHECK-SAME: outs(%[[INIT]] : tensor)
diff --git a/tests/hlo-transform-unranked.mlir b/tests/hlo-transform-unranked.mlir
index cc6d725..f7ef720 100644
--- a/tests/hlo-transform-unranked.mlir
+++ b/tests/hlo-transform-unranked.mlir
@@ -7,7 +7,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
   // Flatten operand shape.
   %shape = shape.shape_of %a : tensor<*xf32> -> tensor
   %num_elements = shape.num_elements %shape : tensor -> index
-  %flat_shape = tensor_from_elements %num_elements : tensor<1xindex>
+  %flat_shape = tensor.from_elements %num_elements : tensor<1xindex>
   %flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape)
       : (tensor<*xf32>, tensor<1xindex>) -> tensor
@@ -29,7 +29,7 @@ func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
   // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor
   // CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
-  // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
+  // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
   // CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor
   // CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor) -> tensor
   // CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor, tensor) -> tensor<*xf32>
@@ -71,7 +71,7 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
   // CHECK: %[[SHAPE_B:.*]] = shape.shape_of %[[B]]
   // CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_A]], %[[SHAPE_B]]
   // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
-  // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
+  // CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
   // CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor
   // CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor
   // CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor
@@ -88,7 +88,7 @@ func @tan(%a : tensor<*xf32>) -> tensor<*xf32> {
   // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor
   // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
-  // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
+  // CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
   // CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor
   // CHECK: %[[FLAT_B:.*]] = chlo.tan %[[FLAT_A]] : tensor
   // CHECK: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor, tensor) -> tensor<*xf32>
@@ -113,7 +113,7 @@ func @addScalarUnranked(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf3
 // to a 1D tensor.
 // CHECK-NEXT: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
 // CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor -> index
-// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
+// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
 // CHECK-NEXT: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor
 // CHECK-NEXT: %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[ARG_0]], %[[RESHAPED]] : (tensor, tensor) -> tensor
 // As part of the unranked logic, the result is reshaped back
@@ -135,7 +135,7 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf3
 // to a 1D tensor.
 // CHECK-NEXT: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
 // CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor -> index
-// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
+// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
 // CHECK-NEXT: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor
 // The assuming region is part of the second stage of lowering
 // with ranked broadcasting logic.
@@ -166,7 +166,7 @@ func @addUnrankedUnranked(
 // CHECK-NEXT: %[[SCALAR_LHS:.*]] = tensor.cast %[[LHS]] : tensor<*xf32> to tensor
 // CHECK-NEXT: %[[RHS_SHAPE_1:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor
 // CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE_1]] : tensor -> index
-// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor_from_elements %[[NUM_RHS]] : tensor<1xindex>
+// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor.from_elements %[[NUM_RHS]] : tensor<1xindex>
 // CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor
 // CHECK-NEXT: %[[LHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RESHAPED_RHS]] : (tensor, tensor) -> tensor
 // CHECK-NEXT: %[[RESHAPED_LHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[LHS_SCALAR_RESULT]], %[[RHS_SHAPE_1]]) : (tensor, tensor) -> tensor<*xf32>
@@ -179,7 +179,7 @@ func @addUnrankedUnranked(
 // CHECK-NEXT: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
 // CHECK-NEXT: %[[SCALAR_RHS:.*]] = tensor.cast %[[RHS]] : tensor<*xf32> to tensor
 // CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor -> index
-// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor_from_elements %[[NUM_LHS]] : tensor<1xindex>
+// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor.from_elements %[[NUM_LHS]] : tensor<1xindex>
 // CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor
 // CHECK-NEXT: %[[RHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[SCALAR_RHS]] : (tensor, tensor) -> tensor
 // CHECK-NEXT: %[[RESHAPED_RHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RHS_SCALAR_RESULT:.*]], %[[LHS_SHAPE]]) : (tensor, tensor) -> tensor<*xf32>
@@ -190,7 +190,7 @@ func @addUnrankedUnranked(
 // CHECK-NEXT: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
 // CHECK-NEXT: %[[ANY_SHAPE:.*]] = shape.any %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor -> tensor
 // CHECK-NEXT: %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor -> index
-// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor_from_elements %[[ANY_NUM]] : tensor<1xindex>
+// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor.from_elements %[[ANY_NUM]] : tensor<1xindex>
 // CHECK-NEXT: %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor
 // CHECK-NEXT: %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor
 // CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = mhlo.add %[[FLATTENED_LHS]], %[[FLATTENED_RHS]] : tensor
diff --git a/tests/mhlo_infer_shape_type_methods.mlir b/tests/mhlo_infer_shape_type_methods.mlir
index 8829e4c..c40eb3e 100644
--- a/tests/mhlo_infer_shape_type_methods.mlir
+++ b/tests/mhlo_infer_shape_type_methods.mlir
@@ -9,7 +9,7 @@ func @select(%pred : tensor<2x?xi1>, %a : tensor<2x?xf32>, %b : tensor<2x?xf32>)
   // CHECK: %[[C1:.*]] = constant 1 : index
   // CHECK: %[[DIM_AS_INDEX:.*]] = dim %[[PRED]], %[[C1]] : tensor<2x?xi1>
   // CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64
-  // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[C2]], %[[DIM]] : tensor<2xi64>
+  // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xi64>
   // CHECK: return %[[SHAPE]] : tensor<2xi64>
   %0 = "mhlo.select"(%pred, %a, %b)
       : (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
@@ -26,7 +26,7 @@ func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xi64> {
   // CHECK: %[[C1:.*]] = constant 1 : index
   // CHECK: %[[DIM_AS_INDEX:.*]] = dim %[[A]], %[[C1]] : tensor<2x?xf32>
   // CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64
-  // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[C2]], %[[DIM]] : tensor<2xi64>
+  // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xi64>
   // CHECK: return %[[SHAPE]] : tensor<2xi64>
   %0 = "mhlo.compare"(%a, %b) {comparison_direction = "NE"}
       : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xi1>
diff --git a/tests/unfuse_batch_norm.mlir b/tests/unfuse_batch_norm.mlir
index 53ee94f..f8fea55 100644
--- a/tests/unfuse_batch_norm.mlir
+++ b/tests/unfuse_batch_norm.mlir
@@ -109,7 +109,7 @@ func @batchNormInference_dynamic_shape(
   // CHECK-DAG: %[[C3:.*]] = constant 3 : index
   // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor
   // CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor
-  // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements %[[DIM]] : tensor<1xindex>
+  // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor.from_elements %[[DIM]] : tensor<1xindex>
   // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor
   // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor
   // CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor) -> tensor
@@ -117,7 +117,7 @@ func @batchNormInference_dynamic_shape(
   // CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor
   // CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor
   // CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor
-  // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements %[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]] : tensor<4xindex>
+  // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor.from_elements %[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]] : tensor<4xindex>
   // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor
   // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor
   // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor