PR #50073: [MLIR] Add GatherOp lowering from lmhlo to Affine.
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50073 -- Lowering of `GatherOp` is added from lmhlo to Affine. The lowering has been added as a part of `lhlo-legalize-to-affine` pass. Signed-off-by: Abhishek Varma <abhishek.varma@polymagelabs.com> Copybara import of the project: -- 5b3dcd4ab31a69f305cd079b869ee35ba6dc8bf5 by Abhishek Varma <abhishek.varma@polymagelabs.com>: [MLIR] Add GatherOp lowering from lmhlo to Affine. -- Lowering of `GatherOp` is added from lmhlo to Affine. The lowering has been added as a part of `lhlo-legalize-to-affine` pass. Signed-off-by: Abhishek Varma <abhishek.varma@polymagelabs.com> PiperOrigin-RevId: 380052157
This commit is contained in:
parent
2ab16024cf
commit
da6593e960
|
@ -177,6 +177,306 @@ struct ConcatOpConverter : public OpRewritePattern<ConcatenateOp> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Returns a zero value of type `type`. `type` is expected to be either
|
||||||
|
/// int or float.
|
||||||
|
static Value getZeroValue(Type type, Location loc, PatternRewriter& rewriter) {
|
||||||
|
assert(type.isIntOrFloat() && "Expected int or float");
|
||||||
|
|
||||||
|
if (IntegerType intType = type.dyn_cast<IntegerType>())
|
||||||
|
return rewriter.create<mlir::ConstantIntOp>(loc, 0, intType.getWidth());
|
||||||
|
|
||||||
|
FloatType floatType = type.cast<FloatType>();
|
||||||
|
return rewriter.create<mlir::ConstantFloatOp>(
|
||||||
|
loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Emits a nest to fill the given `buffer` of memref type with `fillValue`.
|
||||||
|
static void fillBuffer(Location loc, Value buffer, Value fillValue,
|
||||||
|
PatternRewriter& builder) {
|
||||||
|
OpBuilder::InsertionGuard guard(builder);
|
||||||
|
MemRefType bufType = buffer.getType().cast<MemRefType>();
|
||||||
|
unsigned rank = bufType.getRank();
|
||||||
|
SmallVector<Value, 4> dimSizes;
|
||||||
|
dimSizes.reserve(rank);
|
||||||
|
for (unsigned i = 0; i < rank; ++i)
|
||||||
|
dimSizes.push_back(builder.create<mlir::memref::DimOp>(loc, buffer, i));
|
||||||
|
|
||||||
|
AffineMap idSymMap = builder.getSymbolIdentityMap();
|
||||||
|
AffineMap lbMap = builder.getConstantAffineMap(0);
|
||||||
|
SmallVector<Value, 4> ivs(rank);
|
||||||
|
AffineForOp forOp;
|
||||||
|
for (unsigned i = 0; i < rank; ++i) {
|
||||||
|
forOp = builder.create<AffineForOp>(loc, llvm::None, lbMap, dimSizes[i],
|
||||||
|
idSymMap);
|
||||||
|
builder.setInsertionPointToStart(forOp.getBody());
|
||||||
|
ivs[i] = forOp.getInductionVar();
|
||||||
|
}
|
||||||
|
Type fillValueType = fillValue.getType();
|
||||||
|
auto fillMemRefType = fillValueType.dyn_cast<MemRefType>();
|
||||||
|
assert(((fillMemRefType && fillMemRefType.getRank() == 0) ||
|
||||||
|
fillValueType.isIntOrFloat()) &&
|
||||||
|
"init value has to be a 0-d memref or int or fp");
|
||||||
|
Value initVal = fillMemRefType ? builder.create<AffineLoadOp>(
|
||||||
|
loc, fillValue, /*indices=*/llvm::None)
|
||||||
|
: fillValue;
|
||||||
|
builder.create<AffineStoreOp>(loc, initVal, buffer, ivs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts GatherOp to Affine nest form.
|
||||||
|
/// Pseudocode:
|
||||||
|
/// 1. Fill a temporary output tensor with 0.
|
||||||
|
/// 2. Repeat the following for each batch dimension :-
|
||||||
|
/// 1. For each indices in 'operand' :-
|
||||||
|
/// 1. Get hold of start indices from 'start_indices'.
|
||||||
|
/// 2. Add offset to the start indices to get the final indices.
|
||||||
|
/// 3. Load value from 'operand' tensor : 'operand_val'.
|
||||||
|
/// 4. Load value from temporary output : 'prev_val'.
|
||||||
|
/// 5. If the final indices match current indices of 'operand' :
|
||||||
|
/// 'prev_val' = 'prev_val' + 'operand_val'
|
||||||
|
/// 6. Store 'prev_val' back to the temporary output.
|
||||||
|
class GatherOpConverter : public OpRewritePattern<GatherOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<GatherOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(GatherOp op,
|
||||||
|
PatternRewriter& rewriter) const final {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
|
||||||
|
// Operand array.
|
||||||
|
Value operand = op.operand();
|
||||||
|
MemRefType operand_type = operand.getType().cast<MemRefType>();
|
||||||
|
unsigned operand_rank = operand_type.getRank();
|
||||||
|
ArrayRef<int64_t> operand_shape = operand_type.getShape();
|
||||||
|
|
||||||
|
// Start_indices array.
|
||||||
|
Value start_indices = op.start_indices();
|
||||||
|
MemRefType start_indices_type = start_indices.getType().cast<MemRefType>();
|
||||||
|
unsigned start_indices_rank = start_indices_type.getRank();
|
||||||
|
ArrayRef<int64_t> start_indices_shape = start_indices_type.getShape();
|
||||||
|
|
||||||
|
// Output array.
|
||||||
|
Value output = op.output();
|
||||||
|
MemRefType output_type = output.getType().cast<MemRefType>();
|
||||||
|
ArrayRef<int64_t> output_shape = output_type.getShape();
|
||||||
|
|
||||||
|
if (!operand_type.hasStaticShape() ||
|
||||||
|
!start_indices_type.hasStaticShape() || !output_type.hasStaticShape())
|
||||||
|
return rewriter.notifyMatchFailure(op, "only static shaped type allowed");
|
||||||
|
|
||||||
|
mhlo::GatherDimensionNumbers gather_dim = op.dimension_numbersAttr();
|
||||||
|
|
||||||
|
// Collapsed_slice_dim.
|
||||||
|
DenseIntElementsAttr collapsed_slice_dims_attr =
|
||||||
|
gather_dim.collapsed_slice_dims();
|
||||||
|
SmallVector<int64_t, 4> collapsed_slice_dims;
|
||||||
|
for (const APInt& dim : collapsed_slice_dims_attr.getIntValues())
|
||||||
|
collapsed_slice_dims.push_back(dim.getSExtValue());
|
||||||
|
|
||||||
|
// Offset_dim.
|
||||||
|
DenseIntElementsAttr offset_dims_attr = gather_dim.offset_dims();
|
||||||
|
SmallVector<int64_t, 4> offset_dims;
|
||||||
|
for (const APInt& dim : offset_dims_attr.getIntValues())
|
||||||
|
offset_dims.push_back(dim.getSExtValue());
|
||||||
|
|
||||||
|
// Start_index_map.
|
||||||
|
DenseIntElementsAttr start_index_map_attr = gather_dim.start_index_map();
|
||||||
|
SmallVector<int64_t, 4> start_index_map;
|
||||||
|
for (const APInt& dim : start_index_map_attr.getIntValues())
|
||||||
|
start_index_map.push_back(dim.getSExtValue());
|
||||||
|
|
||||||
|
// Index_vector_dim.
|
||||||
|
IntegerAttr index_vector_dim_attr = gather_dim.index_vector_dim();
|
||||||
|
int64_t index_vector_dim = index_vector_dim_attr.getValue().getSExtValue();
|
||||||
|
|
||||||
|
// Slice_sizes.
|
||||||
|
DenseIntElementsAttr slice_sizes_attr = op.slice_sizesAttr();
|
||||||
|
SmallVector<int64_t, 4> slice_sizes;
|
||||||
|
for (const APInt& dim : slice_sizes_attr.getIntValues())
|
||||||
|
slice_sizes.push_back(dim.getSExtValue());
|
||||||
|
|
||||||
|
// Creating constants with 0 value. We need the Integer type constant value
|
||||||
|
// because the indices type will be Integer.
|
||||||
|
Value zero_int_val = rewriter.create<mlir::ConstantIntOp>(
|
||||||
|
loc, 0, start_indices_type.getElementType());
|
||||||
|
Type element_type = output_type.getElementType();
|
||||||
|
Value zero_load_value = getZeroValue(element_type, loc, rewriter);
|
||||||
|
// Initializing the output buffer with 0.
|
||||||
|
fillBuffer(loc, output, zero_load_value, rewriter);
|
||||||
|
|
||||||
|
// We fetch the shape of start_indices at index_vector_dim. In case
|
||||||
|
// index_vector_dim is equal to the rank of start_indices, we implicitly
|
||||||
|
// consider start_indices to have a trailing 1 dimension.
|
||||||
|
unsigned start_indices_numbers =
|
||||||
|
(index_vector_dim == start_indices_rank)
|
||||||
|
? 1
|
||||||
|
: start_indices_shape[index_vector_dim];
|
||||||
|
// We create integer constants till start_incides_index which help us to
|
||||||
|
// fetch start_indices in affine transformation.
|
||||||
|
SmallVector<Value, 4> start_indices_index;
|
||||||
|
for (unsigned i = 0; i < start_indices_numbers; i++) {
|
||||||
|
Value i_val = rewriter.create<mlir::ConstantIntOp>(
|
||||||
|
loc, i, start_indices_type.getElementType());
|
||||||
|
i_val = rewriter.create<IndexCastOp>(loc, i_val, rewriter.getIndexType());
|
||||||
|
start_indices_index.push_back(i_val);
|
||||||
|
}
|
||||||
|
|
||||||
|
// S_in contains the multiple indices that form a starting index used in the
|
||||||
|
// input/operand tensor. O_in contains the multiple offsets of corresponding
|
||||||
|
// starting index used in the input/operand tensor. We initialize both of
|
||||||
|
// them with 0.
|
||||||
|
SmallVector<Value, 4> S_in;
|
||||||
|
SmallVector<Value, 4> O_in;
|
||||||
|
Value zero_index_val = rewriter.create<IndexCastOp>(
|
||||||
|
loc, zero_int_val, rewriter.getIndexType());
|
||||||
|
for (unsigned i = 0; i < operand_rank; i++) {
|
||||||
|
S_in.push_back(zero_index_val);
|
||||||
|
O_in.push_back(zero_index_val);
|
||||||
|
}
|
||||||
|
|
||||||
|
// batch_induction_vars stores the loop induction variables pertaining to
|
||||||
|
// the batches of start indices.
|
||||||
|
SmallVector<Value, 4> batch_induction_vars;
|
||||||
|
// output_induction_vars stores the loop induction variables pertaining to
|
||||||
|
// both batches and offsets within the output tensor.
|
||||||
|
SmallVector<Value, 4> output_induction_vars;
|
||||||
|
// Create loops to iterate over each batch of starting index.
|
||||||
|
for (unsigned i = 0; i < start_indices_rank; i++) {
|
||||||
|
// ith dimension of start_indices doesn't form a batch if it is equal to
|
||||||
|
// index_vector_dim.
|
||||||
|
if (i == index_vector_dim) continue;
|
||||||
|
AffineForOp for_op =
|
||||||
|
rewriter.create<AffineForOp>(loc, 0, start_indices_shape[i]);
|
||||||
|
batch_induction_vars.push_back(for_op.getInductionVar());
|
||||||
|
output_induction_vars.push_back(for_op.getInductionVar());
|
||||||
|
rewriter.setInsertionPointToStart(for_op.getBody());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create loops to iterate over each offset dimension within the output
|
||||||
|
// tensor.
|
||||||
|
for (unsigned i = 0, k = 0, e = offset_dims.size(); i < e; i++) {
|
||||||
|
AffineForOp for_op =
|
||||||
|
rewriter.create<AffineForOp>(loc, 0, output_shape[offset_dims[i]]);
|
||||||
|
rewriter.setInsertionPointToStart(for_op.getBody());
|
||||||
|
// We try to fetch the first non-collapsed dimension.
|
||||||
|
while (k < collapsed_slice_dims.size() && collapsed_slice_dims[k] == i)
|
||||||
|
k++;
|
||||||
|
// Remapping the offset_dim[i] to the non-collapsed dimension.
|
||||||
|
O_in[k++] = for_op.getInductionVar();
|
||||||
|
// We assume offset_dims to be sorted. So when inserted to
|
||||||
|
// output_induction_vars the loop induction variable gets inserted at the
|
||||||
|
// correct position.
|
||||||
|
output_induction_vars.insert(
|
||||||
|
output_induction_vars.begin() + offset_dims[i],
|
||||||
|
for_op.getInductionVar());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create loops to iterate over all dimensions within the operand tensor.
|
||||||
|
SmallVector<Value, 4> operand_index;
|
||||||
|
for (unsigned i = 0, k = 0; i < operand_rank; i++) {
|
||||||
|
// We assume start_index_map to have sorted dimensions. We only include
|
||||||
|
// those dimensions of operand tensor which are present in
|
||||||
|
// start_index_map.
|
||||||
|
if (k < start_index_map.size() && i == start_index_map[k++]) {
|
||||||
|
AffineForOp for_op =
|
||||||
|
rewriter.create<AffineForOp>(loc, 0, operand_shape[i]);
|
||||||
|
operand_index.push_back(for_op.getInductionVar());
|
||||||
|
rewriter.setInsertionPointToStart(for_op.getBody());
|
||||||
|
} else {
|
||||||
|
operand_index.push_back(O_in[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// In case index_vector_dim is not equal to start_indices shape then we
|
||||||
|
// create another loop to iterate over starting index and update
|
||||||
|
// batch_induction_vars.
|
||||||
|
if (index_vector_dim != start_indices_rank) {
|
||||||
|
for (unsigned i = 0; i < start_indices_numbers; i++) {
|
||||||
|
batch_induction_vars.insert(
|
||||||
|
batch_induction_vars.begin() + index_vector_dim,
|
||||||
|
start_indices_index[i]);
|
||||||
|
Value start_index = rewriter.create<AffineLoadOp>(loc, start_indices,
|
||||||
|
batch_induction_vars);
|
||||||
|
start_index = rewriter.create<IndexCastOp>(loc, start_index,
|
||||||
|
rewriter.getIndexType());
|
||||||
|
S_in[start_index_map[i]] = start_index;
|
||||||
|
batch_induction_vars.erase(batch_induction_vars.begin() +
|
||||||
|
index_vector_dim);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Since index_vector_dim is equal to start_indicesRank we can directly
|
||||||
|
// fetch the start_index from batch_induction_vars.
|
||||||
|
Value start_index = rewriter.create<AffineLoadOp>(loc, start_indices,
|
||||||
|
batch_induction_vars);
|
||||||
|
start_index = rewriter.create<IndexCastOp>(loc, start_index,
|
||||||
|
rewriter.getIndexType());
|
||||||
|
S_in[0] = start_index;
|
||||||
|
}
|
||||||
|
|
||||||
|
// We load value at a particular operand index and populate the output
|
||||||
|
// tensor if the index constraints match.
|
||||||
|
Value load_value =
|
||||||
|
rewriter.create<AffineLoadOp>(loc, operand, operand_index);
|
||||||
|
SmallVector<Value, 4> predicates;
|
||||||
|
// Adding offsets to the corresponding starting index and comparing it with
|
||||||
|
// the corresponding operand index.
|
||||||
|
for (unsigned k = 0, i = 0; k < start_index_map.size(); k++) {
|
||||||
|
i = start_index_map[k];
|
||||||
|
Value add_start_index_offset = rewriter.create<mlir::AddIOp>(
|
||||||
|
loc, rewriter.getIndexType(), S_in[i], O_in[i]);
|
||||||
|
Value predicate = rewriter.create<mlir::CmpIOp>(
|
||||||
|
loc, CmpIPredicate::eq, add_start_index_offset, operand_index[i]);
|
||||||
|
predicates.push_back(predicate);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Since the no. of predicates is equal to start_index_map.size() we
|
||||||
|
// iterate over pairs of predicates and join them with AndOp.
|
||||||
|
unsigned num_equality_checks = start_index_map.size() / 2;
|
||||||
|
// We store the final predicate formed by joining other predicates with
|
||||||
|
// AndOp in result_predicate.
|
||||||
|
Value result_predicate = nullptr;
|
||||||
|
for (unsigned i = 0; i < num_equality_checks; i += 2) {
|
||||||
|
Value predicateA = predicates[i];
|
||||||
|
Value predicateB = predicates[i + 1];
|
||||||
|
Value and_predicate =
|
||||||
|
rewriter.create<mlir::AndOp>(loc, predicateA, predicateB);
|
||||||
|
result_predicate = (i == 0) ? and_predicate
|
||||||
|
: rewriter.create<mlir::AndOp>(
|
||||||
|
loc, result_predicate, and_predicate);
|
||||||
|
}
|
||||||
|
// We fetch the last predicate value. In case this is the only predicate
|
||||||
|
// we let result_predicate be equal to this predicate value. Else if there
|
||||||
|
// are odd number of predicates we join it to other predicates using AndOp.
|
||||||
|
Value predicate = predicates.back();
|
||||||
|
if (!result_predicate) result_predicate = predicate;
|
||||||
|
// In case there are odd number of predicates we join the last predicate
|
||||||
|
// to the result_predicate using AndOp.
|
||||||
|
else if (start_index_map.size() % 2 == 1)
|
||||||
|
result_predicate =
|
||||||
|
rewriter.create<mlir::AndOp>(loc, result_predicate, predicate);
|
||||||
|
|
||||||
|
// We use the loaded value if the index computed by adding offsets to
|
||||||
|
// starting index is equal to the current operand index. We use 0 as a value
|
||||||
|
// otherwise.
|
||||||
|
Value select_load = rewriter.create<mlir::SelectOp>(
|
||||||
|
loc, result_predicate, load_value, zero_load_value);
|
||||||
|
// We load value at output array.
|
||||||
|
Value output_value =
|
||||||
|
rewriter.create<AffineLoadOp>(loc, output, output_induction_vars);
|
||||||
|
|
||||||
|
// The selected value is added to the previous value stored in output array.
|
||||||
|
if (element_type.isa<FloatType>())
|
||||||
|
output_value =
|
||||||
|
rewriter.create<AddFOp>(loc, element_type, select_load, output_value);
|
||||||
|
else
|
||||||
|
output_value =
|
||||||
|
rewriter.create<AddIOp>(loc, element_type, select_load, output_value);
|
||||||
|
rewriter.create<AffineStoreOp>(loc, output_value, output,
|
||||||
|
output_induction_vars);
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename LhloOpTy>
|
template <typename LhloOpTy>
|
||||||
struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
|
struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
|
||||||
using OpRewritePattern<LhloOpTy>::OpRewritePattern;
|
using OpRewritePattern<LhloOpTy>::OpRewritePattern;
|
||||||
|
@ -225,7 +525,8 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
|
||||||
BinaryOpConverter<lmhlo::MulOp>,
|
BinaryOpConverter<lmhlo::MulOp>,
|
||||||
BinaryOpConverter<lmhlo::SubOp>,
|
BinaryOpConverter<lmhlo::SubOp>,
|
||||||
ConcatOpConverter,
|
ConcatOpConverter,
|
||||||
DotOpConverter>(context);
|
DotOpConverter,
|
||||||
|
GatherOpConverter>(context);
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -238,3 +238,204 @@ func @concatenate_dynamic(%arg0: memref<1x?xf32>, %arg1: memref<1x?xf32>, %arg2:
|
||||||
"lmhlo.copy"(%0, %arg2) : (memref<1x?xf32>, memref<1x?xf32>) -> ()
|
"lmhlo.copy"(%0, %arg2) : (memref<1x?xf32>, memref<1x?xf32>) -> ()
|
||||||
"lmhlo.terminator"() : () -> ()
|
"lmhlo.terminator"() : () -> ()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Gather op.
|
||||||
|
// Test case 1: A general GatherOp test case.
|
||||||
|
// CHECK-LABEL: func @gather_1
|
||||||
|
// CHECK-SAME: (%[[OPERAND:.*]]: memref<28996x512xf32>, %[[START_INDICES:.*]]: memref<1x128xi32>, %[[OUTPUT:.*]]: memref<1x128x512xf32>)
|
||||||
|
func @gather_1(%arg0: memref<28996x512xf32>, %arg1: memref<1x128xi32>, %arg2: memref<1x128x512xf32>) {
|
||||||
|
%0 = memref.alloc() : memref<1x128x512xf32>
|
||||||
|
"lmhlo.gather"(%arg0, %arg1, %0) {dimension_numbers = { collapsed_slice_dims = dense<0> : tensor<1xi64>,
|
||||||
|
index_vector_dim = 2 : i64,
|
||||||
|
offset_dims = dense<2> : tensor<1xi64>,
|
||||||
|
start_index_map = dense<0> : tensor<1xi64>},
|
||||||
|
indices_are_sorted = false, name = "gather.381", slice_sizes = dense<[1, 512]> : tensor<2xi64>} :
|
||||||
|
(memref<28996x512xf32>, memref<1x128xi32>, memref<1x128x512xf32>) -> ()
|
||||||
|
"lmhlo.copy"(%0, %arg2) : (memref<1x128x512xf32>, memref<1x128x512xf32>) -> ()
|
||||||
|
"lmhlo.terminator"() : () -> ()
|
||||||
|
}
|
||||||
|
// CHECK-NEXT: %[[zero:.*]] = constant 0.000000e+00 : f32
|
||||||
|
// CHECK-NEXT: %[[temp_output:.*]] = memref.alloc() : memref<1x128x512xf32>
|
||||||
|
// CHECK-NEXT: affine.for %{{.*}} = 0 to 1 {
|
||||||
|
// CHECK-NEXT: affine.for %{{.*}} = 0 to 128 {
|
||||||
|
// CHECK-NEXT: affine.for %{{.*}} = 0 to 512 {
|
||||||
|
// CHECK-NEXT: affine.store %[[zero]], %[[temp_output]][%{{.*}}, %{{.*}}, %{{.*}}] : memref<1x128x512xf32>
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: affine.for %[[batch0:.*]] = 0 to 1 {
|
||||||
|
// CHECK-NEXT: affine.for %[[batch1:.*]] = 0 to 128 {
|
||||||
|
// CHECK-NEXT: affine.for %[[offset0:.*]] = 0 to 512 {
|
||||||
|
// CHECK-NEXT: affine.for %[[iv0:.*]] = 0 to 28996 {
|
||||||
|
// CHECK-NEXT: %[[a:.*]] = affine.load %[[START_INDICES]][%[[batch0]], %[[batch1]]] : memref<1x128xi32>
|
||||||
|
// CHECK-NEXT: %[[S_in0:.*]] = index_cast %[[a]] : i32 to index
|
||||||
|
// CHECK-NEXT: %[[operand_val:.*]] = affine.load %[[OPERAND]][%[[iv0]], %[[offset0]]] : memref<28996x512xf32>
|
||||||
|
// CHECK-NEXT: %[[pred:.*]] = cmpi eq, %[[S_in0]], %[[iv0]] : index
|
||||||
|
// CHECK-NEXT: %[[selected_value:.*]] = select %[[pred]], %[[operand_val]], %[[zero]] : f32
|
||||||
|
// CHECK-NEXT: %[[prev_value:.*]] = affine.load %[[temp_output]][%[[batch0]], %[[batch1]], %[[offset0]]] : memref<1x128x512xf32>
|
||||||
|
// CHECK-NEXT: %[[final_value:.*]] = addf %[[selected_value]], %[[prev_value]] : f32
|
||||||
|
// CHECK-NEXT: affine.store %[[final_value]], %[[temp_output]][%[[batch0]], %[[batch1]], %[[offset0]]] : memref<1x128x512xf32>
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
|
||||||
|
// Test case 2: Checks for multi-dimensional starting indices.
|
||||||
|
// CHECK-LABEL: func @gather_2
|
||||||
|
// CHECK-SAME: (%[[OPERAND:.*]]: memref<16x11xf32>, %[[START_INDICES:.*]]: memref<5x2xi32>, %[[OUTPUT:.*]]: memref<5x8x6xf32>)
|
||||||
|
func @gather_2(%arg0: memref<16x11xf32>, %arg1: memref<5x2xi32>, %arg2: memref<5x8x6xf32>) {
|
||||||
|
%0 = memref.alloc() : memref<5x8x6xf32>
|
||||||
|
"lmhlo.gather"(%arg0, %arg1, %0) {dimension_numbers = { collapsed_slice_dims = dense<-1> : tensor<1xi64>,
|
||||||
|
index_vector_dim = 1 : i64,
|
||||||
|
offset_dims = dense<[1,2]> : tensor<2xi64>,
|
||||||
|
start_index_map = dense<[0,1]> : tensor<2xi64>},
|
||||||
|
indices_are_sorted = false, name = "gather.381", slice_sizes = dense<[8, 6]> : tensor<2xi64>} :
|
||||||
|
(memref<16x11xf32>, memref<5x2xi32>, memref<5x8x6xf32>) -> ()
|
||||||
|
"lmhlo.copy"(%0, %arg2) : (memref<5x8x6xf32>, memref<5x8x6xf32>) -> ()
|
||||||
|
"lmhlo.terminator"() : () -> ()
|
||||||
|
}
|
||||||
|
// CHECK-NEXT: %[[zero:.*]] = constant 0.000000e+00 : f32
|
||||||
|
// CHECK-NEXT: %c0 = constant 0 : index
|
||||||
|
// CHECK-NEXT: %c1 = constant 1 : index
|
||||||
|
// CHECK-NEXT: %[[temp_output:.*]] = memref.alloc() : memref<5x8x6xf32>
|
||||||
|
// CHECK-NEXT: affine.for %{{.*}} = 0 to 5 {
|
||||||
|
// CHECK-NEXT: affine.for %{{.*}} = 0 to 8 {
|
||||||
|
// CHECK-NEXT: affine.for %{{.*}} = 0 to 6 {
|
||||||
|
// CHECK-NEXT: affine.store %[[zero]], %[[temp_output]][%{{.*}}, %{{.*}}, %{{.*}}] : memref<5x8x6xf32>
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: affine.for %[[batch0:.*]] = 0 to 5 {
|
||||||
|
// CHECK-NEXT: affine.for %[[offset0:.*]] = 0 to 8 {
|
||||||
|
// CHECK-NEXT: affine.for %[[offset1:.*]] = 0 to 6 {
|
||||||
|
// CHECK-NEXT: affine.for %[[iv0:.*]] = 0 to 16 {
|
||||||
|
// CHECK-NEXT: affine.for %[[iv1:.*]] = 0 to 11 {
|
||||||
|
// CHECK-NEXT: %[[a:.*]] = affine.load %[[START_INDICES]][%[[batch0]], %c0] : memref<5x2xi32>
|
||||||
|
// CHECK-NEXT: %[[S_in0:.*]] = index_cast %[[a]] : i32 to index
|
||||||
|
// CHECK-NEXT: %[[b:.*]] = affine.load %[[START_INDICES]][%[[batch0]], %c1] : memref<5x2xi32>
|
||||||
|
// CHECK-NEXT: %[[S_in1:.*]] = index_cast %[[b]] : i32 to index
|
||||||
|
// CHECK-NEXT: %[[operand_val:.*]] = affine.load %[[OPERAND]][%[[iv0]], %[[iv1]]] : memref<16x11xf32>
|
||||||
|
// CHECK-NEXT: %[[In0:.*]] = addi %[[S_in0]], %[[offset0]] : index
|
||||||
|
// CHECK-NEXT: %[[pred1:.*]] = cmpi eq, %[[In0]], %[[iv0]] : index
|
||||||
|
// CHECK-NEXT: %[[In1:.*]] = addi %[[S_in1]], %[[offset1]] : index
|
||||||
|
// CHECK-NEXT: %[[pred2:.*]] = cmpi eq, %[[In1]], %[[iv1]] : index
|
||||||
|
// CHECK-NEXT: %[[and1:.*]] = and %[[pred1]], %[[pred2]] : i1
|
||||||
|
// CHECK-NEXT: %[[selected_value:.*]] = select %[[and1]], %[[operand_val]], %[[zero]] : f32
|
||||||
|
// CHECK-NEXT: %[[prev_value:.*]] = affine.load %[[temp_output]][%[[batch0]], %[[offset0]], %[[offset1]]] : memref<5x8x6xf32>
|
||||||
|
// CHECK-NEXT: %[[final_value:.*]] = addf %[[selected_value]], %[[prev_value]] : f32
|
||||||
|
// CHECK-NEXT: affine.store %[[final_value]], %[[temp_output]][%[[batch0]], %[[offset0]], %[[offset1]]] : memref<5x8x6xf32>
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
|
||||||
|
// Test case 3: Checks for multi-dimensional start_indices with multi-dimensional batch size. This also tests for f16 type.
|
||||||
|
// CHECK-LABEL: func @gather_3
|
||||||
|
// CHECK-SAME: (%[[OPERAND:.*]]: memref<16x11xf16>, %[[START_INDICES:.*]]: memref<4x2x5xi32>, %[[OUTPUT:.*]]: memref<4x5x8x6xf16>)
|
||||||
|
func @gather_3(%arg0: memref<16x11xf16>, %arg1: memref<4x2x5xi32>, %arg2: memref<4x5x8x6xf16>) {
|
||||||
|
%0 = memref.alloc() : memref<4x5x8x6xf16>
|
||||||
|
"lmhlo.gather"(%arg0, %arg1, %0) {dimension_numbers = { collapsed_slice_dims = dense<-1> : tensor<1xi64>,
|
||||||
|
index_vector_dim = 1 : i64,
|
||||||
|
offset_dims = dense<[2,3]> : tensor<2xi64>,
|
||||||
|
start_index_map = dense<[0,1]> : tensor<2xi64>},
|
||||||
|
indices_are_sorted = false, name = "gather.381", slice_sizes = dense<[8, 6]> : tensor<2xi64>} :
|
||||||
|
(memref<16x11xf16>, memref<4x2x5xi32>, memref<4x5x8x6xf16>) -> ()
|
||||||
|
"lmhlo.copy"(%0, %arg2) : (memref<4x5x8x6xf16>, memref<4x5x8x6xf16>) -> ()
|
||||||
|
"lmhlo.terminator"() : () -> ()
|
||||||
|
}
|
||||||
|
// CHECK-NEXT: %[[zero:.*]] = constant 0.000000e+00 : f16
|
||||||
|
// CHECK-NEXT: %c0 = constant 0 : index
|
||||||
|
// CHECK-NEXT: %c1 = constant 1 : index
|
||||||
|
// CHECK-NEXT: %[[temp_output:.*]] = memref.alloc() : memref<4x5x8x6xf16>
|
||||||
|
// CHECK-NEXT: affine.for %{{.*}} = 0 to 4 {
|
||||||
|
// CHECK-NEXT: affine.for %{{.*}} = 0 to 5 {
|
||||||
|
// CHECK-NEXT: affine.for %{{.*}} = 0 to 8 {
|
||||||
|
// CHECK-NEXT: affine.for %{{.*}} = 0 to 6 {
|
||||||
|
// CHECK-NEXT: affine.store %[[zero]], %[[temp_output]][%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<4x5x8x6xf16>
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: affine.for %[[batch0:.*]] = 0 to 4 {
|
||||||
|
// CHECK-NEXT: affine.for %[[batch1:.*]] = 0 to 5 {
|
||||||
|
// CHECK-NEXT: affine.for %[[offset0:.*]] = 0 to 8 {
|
||||||
|
// CHECK-NEXT: affine.for %[[offset1:.*]] = 0 to 6 {
|
||||||
|
// CHECK-NEXT: affine.for %[[iv0:.*]] = 0 to 16 {
|
||||||
|
// CHECK-NEXT: affine.for %[[iv1:.*]] = 0 to 11 {
|
||||||
|
// CHECK-NEXT: %[[a:.*]] = affine.load %[[START_INDICES]][%[[batch0]], %c0, %[[batch1]]] : memref<4x2x5xi32>
|
||||||
|
// CHECK-NEXT: %[[S_in0:.*]] = index_cast %[[a]] : i32 to index
|
||||||
|
// CHECK-NEXT: %[[b:.*]] = affine.load %[[START_INDICES]][%[[batch0]], %c1, %[[batch1]]] : memref<4x2x5xi32>
|
||||||
|
// CHECK-NEXT: %[[S_in1:.*]] = index_cast %[[b]] : i32 to index
|
||||||
|
// CHECK-NEXT: %[[operand_val:.*]] = affine.load %[[OPERAND]][%[[iv0]], %[[iv1]]] : memref<16x11xf16>
|
||||||
|
// CHECK-NEXT: %[[In0:.*]] = addi %[[S_in0]], %[[offset0]] : index
|
||||||
|
// CHECK-NEXT: %[[pred1:.*]] = cmpi eq, %[[In0]], %[[iv0]] : index
|
||||||
|
// CHECK-NEXT: %[[In1:.*]] = addi %[[S_in1]], %[[offset1]] : index
|
||||||
|
// CHECK-NEXT: %[[pred2:.*]] = cmpi eq, %[[In1]], %[[iv1]] : index
|
||||||
|
// CHECK-NEXT: %[[and1:.*]] = and %[[pred1]], %[[pred2]] : i1
|
||||||
|
// CHECK-NEXT: %[[selected_value:.*]] = select %[[and1]], %[[operand_val]], %[[zero]] : f16
|
||||||
|
// CHECK-NEXT: %[[prev_value:.*]] = affine.load %[[temp_output]][%[[batch0]], %[[batch1]], %[[offset0]], %[[offset1]]] : memref<4x5x8x6xf16>
|
||||||
|
// CHECK-NEXT: %[[final_value:.*]] = addf %[[selected_value]], %[[prev_value]] : f16
|
||||||
|
// CHECK-NEXT: affine.store %[[final_value]], %[[temp_output]][%[[batch0]], %[[batch1]], %[[offset0]], %[[offset1]]] : memref<4x5x8x6xf16>
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
|
||||||
|
// Test case 4: Changing starting_index_map : X -> [0,X]
|
||||||
|
// CHECK-LABEL: func @gather_4
|
||||||
|
// CHECK-SAME: (%[[OPERAND:.*]]: memref<16x11xf32>, %[[START_INDICES:.*]]: memref<5x4xi32>, %[[OUTPUT:.*]]: memref<4x5x6xf32>)
|
||||||
|
func @gather_4(%arg0: memref<16x11xf32>, %arg1: memref<5x4xi32>, %arg2: memref<4x5x6xf32>) {
|
||||||
|
%0 = memref.alloc() : memref<4x5x6xf32>
|
||||||
|
"lmhlo.gather"(%arg0, %arg1, %0) {dimension_numbers = { collapsed_slice_dims = dense<0> : tensor<1xi64>,
|
||||||
|
index_vector_dim = 2 : i64,
|
||||||
|
offset_dims = dense<2> : tensor<1xi64>,
|
||||||
|
start_index_map = dense<0> : tensor<1xi64>},
|
||||||
|
indices_are_sorted = false, name = "gather.381", slice_sizes = dense<[1, 6]> : tensor<2xi64>} :
|
||||||
|
(memref<16x11xf32>, memref<5x4xi32>, memref<4x5x6xf32>) -> ()
|
||||||
|
"lmhlo.copy"(%0, %arg2) : (memref<4x5x6xf32>, memref<4x5x6xf32>) -> ()
|
||||||
|
"lmhlo.terminator"() : () -> ()
|
||||||
|
}
|
||||||
|
// CHECK-NEXT: %[[zero:.*]] = constant 0.000000e+00 : f32
|
||||||
|
// CHECK-NEXT: %[[temp_output:.*]] = memref.alloc() : memref<4x5x6xf32>
|
||||||
|
// CHECK-NEXT: affine.for %{{.*}} = 0 to 4 {
|
||||||
|
// CHECK-NEXT: affine.for %{{.*}} = 0 to 5 {
|
||||||
|
// CHECK-NEXT: affine.for %{{.*}} = 0 to 6 {
|
||||||
|
// CHECK-NEXT: affine.store %[[zero]], %[[temp_output]][%{{.*}}, %{{.*}}, %{{.*}}] : memref<4x5x6xf32>
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: affine.for %[[batch0:.*]] = 0 to 5 {
|
||||||
|
// CHECK-NEXT: affine.for %[[batch1:.*]] = 0 to 4 {
|
||||||
|
// CHECK-NEXT: affine.for %[[offset0:.*]] = 0 to 6 {
|
||||||
|
// CHECK-NEXT: affine.for %[[iv0:.*]] = 0 to 16 {
|
||||||
|
// CHECK-NEXT: %[[a:.*]] = affine.load %[[START_INDICES]][%[[batch0]], %[[batch1]]] : memref<5x4xi32>
|
||||||
|
// CHECK-NEXT: %[[S_in0:.*]] = index_cast %[[a]] : i32 to index
|
||||||
|
// CHECK-NEXT: %[[operand_val:.*]] = affine.load %[[OPERAND]][%[[iv0]], %[[offset0]]] : memref<16x11xf32>
|
||||||
|
// CHECK-NEXT: %[[pred:.*]] = cmpi eq, %[[S_in0]], %[[iv0]] : index
|
||||||
|
// CHECK-NEXT: %[[selected_value:.*]] = select %[[pred]], %[[operand_val]], %[[zero]] : f32
|
||||||
|
// CHECK-NEXT: %[[prev_value:.*]] = affine.load %[[temp_output]][%[[batch0]], %[[batch1]], %[[offset0]]] : memref<4x5x6xf32>
|
||||||
|
// CHECK-NEXT: %[[final_value:.*]] = addf %[[selected_value]], %[[prev_value]] : f32
|
||||||
|
// CHECK-NEXT: affine.store %[[final_value]], %[[temp_output]][%[[batch0]], %[[batch1]], %[[offset0]]] : memref<4x5x6xf32>
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
|
||||||
|
// Test case 5: Testing for more than two equality checks.
|
||||||
|
// CHECK-LABEL: func @gather_5
|
||||||
|
func @gather_5(%arg0: memref<28996x512x256xf32>, %arg1: memref<10x3xi32>, %arg2: memref<10x20x10x5xf32>) {
|
||||||
|
%0 = memref.alloc() : memref<10x20x10x5xf32>
|
||||||
|
"lmhlo.gather"(%arg0, %arg1, %0) {dimension_numbers = { collapsed_slice_dims = dense<-1> : tensor<1xi64>,
|
||||||
|
index_vector_dim = 1 : i64,
|
||||||
|
offset_dims = dense<[1,2,3]> : tensor<3xi64>,
|
||||||
|
start_index_map = dense<[0,1,2]> : tensor<3xi64>},
|
||||||
|
indices_are_sorted = false, name = "gather.381", slice_sizes = dense<[20, 10, 5]> : tensor<3xi64>} :
|
||||||
|
(memref<28996x512x256xf32>, memref<10x3xi32>, memref<10x20x10x5xf32>) -> ()
|
||||||
|
"lmhlo.copy"(%0, %arg2) : (memref<10x20x10x5xf32>, memref<10x20x10x5xf32>) -> ()
|
||||||
|
"lmhlo.terminator"() : () -> ()
|
||||||
|
}
|
||||||
|
// CHECK: %[[and1:.*]] = and %{{.*}}, %{{.*}} : i1
|
||||||
|
// CHECK-NEXT: and %[[and1]], %{{.*}} : i1
|
||||||
|
|
Loading…
Reference in New Issue