diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td index b5411e3..896fe0f 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -99,6 +99,17 @@ def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>; // Any pred, int or floating-point tensor types def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>; +// A layout attribute (1D tensor of index type) +def HLO_LayoutAttr : Attr< + And<[IndexElementsAttr.predicate, + CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>().getType().getRank() + == 1}]>]>, + "A 1D tensor of index type (layout)"> { + let storageType = IndexElementsAttr.storageType; + let returnType = IndexElementsAttr.returnType; + let convertFromStorage = IndexElementsAttr.convertFromStorage; +} + //===----------------------------------------------------------------------===// // MHLO nullary op definitions. //===----------------------------------------------------------------------===// diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index d1bdd49..fb1e176 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -641,7 +641,10 @@ def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType BoolAttr:$left_side, BoolAttr:$lower, BoolAttr:$unit_diagonal, - HLO_TransposeAttr:$transpose_a + HLO_TransposeAttr:$transpose_a, + HLO_LayoutAttr:$layout_a, + HLO_LayoutAttr:$layout_b, + HLO_LayoutAttr:$layout_output ); } diff --git a/tests/lhlo_ops.mlir b/tests/lhlo_ops.mlir index 76be69f..b01beb7 100644 --- a/tests/lhlo_ops.mlir +++ b/tests/lhlo_ops.mlir @@ -865,7 +865,12 @@ func @replica_id_memrefs(%arg_out: memref) -> () { // CHECK-LABEL: func @triangular_solve_memrefs func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () { - "lmhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} + "lmhlo.triangular_solve"(%arg0, %arg1, %arg_out) + {layout_a = dense<[1, 0]> : tensor<2xindex>, + layout_b = dense<[1, 0]> : tensor<2xindex>, + layout_output = dense<[1, 0]> : tensor<2xindex>, + left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", + unit_diagonal = true} : (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> () return }