[XLA:GPU] Add conversion from HLO -> MLIR LMHLO for TriangularSolve
- Also add layout attributes for inputs and output for error checking. PiperOrigin-RevId: 355863625
This commit is contained in:
parent
1c4521cc42
commit
b251712b1d
|
@ -99,6 +99,17 @@ def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>;
|
|||
// Any pred, int or floating-point tensor types
|
||||
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;
|
||||
|
||||
// A layout attribute (1D tensor of index type)
|
||||
def HLO_LayoutAttr : Attr<
|
||||
And<[IndexElementsAttr.predicate,
|
||||
CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>().getType().getRank()
|
||||
== 1}]>]>,
|
||||
"A 1D tensor of index type (layout)"> {
|
||||
let storageType = IndexElementsAttr.storageType;
|
||||
let returnType = IndexElementsAttr.returnType;
|
||||
let convertFromStorage = IndexElementsAttr.convertFromStorage;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MHLO nullary op definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -641,7 +641,10 @@ def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType
|
|||
BoolAttr:$left_side,
|
||||
BoolAttr:$lower,
|
||||
BoolAttr:$unit_diagonal,
|
||||
HLO_TransposeAttr:$transpose_a
|
||||
HLO_TransposeAttr:$transpose_a,
|
||||
HLO_LayoutAttr:$layout_a,
|
||||
HLO_LayoutAttr:$layout_b,
|
||||
HLO_LayoutAttr:$layout_output
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -865,7 +865,12 @@ func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
|
|||
|
||||
// CHECK-LABEL: func @triangular_solve_memrefs
|
||||
func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () {
|
||||
"lmhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true}
|
||||
"lmhlo.triangular_solve"(%arg0, %arg1, %arg_out)
|
||||
{layout_a = dense<[1, 0]> : tensor<2xindex>,
|
||||
layout_b = dense<[1, 0]> : tensor<2xindex>,
|
||||
layout_output = dense<[1, 0]> : tensor<2xindex>,
|
||||
left_side = true, lower = true, transpose_a = "NO_TRANSPOSE",
|
||||
unit_diagonal = true}
|
||||
: (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue