Pass result element type to XlaBuilder for `mhlo.dot_general` and `mhlo.convolution` ops.

`mhlo.dot_general` and `mhlo.convolution` result element type might be different from operand element type. See `preferred_element_type` attribute that allows i8xi8 to i32 dot computation. `mhlo` to HLO exporter should pass the result element type to Xla builder to override the shape inference of XLA.

PiperOrigin-RevId: 358580718
This commit is contained in:
Prakalp Srivastava 2021-02-20 07:06:24 -08:00 committed by TensorFlow MLIR Team
parent f63c93399a
commit 909574e393
1 changed files with 3 additions and 0 deletions

View File

@ -960,6 +960,9 @@ def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>,
let results = (outs HLO_Tensor);
let verifier = [{ return Verify(*this); }];
// DotGeneral op required custom exporter to pass the preferred element type
// to Xla builder.
let hasCustomHLOConverter = 1;
}
// Define Base Einsum op within the HLO dialect as these are client ops and