Updates LLVM usage to match
[96ef4f307df2](https://github.com/llvm/llvm-project/commit/96ef4f307df2)

PiperOrigin-RevId: 352786460
Author:    Tres Popp
Date:      2021-01-20 07:08:32 -08:00
Committer: TensorFlow MLIR Team
Parent:    ec5f5667e1
Commit:    ba0346b071

14 changed files with 49 additions and 45 deletions
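Most of the churn below tracks the upstream move of the tensor-creation ops out of the Standard dialect and into the new tensor dialect at the pinned LLVM revision: TensorFromElementsOp becomes tensor::FromElementsOp, DynamicTensorFromElementsOp becomes tensor::GenerateOp (with tensor::YieldOp as its body terminator), and the assembly forms tensor_from_elements / dynamic_tensor_from_elements become tensor.from_elements / tensor.generate. A minimal sketch of what the rename looks like for downstream builder code, assuming an MLIR build at or after the pinned commit (the helper itself is hypothetical; the op classes and the include are the ones used in the diff):

#include "mlir/Dialect/Tensor/IR/Tensor.h"  // new home of the op classes
#include "mlir/IR/Builders.h"

// Hypothetical helper: materializes a 1-D shape tensor from index values.
mlir::Value MakeShapeTensor(mlir::OpBuilder &b, mlir::Location loc,
                            mlir::ValueRange shape_values) {
  // Before this bump: b.create<TensorFromElementsOp>(loc, shape_values);
  return b.create<mlir::tensor::FromElementsOp>(loc, shape_values);
}

Dialects and passes that now create these ops also need the tensor dialect available, hence the new @llvm-project//mlir:TensorDialect dependency and the loadDialect call further down.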

BUILD
View File

@ -960,6 +960,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)

View File

@ -15,9 +15,9 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
LLVM_COMMIT = "8456c3a789285079ad35d146e487436b5a27b027"
LLVM_COMMIT = "96ef4f307df27f4e0946eb344bac2703017ad073"
LLVM_SHA256 = "cc17723a31207ffa9c0636bf83752de0e2a20cf99d9a9955c796d7e109b4c68d"
LLVM_SHA256 = "69b6c722deed4f128318259ab8f3c511c9aea91357e52d9479e23edee78deb1a"
LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT)

View File

@ -1,2 +1,2 @@
-8456c3a789285079ad35d146e487436b5a27b027
+96ef4f307df27f4e0946eb344bac2703017ad073

View File

@ -58,6 +58,7 @@ limitations under the License.
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/InliningUtils.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h"
namespace mlir {
#include "hlo_patterns.cc.inc"
@ -3064,6 +3065,7 @@ MhloDialect::MhloDialect(MLIRContext* context)
>();
addInterfaces<HLOInlinerInterface>();
addTypes<TokenType>();
+context->loadDialect<tensor::TensorDialect>();
}
Type MhloDialect::parseType(DialectAsmParser& parser) const {
@ -3111,7 +3113,7 @@ LogicalResult deriveShapeFromFirstOperand(
}
}
*reifiedReturnShapes = SmallVector<Value, 1>{
-builder->create<TensorFromElementsOp>(loc, shape_values)};
+builder->create<tensor::FromElementsOp>(loc, shape_values)};
return success();
}
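The loadDialect call above makes the tensor dialect available whenever the MHLO dialect is loaded, since deriveShapeFromFirstOperand now materializes tensor.from_elements ops. A pass that introduces these ops during rewriting would typically declare the same dependency; a minimal sketch under standard MLIR pass infrastructure of this vintage (the pass itself is made up for illustration):

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"

namespace {
// Hypothetical pass whose patterns create tensor.from_elements / tensor.generate.
struct ExampleShapeMaterializationPass
    : public mlir::PassWrapper<ExampleShapeMaterializationPass,
                               mlir::FunctionPass> {
  // Ops can only be created for dialects that are loaded, so declare the
  // dependency even when the input IR contains no tensor ops yet.
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    registry.insert<mlir::tensor::TensorDialect>();
  }
  void runOnFunction() override { /* rewrite patterns would run here */ }
};
}  // namespace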

View File

@ -1098,8 +1098,8 @@ class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type, op_type);
auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
-auto init_tensor = rewriter.create<DynamicTensorFromElementsOp>(
-loc, result_type, dyn_shape);
+auto init_tensor =
+rewriter.create<tensor::GenerateOp>(loc, result_type, dyn_shape);
{
OpBuilder::InsertionGuard guard(rewriter);
SmallVector<Type, 4> arg_types(shaped_type.getRank(),
@ -1107,7 +1107,7 @@ class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
Region& region = init_tensor.body();
Block* block = rewriter.createBlock(&region, region.begin(), arg_types);
rewriter.setInsertionPointToEnd(block);
-rewriter.create<YieldOp>(loc, zero);
+rewriter.create<tensor::YieldOp>(loc, zero);
}
linalg::LinalgOp linalg_op;
switch (op_type) {
@ -1194,8 +1194,8 @@ class DotGeneralOpOnTensorsConversion
rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type);
auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
-auto init_tensor = rewriter.create<DynamicTensorFromElementsOp>(
-loc, result_type, dyn_shape);
+auto init_tensor =
+rewriter.create<tensor::GenerateOp>(loc, result_type, dyn_shape);
{
OpBuilder::InsertionGuard guard(rewriter);
SmallVector<Type, 4> arg_types(shaped_type.getRank(),
@ -1203,7 +1203,7 @@ class DotGeneralOpOnTensorsConversion
Region& region = init_tensor.body();
Block* block = rewriter.createBlock(&region, region.begin(), arg_types);
rewriter.setInsertionPointToEnd(block);
-rewriter.create<YieldOp>(loc, zero);
+rewriter.create<tensor::YieldOp>(loc, zero);
}
auto linalg_op = rewriter.create<linalg::BatchMatmulOp>(
loc, /*resultTensorTypes=*/TypeRange{result_type},

View File

@ -101,7 +101,7 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
Type indexTy = rewriter.getIndexType();
Value numElements =
rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
-Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
+Value flatShape = rewriter.create<tensor::FromElementsOp>(loc, numElements);
// Flatten operands.
SmallVector<Value, 3> flatOperands;
@ -176,7 +176,7 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
Value size_tensor =
-rewriter.create<TensorFromElementsOp>(loc, num_elements);
+rewriter.create<tensor::FromElementsOp>(loc, num_elements);
Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
loc, RankedTensorType::get({-1}, scalar_element_type),
lhs_is_scalar ? rhs : lhs, size_tensor);

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
@ -59,7 +60,7 @@ Value CalculateShapeValue(Location loc, Value operand,
for (int64_t i = 0; i < rank; ++i) {
shape_values.push_back(rewriter.create<mlir::DimOp>(loc, operand, i));
}
-return rewriter.create<TensorFromElementsOp>(loc, shape_values);
+return rewriter.create<tensor::FromElementsOp>(loc, shape_values);
}
Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr,

View File

@ -603,7 +603,7 @@ func @dynamic_reshape_of_dynamic_reshape(%arg0: tensor<?xf16>, %shape: tensor<?x
%0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf16>, tensor<?xindex>) -> tensor<*xf16>
%1 = shape.shape_of %0 : tensor<*xf16> -> tensor<?xindex>
%2 = shape.num_elements %1 : tensor<?xindex> -> index
-%3 = tensor_from_elements %2 : tensor<1xindex>
+%3 = tensor.from_elements %2 : tensor<1xindex>
%4 = "mhlo.dynamic_reshape"(%0, %3) : (tensor<*xf16>, tensor<1xindex>) -> tensor<?xf16>
return %4 : tensor<?xf16>
}

View File

@ -57,7 +57,7 @@ func @trivial_broadcast_wrapper() {
// Test DynamicBroadcastInDimOp.
%c3 = constant 3 : index
%c4 = constant 4 : index
-%shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+%shape = tensor.from_elements %c3, %c4 : tensor<2xindex>
%dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
broadcast_dimensions = dense<0> : tensor<1xi64>
} : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x4xf32>
@ -106,7 +106,7 @@ func @broadcast_in_X_dim_wrapper() {
// Test DynamicBroadcastInDimOp.
%c4 = constant 4 : index
-%shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+%shape = tensor.from_elements %c3, %c4 : tensor<2xindex>
%dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<1x4xf32>, tensor<2xindex>) -> tensor<3x4xf32>
@ -153,7 +153,7 @@ func @broadcast_in_Y_dim_wrapper() {
// Test DynamicBroadcastInDimOp.
%c3 = constant 3 : index
%c4 = constant 4 : index
-%shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+%shape = tensor.from_elements %c3, %c4 : tensor<2xindex>
%dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<3x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
@ -202,7 +202,7 @@ func @broadcast_in_X_dim_transpose_wrapper() {
// Test DynamicBroadcastInDimOp.
%c4 = constant 4 : index
-%shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+%shape = tensor.from_elements %c3, %c4 : tensor<2xindex>
%dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
} : (tensor<4x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
@ -249,7 +249,7 @@ func @broadcast_in_Y_dim_transpose_wrapper() {
// Test DynamicBroadcastInDimOp.
%c3 = constant 3 : index
%c4 = constant 4 : index
-%shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+%shape = tensor.from_elements %c3, %c4 : tensor<2xindex>
%dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
} : (tensor<1x3xf32>, tensor<2xindex>) -> tensor<3x4xf32>
@ -290,7 +290,7 @@ func @broadcast_scalar_1d_wrapper() {
// Test DynamicBroadcastInDimOp.
%c3 = constant 3 : index
%c4 = constant 4 : index
-%shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+%shape = tensor.from_elements %c3, %c4 : tensor<2xindex>
%dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
broadcast_dimensions = dense<0> : tensor<1xi64>
} : (tensor<1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
@ -331,7 +331,7 @@ func @broadcast_scalar_2d_wrapper() {
// Test DynamicBroadcastInDimOp.
%c3 = constant 3 : index
%c4 = constant 4 : index
-%shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+%shape = tensor.from_elements %c3, %c4 : tensor<2xindex>
%dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<1x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
@ -381,7 +381,7 @@ func @broadcast_to_the_same_shape() {
// CHECK-NEXT: [1, 2, 3]
// Test DynamicBroadcastInDimOp.
-%shape = tensor_from_elements %c2, %c3 : tensor<2xindex>
+%shape = tensor.from_elements %c2, %c3 : tensor<2xindex>
%dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<2x3xf32>, tensor<2xindex>) -> tensor<2x3xf32>
@ -429,7 +429,7 @@ func @broadcast_1d_to_2d() {
// Test DynamicBroadcastInDimOp.
%c3 = constant 3 : index
%c4 = constant 3 : index
-%shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+%shape = tensor.from_elements %c3, %c4 : tensor<2xindex>
%dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
broadcast_dimensions = dense<0> : tensor<1xi64>
} : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32>
@ -477,7 +477,7 @@ func @broadcast_1d_to_2d_with_transpose() {
// Test DynamicBroadcastInDimOp.
%c3 = constant 3 : index
-%shape = tensor_from_elements %c3, %c3 : tensor<2xindex>
+%shape = tensor.from_elements %c3, %c3 : tensor<2xindex>
%dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
broadcast_dimensions = dense<1> : tensor<1xi64>
} : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32>

View File

@ -135,13 +135,13 @@ func @broadcast(%operand: tensor<5xf32>) -> tensor<10x5xf32> {
func @dyn_broadcast(%operand: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
// CHECK-SAME: %[[OPERAND:.*]]: memref<?x?xf32>
%c1 = constant 1 : i64
-%shape = tensor_from_elements %c1, %c1, %c1 : tensor<3xi64>
+%shape = tensor.from_elements %c1, %c1, %c1 : tensor<3xi64>
%result = "mhlo.dynamic_broadcast_in_dim"(%operand, %shape) {
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
return %result : tensor<?x?x?xf32>
}
-// CHECK: %[[SHAPE:.*]] = tensor_from_elements
+// CHECK: %[[SHAPE:.*]] = tensor.from_elements
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
@ -463,7 +463,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
-// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
+// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xi64>
@ -487,7 +487,7 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
-// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
+// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xi64>

View File

@ -849,7 +849,7 @@ func @dot_matmul(%arg0: tensor<2x3xf32>,
return %0 : tensor<2x?xf32>
}
// CHECK: func @dot_matmul(%[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
-// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
+// CHECK: %[[INIT:.*]] = tensor.generate
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xf32>, tensor<3x?xf32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<2x?xf32>)
@ -863,7 +863,7 @@ func @dot_matvec(%arg0: tensor<?x3xf32>,
return %0 : tensor<?xf32>
}
// CHECK: func @dot_matvec(%[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>)
-// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
+// CHECK: %[[INIT:.*]] = tensor.generate
// CHECK: linalg.matvec
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x3xf32>, tensor<3xf32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<?xf32>)
@ -876,7 +876,7 @@ func @dot_dot(%arg0: tensor<?xf32>,
return %0 : tensor<f32>
}
// CHECK: func @dot_dot(%[[ARG0:.*]]: tensor<?xf32>, %[[ARG1:.*]]: tensor<?xf32>)
-// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
+// CHECK: %[[INIT:.*]] = tensor.generate
// CHECK: linalg.dot
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?xf32>, tensor<?xf32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<f32>)
@ -897,7 +897,7 @@ func @dot_general(%arg0: tensor<?x?x3xf32>,
return %0 : tensor<?x?x?xf32>
}
// CHECK: func @dot_general(%[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>)
-// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
+// CHECK: %[[INIT:.*]] = tensor.generate
// CHECK: linalg.batch_matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xf32>, tensor<?x3x?xf32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?xf32>)

View File

@ -7,7 +7,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
// Flatten operand shape.
%shape = shape.shape_of %a : tensor<*xf32> -> tensor<?xindex>
%num_elements = shape.num_elements %shape : tensor<?xindex> -> index
-%flat_shape = tensor_from_elements %num_elements : tensor<1xindex>
+%flat_shape = tensor.from_elements %num_elements : tensor<1xindex>
%flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape)
: (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
@ -29,7 +29,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
-// CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
+// CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
@ -71,7 +71,7 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[SHAPE_B:.*]] = shape.shape_of %[[B]]
// CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_A]], %[[SHAPE_B]]
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
-// CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
+// CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
@ -88,7 +88,7 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
func @tan(%a : tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor<?xindex>
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
-// CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
+// CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_B:.*]] = chlo.tan %[[FLAT_A]] : tensor<?xf32>
// CHECK: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
@ -113,7 +113,7 @@ func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf3
// to a 1D tensor.
// CHECK-NEXT: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index
-// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
+// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[ARG_0]], %[[RESHAPED]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
// As part of the unranked logic, the result is reshaped back
@ -135,7 +135,7 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf3
// to a 1D tensor.
// CHECK-NEXT: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index
-// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
+// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// The assuming region is part of the second stage of lowering
// with ranked broadcasting logic.
@ -166,7 +166,7 @@ func @addUnrankedUnranked(
// CHECK-NEXT: %[[SCALAR_LHS:.*]] = tensor.cast %[[LHS]] : tensor<*xf32> to tensor<f32>
// CHECK-NEXT: %[[RHS_SHAPE_1:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE_1]] : tensor<?xindex> -> index
-// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor_from_elements %[[NUM_RHS]] : tensor<1xindex>
+// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor.from_elements %[[NUM_RHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[LHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RESHAPED_RHS]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_LHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[LHS_SCALAR_RESULT]], %[[RHS_SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
@ -179,7 +179,7 @@ func @addUnrankedUnranked(
// CHECK-NEXT: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[SCALAR_RHS:.*]] = tensor.cast %[[RHS]] : tensor<*xf32> to tensor<f32>
// CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
-// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor_from_elements %[[NUM_LHS]] : tensor<1xindex>
+// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor.from_elements %[[NUM_LHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[SCALAR_RHS]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RHS_SCALAR_RESULT:.*]], %[[LHS_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
@ -190,7 +190,7 @@ func @addUnrankedUnranked(
// CHECK-NEXT: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[ANY_SHAPE:.*]] = shape.any %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor<?xindex> -> index
-// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor_from_elements %[[ANY_NUM]] : tensor<1xindex>
+// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor.from_elements %[[ANY_NUM]] : tensor<1xindex>
// CHECK-NEXT: %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = mhlo.add %[[FLATTENED_LHS]], %[[FLATTENED_RHS]] : tensor<?xf32>

View File

@ -9,7 +9,7 @@ func @select(%pred : tensor<2x?xi1>, %a : tensor<2x?xf32>, %b : tensor<2x?xf32>)
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM_AS_INDEX:.*]] = dim %[[PRED]], %[[C1]] : tensor<2x?xi1>
// CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64
-// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[C2]], %[[DIM]] : tensor<2xi64>
+// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xi64>
// CHECK: return %[[SHAPE]] : tensor<2xi64>
%0 = "mhlo.select"(%pred, %a, %b)
: (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
@ -26,7 +26,7 @@ func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xi64> {
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM_AS_INDEX:.*]] = dim %[[A]], %[[C1]] : tensor<2x?xf32>
// CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64
-// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[C2]], %[[DIM]] : tensor<2xi64>
+// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xi64>
// CHECK: return %[[SHAPE]] : tensor<2xi64>
%0 = "mhlo.compare"(%a, %b) {comparison_direction = "NE"}
: (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xi1>

View File

@ -109,7 +109,7 @@ func @batchNormInference_dynamic_shape(
// CHECK-DAG: %[[C3:.*]] = constant 3 : index
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
// CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor<?xf32>
-// CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements %[[DIM]] : tensor<1xindex>
+// CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor.from_elements %[[DIM]] : tensor<1xindex>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
@ -117,7 +117,7 @@ func @batchNormInference_dynamic_shape(
// CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32>
-// CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements %[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]] : tensor<4xindex>
+// CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor.from_elements %[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]] : tensor<4xindex>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>