diff --git a/src/Conversion/ONNXToKrnl/Math/Reduction.cpp b/src/Conversion/ONNXToKrnl/Math/Reduction.cpp
index 2a66c39..a21e579 100644
--- a/src/Conversion/ONNXToKrnl/Math/Reduction.cpp
+++ b/src/Conversion/ONNXToKrnl/Math/Reduction.cpp
@@ -37,6 +37,12 @@ Value getIdentityValue(
   return emitConstantOp(rewriter, loc, type, 0);
 }
 
+template <>
+Value getIdentityValue<ONNXReduceMeanOp>(
+    ConversionPatternRewriter &rewriter, Location loc, Type type) {
+  return emitConstantOp(rewriter, loc, type, 0);
+}
+
 // Scalar ops
 template <>
 struct ScalarOp<ONNXReduceProdOp> {
@@ -50,6 +56,50 @@ struct ScalarOp {
   using IOp = AddIOp;
 };
 
+template <>
+struct ScalarOp<ONNXReduceMeanOp> {
+  using FOp = AddFOp;
+  using IOp = AddIOp;
+};
+
+/// Helper function to get the size of a MemRef, as a value of a given type.
+Value getSizeInType(ConversionPatternRewriter &rewriter, Location loc,
+    Value memRef, Type elementType) {
+  auto shape = memRef.getType().cast<MemRefType>().getShape();
+
+  // We accumulate static dimensions first and then unknown dimensions.
+  int64_t staticNumElement = 1;
+  bool allStaticDimensions = true;
+
+  // 1. Static dimensions.
+  for (unsigned i = 0; i < shape.size(); i++) {
+    if (shape[i] != -1)
+      staticNumElement *= shape[i];
+    else
+      allStaticDimensions = false;
+  }
+  // 2. Unknown dimensions.
+  Value sizeVal = emitConstantOp(rewriter, loc, elementType, staticNumElement);
+  if (!allStaticDimensions) {
+    for (unsigned i = 0; i < shape.size(); i++) {
+      if (shape[i] == -1) {
+        Value index = rewriter.create<DimOp>(loc, memRef, i);
+        if (elementType.isa<FloatType>()) {
+          Value dim =
+              rewriter.create<IndexCastOp>(loc, index, rewriter.getI64Type());
+          dim = rewriter.create<UIToFPOp>(loc, dim, elementType);
+          sizeVal = rewriter.create<MulFOp>(loc, sizeVal, dim);
+        } else if (elementType.isa<IntegerType>()) {
+          Value dim = rewriter.create<IndexCastOp>(loc, index, elementType);
+          sizeVal = rewriter.create<MulIOp>(loc, sizeVal, dim);
+        } else
+          llvm_unreachable("unsupported element type");
+      }
+    }
+  }
+  return sizeVal;
+}
+
 //===----------------------------------------------------------------------===//
 // Scalar unary ops for lowering ONNXReduceMaxOp
 //===----------------------------------------------------------------------===//
@@ -97,8 +147,12 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter,
 
 template <typename ONNXReductionOp>
 struct ONNXReductionOpLowering : public ConversionPattern {
-  ONNXReductionOpLowering(MLIRContext *ctx)
-      : ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {}
+  bool computeMean = false;
+
+  ONNXReductionOpLowering(MLIRContext *ctx, bool computeMean = false)
+      : ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {
+    this->computeMean = computeMean;
+  }
 
   LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const final {
@@ -123,7 +177,8 @@ struct ONNXReductionOpLowering : public ConversionPattern {
      *
      */
     auto loc = op->getLoc();
-    auto memRefInType = operands[0].getType().cast<MemRefType>();
+    auto input = operands[0];
+    auto memRefInType = input.getType().cast<MemRefType>();
     auto memRefInShape = memRefInType.getShape();
     auto memRefOutType = convertToMemRefType(*op->result_type_begin());
     int64_t inRank = memRefInType.getRank();
@@ -165,7 +220,7 @@ struct ONNXReductionOpLowering : public ConversionPattern {
       SmallVector<Value, 2> allocOperands;
       for (decltype(outRank) i = 0; i < outRank; ++i) {
         if (memRefOutShape[i] < 0) {
-          auto dim = rewriter.create<DimOp>(loc, operands[0], outInDimMap[i]);
+          auto dim = rewriter.create<DimOp>(loc, input, outInDimMap[i]);
           allocOperands.push_back(dim);
         }
       }
@@ -177,11 +232,12 @@ struct ONNXReductionOpLowering : public ConversionPattern {
       }
     }
 
-    // There are two Krnl loops:
-    // - One to initialize the result memref, and
-    // - One to do reduction
+    // There are two required Krnl loops and one optional loop:
+    // - One to initialize the result memref,
+    // - One to do the reduction, and
+    // - One to compute the mean (optional).
 
-    // Define loops to initialize the result.
+    // 1. Define loops to initialize the result.
     std::vector<Value> originalLoopsInit;
     defineLoops(rewriter, loc, originalLoopsInit, outRank);
@@ -208,14 +264,15 @@ struct ONNXReductionOpLowering : public ConversionPattern {
         getIdentityValue<ONNXReductionOp>(rewriter, loc, elementOutType);
     rewriter.create<AffineStoreOp>(loc, identity, alloc, loopIVs);
 
-    // Define an Krnl loop to do reduction.
+    // 2. Define a Krnl loop to do the reduction.
     rewriter.setInsertionPointAfter(iterateOpInit);
+    auto ipMainRegion = rewriter.saveInsertionPoint();
     std::vector<Value> originalLoops;
     defineLoops(rewriter, loc, originalLoops, inRank);
     // Iteration information
     KrnlIterateOperandPack pack(rewriter, originalLoops);
     for (decltype(inRank) i = 0; i < inRank; ++i) {
-      addDimensionToPack(rewriter, loc, pack, operands[0], i);
+      addDimensionToPack(rewriter, loc, pack, input, i);
     }
     auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
     Block &iterationBlock = iterateOp.bodyRegion().front();
@@ -245,12 +302,44 @@ struct ONNXReductionOpLowering : public ConversionPattern {
     }
 
     Value next, accumulated;
-    next = rewriter.create<AffineLoadOp>(loc, operands[0], inLoopIVs);
+    next = rewriter.create<AffineLoadOp>(loc, input, inLoopIVs);
     accumulated = rewriter.create<AffineLoadOp>(loc, alloc, outLoopIVs);
     accumulated = emitScalarOpFor<ONNXReductionOp>(
         rewriter, loc, op, memRefOutType.getElementType(), {accumulated, next});
     rewriter.create<AffineStoreOp>(loc, accumulated, alloc, outLoopIVs);
 
+    // 3. Define a Krnl loop to compute the mean (optional).
+    rewriter.restoreInsertionPoint(ipMainRegion);
+    if (computeMean) {
+      Type elementType = memRefOutType.getElementType();
+      // Compute the divisor, i.e. the number of elements participating in
+      // the reduction: 'divisor = size of input / size of output'.
+      Value inputSize = getSizeInType(rewriter, loc, input, elementType);
+      Value outputSize = getSizeInType(rewriter, loc, alloc, elementType);
+      Value divisor;
+      if (elementType.isa<FloatType>())
+        divisor = rewriter.create<DivFOp>(loc, inputSize, outputSize);
+      else if (elementType.isa<IntegerType>())
+        divisor = rewriter.create<SignedDivIOp>(loc, inputSize, outputSize);
+      else
+        llvm_unreachable("unsupported element type");
+
+      // Compute the mean.
+      BuildKrnlLoop meanLoops(rewriter, loc, outRank);
+      meanLoops.createDefineAndIterateOp(alloc);
+      rewriter.setInsertionPointToStart(meanLoops.getIterateBlock());
+      auto meanIVs = meanLoops.getAllInductionVar();
+      auto loadData = rewriter.create<AffineLoadOp>(loc, alloc, meanIVs);
+      Value meanVal;
+      if (elementType.isa<FloatType>())
+        meanVal = rewriter.create<DivFOp>(loc, loadData, divisor);
+      else if (elementType.isa<IntegerType>())
+        meanVal = rewriter.create<SignedDivIOp>(loc, loadData, divisor);
+      else
+        llvm_unreachable("unsupported element type");
+      rewriter.create<AffineStoreOp>(loc, meanVal, alloc, meanIVs);
+    }
+
     rewriter.replaceOp(op, alloc);
     return success();
   }
@@ -262,4 +351,6 @@ void populateLoweringONNXReductionOpPattern(
       ONNXReductionOpLowering<mlir::ONNXReduceMinOp>,
       ONNXReductionOpLowering<mlir::ONNXReduceProdOp>,
      ONNXReductionOpLowering<mlir::ONNXReduceSumOp>>(ctx);
+  patterns.insert<ONNXReductionOpLowering<mlir::ONNXReduceMeanOp>>(
+      ctx, /*computeMean=*/true);
 }
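For intuition: the mean is obtained by running the ordinary sum reduction first and then dividing every element of the result by 'divisor = size of input / size of output'. For example, reducing a tensor<3x2x2xf32> over axis 1 gives an input size of 12 and an output size of 6, so the divisor is 2, the extent of the reduced axis. Below is a minimal sketch of the f32 path of that post-processing step, reusing the getSizeInType helper added above; it assumes input, alloc, rewriter, loc, and elementType are in scope as in matchAndRewrite:

    // Sketch only: form divisor = size(input) / size(output), e.g. 12 / 6 = 2.
    Value inputSize = getSizeInType(rewriter, loc, input, elementType);
    Value outputSize = getSizeInType(rewriter, loc, alloc, elementType);
    Value divisor = rewriter.create<DivFOp>(loc, inputSize, outputSize);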
diff --git a/src/Conversion/ONNXToKrnl/Tensor/Split.cpp b/src/Conversion/ONNXToKrnl/Tensor/Split.cpp
index f647102..fbd12ed 100644
--- a/src/Conversion/ONNXToKrnl/Tensor/Split.cpp
+++ b/src/Conversion/ONNXToKrnl/Tensor/Split.cpp
@@ -71,7 +71,6 @@ struct ONNXSplitOpLowering : public ConversionPattern {
       // Create loop.
       BuildKrnlLoop outputLoops(rewriter, loc, rank);
       outputLoops.createDefineAndIterateOp(allocs[i]);
-      outputLoops.createIterateOp();
       rewriter.setInsertionPointToStart(outputLoops.getIterateBlock());
       // Indices for the read and write.
       SmallVector<Value, 4> readIndices;
diff --git a/src/Dialect/Krnl/KrnlHelper.cpp b/src/Dialect/Krnl/KrnlHelper.cpp
index b683346..a211e74 100644
--- a/src/Dialect/Krnl/KrnlHelper.cpp
+++ b/src/Dialect/Krnl/KrnlHelper.cpp
@@ -271,4 +271,9 @@ BlockArgument &BuildKrnlLoop::getInductionVar(int originalLoopIndex) {
   return iterBlock->getArguments()[originalLoopIndex];
 }
 
+ArrayRef<BlockArgument> BuildKrnlLoop::getAllInductionVar() {
+  return ArrayRef<BlockArgument>(
+      iterBlock->getArguments().begin(), iterBlock->getArguments().end());
+}
+
 } // namespace mlir
diff --git a/src/Dialect/Krnl/KrnlHelper.hpp b/src/Dialect/Krnl/KrnlHelper.hpp
index a16ccd9..6d37c4b 100644
--- a/src/Dialect/Krnl/KrnlHelper.hpp
+++ b/src/Dialect/Krnl/KrnlHelper.hpp
@@ -186,6 +186,9 @@ public:
   // index. Use the index returned when pushing the bounds.
   BlockArgument &getInductionVar(int originalLoopIndex);
 
+  // Get all of the (original loop) induction variables.
+  ArrayRef<BlockArgument> getAllInductionVar();
+
   // Get a reference to the code region of the optimization operation.
   // This allows us to set the insertion point to the inner block of the
   // loop nest optimization operation.
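The new getAllInductionVar() helper returns every induction variable of a BuildKrnlLoop nest at once, instead of one getInductionVar(i) call per axis, which is convenient when the whole list is forwarded to a load or store. A minimal usage sketch, mirroring the mean loop in Reduction.cpp; it assumes rewriter, loc, and a result buffer alloc of rank outRank are in scope:

    BuildKrnlLoop loops(rewriter, loc, outRank);
    loops.createDefineAndIterateOp(alloc); // bounds are taken from alloc's shape
    rewriter.setInsertionPointToStart(loops.getIterateBlock());
    auto ivs = loops.getAllInductionVar(); // all IVs of the nest at once
    Value element = rewriter.create<AffineLoadOp>(loc, alloc, ivs);
    rewriter.create<AffineStoreOp>(loc, element, alloc, ivs);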
diff --git a/test/backend/test.py b/test/backend/test.py
index f6649fc..7f59e1a 100644
--- a/test/backend/test.py
+++ b/test/backend/test.py
@@ -290,6 +290,16 @@ test_to_enable = [
     "test_reduce_sum_square_negative_axes_keepdims_example_cpu",
     "test_reduce_sum_square_negative_axes_keepdims_random_cpu",
 
+    # ReduceMean
+    "test_reduce_mean_default_axes_keepdims_example_cpu",
+    "test_reduce_mean_default_axes_keepdims_random_cpu",
+    "test_reduce_mean_do_not_keepdims_example_cpu",
+    "test_reduce_mean_do_not_keepdims_random_cpu",
+    "test_reduce_mean_keepdims_example_cpu",
+    "test_reduce_mean_keepdims_random_cpu",
+    "test_reduce_mean_negative_axes_keepdims_example_cpu",
+    "test_reduce_mean_negative_axes_keepdims_random_cpu",
+
     # Selu Op:
     "test_selu_cpu",
     "test_selu_default_cpu",
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index 6b437a5..45af9bb 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -709,7 +709,7 @@ func @test_reducemax(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> {
   // CHECK: [[DEF_LOOPS2:%.+]]:3 = krnl.define_loops 3
   // CHECK: krnl.iterate([[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 2, [[DEF_LOOPS2]]#2 -> %arg3 = 0 to 2) {
   // CHECK: [[LOAD1:%.+]] = affine.load %arg0[%arg1, %arg2, %arg3] : memref<3x2x2xf32>
-  // CHECK: [[LOAD2:%.+]] = affine.load %0[%arg1, %arg3] : memref<3x2xf32>
+  // CHECK: [[LOAD2:%.+]] = affine.load [[RES]][%arg1, %arg3] : memref<3x2xf32>
   // CHECK: [[CMP:%.+]] = cmpf "ogt", [[LOAD2]], [[LOAD1]] : f32
   // CHECK: [[SELECT:%.+]] = select [[CMP]], [[LOAD2]], [[LOAD1]] : f32
   // CHECK: store [[SELECT]], [[RES]][%arg1, %arg3] : memref<3x2xf32>
@@ -733,7 +733,7 @@ func @test_reducemin(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> {
   // CHECK: [[DEF_LOOPS2:%.+]]:3 = krnl.define_loops 3
   // CHECK: krnl.iterate([[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 2, [[DEF_LOOPS2]]#2 -> %arg3 = 0 to 2) {
   // CHECK: [[LOAD1:%.+]] = affine.load %arg0[%arg1, %arg2, %arg3] : memref<3x2x2xf32>
-  // CHECK: [[LOAD2:%.+]] = affine.load %0[%arg1, %arg3] : memref<3x2xf32>
+  // CHECK: [[LOAD2:%.+]] = affine.load [[RES]][%arg1, %arg3] : memref<3x2xf32>
   // CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD2]], [[LOAD1]] : f32
   // CHECK: [[SELECT:%.+]] = select [[CMP]], [[LOAD2]], [[LOAD1]] : f32
   // CHECK: affine.store [[SELECT]], [[RES]][%arg1, %arg3] : memref<3x2xf32>
@@ -757,7 +757,7 @@ func @test_reduceprod(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> {
   // CHECK: [[DEF_LOOPS2:%.+]]:3 = krnl.define_loops 3
   // CHECK: krnl.iterate([[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 2, [[DEF_LOOPS2]]#2 -> %arg3 = 0 to 2) {
   // CHECK: [[LOAD1:%.+]] = affine.load %arg0[%arg1, %arg2, %arg3] : memref<3x2x2xf32>
-  // CHECK: [[LOAD2:%.+]] = affine.load %0[%arg1, %arg3] : memref<3x2xf32>
+  // CHECK: [[LOAD2:%.+]] = affine.load [[RES]][%arg1, %arg3] : memref<3x2xf32>
   // CHECK: [[REDUCE:%.+]] = mulf [[LOAD2]], [[LOAD1]] : f32
   // CHECK: affine.store [[REDUCE]], [[RES]][%arg1, %arg3] : memref<3x2xf32>
   // CHECK: }
@@ -780,13 +780,116 @@ func @test_reducesum(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> {
   // CHECK: [[DEF_LOOPS2:%.+]]:3 = krnl.define_loops 3
   // CHECK: krnl.iterate([[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 2, [[DEF_LOOPS2]]#2 -> %arg3 = 0 to 2) {
   // CHECK: [[LOAD1:%.+]] = affine.load %arg0[%arg1, %arg2, %arg3] : memref<3x2x2xf32>
-  // CHECK: [[LOAD2:%.+]] = affine.load %0[%arg1, %arg3] : memref<3x2xf32>
+  // CHECK: [[LOAD2:%.+]] = affine.load [[RES]][%arg1, %arg3] : memref<3x2xf32>
   // CHECK: [[REDUCE:%.+]] = addf [[LOAD2]], [[LOAD1]] : f32
   // CHECK: affine.store [[REDUCE]], [[RES]][%arg1, %arg3] : memref<3x2xf32>
   // CHECK: }
   // CHECK: return [[RES]] : memref<3x2xf32>
 }
- 
+
+// -----
+
+/// Check ReduceMean with f32.
+func @test_reducemean_f32(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ReduceMean"(%arg0) {axes = [1], keepdims = 0 : si64} : (tensor<3x2x2xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_reducemean_f32
+  // CHECK: [[RES:%.+]] = alloc() : memref<3x2xf32>
+  // CHECK: [[DEF_LOOPS1:%.+]]:2 = krnl.define_loops 2
+  // CHECK: krnl.iterate([[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1) with ([[DEF_LOOPS1]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS1]]#1 -> %arg2 = 0 to 2) {
+  // CHECK: [[IDENTITY:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: affine.store [[IDENTITY]], [[RES]][%arg1, %arg2] : memref<3x2xf32>
+
+  // CHECK: [[DEF_LOOPS2:%.+]]:3 = krnl.define_loops 3
+  // CHECK: krnl.iterate([[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 2, [[DEF_LOOPS2]]#2 -> %arg3 = 0 to 2) {
+  // CHECK: [[LOAD1:%.+]] = affine.load %arg0[%arg1, %arg2, %arg3] : memref<3x2x2xf32>
+  // CHECK: [[LOAD2:%.+]] = affine.load [[RES]][%arg1, %arg3] : memref<3x2xf32>
+  // CHECK: [[REDUCE:%.+]] = addf [[LOAD2]], [[LOAD1]] : f32
+  // CHECK: affine.store [[REDUCE]], [[RES]][%arg1, %arg3] : memref<3x2xf32>
+  // CHECK: }
+
+  // CHECK: [[INPUT_SIZE:%.+]] = constant 1.200000e+01 : f32
+  // CHECK: [[OUTPUT_SIZE:%.+]] = constant 6.000000e+00 : f32
+  // CHECK: [[DIVISOR:%.+]] = divf [[INPUT_SIZE]], [[OUTPUT_SIZE]] : f32
+  // CHECK: [[DEF_MEAN_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: krnl.iterate([[DEF_MEAN_LOOPS]]#0, [[DEF_MEAN_LOOPS]]#1) with ([[DEF_MEAN_LOOPS]]#0 -> %arg1 = 0 to 3, [[DEF_MEAN_LOOPS]]#1 -> %arg2 = 0 to 2) {
+  // CHECK: [[LOAD3:%.+]] = affine.load [[RES]][%arg1, %arg2] : memref<3x2xf32>
+  // CHECK: [[MEAN:%.+]] = divf [[LOAD3]], [[DIVISOR]] : f32
+  // CHECK: affine.store [[MEAN]], [[RES]][%arg1, %arg2] : memref<3x2xf32>
+  // CHECK: }
+  // CHECK: return [[RES]] : memref<3x2xf32>
+}
+
+// -----
+
+/// Check ReduceMean with i32.
+func @test_reducemean_i32(%arg0 : tensor<3x2x2xi32>) -> tensor<*xi32> {
+  %0 = "onnx.ReduceMean"(%arg0) {axes = [1], keepdims = 0 : si64} : (tensor<3x2x2xi32>) -> tensor<*xi32>
+  "std.return"(%0) : (tensor<*xi32>) -> ()
+
+  // CHECK-LABEL: test_reducemean_i32
+  // CHECK: [[RES:%.+]] = alloc() : memref<3x2xi32>
+  // CHECK: [[DEF_LOOPS1:%.+]]:2 = krnl.define_loops 2
+  // CHECK: krnl.iterate([[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1) with ([[DEF_LOOPS1]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS1]]#1 -> %arg2 = 0 to 2) {
+  // CHECK: [[IDENTITY:%.+]] = constant 0 : i32
+  // CHECK: affine.store [[IDENTITY]], [[RES]][%arg1, %arg2] : memref<3x2xi32>
+
+  // CHECK: [[DEF_LOOPS2:%.+]]:3 = krnl.define_loops 3
+  // CHECK: krnl.iterate([[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 2, [[DEF_LOOPS2]]#2 -> %arg3 = 0 to 2) {
+  // CHECK: [[LOAD1:%.+]] = affine.load %arg0[%arg1, %arg2, %arg3] : memref<3x2x2xi32>
+  // CHECK: [[LOAD2:%.+]] = affine.load [[RES]][%arg1, %arg3] : memref<3x2xi32>
+  // CHECK: [[REDUCE:%.+]] = addi [[LOAD2]], [[LOAD1]] : i32
+  // CHECK: affine.store [[REDUCE]], [[RES]][%arg1, %arg3] : memref<3x2xi32>
+  // CHECK: }
+
+  // CHECK: [[INPUT_SIZE:%.+]] = constant 12 : i32
+  // CHECK: [[OUTPUT_SIZE:%.+]] = constant 6 : i32
+  // CHECK: [[DIVISOR:%.+]] = divi_signed [[INPUT_SIZE]], [[OUTPUT_SIZE]] : i32
+  // CHECK: [[DEF_MEAN_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: krnl.iterate([[DEF_MEAN_LOOPS]]#0, [[DEF_MEAN_LOOPS]]#1) with ([[DEF_MEAN_LOOPS]]#0 -> %arg1 = 0 to 3, [[DEF_MEAN_LOOPS]]#1 -> %arg2 = 0 to 2) {
+  // CHECK: [[LOAD3:%.+]] = affine.load [[RES]][%arg1, %arg2] : memref<3x2xi32>
+  // CHECK: [[MEAN:%.+]] = divi_signed [[LOAD3]], [[DIVISOR]] : i32
+  // CHECK: affine.store [[MEAN]], [[RES]][%arg1, %arg2] : memref<3x2xi32>
+  // CHECK: }
+  // CHECK: return [[RES]] : memref<3x2xi32>
+}
+
+// -----
+
+/// Check computing the divisor in ReduceMean
+/// when the input has unknown dimensions and i32 elements.
+func @test_reducemean_i32_unknown_dims(%arg0 : tensor<3x?x2xi32>) -> tensor<*xi32> {
+  %0 = "onnx.ReduceMean"(%arg0) {axes = [1], keepdims = 0 : si64} : (tensor<3x?x2xi32>) -> tensor<*xi32>
+  "std.return"(%0) : (tensor<*xi32>) -> ()
+  // CHECK-LABEL: test_reducemean_i32_unknown_dims
+  // CHECK: [[INPUT_SIZE_CONSTANT:%.+]] = constant 6 : i32
+  // CHECK: [[ONE:%.+]] = constant 1 : index
+  // CHECK: [[DIM:%.+]] = dim %arg0, [[ONE]] : memref<3x?x2xi32>
+  // CHECK: [[UNKNOWN_DIM:%.+]] = index_cast [[DIM]] : index to i32
+  // CHECK: [[INPUT_SIZE:%.+]] = muli [[INPUT_SIZE_CONSTANT]], [[UNKNOWN_DIM]] : i32
+  // CHECK: [[OUTPUT_SIZE:%.+]] = constant 6 : i32
+  // CHECK: [[DIVISOR:%.+]] = divi_signed [[INPUT_SIZE]], [[OUTPUT_SIZE]] : i32
+}
+
+// -----
+
+/// Check computing the divisor in ReduceMean
+/// when the input has unknown dimensions and f32 elements.
+func @test_reducemean_f32_unknown_dims(%arg0 : tensor<3x?x2xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ReduceMean"(%arg0) {axes = [1], keepdims = 0 : si64} : (tensor<3x?x2xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+  // CHECK-LABEL: test_reducemean_f32_unknown_dims
+  // CHECK: [[INPUT_SIZE_CONSTANT:%.+]] = constant 6.000000e+00 : f32
+  // CHECK: [[ONE:%.+]] = constant 1 : index
+  // CHECK: [[DIM:%.+]] = dim %arg0, [[ONE]] : memref<3x?x2xf32>
+  // CHECK: [[UNKNOWN_DIM_i64:%.+]] = index_cast [[DIM]] : index to i64
+  // CHECK: [[UNKNOWN_DIM:%.+]] = uitofp [[UNKNOWN_DIM_i64]] : i64 to f32
+  // CHECK: [[INPUT_SIZE:%.+]] = mulf [[INPUT_SIZE_CONSTANT]], [[UNKNOWN_DIM]] : f32
+  // CHECK: [[OUTPUT_SIZE:%.+]] = constant 6.000000e+00 : f32
+  // CHECK: [[DIVISOR:%.+]] = divf [[INPUT_SIZE]], [[OUTPUT_SIZE]] : f32
+}
+
 // -----
 
 func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {