PR #50236: [MLIR][DISC] Bufferize TransposeOp and ConcatenateOp
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50236 support hlo-to-lhlo conversion for TransposeOp and ConcatenateOp Copybara import of the project: -- 62860e717f2a14fbd3ddfb634aa6ff132d245a72 by Wenyi Zhao <reyizero@gmail.com>: [MLIR][DISC] Bufferize TransposeOp and ConcatenateOp -- ce2ff57c1edee1172cd2f36346cc0b34ec1c7467 by Wenyi Zhao <reyizero@gmail.com>: fix PiperOrigin-RevId: 379330954
This commit is contained in:
parent
23ebbb28d1
commit
7f94bd923b
|
@ -1382,7 +1382,7 @@ def HLO_ClampOp : HLO_Op<"clamp",
|
|||
let results = (outs HLO_Tensor);
|
||||
}
|
||||
|
||||
def HLO_ConcatenateOp : HLO_Op<"concatenate",
|
||||
def HLO_ConcatenateOp : HLO_ShapedInterfaceOp<"concatenate",
|
||||
[NoSideEffect, SameOperandsAndResultElementType,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "XLA's concatenate op";
|
||||
|
@ -1901,7 +1901,7 @@ def HLO_TraceOp: HLO_Op<"trace", []> {
|
|||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
def HLO_TransposeOp: HLO_Op<"transpose",
|
||||
def HLO_TransposeOp: HLO_ShapedInterfaceOp<"transpose",
|
||||
[NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
let summary = "Transpose operator";
|
||||
let description = [{
|
||||
|
|
|
@ -47,6 +47,7 @@ MAP_HLO_TO_LHLO(ClampOp);
|
|||
MAP_HLO_TO_LHLO(ConstOp);
|
||||
MAP_HLO_TO_LHLO(CompareOp);
|
||||
MAP_HLO_TO_LHLO(ComplexOp);
|
||||
MAP_HLO_TO_LHLO(ConcatenateOp);
|
||||
MAP_HLO_TO_LHLO(ConvOp);
|
||||
MAP_HLO_TO_LHLO(ConvertOp);
|
||||
MAP_HLO_TO_LHLO(CopyOp);
|
||||
|
|
|
@ -1417,6 +1417,57 @@ static LogicalResult Verify(ConcatenateOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConcatenateOp::reifyReturnTypeShapes(
|
||||
OpBuilder& builder, ValueRange operands,
|
||||
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||
ConcatenateOp::Adaptor adaptor(operands);
|
||||
auto inputs = adaptor.val();
|
||||
|
||||
auto operand_type = inputs[0].getType().dyn_cast<RankedTensorType>();
|
||||
// Not support unranked type a.t.m.
|
||||
if (!operand_type) return failure();
|
||||
|
||||
Location loc = this->getLoc();
|
||||
Type shape_scalar_type = builder.getIndexType();
|
||||
auto to_shape_scalar_type = [&](Value v) {
|
||||
return MaybeCastTo(builder, loc, v, shape_scalar_type);
|
||||
};
|
||||
|
||||
SmallVector<SmallVector<Value, 4>, 4> all_shape_values;
|
||||
for (size_t input_id = 0; input_id < inputs.size(); ++input_id) {
|
||||
Value operand = inputs[input_id];
|
||||
auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
|
||||
if (!operand_type) return failure();
|
||||
|
||||
SmallVector<Value, 4> shape_vals;
|
||||
for (const auto& element : llvm::enumerate(operand_type.getShape())) {
|
||||
Value value_dim = to_shape_scalar_type(
|
||||
builder.create<memref::DimOp>(loc, operand, element.index()));
|
||||
shape_vals.push_back(value_dim);
|
||||
}
|
||||
all_shape_values.emplace_back(std::move(shape_vals));
|
||||
}
|
||||
|
||||
int axis = this->dimension();
|
||||
auto& shape_values = all_shape_values[0];
|
||||
for (size_t vec_id = 1; vec_id < all_shape_values.size(); ++vec_id) {
|
||||
auto& other_shape_values = all_shape_values[vec_id];
|
||||
if (other_shape_values.size() != shape_values.size()) {
|
||||
this->emitOpError()
|
||||
<< "Concatenate expects all operands must be of the same rank";
|
||||
return failure();
|
||||
}
|
||||
shape_values[axis] = builder.create<AddIOp>(loc, shape_values[axis],
|
||||
other_shape_values[axis]);
|
||||
}
|
||||
|
||||
Value output_shape = builder.create<tensor::FromElementsOp>(
|
||||
loc, shape_scalar_type, shape_values);
|
||||
reifiedReturnShapes.push_back(output_shape);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DynamicReshapeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -3363,6 +3414,40 @@ static LogicalResult Verify(TransposeOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult TransposeOp::reifyReturnTypeShapes(
|
||||
OpBuilder& builder, ValueRange operands,
|
||||
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||
TransposeOp::Adaptor adaptor(operands);
|
||||
Value operand = adaptor.operand();
|
||||
|
||||
auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
|
||||
// Not support unranked type a.t.m.
|
||||
if (!operand_type) return failure();
|
||||
|
||||
Location loc = this->getLoc();
|
||||
SmallVector<int64_t, 4> permutation(this->permutation().getValues<int64_t>());
|
||||
SmallVector<Value, 4> shape_values(permutation.size());
|
||||
|
||||
Type shape_scalar_type = builder.getIndexType();
|
||||
auto to_shape_scalar_type = [&](Value v) {
|
||||
return MaybeCastTo(builder, loc, v, shape_scalar_type);
|
||||
};
|
||||
|
||||
for (const auto& element : llvm::enumerate(operand_type.getShape())) {
|
||||
int64_t idx = element.index();
|
||||
auto it = std::find(permutation.begin(), permutation.end(), idx);
|
||||
Value value_dim = to_shape_scalar_type(
|
||||
builder.create<memref::DimOp>(loc, operand, element.index()));
|
||||
shape_values[std::distance(permutation.begin(), it)] = value_dim;
|
||||
}
|
||||
|
||||
Value output_shape = builder.create<tensor::FromElementsOp>(
|
||||
loc, shape_scalar_type, shape_values);
|
||||
reifiedReturnShapes.push_back(output_shape);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TriangularSolveOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -696,6 +696,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
|||
HloToLhloOpConverter<mhlo::CeilOp>,
|
||||
HloToLhloOpConverter<mhlo::CompareOp>,
|
||||
HloToLhloOpConverter<mhlo::ComplexOp>,
|
||||
HloToLhloOpConverter<mhlo::ConcatenateOp>,
|
||||
HloToLhloOpConverter<mhlo::ConstOp>,
|
||||
HloToLhloOpConverter<mhlo::ConvOp>,
|
||||
HloToLhloOpConverter<mhlo::ConvertOp>,
|
||||
|
|
|
@ -150,3 +150,37 @@ func @column_reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>) -> tensor<?xf32>
|
|||
: (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
|
||||
return %0: tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transpose
|
||||
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>) -> memref<?x?xf32>
|
||||
func @transpose(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// CHECK-NOT: tensor_load
|
||||
// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG]], %c0 : memref<?x?xf32>
|
||||
// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %c1 : memref<?x?xf32>
|
||||
// CHECK: %[[OUT:.*]] = memref.alloc(%[[DIM1]], %[[DIM0]]) : memref<?x?xf32>
|
||||
// CHECK: "lmhlo.transpose"(%[[ARG]], %[[OUT]])
|
||||
%0 = "mhlo.transpose"(%arg0) {permutation = dense<[1,0]> : tensor<2xi64>} : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
return %0: tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @concatenate
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xi32>, %[[ARG1:.*]]: memref<?x?xi32>, %[[ARG2:.*]]: memref<?x?xi32>) -> memref<?x?xi32>
|
||||
func @concatenate(%a: tensor<?x?xi32>, %b: tensor<?x?xi32>, %c: tensor<?x?xi32>) -> tensor<?x?xi32> {
|
||||
// CHECK-NOT: tensor_load
|
||||
// CHECK: %[[ARG0_DIM0:.*]] = memref.dim %[[ARG0]], %c0 : memref<?x?xi32>
|
||||
// CHECK: %[[ARG0_DIM1:.*]] = memref.dim %[[ARG0]], %c1 : memref<?x?xi32>
|
||||
// CHECK: %[[ARG1_DIM1:.*]] = memref.dim %[[ARG1]], %c1 : memref<?x?xi32>
|
||||
// CHECK: %[[ARG2_DIM1:.*]] = memref.dim %[[ARG2]], %c1 : memref<?x?xi32>
|
||||
// CHECK: %[[TMP:.*]] = addi %[[ARG0_DIM1]], %[[ARG1_DIM1]] : index
|
||||
// CHECK: %[[OUT_DIM1:.*]] = addi %[[TMP]], %[[ARG2_DIM1]] : index
|
||||
// CHECK: %[[OUT:.*]] = memref.alloc(%[[ARG0_DIM0]], %[[OUT_DIM1]]) : memref<?x?xi32>
|
||||
// CHECK: "lmhlo.concatenate"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[OUT]])
|
||||
%concat = "mhlo.concatenate"(%a, %b, %c) {
|
||||
dimension = 1
|
||||
} : (tensor<?x?xi32>, tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
return %concat : tensor<?x?xi32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue