[MLIR:LHLO_GPU] Add additional constraints for batchnorm
- Constrain batchnorm inputs and outputs to be fp memrefs. PiperOrigin-RevId: 348665747
This commit is contained in:
parent
ccdd07f8e4
commit
bc367971ec
|
@ -47,14 +47,14 @@ def I32Buffer : MemRefOf<[I32]>;
|
|||
def LHLOGPU_BatchNormGradOp : LHLOGPU_Op<"batch_norm_grad">,
|
||||
BASE_HLO_BatchNormGradOp {
|
||||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$mean,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$stddev,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$grad_output,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_operand, // gradient of $operand.
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_scale,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_offset,
|
||||
Arg<LHLO_FpBuffer, "", [MemRead]>:$operand,
|
||||
Arg<LHLO_FpBuffer, "", [MemRead]>:$scale,
|
||||
Arg<LHLO_FpBuffer, "", [MemRead]>:$mean,
|
||||
Arg<LHLO_FpBuffer, "", [MemRead]>:$stddev,
|
||||
Arg<LHLO_FpBuffer, "", [MemRead]>:$grad_output,
|
||||
Arg<LHLO_FpBuffer, "", [MemWrite]>:$grad_operand, // gradient of $operand.
|
||||
Arg<LHLO_FpBuffer, "", [MemWrite]>:$grad_scale,
|
||||
Arg<LHLO_FpBuffer, "", [MemWrite]>:$grad_offset,
|
||||
F32Attr:$epsilon,
|
||||
I64Attr:$feature_index
|
||||
);
|
||||
|
@ -63,12 +63,12 @@ def LHLOGPU_BatchNormGradOp : LHLOGPU_Op<"batch_norm_grad">,
|
|||
def LHLOGPU_BatchNormInferenceOp : LHLOGPU_Op<"batch_norm_inference">,
|
||||
BASE_HLO_BatchNormInferenceOp {
|
||||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$offset,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$mean,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$stddev,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||
Arg<LHLO_FpBuffer, "", [MemRead]>:$operand,
|
||||
Arg<LHLO_FpBuffer, "", [MemRead]>:$scale,
|
||||
Arg<LHLO_FpBuffer, "", [MemRead]>:$offset,
|
||||
Arg<LHLO_FpBuffer, "", [MemRead]>:$mean,
|
||||
Arg<LHLO_FpBuffer, "", [MemRead]>:$stddev,
|
||||
Arg<LHLO_FpBuffer, "", [MemWrite]>:$output,
|
||||
F32Attr:$epsilon,
|
||||
I64Attr:$feature_index);
|
||||
}
|
||||
|
@ -77,12 +77,12 @@ def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training">,
|
|||
BASE_HLO_BatchNormTrainingOp {
|
||||
|
||||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$offset,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$batch_mean,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$batch_stddev,
|
||||
Arg<LHLO_FpBuffer, "", [MemRead]>:$operand,
|
||||
Arg<LHLO_FpBuffer, "", [MemRead]>:$scale,
|
||||
Arg<LHLO_FpBuffer, "", [MemRead]>:$offset,
|
||||
Arg<LHLO_FpBuffer, "", [MemWrite]>:$output,
|
||||
Arg<LHLO_FpBuffer, "", [MemWrite]>:$batch_mean,
|
||||
Arg<LHLO_FpBuffer, "", [MemWrite]>:$batch_stddev,
|
||||
F32Attr:$epsilon,
|
||||
I64Attr:$feature_index
|
||||
);
|
||||
|
|
Loading…
Reference in New Issue