[MLIR] Add support for representing variadic reduce-window in HLO/LMHLO dialect.
- Fixed a subset of transformations to handle variadic reduce-window. PiperOrigin-RevId: 366278650
This commit is contained in:
parent
d1f697e618
commit
ff2cbfa2ec
|
@ -1257,6 +1257,7 @@ def HLO_TriangularSolveOp: HLO_Op<"triangular_solve",
|
||||||
|
|
||||||
def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [
|
def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [
|
||||||
RecursiveSideEffects,
|
RecursiveSideEffects,
|
||||||
|
SameVariadicOperandSize,
|
||||||
SingleBlockImplicitTerminator<"ReturnOp">
|
SingleBlockImplicitTerminator<"ReturnOp">
|
||||||
]>, BASE_HLO_ReduceWindowOp {
|
]>, BASE_HLO_ReduceWindowOp {
|
||||||
|
|
||||||
|
@ -1264,8 +1265,8 @@ def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [
|
||||||
// attributes are 1-d. Attributes' leading dimension should match rank of the
|
// attributes are 1-d. Attributes' leading dimension should match rank of the
|
||||||
// inputs.
|
// inputs.
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
HLO_Tensor:$operand,
|
Variadic<HLO_Tensor>:$inputs,
|
||||||
HLO_Tensor:$init_value,
|
Variadic<HLO_Tensor>:$init_values,
|
||||||
I64ElementsAttr:$window_dimensions,
|
I64ElementsAttr:$window_dimensions,
|
||||||
// If strides or dilations attributes are missing then the default value is
|
// If strides or dilations attributes are missing then the default value is
|
||||||
// one for each of the input dimensions. Similarly, padding values are zero
|
// one for each of the input dimensions. Similarly, padding values are zero
|
||||||
|
@ -1276,15 +1277,36 @@ def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [
|
||||||
OptionalAttr<I64ElementsAttr>:$padding
|
OptionalAttr<I64ElementsAttr>:$padding
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs HLO_Tensor);
|
let results = (outs Variadic<HLO_Tensor>);
|
||||||
|
|
||||||
// TODO(hinsu): Verify that the attached body arguments and results are
|
// TODO(hinsu): Verify that the attached body arguments and results are
|
||||||
// compatible with reduce op's operands.
|
// compatible with reduce op's operands.
|
||||||
let regions = (region SizedRegion<1>:$body);
|
let regions = (region SizedRegion<1>:$body);
|
||||||
|
|
||||||
let hasCustomHLOConverter = 1;
|
// Builder for non-variadic version of the operation.
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<(ins "Type":$result_type, "Value":$operand,
|
||||||
|
"Value":$init_value,
|
||||||
|
"DenseIntElementsAttr":$window_dimensions,
|
||||||
|
"DenseIntElementsAttr":$window_strides,
|
||||||
|
"DenseIntElementsAttr":$base_dilations,
|
||||||
|
"DenseIntElementsAttr":$window_dilations,
|
||||||
|
"DenseIntElementsAttr":$padding),
|
||||||
|
[{
|
||||||
|
build($_builder, $_state, TypeRange(result_type), ValueRange(operand),
|
||||||
|
ValueRange(init_value), window_dimensions, window_strides,
|
||||||
|
base_dilations, window_dilations, padding);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
|
|
||||||
|
let hasCustomHLOConverter = 1;
|
||||||
|
let verifier = [{ return Verify(*this); }];
|
||||||
// TODO(hinsu): Implement custom printer and parser.
|
// TODO(hinsu): Implement custom printer and parser.
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// Get the operation used for reduction applied to `result_index`th result.
|
||||||
|
Operation *getReductionOp(int result_index);
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_ReturnOp : HLO_Op<"return", [NoSideEffect, Terminator]> {
|
def HLO_ReturnOp : HLO_Op<"return", [NoSideEffect, Terminator]> {
|
||||||
|
|
|
@ -216,12 +216,12 @@ def LHLO_ReduceOp: LHLO_Op<"reduce", [SameVariadicOperandSize]>, BASE_HLO_Reduce
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", []>, BASE_HLO_ReduceWindowOp {
|
def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [SameVariadicOperandSize]>,
|
||||||
|
BASE_HLO_ReduceWindowOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$inputs,
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$init_value,
|
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$init_values,
|
||||||
Arg<LHLO_Buffer, "", [MemWrite]>:$out,
|
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$out,
|
||||||
I64ElementsAttr:$window_dimensions,
|
I64ElementsAttr:$window_dimensions,
|
||||||
// If strides or dilations attributes are missing then the default value is
|
// If strides or dilations attributes are missing then the default value is
|
||||||
// one for each of the input dimensions. Similarly, padding values are zero
|
// one for each of the input dimensions. Similarly, padding values are zero
|
||||||
|
@ -233,6 +233,7 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", []>, BASE_HLO_ReduceWindowOp {
|
||||||
);
|
);
|
||||||
|
|
||||||
let regions = (region SizedRegion<1>:$body);
|
let regions = (region SizedRegion<1>:$body);
|
||||||
|
let verifier = [{ return Verify(*this); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(timshen): Add a custom syntax for this.
|
// TODO(timshen): Add a custom syntax for this.
|
||||||
|
|
|
@ -1728,6 +1728,38 @@ static LogicalResult Verify(RecvOp op) {
|
||||||
|
|
||||||
OpFoldResult CopyOp::fold(ArrayRef<Attribute> operands) { return getOperand(); }
|
OpFoldResult CopyOp::fold(ArrayRef<Attribute> operands) { return getOperand(); }
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ReduceWindowOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// For reduce-window, all `inputs` need to have compatible shapes.
|
||||||
|
static LogicalResult Verify(ReduceWindowOp op) {
|
||||||
|
if (failed(verifyCompatibleShapes(op.inputs().getTypes())))
|
||||||
|
return op.emitOpError() << "requires same shape for all inputs";
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the operation used for reduction applied to `result_index`th result. Its
|
||||||
|
// expected to be a binary operation that consumes `result_index`th and
|
||||||
|
// `result_index + operands().size`th arguments of the body.
|
||||||
|
Operation* ReduceWindowOp::getReductionOp(int result_index) {
|
||||||
|
auto return_op = cast<ReturnOp>(body().front().getTerminator());
|
||||||
|
Operation* compute_op = return_op.results()[result_index].getDefiningOp();
|
||||||
|
if (compute_op->getNumOperands() != 2) return nullptr;
|
||||||
|
auto arg0 = compute_op->getOperand(0).dyn_cast<BlockArgument>();
|
||||||
|
auto arg1 = compute_op->getOperand(1).dyn_cast<BlockArgument>();
|
||||||
|
if (!arg0 || !arg1) return nullptr;
|
||||||
|
int arg0_num = arg0.getArgNumber();
|
||||||
|
int arg1_num = arg1.getArgNumber();
|
||||||
|
int other_arg_index = result_index + inputs().size();
|
||||||
|
if (arg0_num == result_index && arg1_num == other_arg_index)
|
||||||
|
return compute_op;
|
||||||
|
if (arg0_num == other_arg_index && arg1_num == result_index &&
|
||||||
|
compute_op->hasTrait<OpTrait::IsCommutative>())
|
||||||
|
return compute_op;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ReverseOp
|
// ReverseOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -281,6 +281,17 @@ void ReduceOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
|
||||||
results.insert<RemoveCopyInReduceBody>(context);
|
results.insert<RemoveCopyInReduceBody>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ReduceWindowOp.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// For reduce-window, all `inputs` need to have compatible shapes.
|
||||||
|
static LogicalResult Verify(ReduceWindowOp op) {
|
||||||
|
if (failed(verifyCompatibleShapes(op.inputs().getTypes())))
|
||||||
|
return op.emitOpError() << "requires same shape for all operands";
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace lmhlo
|
} // namespace lmhlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -1687,40 +1687,34 @@ struct ReduceWindowOpOnTensorsConversion
|
||||||
/// the pooling is determined based on the body of the reduce window
|
/// the pooling is determined based on the body of the reduce window
|
||||||
/// operation. This class enumerates the different variants.
|
/// operation. This class enumerates the different variants.
|
||||||
enum class PoolingType {
|
enum class PoolingType {
|
||||||
|
kInvalid,
|
||||||
kMin,
|
kMin,
|
||||||
kMax,
|
kMax,
|
||||||
kAdd,
|
kAdd,
|
||||||
};
|
};
|
||||||
|
|
||||||
static PoolingType getPoolingType(Region& region) {
|
static PoolingType getPoolingType(mhlo::ReduceWindowOp reduce_op,
|
||||||
assert(region.getBlocks().size() == 1 &&
|
int result_index) {
|
||||||
"expected the region has exactlly one block");
|
if (Operation* op = reduce_op.getReductionOp(result_index)) {
|
||||||
Block& block = region.front();
|
if (isa<mhlo::MinOp>(*op)) return PoolingType::kMin;
|
||||||
assert(block.getOperations().size() == 2 &&
|
if (isa<mhlo::MaxOp>(*op)) return PoolingType::kMax;
|
||||||
"expected the block has exactlly two operations");
|
if (isa<mhlo::AddOp>(*op)) return PoolingType::kAdd;
|
||||||
auto op = block.begin();
|
}
|
||||||
if (isa<mhlo::MinOp>(op)) return PoolingType::kMin;
|
return PoolingType::kInvalid;
|
||||||
if (isa<mhlo::MaxOp>(op)) return PoolingType::kMax;
|
|
||||||
if (isa<mhlo::AddOp>(op)) return PoolingType::kAdd;
|
|
||||||
|
|
||||||
llvm_unreachable("unknown pooling type");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(
|
LogicalResult matchAndRewrite(
|
||||||
mhlo::ReduceWindowOp op, ArrayRef<Value> args,
|
mhlo::ReduceWindowOp op, ArrayRef<Value> args,
|
||||||
ConversionPatternRewriter& rewriter) const override {
|
ConversionPatternRewriter& rewriter) const override {
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
auto result_type = op.getResult().getType().cast<ShapedType>();
|
int rank = op.getResultTypes()[0].cast<ShapedType>().getRank();
|
||||||
if (result_type.getRank() != 4) {
|
if (rank != 4) {
|
||||||
return rewriter.notifyMatchFailure(op, "expected NHWC pooling-based op");
|
return rewriter.notifyMatchFailure(op, "expected NHWC pooling-based op");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a fake window dimension.
|
SmallVector<int64_t, 2> shapes;
|
||||||
SmallVector<int64_t, 4> shapes;
|
|
||||||
shapes.push_back(op.window_dimensions().getValue<int64_t>(1));
|
shapes.push_back(op.window_dimensions().getValue<int64_t>(1));
|
||||||
shapes.push_back(op.window_dimensions().getValue<int64_t>(2));
|
shapes.push_back(op.window_dimensions().getValue<int64_t>(2));
|
||||||
auto fake_window_dims = rewriter.create<linalg::InitTensorOp>(
|
|
||||||
loc, shapes, result_type.getElementType());
|
|
||||||
|
|
||||||
if (op.window_strides() &&
|
if (op.window_strides() &&
|
||||||
(op.window_strides().getValue().getValue<int64_t>(0) != 1 ||
|
(op.window_strides().getValue().getValue<int64_t>(0) != 1 ||
|
||||||
|
@ -1735,10 +1729,6 @@ struct ReduceWindowOpOnTensorsConversion
|
||||||
op, "expected window_dimensions to be [1,x,y,1]");
|
op, "expected window_dimensions to be [1,x,y,1]");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!args[0].getType().cast<ShapedType>().getElementType().isF32()) {
|
|
||||||
return rewriter.notifyMatchFailure(op, "expected element type to be f32");
|
|
||||||
}
|
|
||||||
|
|
||||||
Attribute strides;
|
Attribute strides;
|
||||||
if (op.window_stridesAttr()) {
|
if (op.window_stridesAttr()) {
|
||||||
strides = rewriter.getI64VectorAttr(
|
strides = rewriter.getI64VectorAttr(
|
||||||
|
@ -1756,9 +1746,25 @@ struct ReduceWindowOpOnTensorsConversion
|
||||||
dilations = rewriter.getI64VectorAttr({1, 1});
|
dilations = rewriter.getI64VectorAttr({1, 1});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> pooling_ops;
|
||||||
|
|
||||||
|
ArrayRef<Value> inputs = args.take_front(op.inputs().size());
|
||||||
|
ArrayRef<Value> init_values = args.drop_front(op.inputs().size());
|
||||||
|
for (auto it : llvm::zip(op.getResults(), inputs, init_values)) {
|
||||||
|
OpResult result = std::get<0>(it);
|
||||||
|
Value input = std::get<1>(it);
|
||||||
|
Value init_value = std::get<2>(it);
|
||||||
|
auto result_type = result.getType().cast<ShapedType>();
|
||||||
|
if (!input.getType().cast<ShapedType>().getElementType().isF32()) {
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"expected element type to be f32");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a fake window dimension.
|
||||||
|
auto fake_window_dims = rewriter.create<linalg::InitTensorOp>(
|
||||||
|
loc, shapes, result_type.getElementType());
|
||||||
Value init_tensor = rewriter.create<linalg::InitTensorOp>(
|
Value init_tensor = rewriter.create<linalg::InitTensorOp>(
|
||||||
loc, result_type.getShape(), result_type.getElementType());
|
loc, result_type.getShape(), result_type.getElementType());
|
||||||
Value init_value = args[1];
|
|
||||||
init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
|
init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
|
||||||
Value filled_init_tensor =
|
Value filled_init_tensor =
|
||||||
rewriter.create<linalg::FillOp>(loc, init_tensor, init_value)
|
rewriter.create<linalg::FillOp>(loc, init_tensor, init_value)
|
||||||
|
@ -1773,22 +1779,29 @@ struct ReduceWindowOpOnTensorsConversion
|
||||||
.getOperation());
|
.getOperation());
|
||||||
};
|
};
|
||||||
linalg::LinalgOp pooling_op;
|
linalg::LinalgOp pooling_op;
|
||||||
PoolingType pooling_type = getPoolingType(op.body());
|
PoolingType pooling_type = getPoolingType(op, result.getResultNumber());
|
||||||
switch (pooling_type) {
|
switch (pooling_type) {
|
||||||
case PoolingType::kMin: {
|
case PoolingType::kMin: {
|
||||||
pooling_op = create_op(static_cast<linalg::PoolingNHWCMinOp*>(nullptr));
|
pooling_op =
|
||||||
|
create_op(static_cast<linalg::PoolingNHWCMinOp*>(nullptr));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case PoolingType::kMax: {
|
case PoolingType::kMax: {
|
||||||
pooling_op = create_op(static_cast<linalg::PoolingNHWCMaxOp*>(nullptr));
|
pooling_op =
|
||||||
|
create_op(static_cast<linalg::PoolingNHWCMaxOp*>(nullptr));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case PoolingType::kAdd: {
|
case PoolingType::kAdd: {
|
||||||
pooling_op = create_op(static_cast<linalg::PoolingNHWCSumOp*>(nullptr));
|
pooling_op =
|
||||||
|
create_op(static_cast<linalg::PoolingNHWCSumOp*>(nullptr));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case PoolingType::kInvalid:
|
||||||
|
return rewriter.notifyMatchFailure(op, "unknown reduction operation");
|
||||||
}
|
}
|
||||||
rewriter.replaceOp(op, pooling_op->getResult(0));
|
pooling_ops.push_back(pooling_op->getResult(0));
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, pooling_ops);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -93,8 +93,8 @@ struct MappedIvs {
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
MappedIvs MapWindowIvsToInput(OpTy op, ValueRange ivs, ValueRange window_ivs,
|
MappedIvs MapWindowIvsToInput(OpTy op, Value operand, ValueRange ivs,
|
||||||
OpBuilder* b) {
|
ValueRange window_ivs, OpBuilder* b) {
|
||||||
MappedIvs mapped_ivs;
|
MappedIvs mapped_ivs;
|
||||||
|
|
||||||
if (!op.window_strides().hasValue()) {
|
if (!op.window_strides().hasValue()) {
|
||||||
|
@ -108,7 +108,6 @@ MappedIvs MapWindowIvsToInput(OpTy op, ValueRange ivs, ValueRange window_ivs,
|
||||||
auto padding = op.padding().getValue();
|
auto padding = op.padding().getValue();
|
||||||
|
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
auto operand = op.operand();
|
|
||||||
auto operand_shape = operand.getType().template cast<MemRefType>().getShape();
|
auto operand_shape = operand.getType().template cast<MemRefType>().getShape();
|
||||||
|
|
||||||
// `in_bounds` is false when the mapped indices are in the padding area.
|
// `in_bounds` is false when the mapped indices are in the padding area.
|
||||||
|
@ -196,7 +195,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
||||||
LogicalResult matchAndRewrite(
|
LogicalResult matchAndRewrite(
|
||||||
lmhlo::ReduceOp reduce_op, ArrayRef<Value> /*args*/,
|
lmhlo::ReduceOp reduce_op, ArrayRef<Value> /*args*/,
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
// TODO(b/137624192) Implement variadic reduce.
|
// TODO(b/183977252) : Handle variadic ReduceOp/ReduceWindowOp
|
||||||
if (reduce_op.out().size() != 1) return failure();
|
if (reduce_op.out().size() != 1) return failure();
|
||||||
|
|
||||||
scf::ReduceOp scf_reduce_op =
|
scf::ReduceOp scf_reduce_op =
|
||||||
|
@ -312,7 +311,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
||||||
// value = input[I]
|
// value = input[I]
|
||||||
// else
|
// else
|
||||||
// value = neutral_value
|
// value = neutral_value
|
||||||
// accumulator = reduction_operator(output[O], value)
|
// accumulator = reduction_operator(accumulator, value)
|
||||||
// output[O] = accumulator
|
// output[O] = accumulator
|
||||||
//
|
//
|
||||||
// Converts `lmhlo.ReduceWindowOp` into two scf::ParallelOp and a
|
// Converts `lmhlo.ReduceWindowOp` into two scf::ParallelOp and a
|
||||||
|
@ -367,6 +366,9 @@ class ReduceWindowOpConverter
|
||||||
LogicalResult matchAndRewrite(
|
LogicalResult matchAndRewrite(
|
||||||
lmhlo::ReduceWindowOp reduce_window_op, ArrayRef<Value> /*args*/,
|
lmhlo::ReduceWindowOp reduce_window_op, ArrayRef<Value> /*args*/,
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
// TODO(b/183977252) : Handle variadic ReduceOp/ReduceWindowOp
|
||||||
|
if (reduce_window_op.out().size() != 1) return failure();
|
||||||
|
|
||||||
scf::ParallelOp output_loop, window_loop;
|
scf::ParallelOp output_loop, window_loop;
|
||||||
std::tie(output_loop, window_loop) =
|
std::tie(output_loop, window_loop) =
|
||||||
CreateParallelLoopsToTraverseOutputAndWindow(reduce_window_op,
|
CreateParallelLoopsToTraverseOutputAndWindow(reduce_window_op,
|
||||||
|
@ -387,14 +389,14 @@ class ReduceWindowOpConverter
|
||||||
lmhlo::ReduceWindowOp reduce_window_op,
|
lmhlo::ReduceWindowOp reduce_window_op,
|
||||||
ConversionPatternRewriter* rewriter) const {
|
ConversionPatternRewriter* rewriter) const {
|
||||||
auto loc = reduce_window_op.getLoc();
|
auto loc = reduce_window_op.getLoc();
|
||||||
Value init_value =
|
Value init_value = rewriter->create<memref::LoadOp>(
|
||||||
rewriter->create<memref::LoadOp>(loc, reduce_window_op.init_value());
|
loc, reduce_window_op.init_values()[0]);
|
||||||
|
|
||||||
Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
|
Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
|
||||||
Value one = rewriter->create<ConstantIndexOp>(loc, 1);
|
Value one = rewriter->create<ConstantIndexOp>(loc, 1);
|
||||||
|
|
||||||
// Create an outer parallel loop that spans the output of ReduceWindowOp.
|
// Create an outer parallel loop that spans the output of ReduceWindowOp.
|
||||||
Value output = reduce_window_op.out();
|
Value output = reduce_window_op.out()[0];
|
||||||
auto output_loop = MakeLoopOverShape(loc, output, rewriter);
|
auto output_loop = MakeLoopOverShape(loc, output, rewriter);
|
||||||
|
|
||||||
// Create a nested loop that traverses the window.
|
// Create a nested loop that traverses the window.
|
||||||
|
@ -429,22 +431,22 @@ class ReduceWindowOpConverter
|
||||||
"`window_dilations` attributes yet. The attributes will be ignored.");
|
"`window_dilations` attributes yet. The attributes will be ignored.");
|
||||||
}
|
}
|
||||||
|
|
||||||
Value operand = reduce_window_op.operand();
|
Value input = reduce_window_op.inputs()[0];
|
||||||
auto operand_type = operand.getType().cast<MemRefType>();
|
auto input_type = input.getType().cast<MemRefType>();
|
||||||
|
|
||||||
// Compute ivs in 'arg' buffer and whether these ivs are in pad area or not.
|
// Compute ivs in 'arg' buffer and whether these ivs are in pad area or not.
|
||||||
MappedIvs mapped_ivs =
|
MappedIvs mapped_ivs = MapWindowIvsToInput(
|
||||||
MapWindowIvsToInput(reduce_window_op, output_loop.getInductionVars(),
|
reduce_window_op, input, output_loop.getInductionVars(),
|
||||||
window_loop.getInductionVars(), rewriter);
|
window_loop.getInductionVars(), rewriter);
|
||||||
|
|
||||||
auto elem_or_init = rewriter->create<scf::IfOp>(
|
auto elem_or_init = rewriter->create<scf::IfOp>(
|
||||||
loc, operand_type.getElementType(), mapped_ivs.in_bounds,
|
loc, input_type.getElementType(), mapped_ivs.in_bounds,
|
||||||
/*withElseRegion=*/true);
|
/*withElseRegion=*/true);
|
||||||
|
|
||||||
OpBuilder then_builder =
|
OpBuilder then_builder =
|
||||||
elem_or_init.getThenBodyBuilder(rewriter->getListener());
|
elem_or_init.getThenBodyBuilder(rewriter->getListener());
|
||||||
Value elem = then_builder.create<mlir::memref::LoadOp>(
|
Value elem =
|
||||||
loc, reduce_window_op.operand(), mapped_ivs.ivs);
|
then_builder.create<mlir::memref::LoadOp>(loc, input, mapped_ivs.ivs);
|
||||||
then_builder.create<scf::YieldOp>(loc, elem);
|
then_builder.create<scf::YieldOp>(loc, elem);
|
||||||
|
|
||||||
OpBuilder else_builder =
|
OpBuilder else_builder =
|
||||||
|
@ -611,8 +613,8 @@ class SelectAndScatterOpConverter
|
||||||
OpBuilder::atBlockEnd(window_loops.inner_loop.getBody());
|
OpBuilder::atBlockEnd(window_loops.inner_loop.getBody());
|
||||||
|
|
||||||
// Compute ivs in 'arg' buffer and whether these ivs are in the pad area.
|
// Compute ivs in 'arg' buffer and whether these ivs are in the pad area.
|
||||||
MappedIvs mapped_ivs =
|
MappedIvs mapped_ivs = MapWindowIvsToInput(
|
||||||
MapWindowIvsToInput(s_and_s_op, loop_over_src.getInductionVars(),
|
s_and_s_op, s_and_s_op.operand(), loop_over_src.getInductionVars(),
|
||||||
window_loops.window_ivs, &inner_loop_b);
|
window_loops.window_ivs, &inner_loop_b);
|
||||||
|
|
||||||
IterArgs ivs_val_flag(window_loops.inner_loop.getRegionIterArgs());
|
IterArgs ivs_val_flag(window_loops.inner_loop.getRegionIterArgs());
|
||||||
|
|
|
@ -1863,6 +1863,43 @@ func @reduce_window_max_nhwc_with_cst(%arg0: tensor<1x18x18x64xf32>) -> tensor<1
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @reduce_window_sum_max_nhwc(%arg0: tensor<1x18x18x64xf32>,
|
||||||
|
%arg1: tensor<f32>) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) {
|
||||||
|
%0:2 = "mhlo.reduce_window"(%arg0, %arg0, %arg1, %arg1) ( {
|
||||||
|
^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>, %arg4: tensor<f32>, %arg5 : tensor<f32>):
|
||||||
|
%1 = mhlo.add %arg2, %arg4 : tensor<f32>
|
||||||
|
%2 = mhlo.maximum %arg3, %arg5 : tensor<f32>
|
||||||
|
"mhlo.return"(%1, %2) : (tensor<f32>, tensor<f32>) -> ()
|
||||||
|
}) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
|
||||||
|
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x18x18x64xf32>, tensor<1x18x18x64xf32>, tensor<f32>, tensor<f32>) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>)
|
||||||
|
return %0#0, %0#1 : tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @reduce_window_sum_max_nhwc
|
||||||
|
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
||||||
|
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
|
||||||
|
// CHECK: %[[WINDOW0:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32>
|
||||||
|
// CHECK: %[[INIT0:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32>
|
||||||
|
// CHECK: %[[INIT_VAL0:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
|
||||||
|
// CHECK: %[[FILL0:.+]] = linalg.fill(%[[INIT]], %[[INIT_VAL]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32>
|
||||||
|
// CHECK: %[[RES0:.+]] = linalg.pooling_nhwc_sum
|
||||||
|
// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
|
||||||
|
// CHECK-SAME: strides = dense<2> : vector<2xi64>}
|
||||||
|
// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW0]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
|
||||||
|
// CHECK-SAME: outs(%[[FILL0]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
|
||||||
|
// CHECK: %[[WINDOW1:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32>
|
||||||
|
// CHECK: %[[INIT1:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32>
|
||||||
|
// CHECK: %[[INIT_VAL1:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
|
||||||
|
// CHECK: %[[FILL1:.+]] = linalg.fill(%[[INIT1]], %[[INIT_VAL1]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32>
|
||||||
|
// CHECK: %[[RES1:.+]] = linalg.pooling_nhwc_max
|
||||||
|
// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
|
||||||
|
// CHECK-SAME: strides = dense<2> : vector<2xi64>}
|
||||||
|
// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW1]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
|
||||||
|
// CHECK-SAME: outs(%[[FILL1]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
|
||||||
|
// CHECK: return %[[RES0]], %[[RES1]]
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @torch_select_index(%arg0: tensor<5x1x5xi32>,
|
func @torch_select_index(%arg0: tensor<5x1x5xi32>,
|
||||||
%arg1: tensor<2xi32>) -> tensor<2x1x5xi32> {
|
%arg1: tensor<2xi32>) -> tensor<2x1x5xi32> {
|
||||||
%0 = "mhlo.torch_index_select"(%arg0, %arg1) {
|
%0 = "mhlo.torch_index_select"(%arg0, %arg1) {
|
||||||
|
|
|
@ -1389,3 +1389,37 @@ func @custom_call_multiple_outputs(%x: tensor<2xf32>) -> tensor<2xf32> {
|
||||||
%1 = "mhlo.add"(%0#0, %0#1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
%1 = "mhlo.add"(%0#0, %0#1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||||
return %1 : tensor<2xf32>
|
return %1 : tensor<2xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK: func @reduce_window
|
||||||
|
func @reduce_window(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %init0: tensor<f32>, %init1: tensor<i32>) -> (tensor<2x2xf32>, tensor<2x2xi32>) {
|
||||||
|
%0:2 = "mhlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({
|
||||||
|
^bb0(%a0: tensor<f32>, %a1: tensor<i32>, %b0: tensor<f32>, %b1: tensor<i32>): // no predecessors
|
||||||
|
%2 = mhlo.add %a0, %b0 : tensor<f32>
|
||||||
|
%3 = mhlo.add %a1, %b1 : tensor<i32>
|
||||||
|
%4 = "mhlo.tuple"(%2, %3) : (tensor<f32>, tensor<i32>) -> tuple<tensor<f32>, tensor<i32>>
|
||||||
|
"mhlo.return"(%4) : (tuple<tensor<f32>, tensor<i32>>) -> ()
|
||||||
|
})
|
||||||
|
{ padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>,
|
||||||
|
window_dimensions = dense<[5, 1]> : tensor<2xi64>,
|
||||||
|
window_strides = dense<[3, 1]> : tensor<2xi64> } : (tensor<4x2xf32>, tensor<4x2xi32>, tensor<f32>, tensor<i32>) -> (tensor<2x2xf32>, tensor<2x2xi32>)
|
||||||
|
return %0#0, %0#1 : tensor<2x2xf32>, tensor<2x2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @reduce_window_invalid(%arg0: tensor<4x2xf32>, %arg1: tensor<4x3xi32>, %init0: tensor<f32>, %init1: tensor<i32>) -> (tensor<2x2xf32>, tensor<2x2xi32>) {
|
||||||
|
// expected-error @+1 {{requires same shape for all inputs}}
|
||||||
|
%0:2 = "mhlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({
|
||||||
|
^bb0(%a0: tensor<f32>, %a1: tensor<i32>, %b0: tensor<f32>, %b1: tensor<i32>): // no predecessors
|
||||||
|
%2 = mhlo.add %a0, %b0 : tensor<f32>
|
||||||
|
%3 = mhlo.add %a1, %b1 : tensor<i32>
|
||||||
|
%4 = "mhlo.tuple"(%2, %3) : (tensor<f32>, tensor<i32>) -> tuple<tensor<f32>, tensor<i32>>
|
||||||
|
"mhlo.return"(%4) : (tuple<tensor<f32>, tensor<i32>>) -> ()
|
||||||
|
})
|
||||||
|
{ padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>,
|
||||||
|
window_dimensions = dense<[5, 1]> : tensor<2xi64>,
|
||||||
|
window_strides = dense<[3, 1]> : tensor<2xi64> } : (tensor<4x2xf32>, tensor<4x3xi32>, tensor<f32>, tensor<i32>) -> (tensor<2x2xf32>, tensor<2x2xi32>)
|
||||||
|
return %0#0, %0#1 : tensor<2x2xf32>, tensor<2x2xi32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue