[MLIR][HLO] Support all smaller ranks in rank specialization cases
Rank specialization cases can be applied to all argument tensors of smaller ranks than the expected maximum rank. This is crucial if all operands are effectively scalars and the maximum reduced rank is 0. PiperOrigin-RevId: 375712020
This commit is contained in:
		
							parent
							
								
									c5af02fd8d
								
							
						
					
					
						commit
						cb46298a07
					
				|  | @ -423,14 +423,14 @@ Value RecusivelyMaterializeTargetRankSpecializationCases( | |||
|     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, | ||||
|     const SmallVector<Value, 8> &shapes, Value max_rank, | ||||
|     int64_t min_target_rank, int64_t max_target_rank) { | ||||
|   Value min_target_rank_predicate = | ||||
|       b.create<CmpIOp>(loc, CmpIPredicate::eq, max_rank, | ||||
|   Value condition = | ||||
|       b.create<CmpIOp>(loc, CmpIPredicate::ule, max_rank, | ||||
|                        b.create<ConstantIndexOp>(loc, min_target_rank)); | ||||
| 
 | ||||
|   // If only a unique target rank is left, we can lower to an assert instead
 | ||||
|   // of the usual if operation.
 | ||||
|   if (min_target_rank == max_target_rank) { | ||||
|     b.create<AssertOp>(loc, min_target_rank_predicate, | ||||
|     b.create<AssertOp>(loc, condition, | ||||
|                        "Input for dynamic binary or n-ary op lowering was of " | ||||
|                        "a rank greater than " + | ||||
|                            std::to_string(max_target_rank)); | ||||
|  | @ -439,9 +439,8 @@ Value RecusivelyMaterializeTargetRankSpecializationCases( | |||
|   } | ||||
| 
 | ||||
|   // Materialize IR for the smallest considered target rank.
 | ||||
|   auto if_op = | ||||
|       b.create<scf::IfOp>(loc, op->getResultTypes(), min_target_rank_predicate, | ||||
|                           /*withElseRegion=*/true); | ||||
|   auto if_op = b.create<scf::IfOp>(loc, op->getResultTypes(), condition, | ||||
|                                    /*withElseRegion=*/true); | ||||
|   auto then_builder = if_op.getThenBodyBuilder(); | ||||
|   then_builder.create<scf::YieldOp>( | ||||
|       loc, MaterializeTargetRankSpecializationCase(then_builder, loc, op, | ||||
|  |  | |||
|  | @ -67,8 +67,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, | |||
| // CHECK-SCF-DAG:    %[[R20_GT_R1:.*]] = cmpi sgt, %[[R20]], %[[REDUCED_RANK1]] | ||||
| // CHECK-SCF-DAG:    %[[MAX_RED_RANK:.*]] = select %[[R20_GT_R1]], %[[R20]], %[[REDUCED_RANK1]] | ||||
| //                 Generic case 1: | ||||
| // CHECK-SCF:        %[[MAX_RED_RANK_EQ_1:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C1]] | ||||
| // CHECK-SCF:        %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_EQ_1]] | ||||
| // CHECK-SCF:        %[[MAX_RED_RANK_LE_1:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C1]] | ||||
| // CHECK-SCF:        %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_LE_1]] | ||||
| // CHECK-SCF-DAG:      %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]] | ||||
| // CHECK-SCF-DAG:      %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_1]] | ||||
| // CHECK-SCF-DAG:      %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]] | ||||
|  | @ -84,8 +84,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, | |||
| // CHECK-SCF:          scf.yield %[[INNER_RES_]] | ||||
| // CHECK-SCF:        else | ||||
| //                 Generic case 2: | ||||
| // CHECK-SCF:          %[[MAX_RED_RANK_EQ_2:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C2]] | ||||
| // CHECK-SCF:          %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_EQ_2]] | ||||
| // CHECK-SCF:          %[[MAX_RED_RANK_LE_2:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C2]] | ||||
| // CHECK-SCF:          %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_LE_2]] | ||||
| // CHECK-SCF-DAG:        %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_2]] | ||||
| // CHECK-SCF-DAG:        %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_2]] | ||||
| // CHECK-SCF-DAG:        %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_2]] | ||||
|  | @ -101,8 +101,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, | |||
| // CHECK-SCF:            scf.yield %[[INNER_RES_]] | ||||
| // CHECK-SCF:          else | ||||
| //                 Generic case 3: | ||||
| // CHECK-SCF:            %[[MAX_RED_RANK_EQ_3:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C3]] | ||||
| // CHECK-SCF:            %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_EQ_3]] | ||||
| // CHECK-SCF:            %[[MAX_RED_RANK_LE_3:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C3]] | ||||
| // CHECK-SCF:            %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_LE_3]] | ||||
| // CHECK-SCF-DAG:          %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]] | ||||
| // CHECK-SCF-DAG:          %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_3]] | ||||
| // CHECK-SCF-DAG:          %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]] | ||||
|  | @ -118,8 +118,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, | |||
| // CHECK-SCF:              scf.yield %[[INNER_RES_]] | ||||
| // CHECK-SCF:            else | ||||
| //                 Generic case 4: | ||||
| // CHECK-SCF:              %[[MAX_RED_RANK_EQ_4:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C4]] | ||||
| // CHECK-SCF:              %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_EQ_4]] | ||||
| // CHECK-SCF:              %[[MAX_RED_RANK_LE_4:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C4]] | ||||
| // CHECK-SCF:              %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_LE_4]] | ||||
| // CHECK-SCF-DAG:            %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_4]] | ||||
| // CHECK-SCF-DAG:            %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_4]] | ||||
| // CHECK-SCF-DAG:            %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_4]] | ||||
|  | @ -135,8 +135,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, | |||
| // CHECK-SCF:                scf.yield %[[INNER_RES_]] | ||||
| // CHECK-SCF:              else | ||||
| //                 Generic case 5: | ||||
| // CHECK-SCF:                %[[MAX_RED_RANK_EQ_5:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C5]] | ||||
| // CHECK-SCF:                %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_EQ_5]] | ||||
| // CHECK-SCF:                %[[MAX_RED_RANK_LE_5:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C5]] | ||||
| // CHECK-SCF:                %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_LE_5]] | ||||
| // CHECK-SCF-DAG:              %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_5]] | ||||
| // CHECK-SCF-DAG:              %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_5]] | ||||
| // CHECK-SCF-DAG:              %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_5]] | ||||
|  | @ -152,8 +152,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, | |||
| // CHECK-SCF:                  scf.yield %[[INNER_RES_]] | ||||
| // CHECK-SCF:                else | ||||
| //                 Generic case 6: | ||||
| // CHECK-SCF:                  %[[MAX_RED_RANK_EQ_6:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C6]] | ||||
| // CHECK-SCF:                  %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_EQ_6]] | ||||
| // CHECK-SCF:                  %[[MAX_RED_RANK_LE_6:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C6]] | ||||
| // CHECK-SCF:                  %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_LE_6]] | ||||
| // CHECK-SCF-DAG:                %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_6]] | ||||
| // CHECK-SCF-DAG:                %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_6]] | ||||
| // CHECK-SCF-DAG:                %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_6]] | ||||
|  | @ -169,8 +169,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, | |||
| // CHECK-SCF:                    scf.yield %[[INNER_RES_]] | ||||
| // CHECK-SCF:                  else | ||||
| //                 Generic case 7: | ||||
| // CHECK-SCF:                    %[[MAX_RED_RANK_EQ_7:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C7]] | ||||
| // CHECK-SCF:                    %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_EQ_7]] | ||||
| // CHECK-SCF:                    %[[MAX_RED_RANK_LE_7:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C7]] | ||||
| // CHECK-SCF:                    %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_LE_7]] | ||||
| // CHECK-SCF-DAG:                  %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_7]] | ||||
| // CHECK-SCF-DAG:                  %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_7]] | ||||
| // CHECK-SCF-DAG:                  %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_7]] | ||||
|  | @ -186,8 +186,8 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, | |||
| // CHECK-SCF:                      scf.yield %[[INNER_RES_]] | ||||
| // CHECK-SCF:                    else | ||||
| //                 Generic case 8: | ||||
| // CHECK-SCF:                      %[[MAX_RED_RANK_EQ_8:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C8]] | ||||
| // CHECK-SCF:                      assert %[[MAX_RED_RANK_EQ_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8" | ||||
| // CHECK-SCF:                      %[[MAX_RED_RANK_LE_8:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C8]] | ||||
| // CHECK-SCF:                      assert %[[MAX_RED_RANK_LE_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8" | ||||
| // CHECK-SCF-DAG:                  %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_8]] | ||||
| // CHECK-SCF-DAG:                  %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_8]] | ||||
| // CHECK-SCF-DAG:                  %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_8]] | ||||
|  | @ -524,8 +524,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> { | |||
| // CHECK-SCF-DAG:        %[[R0_GT_R1:.*]] = cmpi sgt, %[[REDUCED_RANK0]], %[[REDUCED_RANK1]] | ||||
| // CHECK-SCF-DAG:        %[[MAX_RED_RANK:.*]] = select %[[R0_GT_R1]], %[[REDUCED_RANK0]], %[[REDUCED_RANK1]] | ||||
| //                 Generic case 1: | ||||
| // CHECK-SCF:            %[[MAX_RED_RANK_EQ_1:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C1]] | ||||
| // CHECK-SCF:            %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_EQ_1]] | ||||
| // CHECK-SCF:            %[[MAX_RED_RANK_LE_1:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C1]] | ||||
| // CHECK-SCF:            %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_LE_1]] | ||||
| // CHECK-SCF-DAG:          %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]] | ||||
| // CHECK-SCF-DAG:          %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]] | ||||
| // CHECK-SCF-DAG:          %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] | ||||
|  | @ -537,8 +537,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> { | |||
| // CHECK-SCF:              scf.yield %[[INNER_RES_]] | ||||
| // CHECK-SCF:            else | ||||
| //                 Generic case 2: | ||||
| // CHECK-SCF:              %[[MAX_RED_RANK_EQ_2:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C2]] | ||||
| // CHECK-SCF:              %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_EQ_2]] | ||||
| // CHECK-SCF:              %[[MAX_RED_RANK_LE_2:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C2]] | ||||
| // CHECK-SCF:              %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_LE_2]] | ||||
| // CHECK-SCF-DAG:            %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_2]] | ||||
| // CHECK-SCF-DAG:            %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_2]] | ||||
| // CHECK-SCF-DAG:            %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] | ||||
|  | @ -550,8 +550,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> { | |||
| // CHECK-SCF:                scf.yield %[[INNER_RES_]] | ||||
| // CHECK-SCF:              else | ||||
| //                 Generic case 3: | ||||
| // CHECK-SCF:                %[[MAX_RED_RANK_EQ_3:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C3]] | ||||
| // CHECK-SCF:                %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_EQ_3]] | ||||
| // CHECK-SCF:                %[[MAX_RED_RANK_LE_3:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C3]] | ||||
| // CHECK-SCF:                %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_LE_3]] | ||||
| // CHECK-SCF-DAG:              %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]] | ||||
| // CHECK-SCF-DAG:              %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]] | ||||
| // CHECK-SCF-DAG:              %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] | ||||
|  | @ -563,8 +563,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> { | |||
| // CHECK-SCF:                  scf.yield %[[INNER_RES_]] | ||||
| // CHECK-SCF:                else | ||||
| //                 Generic case 4: | ||||
| // CHECK-SCF:                  %[[MAX_RED_RANK_EQ_4:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C4]] | ||||
| // CHECK-SCF:                  %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_EQ_4]] | ||||
| // CHECK-SCF:                  %[[MAX_RED_RANK_LE_4:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C4]] | ||||
| // CHECK-SCF:                  %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_LE_4]] | ||||
| // CHECK-SCF-DAG:                %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_4]] | ||||
| // CHECK-SCF-DAG:                %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_4]] | ||||
| // CHECK-SCF-DAG:                %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] | ||||
|  | @ -576,8 +576,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> { | |||
| // CHECK-SCF:                    scf.yield %[[INNER_RES_]] | ||||
| // CHECK-SCF:                  else | ||||
| //                 Generic case 5: | ||||
| // CHECK-SCF:                    %[[MAX_RED_RANK_EQ_5:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C5]] | ||||
| // CHECK-SCF:                    %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_EQ_5]] | ||||
| // CHECK-SCF:                    %[[MAX_RED_RANK_LE_5:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C5]] | ||||
| // CHECK-SCF:                    %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_LE_5]] | ||||
| // CHECK-SCF-DAG:                  %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_5]] | ||||
| // CHECK-SCF-DAG:                  %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_5]] | ||||
| // CHECK-SCF-DAG:                  %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] | ||||
|  | @ -589,8 +589,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> { | |||
| // CHECK-SCF:                      scf.yield %[[INNER_RES_]] | ||||
| // CHECK-SCF:                    else | ||||
| //                 Generic case 6: | ||||
| // CHECK-SCF:                      %[[MAX_RED_RANK_EQ_6:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C6]] | ||||
| // CHECK-SCF:                      %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_EQ_6]] | ||||
| // CHECK-SCF:                      %[[MAX_RED_RANK_LE_6:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C6]] | ||||
| // CHECK-SCF:                      %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_LE_6]] | ||||
| // CHECK-SCF-DAG:                    %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_6]] | ||||
| // CHECK-SCF-DAG:                    %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_6]] | ||||
| // CHECK-SCF-DAG:                    %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] | ||||
|  | @ -602,8 +602,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> { | |||
| // CHECK-SCF:                        scf.yield %[[INNER_RES_]] | ||||
| // CHECK-SCF:                      else | ||||
| //                 Generic case 7: | ||||
| // CHECK-SCF:                        %[[MAX_RED_RANK_EQ_7:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C7]] | ||||
| // CHECK-SCF:                        %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_EQ_7]] | ||||
| // CHECK-SCF:                        %[[MAX_RED_RANK_LE_7:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C7]] | ||||
| // CHECK-SCF:                        %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_LE_7]] | ||||
| // CHECK-SCF-DAG:                      %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_7]] | ||||
| // CHECK-SCF-DAG:                      %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_7]] | ||||
| // CHECK-SCF-DAG:                      %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] | ||||
|  | @ -615,8 +615,8 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> { | |||
| // CHECK-SCF:                          scf.yield %[[INNER_RES_]] | ||||
| // CHECK-SCF:                        else | ||||
| //                 Generic case 8: | ||||
| // CHECK-SCF:                          %[[MAX_RED_RANK_EQ_8:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C8]] | ||||
| // CHECK-SCF:                          assert %[[MAX_RED_RANK_EQ_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8" | ||||
| // CHECK-SCF:                          %[[MAX_RED_RANK_LE_8:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C8]] | ||||
| // CHECK-SCF:                          assert %[[MAX_RED_RANK_LE_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8" | ||||
| // CHECK-SCF-DAG:                      %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_8]] | ||||
| // CHECK-SCF-DAG:                      %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_8]] | ||||
| // CHECK-SCF-DAG:                      %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue