PR #50236: [MLIR][DISC] Bufferize TransposeOp and ConcatenateOp
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50236 support hlo-to-lhlo conversion for TransposeOp and ConcatenateOp Copybara import of the project: -- 62860e717f2a14fbd3ddfb634aa6ff132d245a72 by Wenyi Zhao <reyizero@gmail.com>: [MLIR][DISC] Bufferize TransposeOp and ConcatenateOp -- ce2ff57c1edee1172cd2f36346cc0b34ec1c7467 by Wenyi Zhao <reyizero@gmail.com>: fix PiperOrigin-RevId: 379330954
This commit is contained in:
parent
23ebbb28d1
commit
7f94bd923b
|
@ -1382,7 +1382,7 @@ def HLO_ClampOp : HLO_Op<"clamp",
|
||||||
let results = (outs HLO_Tensor);
|
let results = (outs HLO_Tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_ConcatenateOp : HLO_Op<"concatenate",
|
def HLO_ConcatenateOp : HLO_ShapedInterfaceOp<"concatenate",
|
||||||
[NoSideEffect, SameOperandsAndResultElementType,
|
[NoSideEffect, SameOperandsAndResultElementType,
|
||||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||||
let summary = "XLA's concatenate op";
|
let summary = "XLA's concatenate op";
|
||||||
|
@ -1901,7 +1901,7 @@ def HLO_TraceOp: HLO_Op<"trace", []> {
|
||||||
let hasCustomHLOConverter = 1;
|
let hasCustomHLOConverter = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_TransposeOp: HLO_Op<"transpose",
|
def HLO_TransposeOp: HLO_ShapedInterfaceOp<"transpose",
|
||||||
[NoSideEffect, SameOperandsAndResultElementType]> {
|
[NoSideEffect, SameOperandsAndResultElementType]> {
|
||||||
let summary = "Transpose operator";
|
let summary = "Transpose operator";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -47,6 +47,7 @@ MAP_HLO_TO_LHLO(ClampOp);
|
||||||
MAP_HLO_TO_LHLO(ConstOp);
|
MAP_HLO_TO_LHLO(ConstOp);
|
||||||
MAP_HLO_TO_LHLO(CompareOp);
|
MAP_HLO_TO_LHLO(CompareOp);
|
||||||
MAP_HLO_TO_LHLO(ComplexOp);
|
MAP_HLO_TO_LHLO(ComplexOp);
|
||||||
|
MAP_HLO_TO_LHLO(ConcatenateOp);
|
||||||
MAP_HLO_TO_LHLO(ConvOp);
|
MAP_HLO_TO_LHLO(ConvOp);
|
||||||
MAP_HLO_TO_LHLO(ConvertOp);
|
MAP_HLO_TO_LHLO(ConvertOp);
|
||||||
MAP_HLO_TO_LHLO(CopyOp);
|
MAP_HLO_TO_LHLO(CopyOp);
|
||||||
|
|
|
@ -1417,6 +1417,57 @@ static LogicalResult Verify(ConcatenateOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult ConcatenateOp::reifyReturnTypeShapes(
|
||||||
|
OpBuilder& builder, ValueRange operands,
|
||||||
|
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
|
ConcatenateOp::Adaptor adaptor(operands);
|
||||||
|
auto inputs = adaptor.val();
|
||||||
|
|
||||||
|
auto operand_type = inputs[0].getType().dyn_cast<RankedTensorType>();
|
||||||
|
// Not support unranked type a.t.m.
|
||||||
|
if (!operand_type) return failure();
|
||||||
|
|
||||||
|
Location loc = this->getLoc();
|
||||||
|
Type shape_scalar_type = builder.getIndexType();
|
||||||
|
auto to_shape_scalar_type = [&](Value v) {
|
||||||
|
return MaybeCastTo(builder, loc, v, shape_scalar_type);
|
||||||
|
};
|
||||||
|
|
||||||
|
SmallVector<SmallVector<Value, 4>, 4> all_shape_values;
|
||||||
|
for (size_t input_id = 0; input_id < inputs.size(); ++input_id) {
|
||||||
|
Value operand = inputs[input_id];
|
||||||
|
auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!operand_type) return failure();
|
||||||
|
|
||||||
|
SmallVector<Value, 4> shape_vals;
|
||||||
|
for (const auto& element : llvm::enumerate(operand_type.getShape())) {
|
||||||
|
Value value_dim = to_shape_scalar_type(
|
||||||
|
builder.create<memref::DimOp>(loc, operand, element.index()));
|
||||||
|
shape_vals.push_back(value_dim);
|
||||||
|
}
|
||||||
|
all_shape_values.emplace_back(std::move(shape_vals));
|
||||||
|
}
|
||||||
|
|
||||||
|
int axis = this->dimension();
|
||||||
|
auto& shape_values = all_shape_values[0];
|
||||||
|
for (size_t vec_id = 1; vec_id < all_shape_values.size(); ++vec_id) {
|
||||||
|
auto& other_shape_values = all_shape_values[vec_id];
|
||||||
|
if (other_shape_values.size() != shape_values.size()) {
|
||||||
|
this->emitOpError()
|
||||||
|
<< "Concatenate expects all operands must be of the same rank";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
shape_values[axis] = builder.create<AddIOp>(loc, shape_values[axis],
|
||||||
|
other_shape_values[axis]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value output_shape = builder.create<tensor::FromElementsOp>(
|
||||||
|
loc, shape_scalar_type, shape_values);
|
||||||
|
reifiedReturnShapes.push_back(output_shape);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// DynamicReshapeOp
|
// DynamicReshapeOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -3363,6 +3414,40 @@ static LogicalResult Verify(TransposeOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult TransposeOp::reifyReturnTypeShapes(
|
||||||
|
OpBuilder& builder, ValueRange operands,
|
||||||
|
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
|
TransposeOp::Adaptor adaptor(operands);
|
||||||
|
Value operand = adaptor.operand();
|
||||||
|
|
||||||
|
auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
|
||||||
|
// Not support unranked type a.t.m.
|
||||||
|
if (!operand_type) return failure();
|
||||||
|
|
||||||
|
Location loc = this->getLoc();
|
||||||
|
SmallVector<int64_t, 4> permutation(this->permutation().getValues<int64_t>());
|
||||||
|
SmallVector<Value, 4> shape_values(permutation.size());
|
||||||
|
|
||||||
|
Type shape_scalar_type = builder.getIndexType();
|
||||||
|
auto to_shape_scalar_type = [&](Value v) {
|
||||||
|
return MaybeCastTo(builder, loc, v, shape_scalar_type);
|
||||||
|
};
|
||||||
|
|
||||||
|
for (const auto& element : llvm::enumerate(operand_type.getShape())) {
|
||||||
|
int64_t idx = element.index();
|
||||||
|
auto it = std::find(permutation.begin(), permutation.end(), idx);
|
||||||
|
Value value_dim = to_shape_scalar_type(
|
||||||
|
builder.create<memref::DimOp>(loc, operand, element.index()));
|
||||||
|
shape_values[std::distance(permutation.begin(), it)] = value_dim;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value output_shape = builder.create<tensor::FromElementsOp>(
|
||||||
|
loc, shape_scalar_type, shape_values);
|
||||||
|
reifiedReturnShapes.push_back(output_shape);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TriangularSolveOp
|
// TriangularSolveOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -696,6 +696,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
||||||
HloToLhloOpConverter<mhlo::CeilOp>,
|
HloToLhloOpConverter<mhlo::CeilOp>,
|
||||||
HloToLhloOpConverter<mhlo::CompareOp>,
|
HloToLhloOpConverter<mhlo::CompareOp>,
|
||||||
HloToLhloOpConverter<mhlo::ComplexOp>,
|
HloToLhloOpConverter<mhlo::ComplexOp>,
|
||||||
|
HloToLhloOpConverter<mhlo::ConcatenateOp>,
|
||||||
HloToLhloOpConverter<mhlo::ConstOp>,
|
HloToLhloOpConverter<mhlo::ConstOp>,
|
||||||
HloToLhloOpConverter<mhlo::ConvOp>,
|
HloToLhloOpConverter<mhlo::ConvOp>,
|
||||||
HloToLhloOpConverter<mhlo::ConvertOp>,
|
HloToLhloOpConverter<mhlo::ConvertOp>,
|
||||||
|
|
|
@ -150,3 +150,37 @@ func @column_reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>) -> tensor<?xf32>
|
||||||
: (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
|
: (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
|
||||||
return %0: tensor<?xf32>
|
return %0: tensor<?xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @transpose
|
||||||
|
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>) -> memref<?x?xf32>
|
||||||
|
func @transpose(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||||
|
// CHECK-NOT: tensor_load
|
||||||
|
// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG]], %c0 : memref<?x?xf32>
|
||||||
|
// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %c1 : memref<?x?xf32>
|
||||||
|
// CHECK: %[[OUT:.*]] = memref.alloc(%[[DIM1]], %[[DIM0]]) : memref<?x?xf32>
|
||||||
|
// CHECK: "lmhlo.transpose"(%[[ARG]], %[[OUT]])
|
||||||
|
%0 = "mhlo.transpose"(%arg0) {permutation = dense<[1,0]> : tensor<2xi64>} : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
return %0: tensor<?x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @concatenate
|
||||||
|
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xi32>, %[[ARG1:.*]]: memref<?x?xi32>, %[[ARG2:.*]]: memref<?x?xi32>) -> memref<?x?xi32>
|
||||||
|
func @concatenate(%a: tensor<?x?xi32>, %b: tensor<?x?xi32>, %c: tensor<?x?xi32>) -> tensor<?x?xi32> {
|
||||||
|
// CHECK-NOT: tensor_load
|
||||||
|
// CHECK: %[[ARG0_DIM0:.*]] = memref.dim %[[ARG0]], %c0 : memref<?x?xi32>
|
||||||
|
// CHECK: %[[ARG0_DIM1:.*]] = memref.dim %[[ARG0]], %c1 : memref<?x?xi32>
|
||||||
|
// CHECK: %[[ARG1_DIM1:.*]] = memref.dim %[[ARG1]], %c1 : memref<?x?xi32>
|
||||||
|
// CHECK: %[[ARG2_DIM1:.*]] = memref.dim %[[ARG2]], %c1 : memref<?x?xi32>
|
||||||
|
// CHECK: %[[TMP:.*]] = addi %[[ARG0_DIM1]], %[[ARG1_DIM1]] : index
|
||||||
|
// CHECK: %[[OUT_DIM1:.*]] = addi %[[TMP]], %[[ARG2_DIM1]] : index
|
||||||
|
// CHECK: %[[OUT:.*]] = memref.alloc(%[[ARG0_DIM0]], %[[OUT_DIM1]]) : memref<?x?xi32>
|
||||||
|
// CHECK: "lmhlo.concatenate"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[OUT]])
|
||||||
|
%concat = "mhlo.concatenate"(%a, %b, %c) {
|
||||||
|
dimension = 1
|
||||||
|
} : (tensor<?x?xi32>, tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||||
|
return %concat : tensor<?x?xi32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue