PR #50211: [MLIR][DISC] Bufferize RealDynamicSliceOp and ReduceOp
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50211

Supports the hlo-to-lhlo (tensor-to-buffer) conversion for RealDynamicSliceOp and ReduceOp.

Copybara import of the project:

-- c417b336670a1fc256f7026dfe8080e46d13d79a by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] Bufferize RealDynamicSliceOp and ReduceOp

PiperOrigin-RevId: 378972113
parent 95ba03534f
commit 8388303fd2
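In brief, the patch has four parts: (1) the ODS definitions of ReduceOp and RealDynamicSliceOp switch their base class to HLO_ShapedInterfaceOp so the ops expose a shape-reification hook; (2) C++ implementations of reifyReturnTypeShapes are added for both ops; (3) the ops are registered with the hlo-to-lhlo conversion patterns; and (4) lit tests exercise the dynamic-shape lowering.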
@@ -967,7 +967,7 @@ def HLO_AllToAllOp : HLO_Op<"all_to_all",
   let results = (outs HLO_Tensor);
 }
 
-def HLO_ReduceOp: HLO_Op<"reduce", [
+def HLO_ReduceOp: HLO_ShapedInterfaceOp<"reduce", [
       RecursiveSideEffects,
       SameVariadicOperandSize,
       SingleBlockImplicitTerminator<"ReturnOp">,
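The base-class switch is what enables the bufferization below: HLO_ShapedInterfaceOp presumably declares InferShapedTypeOpInterface's reifyReturnTypeShapes method on the op (the class definition itself is not part of this diff). The hook the patch implements further down has this shape:

    LogicalResult ReduceOp::reifyReturnTypeShapes(
        OpBuilder& builder, ValueRange operands,
        SmallVectorImpl<Value>& reifiedReturnShapes);

Given the op's operands as SSA values, it emits IR that computes each result's shape at runtime and appends one shape value per result to reifiedReturnShapes.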
@@ -2199,7 +2199,7 @@ def HLO_ReducePrecisionOp :
   let results = (outs HLO_FpTensor:$output);
 }
 
-def HLO_RealDynamicSliceOp: HLO_Op<
+def HLO_RealDynamicSliceOp: HLO_ShapedInterfaceOp<
       "real_dynamic_slice",
       [NoSideEffect, AllElementTypesMatch<["operand", "result"]>,
        AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
@@ -75,6 +75,7 @@ MAP_HLO_TO_LHLO(NegOp);
 MAP_HLO_TO_LHLO(NotOp);
 MAP_HLO_TO_LHLO(OrOp);
 MAP_HLO_TO_LHLO(PowOp);
+MAP_HLO_TO_LHLO(RealDynamicSliceOp);
 MAP_HLO_TO_LHLO(RealOp);
 MAP_HLO_TO_LHLO(ReduceOp);
 MAP_HLO_TO_LHLO(ReshapeOp);
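The MAP_HLO_TO_LHLO entries extend a compile-time table from mhlo ops to their lmhlo buffer counterparts. The macro itself lies outside this hunk; a minimal sketch of what it presumably expands to in map_hlo_to_lhlo_op.h (assumed, not quoted from the patch):

    // Assumed shape of the mapping machinery: each MAP_HLO_TO_LHLO(OpName)
    // specializes the trait so that HloToLhloOp<mhlo::OpName> names the
    // matching lmhlo op.
    template <typename HloOpTy>
    struct HloToLhloOpImpl {
      using Type = std::false_type;  // no mapping registered
    };
    template <typename HloOpTy>
    using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;

    #define MAP_HLO_TO_LHLO(OpName)          \
      template <>                            \
      struct HloToLhloOpImpl<mhlo::OpName> { \
        using Type = lmhlo::OpName;          \
      }

With the new entry, the generic HloToLhloOpConverter<mhlo::RealDynamicSliceOp> registered later in the patch can name the lmhlo op it lowers to.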
@@ -1626,6 +1626,52 @@ static LogicalResult Verify(RealDynamicSliceOp op) {
   return success();
 }
 
+LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes(
+    OpBuilder& builder, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  RealDynamicSliceOp::Adaptor adaptor(operands);
+  Value operand = adaptor.operand();
+  Value start_indices = adaptor.start_indices();
+  Value limit_indices = adaptor.limit_indices();
+  Value strides = adaptor.strides();
+
+  auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
+  // Not support unranked type a.t.m.
+  if (!operand_type) return failure();
+
+  Location loc = this->getLoc();
+  SmallVector<Value, 4> shape_values;
+  shape_values.reserve(operand_type.getRank());
+  Type shape_scalar_type =
+      start_indices.getType().cast<ShapedType>().getElementType();
+  Value one = builder.create<ConstantIndexOp>(loc, 1);
+  one = MaybeCastTo(builder, loc, one, shape_scalar_type);
+  for (const auto& element : llvm::enumerate(operand_type.getShape())) {
+    Value offset = builder.create<ConstantIndexOp>(loc, element.index());
+    Value value_start =
+        builder.create<tensor::ExtractOp>(loc, start_indices, offset);
+    Value value_limit =
+        builder.create<tensor::ExtractOp>(loc, limit_indices, offset);
+    Value value_stride =
+        builder.create<tensor::ExtractOp>(loc, strides, offset);
+    // size = (limit - start + stride - 1) / stride
+    shape_values.push_back(builder.create<SignedDivIOp>(
+        loc,
+        builder.create<SubIOp>(
+            loc,
+            builder.create<AddIOp>(
+                loc, value_stride,
+                builder.create<SubIOp>(loc, value_limit, value_start)),
+            one),
+        value_stride));
+  }
+
+  reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
+      loc, shape_scalar_type, shape_values));
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // InfeedOp
 //===----------------------------------------------------------------------===//
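The size = (limit - start + stride - 1) / stride comment is the usual ceiling division: the number of indices selected from [start, limit) when stepping by stride. A standalone worked example of the same arithmetic (plain C++, mirroring the emitted subi/addi/subi/divi_signed chain; not part of the patch):

    #include <cassert>
    #include <cstdint>

    // size = (stride + (limit - start) - 1) / stride, exactly the expansion
    // emitted by RealDynamicSliceOp::reifyReturnTypeShapes above.
    int64_t SliceDimSize(int64_t start, int64_t limit, int64_t stride) {
      return (stride + (limit - start) - 1) / stride;
    }

    int main() {
      // start=2, limit=10, stride=3 selects indices 2, 5, 8 -> size 3.
      assert(SliceDimSize(2, 10, 3) == 3);
      // A dense slice: indices 0..9 with stride 1 -> size 10.
      assert(SliceDimSize(0, 10, 1) == 10);
      return 0;
    }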
@@ -2072,6 +2118,46 @@ void ReduceOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                            MLIRContext* context) {
   results.insert<LowerBoolSplatConstantsIntoRegion>(context);
 }
 
+LogicalResult ReduceOp::reifyReturnTypeShapes(
+    OpBuilder& builder, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  ReduceOp::Adaptor adaptor(operands);
+  auto inputs = adaptor.inputs();
+
+  auto operand_type = inputs[0].getType().dyn_cast<RankedTensorType>();
+  // Not support unranked type a.t.m.
+  if (!operand_type) return failure();
+
+  Location loc = this->getLoc();
+  SmallVector<Value, 4> shape_values;
+  SmallVector<int64_t, 4> dimensions(this->dimensions().getValues<int64_t>());
+  shape_values.reserve(operand_type.getRank());
+  Type shape_scalar_type = builder.getIndexType();
+  auto to_shape_scalar_type = [&](Value v) {
+    return MaybeCastTo(builder, loc, v, shape_scalar_type);
+  };
+
+  for (const auto& element : llvm::enumerate(operand_type.getShape())) {
+    int64_t idx = element.index();
+    auto it = std::find(dimensions.begin(), dimensions.end(), idx);
+    if (it != dimensions.end()) {
+      continue;
+    }
+    Value value_dim = to_shape_scalar_type(
+        builder.create<memref::DimOp>(loc, inputs[0], element.index()));
+    shape_values.push_back(value_dim);
+  }
+
+  Value output_shape = builder.create<tensor::FromElementsOp>(
+      loc, shape_scalar_type, shape_values);
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    reifiedReturnShapes.push_back(output_shape);
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // RngNormalOp
 //===----------------------------------------------------------------------===//
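The loop keeps every operand dimension that is not being reduced, then reuses the resulting shape for all results (mhlo.reduce can reduce several same-shaped inputs at once). The same logic in standalone C++ (illustrative, not part of the patch):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Result shape of a reduction: operand dims not listed in `dimensions`.
    std::vector<int64_t> ReduceResultShape(
        const std::vector<int64_t>& operand_shape,
        const std::vector<int64_t>& dimensions) {
      std::vector<int64_t> result;
      for (int64_t i = 0; i < static_cast<int64_t>(operand_shape.size()); ++i) {
        if (std::find(dimensions.begin(), dimensions.end(), i) ==
            dimensions.end())
          result.push_back(operand_shape[i]);
      }
      return result;
    }

    int main() {
      // Row reduce over dim 1: 4x8 -> 4 (matches @row_reduce in the tests).
      assert((ReduceResultShape({4, 8}, {1}) == std::vector<int64_t>({4})));
      // Column reduce over dim 0: 4x8 -> 8 (matches @column_reduce).
      assert((ReduceResultShape({4, 8}, {0}) == std::vector<int64_t>({8})));
      return 0;
    }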
@@ -434,11 +434,8 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
              << "tensor to buffer conversion expects a single block "
                 "in the region containing the operation";
     }
-    const auto& original_results = op.getResults();
     SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
-    for (auto result : original_results) {
-      buffer_args.push_back(InsertAlloc(loc, result, &rewriter));
-    }
+    if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
     auto new_op = rewriter.create<lmhlo::ReduceOp>(loc, llvm::None, buffer_args,
                                                    op->getAttrs());
 
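Note the converter no longer allocates result buffers by hand with InsertAlloc; it now defers to the shared ConvertResults helper. Presumably ConvertResults uses the op's newly implemented reifyReturnTypeShapes to size allocations for dynamically shaped results (this is an inference from the patch as a whole, not something this hunk states).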
@@ -671,7 +668,8 @@ void populateDynamicHLOToLHLOOnlyConversionPattern(
   patterns->insert<HloToLhloOpConverter<mhlo::DynamicBroadcastInDimOp>,
                    HloToLhloOpConverter<mhlo::DynamicIotaOp>,
                    HloToLhloOpConverter<mhlo::DynamicPadOp>,
-                   HloToLhloOpConverter<mhlo::DynamicReshapeOp>
+                   HloToLhloOpConverter<mhlo::DynamicReshapeOp>,
+                   HloToLhloOpConverter<mhlo::RealDynamicSliceOp>
                    >(*converter, context);
   // clang-format on
 }
@@ -81,3 +81,72 @@ func @dynamic_pad(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>, %arg2: tensor<2xin
   %0 = "mhlo.dynamic_pad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<?x?xf32>, tensor<f32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xf32>
   return %0: tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @real_dynamic_slice
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[START:.*]]: memref<2xi32>, %[[LIMIT:.*]]: memref<2xi32>, %[[STRIDE:.*]]: memref<2xi32>) -> memref<?x?xf32>
+func @real_dynamic_slice(%arg0: tensor<?x?xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>) -> tensor<?x?xf32> {
+  // CHECK-NOT: tensor_load
+  // CHECK: %[[T0:.*]] = memref.load %[[START]][%c0] : memref<2xi32>
+  // CHECK: %[[T1:.*]] = memref.load %[[LIMIT]][%c0] : memref<2xi32>
+  // CHECK: %[[T2:.*]] = memref.load %[[STRIDE]][%c0] : memref<2xi32>
+  // CHECK: %[[T3:.*]] = subi %[[T1]], %[[T0]] : i32
+  // CHECK: %[[T4:.*]] = addi %[[T2]], %[[T3]] : i32
+  // CHECK: %[[T5:.*]] = subi %[[T4]], %c1_i32 : i32
+  // CHECK: %[[T6:.*]] = divi_signed %[[T5]], %[[T2]] : i32
+  // CHECK: %[[T7:.*]] = memref.load %[[START]][%c1] : memref<2xi32>
+  // CHECK: %[[T8:.*]] = memref.load %[[LIMIT]][%c1] : memref<2xi32>
+  // CHECK: %[[T9:.*]] = memref.load %[[STRIDE]][%c1] : memref<2xi32>
+  // CHECK: %[[T10:.*]] = subi %[[T8]], %[[T7]] : i32
+  // CHECK: %[[T11:.*]] = addi %[[T9]], %[[T10]] : i32
+  // CHECK: %[[T12:.*]] = subi %[[T11]], %c1_i32 : i32
+  // CHECK: %[[T13:.*]] = divi_signed %[[T12]], %[[T9]] : i32
+  // CHECK: %[[T14:.*]] = index_cast %[[T6]] : i32 to index
+  // CHECK: %[[T15:.*]] = index_cast %[[T13]] : i32 to index
+  // CHECK: %[[T16:.*]] = memref.alloc(%[[T14]], %[[T15]]) : memref<?x?xf32>
+  // CHECK: "lmhlo.real_dynamic_slice"(%[[ARG]], %[[START]], %[[LIMIT]], %[[STRIDE]], %[[T16]])
+  %0 = "mhlo.real_dynamic_slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<?x?xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32>
+  return %0: tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @row_reduce
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>, %[[VAL:.*]]: memref<f32>) -> memref<?xf32>
+func @row_reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>) -> tensor<?xf32> {
+  // CHECK-NOT: tensor_load
+  // CHECK: %[[DIM0:.*]] = memref.dim %[[ARG]], %c0 : memref<?x?xf32>
+  // CHECK: %[[OUT:.*]] = memref.alloc(%[[DIM0]]) : memref<?xf32>
+  // CHECK: lmhlo.reduce
+  // CHECK-SAME: %[[ARG]], %[[VAL]], %[[OUT]]
+  // CHECK: return %[[OUT]] : memref<?xf32>
+  %0 = "mhlo.reduce"(%arg0, %arg1) ( {
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):  // no predecessors
+    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+    "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {dimensions = dense<1> : tensor<1xi64>}
+      : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
+  return %0: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @column_reduce
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>, %[[VAL:.*]]: memref<f32>) -> memref<?xf32>
+func @column_reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>) -> tensor<?xf32> {
+  // CHECK-NOT: tensor_load
+  // CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %c1 : memref<?x?xf32>
+  // CHECK: %[[OUT:.*]] = memref.alloc(%[[DIM1]]) : memref<?xf32>
+  // CHECK: lmhlo.reduce
+  // CHECK-SAME: %[[ARG]], %[[VAL]], %[[OUT]]
+  // CHECK: return %[[OUT]] : memref<?xf32>
+  %0 = "mhlo.reduce"(%arg0, %arg1) ( {
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):  // no predecessors
+    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+    "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {dimensions = dense<0> : tensor<1xi64>}
+      : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
+  return %0: tensor<?xf32>
+}