PR #50073: [MLIR] Add GatherOp lowering from lmhlo to Affine.

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50073

-- Lowering of `GatherOp` is added from lmhlo to Affine. The lowering
   has been added as a part of `lhlo-legalize-to-affine` pass.

Signed-off-by: Abhishek Varma <abhishek.varma@polymagelabs.com>
Copybara import of the project:

--
5b3dcd4ab31a69f305cd079b869ee35ba6dc8bf5 by Abhishek Varma <abhishek.varma@polymagelabs.com>:

[MLIR] Add GatherOp lowering from lmhlo to Affine.

-- Lowering of `GatherOp` is added from lmhlo to Affine. The lowering
   has been added as a part of `lhlo-legalize-to-affine` pass.

Signed-off-by: Abhishek Varma <abhishek.varma@polymagelabs.com>
PiperOrigin-RevId: 380052157
This commit is contained in:
Abhishek Varma 2021-06-17 14:54:39 -07:00 committed by TensorFlow MLIR Team
parent 2ab16024cf
commit da6593e960
2 changed files with 503 additions and 1 deletions

View File

@ -177,6 +177,306 @@ struct ConcatOpConverter : public OpRewritePattern<ConcatenateOp> {
}
};
/// Returns a zero value of type `type`. `type` is expected to be either
/// int or float.
static Value getZeroValue(Type type, Location loc, PatternRewriter& rewriter) {
assert(type.isIntOrFloat() && "Expected int or float");
if (IntegerType intType = type.dyn_cast<IntegerType>())
return rewriter.create<mlir::ConstantIntOp>(loc, 0, intType.getWidth());
FloatType floatType = type.cast<FloatType>();
return rewriter.create<mlir::ConstantFloatOp>(
loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
}
/// Emits a nest to fill the given `buffer` of memref type with `fillValue`.
static void fillBuffer(Location loc, Value buffer, Value fillValue,
PatternRewriter& builder) {
OpBuilder::InsertionGuard guard(builder);
MemRefType bufType = buffer.getType().cast<MemRefType>();
unsigned rank = bufType.getRank();
SmallVector<Value, 4> dimSizes;
dimSizes.reserve(rank);
for (unsigned i = 0; i < rank; ++i)
dimSizes.push_back(builder.create<mlir::memref::DimOp>(loc, buffer, i));
AffineMap idSymMap = builder.getSymbolIdentityMap();
AffineMap lbMap = builder.getConstantAffineMap(0);
SmallVector<Value, 4> ivs(rank);
AffineForOp forOp;
for (unsigned i = 0; i < rank; ++i) {
forOp = builder.create<AffineForOp>(loc, llvm::None, lbMap, dimSizes[i],
idSymMap);
builder.setInsertionPointToStart(forOp.getBody());
ivs[i] = forOp.getInductionVar();
}
Type fillValueType = fillValue.getType();
auto fillMemRefType = fillValueType.dyn_cast<MemRefType>();
assert(((fillMemRefType && fillMemRefType.getRank() == 0) ||
fillValueType.isIntOrFloat()) &&
"init value has to be a 0-d memref or int or fp");
Value initVal = fillMemRefType ? builder.create<AffineLoadOp>(
loc, fillValue, /*indices=*/llvm::None)
: fillValue;
builder.create<AffineStoreOp>(loc, initVal, buffer, ivs);
}
/// Converts GatherOp to Affine nest form.
/// Pseudocode:
/// 1. Fill a temporary output tensor with 0.
/// 2. Repeat the following for each batch dimension :-
/// 1. For each indices in 'operand' :-
/// 1. Get hold of start indices from 'start_indices'.
/// 2. Add offset to the start indices to get the final indices.
/// 3. Load value from 'operand' tensor : 'operand_val'.
/// 4. Load value from temporary output : 'prev_val'.
/// 5. If the final indices match current indices of 'operand' :
/// 'prev_val' = 'prev_val' + 'operand_val'
/// 6. Store 'prev_val' back to the temporary output.
class GatherOpConverter : public OpRewritePattern<GatherOp> {
public:
using OpRewritePattern<GatherOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GatherOp op,
PatternRewriter& rewriter) const final {
Location loc = op.getLoc();
// Operand array.
Value operand = op.operand();
MemRefType operand_type = operand.getType().cast<MemRefType>();
unsigned operand_rank = operand_type.getRank();
ArrayRef<int64_t> operand_shape = operand_type.getShape();
// Start_indices array.
Value start_indices = op.start_indices();
MemRefType start_indices_type = start_indices.getType().cast<MemRefType>();
unsigned start_indices_rank = start_indices_type.getRank();
ArrayRef<int64_t> start_indices_shape = start_indices_type.getShape();
// Output array.
Value output = op.output();
MemRefType output_type = output.getType().cast<MemRefType>();
ArrayRef<int64_t> output_shape = output_type.getShape();
if (!operand_type.hasStaticShape() ||
!start_indices_type.hasStaticShape() || !output_type.hasStaticShape())
return rewriter.notifyMatchFailure(op, "only static shaped type allowed");
mhlo::GatherDimensionNumbers gather_dim = op.dimension_numbersAttr();
// Collapsed_slice_dim.
DenseIntElementsAttr collapsed_slice_dims_attr =
gather_dim.collapsed_slice_dims();
SmallVector<int64_t, 4> collapsed_slice_dims;
for (const APInt& dim : collapsed_slice_dims_attr.getIntValues())
collapsed_slice_dims.push_back(dim.getSExtValue());
// Offset_dim.
DenseIntElementsAttr offset_dims_attr = gather_dim.offset_dims();
SmallVector<int64_t, 4> offset_dims;
for (const APInt& dim : offset_dims_attr.getIntValues())
offset_dims.push_back(dim.getSExtValue());
// Start_index_map.
DenseIntElementsAttr start_index_map_attr = gather_dim.start_index_map();
SmallVector<int64_t, 4> start_index_map;
for (const APInt& dim : start_index_map_attr.getIntValues())
start_index_map.push_back(dim.getSExtValue());
// Index_vector_dim.
IntegerAttr index_vector_dim_attr = gather_dim.index_vector_dim();
int64_t index_vector_dim = index_vector_dim_attr.getValue().getSExtValue();
// Slice_sizes.
DenseIntElementsAttr slice_sizes_attr = op.slice_sizesAttr();
SmallVector<int64_t, 4> slice_sizes;
for (const APInt& dim : slice_sizes_attr.getIntValues())
slice_sizes.push_back(dim.getSExtValue());
// Creating constants with 0 value. We need the Integer type constant value
// because the indices type will be Integer.
Value zero_int_val = rewriter.create<mlir::ConstantIntOp>(
loc, 0, start_indices_type.getElementType());
Type element_type = output_type.getElementType();
Value zero_load_value = getZeroValue(element_type, loc, rewriter);
// Initializing the output buffer with 0.
fillBuffer(loc, output, zero_load_value, rewriter);
// We fetch the shape of start_indices at index_vector_dim. In case
// index_vector_dim is equal to the rank of start_indices, we implicitly
// consider start_indices to have a trailing 1 dimension.
unsigned start_indices_numbers =
(index_vector_dim == start_indices_rank)
? 1
: start_indices_shape[index_vector_dim];
// We create integer constants till start_incides_index which help us to
// fetch start_indices in affine transformation.
SmallVector<Value, 4> start_indices_index;
for (unsigned i = 0; i < start_indices_numbers; i++) {
Value i_val = rewriter.create<mlir::ConstantIntOp>(
loc, i, start_indices_type.getElementType());
i_val = rewriter.create<IndexCastOp>(loc, i_val, rewriter.getIndexType());
start_indices_index.push_back(i_val);
}
// S_in contains the multiple indices that form a starting index used in the
// input/operand tensor. O_in contains the multiple offsets of corresponding
// starting index used in the input/operand tensor. We initialize both of
// them with 0.
SmallVector<Value, 4> S_in;
SmallVector<Value, 4> O_in;
Value zero_index_val = rewriter.create<IndexCastOp>(
loc, zero_int_val, rewriter.getIndexType());
for (unsigned i = 0; i < operand_rank; i++) {
S_in.push_back(zero_index_val);
O_in.push_back(zero_index_val);
}
// batch_induction_vars stores the loop induction variables pertaining to
// the batches of start indices.
SmallVector<Value, 4> batch_induction_vars;
// output_induction_vars stores the loop induction variables pertaining to
// both batches and offsets within the output tensor.
SmallVector<Value, 4> output_induction_vars;
// Create loops to iterate over each batch of starting index.
for (unsigned i = 0; i < start_indices_rank; i++) {
// ith dimension of start_indices doesn't form a batch if it is equal to
// index_vector_dim.
if (i == index_vector_dim) continue;
AffineForOp for_op =
rewriter.create<AffineForOp>(loc, 0, start_indices_shape[i]);
batch_induction_vars.push_back(for_op.getInductionVar());
output_induction_vars.push_back(for_op.getInductionVar());
rewriter.setInsertionPointToStart(for_op.getBody());
}
// Create loops to iterate over each offset dimension within the output
// tensor.
for (unsigned i = 0, k = 0, e = offset_dims.size(); i < e; i++) {
AffineForOp for_op =
rewriter.create<AffineForOp>(loc, 0, output_shape[offset_dims[i]]);
rewriter.setInsertionPointToStart(for_op.getBody());
// We try to fetch the first non-collapsed dimension.
while (k < collapsed_slice_dims.size() && collapsed_slice_dims[k] == i)
k++;
// Remapping the offset_dim[i] to the non-collapsed dimension.
O_in[k++] = for_op.getInductionVar();
// We assume offset_dims to be sorted. So when inserted to
// output_induction_vars the loop induction variable gets inserted at the
// correct position.
output_induction_vars.insert(
output_induction_vars.begin() + offset_dims[i],
for_op.getInductionVar());
}
// Create loops to iterate over all dimensions within the operand tensor.
SmallVector<Value, 4> operand_index;
for (unsigned i = 0, k = 0; i < operand_rank; i++) {
// We assume start_index_map to have sorted dimensions. We only include
// those dimensions of operand tensor which are present in
// start_index_map.
if (k < start_index_map.size() && i == start_index_map[k++]) {
AffineForOp for_op =
rewriter.create<AffineForOp>(loc, 0, operand_shape[i]);
operand_index.push_back(for_op.getInductionVar());
rewriter.setInsertionPointToStart(for_op.getBody());
} else {
operand_index.push_back(O_in[i]);
}
}
// In case index_vector_dim is not equal to start_indices shape then we
// create another loop to iterate over starting index and update
// batch_induction_vars.
if (index_vector_dim != start_indices_rank) {
for (unsigned i = 0; i < start_indices_numbers; i++) {
batch_induction_vars.insert(
batch_induction_vars.begin() + index_vector_dim,
start_indices_index[i]);
Value start_index = rewriter.create<AffineLoadOp>(loc, start_indices,
batch_induction_vars);
start_index = rewriter.create<IndexCastOp>(loc, start_index,
rewriter.getIndexType());
S_in[start_index_map[i]] = start_index;
batch_induction_vars.erase(batch_induction_vars.begin() +
index_vector_dim);
}
} else {
// Since index_vector_dim is equal to start_indicesRank we can directly
// fetch the start_index from batch_induction_vars.
Value start_index = rewriter.create<AffineLoadOp>(loc, start_indices,
batch_induction_vars);
start_index = rewriter.create<IndexCastOp>(loc, start_index,
rewriter.getIndexType());
S_in[0] = start_index;
}
// We load value at a particular operand index and populate the output
// tensor if the index constraints match.
Value load_value =
rewriter.create<AffineLoadOp>(loc, operand, operand_index);
SmallVector<Value, 4> predicates;
// Adding offsets to the corresponding starting index and comparing it with
// the corresponding operand index.
for (unsigned k = 0, i = 0; k < start_index_map.size(); k++) {
i = start_index_map[k];
Value add_start_index_offset = rewriter.create<mlir::AddIOp>(
loc, rewriter.getIndexType(), S_in[i], O_in[i]);
Value predicate = rewriter.create<mlir::CmpIOp>(
loc, CmpIPredicate::eq, add_start_index_offset, operand_index[i]);
predicates.push_back(predicate);
}
// Since the no. of predicates is equal to start_index_map.size() we
// iterate over pairs of predicates and join them with AndOp.
unsigned num_equality_checks = start_index_map.size() / 2;
// We store the final predicate formed by joining other predicates with
// AndOp in result_predicate.
Value result_predicate = nullptr;
for (unsigned i = 0; i < num_equality_checks; i += 2) {
Value predicateA = predicates[i];
Value predicateB = predicates[i + 1];
Value and_predicate =
rewriter.create<mlir::AndOp>(loc, predicateA, predicateB);
result_predicate = (i == 0) ? and_predicate
: rewriter.create<mlir::AndOp>(
loc, result_predicate, and_predicate);
}
// We fetch the last predicate value. In case this is the only predicate
// we let result_predicate be equal to this predicate value. Else if there
// are odd number of predicates we join it to other predicates using AndOp.
Value predicate = predicates.back();
if (!result_predicate) result_predicate = predicate;
// In case there are odd number of predicates we join the last predicate
// to the result_predicate using AndOp.
else if (start_index_map.size() % 2 == 1)
result_predicate =
rewriter.create<mlir::AndOp>(loc, result_predicate, predicate);
// We use the loaded value if the index computed by adding offsets to
// starting index is equal to the current operand index. We use 0 as a value
// otherwise.
Value select_load = rewriter.create<mlir::SelectOp>(
loc, result_predicate, load_value, zero_load_value);
// We load value at output array.
Value output_value =
rewriter.create<AffineLoadOp>(loc, output, output_induction_vars);
// The selected value is added to the previous value stored in output array.
if (element_type.isa<FloatType>())
output_value =
rewriter.create<AddFOp>(loc, element_type, select_load, output_value);
else
output_value =
rewriter.create<AddIOp>(loc, element_type, select_load, output_value);
rewriter.create<AffineStoreOp>(loc, output_value, output,
output_induction_vars);
rewriter.eraseOp(op);
return success();
}
};
template <typename LhloOpTy>
struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
using OpRewritePattern<LhloOpTy>::OpRewritePattern;
@ -225,7 +525,8 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
BinaryOpConverter<lmhlo::MulOp>,
BinaryOpConverter<lmhlo::SubOp>,
ConcatOpConverter,
DotOpConverter>(context);
DotOpConverter,
GatherOpConverter>(context);
// clang-format on
}

View File

@ -238,3 +238,204 @@ func @concatenate_dynamic(%arg0: memref<1x?xf32>, %arg1: memref<1x?xf32>, %arg2:
"lmhlo.copy"(%0, %arg2) : (memref<1x?xf32>, memref<1x?xf32>) -> ()
"lmhlo.terminator"() : () -> ()
}
// Gather op.
// Test case 1: A general GatherOp test case.
// CHECK-LABEL: func @gather_1
// CHECK-SAME: (%[[OPERAND:.*]]: memref<28996x512xf32>, %[[START_INDICES:.*]]: memref<1x128xi32>, %[[OUTPUT:.*]]: memref<1x128x512xf32>)
func @gather_1(%arg0: memref<28996x512xf32>, %arg1: memref<1x128xi32>, %arg2: memref<1x128x512xf32>) {
%0 = memref.alloc() : memref<1x128x512xf32>
"lmhlo.gather"(%arg0, %arg1, %0) {dimension_numbers = { collapsed_slice_dims = dense<0> : tensor<1xi64>,
index_vector_dim = 2 : i64,
offset_dims = dense<2> : tensor<1xi64>,
start_index_map = dense<0> : tensor<1xi64>},
indices_are_sorted = false, name = "gather.381", slice_sizes = dense<[1, 512]> : tensor<2xi64>} :
(memref<28996x512xf32>, memref<1x128xi32>, memref<1x128x512xf32>) -> ()
"lmhlo.copy"(%0, %arg2) : (memref<1x128x512xf32>, memref<1x128x512xf32>) -> ()
"lmhlo.terminator"() : () -> ()
}
// CHECK-NEXT: %[[zero:.*]] = constant 0.000000e+00 : f32
// CHECK-NEXT: %[[temp_output:.*]] = memref.alloc() : memref<1x128x512xf32>
// CHECK-NEXT: affine.for %{{.*}} = 0 to 1 {
// CHECK-NEXT: affine.for %{{.*}} = 0 to 128 {
// CHECK-NEXT: affine.for %{{.*}} = 0 to 512 {
// CHECK-NEXT: affine.store %[[zero]], %[[temp_output]][%{{.*}}, %{{.*}}, %{{.*}}] : memref<1x128x512xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: affine.for %[[batch0:.*]] = 0 to 1 {
// CHECK-NEXT: affine.for %[[batch1:.*]] = 0 to 128 {
// CHECK-NEXT: affine.for %[[offset0:.*]] = 0 to 512 {
// CHECK-NEXT: affine.for %[[iv0:.*]] = 0 to 28996 {
// CHECK-NEXT: %[[a:.*]] = affine.load %[[START_INDICES]][%[[batch0]], %[[batch1]]] : memref<1x128xi32>
// CHECK-NEXT: %[[S_in0:.*]] = index_cast %[[a]] : i32 to index
// CHECK-NEXT: %[[operand_val:.*]] = affine.load %[[OPERAND]][%[[iv0]], %[[offset0]]] : memref<28996x512xf32>
// CHECK-NEXT: %[[pred:.*]] = cmpi eq, %[[S_in0]], %[[iv0]] : index
// CHECK-NEXT: %[[selected_value:.*]] = select %[[pred]], %[[operand_val]], %[[zero]] : f32
// CHECK-NEXT: %[[prev_value:.*]] = affine.load %[[temp_output]][%[[batch0]], %[[batch1]], %[[offset0]]] : memref<1x128x512xf32>
// CHECK-NEXT: %[[final_value:.*]] = addf %[[selected_value]], %[[prev_value]] : f32
// CHECK-NEXT: affine.store %[[final_value]], %[[temp_output]][%[[batch0]], %[[batch1]], %[[offset0]]] : memref<1x128x512xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// Test case 2: Checks for multi-dimensional starting indices.
// CHECK-LABEL: func @gather_2
// CHECK-SAME: (%[[OPERAND:.*]]: memref<16x11xf32>, %[[START_INDICES:.*]]: memref<5x2xi32>, %[[OUTPUT:.*]]: memref<5x8x6xf32>)
func @gather_2(%arg0: memref<16x11xf32>, %arg1: memref<5x2xi32>, %arg2: memref<5x8x6xf32>) {
%0 = memref.alloc() : memref<5x8x6xf32>
"lmhlo.gather"(%arg0, %arg1, %0) {dimension_numbers = { collapsed_slice_dims = dense<-1> : tensor<1xi64>,
index_vector_dim = 1 : i64,
offset_dims = dense<[1,2]> : tensor<2xi64>,
start_index_map = dense<[0,1]> : tensor<2xi64>},
indices_are_sorted = false, name = "gather.381", slice_sizes = dense<[8, 6]> : tensor<2xi64>} :
(memref<16x11xf32>, memref<5x2xi32>, memref<5x8x6xf32>) -> ()
"lmhlo.copy"(%0, %arg2) : (memref<5x8x6xf32>, memref<5x8x6xf32>) -> ()
"lmhlo.terminator"() : () -> ()
}
// CHECK-NEXT: %[[zero:.*]] = constant 0.000000e+00 : f32
// CHECK-NEXT: %c0 = constant 0 : index
// CHECK-NEXT: %c1 = constant 1 : index
// CHECK-NEXT: %[[temp_output:.*]] = memref.alloc() : memref<5x8x6xf32>
// CHECK-NEXT: affine.for %{{.*}} = 0 to 5 {
// CHECK-NEXT: affine.for %{{.*}} = 0 to 8 {
// CHECK-NEXT: affine.for %{{.*}} = 0 to 6 {
// CHECK-NEXT: affine.store %[[zero]], %[[temp_output]][%{{.*}}, %{{.*}}, %{{.*}}] : memref<5x8x6xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: affine.for %[[batch0:.*]] = 0 to 5 {
// CHECK-NEXT: affine.for %[[offset0:.*]] = 0 to 8 {
// CHECK-NEXT: affine.for %[[offset1:.*]] = 0 to 6 {
// CHECK-NEXT: affine.for %[[iv0:.*]] = 0 to 16 {
// CHECK-NEXT: affine.for %[[iv1:.*]] = 0 to 11 {
// CHECK-NEXT: %[[a:.*]] = affine.load %[[START_INDICES]][%[[batch0]], %c0] : memref<5x2xi32>
// CHECK-NEXT: %[[S_in0:.*]] = index_cast %[[a]] : i32 to index
// CHECK-NEXT: %[[b:.*]] = affine.load %[[START_INDICES]][%[[batch0]], %c1] : memref<5x2xi32>
// CHECK-NEXT: %[[S_in1:.*]] = index_cast %[[b]] : i32 to index
// CHECK-NEXT: %[[operand_val:.*]] = affine.load %[[OPERAND]][%[[iv0]], %[[iv1]]] : memref<16x11xf32>
// CHECK-NEXT: %[[In0:.*]] = addi %[[S_in0]], %[[offset0]] : index
// CHECK-NEXT: %[[pred1:.*]] = cmpi eq, %[[In0]], %[[iv0]] : index
// CHECK-NEXT: %[[In1:.*]] = addi %[[S_in1]], %[[offset1]] : index
// CHECK-NEXT: %[[pred2:.*]] = cmpi eq, %[[In1]], %[[iv1]] : index
// CHECK-NEXT: %[[and1:.*]] = and %[[pred1]], %[[pred2]] : i1
// CHECK-NEXT: %[[selected_value:.*]] = select %[[and1]], %[[operand_val]], %[[zero]] : f32
// CHECK-NEXT: %[[prev_value:.*]] = affine.load %[[temp_output]][%[[batch0]], %[[offset0]], %[[offset1]]] : memref<5x8x6xf32>
// CHECK-NEXT: %[[final_value:.*]] = addf %[[selected_value]], %[[prev_value]] : f32
// CHECK-NEXT: affine.store %[[final_value]], %[[temp_output]][%[[batch0]], %[[offset0]], %[[offset1]]] : memref<5x8x6xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// Test case 3: Checks for multi-dimensional start_indices with multi-dimensional batch size. This also tests for f16 type.
// CHECK-LABEL: func @gather_3
// CHECK-SAME: (%[[OPERAND:.*]]: memref<16x11xf16>, %[[START_INDICES:.*]]: memref<4x2x5xi32>, %[[OUTPUT:.*]]: memref<4x5x8x6xf16>)
func @gather_3(%arg0: memref<16x11xf16>, %arg1: memref<4x2x5xi32>, %arg2: memref<4x5x8x6xf16>) {
%0 = memref.alloc() : memref<4x5x8x6xf16>
"lmhlo.gather"(%arg0, %arg1, %0) {dimension_numbers = { collapsed_slice_dims = dense<-1> : tensor<1xi64>,
index_vector_dim = 1 : i64,
offset_dims = dense<[2,3]> : tensor<2xi64>,
start_index_map = dense<[0,1]> : tensor<2xi64>},
indices_are_sorted = false, name = "gather.381", slice_sizes = dense<[8, 6]> : tensor<2xi64>} :
(memref<16x11xf16>, memref<4x2x5xi32>, memref<4x5x8x6xf16>) -> ()
"lmhlo.copy"(%0, %arg2) : (memref<4x5x8x6xf16>, memref<4x5x8x6xf16>) -> ()
"lmhlo.terminator"() : () -> ()
}
// CHECK-NEXT: %[[zero:.*]] = constant 0.000000e+00 : f16
// CHECK-NEXT: %c0 = constant 0 : index
// CHECK-NEXT: %c1 = constant 1 : index
// CHECK-NEXT: %[[temp_output:.*]] = memref.alloc() : memref<4x5x8x6xf16>
// CHECK-NEXT: affine.for %{{.*}} = 0 to 4 {
// CHECK-NEXT: affine.for %{{.*}} = 0 to 5 {
// CHECK-NEXT: affine.for %{{.*}} = 0 to 8 {
// CHECK-NEXT: affine.for %{{.*}} = 0 to 6 {
// CHECK-NEXT: affine.store %[[zero]], %[[temp_output]][%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<4x5x8x6xf16>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: affine.for %[[batch0:.*]] = 0 to 4 {
// CHECK-NEXT: affine.for %[[batch1:.*]] = 0 to 5 {
// CHECK-NEXT: affine.for %[[offset0:.*]] = 0 to 8 {
// CHECK-NEXT: affine.for %[[offset1:.*]] = 0 to 6 {
// CHECK-NEXT: affine.for %[[iv0:.*]] = 0 to 16 {
// CHECK-NEXT: affine.for %[[iv1:.*]] = 0 to 11 {
// CHECK-NEXT: %[[a:.*]] = affine.load %[[START_INDICES]][%[[batch0]], %c0, %[[batch1]]] : memref<4x2x5xi32>
// CHECK-NEXT: %[[S_in0:.*]] = index_cast %[[a]] : i32 to index
// CHECK-NEXT: %[[b:.*]] = affine.load %[[START_INDICES]][%[[batch0]], %c1, %[[batch1]]] : memref<4x2x5xi32>
// CHECK-NEXT: %[[S_in1:.*]] = index_cast %[[b]] : i32 to index
// CHECK-NEXT: %[[operand_val:.*]] = affine.load %[[OPERAND]][%[[iv0]], %[[iv1]]] : memref<16x11xf16>
// CHECK-NEXT: %[[In0:.*]] = addi %[[S_in0]], %[[offset0]] : index
// CHECK-NEXT: %[[pred1:.*]] = cmpi eq, %[[In0]], %[[iv0]] : index
// CHECK-NEXT: %[[In1:.*]] = addi %[[S_in1]], %[[offset1]] : index
// CHECK-NEXT: %[[pred2:.*]] = cmpi eq, %[[In1]], %[[iv1]] : index
// CHECK-NEXT: %[[and1:.*]] = and %[[pred1]], %[[pred2]] : i1
// CHECK-NEXT: %[[selected_value:.*]] = select %[[and1]], %[[operand_val]], %[[zero]] : f16
// CHECK-NEXT: %[[prev_value:.*]] = affine.load %[[temp_output]][%[[batch0]], %[[batch1]], %[[offset0]], %[[offset1]]] : memref<4x5x8x6xf16>
// CHECK-NEXT: %[[final_value:.*]] = addf %[[selected_value]], %[[prev_value]] : f16
// CHECK-NEXT: affine.store %[[final_value]], %[[temp_output]][%[[batch0]], %[[batch1]], %[[offset0]], %[[offset1]]] : memref<4x5x8x6xf16>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// Test case 4: Changing starting_index_map : X -> [0,X]
// CHECK-LABEL: func @gather_4
// CHECK-SAME: (%[[OPERAND:.*]]: memref<16x11xf32>, %[[START_INDICES:.*]]: memref<5x4xi32>, %[[OUTPUT:.*]]: memref<4x5x6xf32>)
func @gather_4(%arg0: memref<16x11xf32>, %arg1: memref<5x4xi32>, %arg2: memref<4x5x6xf32>) {
%0 = memref.alloc() : memref<4x5x6xf32>
"lmhlo.gather"(%arg0, %arg1, %0) {dimension_numbers = { collapsed_slice_dims = dense<0> : tensor<1xi64>,
index_vector_dim = 2 : i64,
offset_dims = dense<2> : tensor<1xi64>,
start_index_map = dense<0> : tensor<1xi64>},
indices_are_sorted = false, name = "gather.381", slice_sizes = dense<[1, 6]> : tensor<2xi64>} :
(memref<16x11xf32>, memref<5x4xi32>, memref<4x5x6xf32>) -> ()
"lmhlo.copy"(%0, %arg2) : (memref<4x5x6xf32>, memref<4x5x6xf32>) -> ()
"lmhlo.terminator"() : () -> ()
}
// CHECK-NEXT: %[[zero:.*]] = constant 0.000000e+00 : f32
// CHECK-NEXT: %[[temp_output:.*]] = memref.alloc() : memref<4x5x6xf32>
// CHECK-NEXT: affine.for %{{.*}} = 0 to 4 {
// CHECK-NEXT: affine.for %{{.*}} = 0 to 5 {
// CHECK-NEXT: affine.for %{{.*}} = 0 to 6 {
// CHECK-NEXT: affine.store %[[zero]], %[[temp_output]][%{{.*}}, %{{.*}}, %{{.*}}] : memref<4x5x6xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: affine.for %[[batch0:.*]] = 0 to 5 {
// CHECK-NEXT: affine.for %[[batch1:.*]] = 0 to 4 {
// CHECK-NEXT: affine.for %[[offset0:.*]] = 0 to 6 {
// CHECK-NEXT: affine.for %[[iv0:.*]] = 0 to 16 {
// CHECK-NEXT: %[[a:.*]] = affine.load %[[START_INDICES]][%[[batch0]], %[[batch1]]] : memref<5x4xi32>
// CHECK-NEXT: %[[S_in0:.*]] = index_cast %[[a]] : i32 to index
// CHECK-NEXT: %[[operand_val:.*]] = affine.load %[[OPERAND]][%[[iv0]], %[[offset0]]] : memref<16x11xf32>
// CHECK-NEXT: %[[pred:.*]] = cmpi eq, %[[S_in0]], %[[iv0]] : index
// CHECK-NEXT: %[[selected_value:.*]] = select %[[pred]], %[[operand_val]], %[[zero]] : f32
// CHECK-NEXT: %[[prev_value:.*]] = affine.load %[[temp_output]][%[[batch0]], %[[batch1]], %[[offset0]]] : memref<4x5x6xf32>
// CHECK-NEXT: %[[final_value:.*]] = addf %[[selected_value]], %[[prev_value]] : f32
// CHECK-NEXT: affine.store %[[final_value]], %[[temp_output]][%[[batch0]], %[[batch1]], %[[offset0]]] : memref<4x5x6xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// Test case 5: Testing for more than two equality checks.
// CHECK-LABEL: func @gather_5
func @gather_5(%arg0: memref<28996x512x256xf32>, %arg1: memref<10x3xi32>, %arg2: memref<10x20x10x5xf32>) {
%0 = memref.alloc() : memref<10x20x10x5xf32>
"lmhlo.gather"(%arg0, %arg1, %0) {dimension_numbers = { collapsed_slice_dims = dense<-1> : tensor<1xi64>,
index_vector_dim = 1 : i64,
offset_dims = dense<[1,2,3]> : tensor<3xi64>,
start_index_map = dense<[0,1,2]> : tensor<3xi64>},
indices_are_sorted = false, name = "gather.381", slice_sizes = dense<[20, 10, 5]> : tensor<3xi64>} :
(memref<28996x512x256xf32>, memref<10x3xi32>, memref<10x20x10x5xf32>) -> ()
"lmhlo.copy"(%0, %arg2) : (memref<10x20x10x5xf32>, memref<10x20x10x5xf32>) -> ()
"lmhlo.terminator"() : () -> ()
}
// CHECK: %[[and1:.*]] = and %{{.*}}, %{{.*}} : i1
// CHECK-NEXT: and %[[and1]], %{{.*}} : i1