[MLIR:LHLO_GPU] Add additional constraints for batchnorm

- Constrain batchnorm inputs and outputs to be fp memrefs.

PiperOrigin-RevId: 348665747
Rahul Joshi authored on 2020-12-22 11:29:31 -08:00, committed by TensorFlow MLIR Team
parent ccdd07f8e4
commit bc367971ec
1 changed file with 20 additions and 20 deletions
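
For context on the constraint being swapped in: LHLO_FpBuffer is a memref type constraint restricted to floating-point element types, by analogy with the I32Buffer definition visible in the context line of the first hunk below. The following is a minimal ODS sketch of such a constraint, not taken from this commit; the exact allowed element-type list and where the definition lives are assumptions.

// Hypothetical sketch (assumed, not from this commit): a memref whose
// element type may be any floating-point type. MemRefOf and AnyFloat are
// standard constraints from mlir/IR/OpBase.td.
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;

// For comparison, the integer-only buffer constraint shown in the hunk
// context below:
def I32Buffer : MemRefOf<[I32]>;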

@@ -47,14 +47,14 @@ def I32Buffer : MemRefOf<[I32]>;
 def LHLOGPU_BatchNormGradOp : LHLOGPU_Op<"batch_norm_grad">,
     BASE_HLO_BatchNormGradOp {
   let arguments = (ins
-    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
-    Arg<LHLO_Buffer, "", [MemRead]>:$scale,
-    Arg<LHLO_Buffer, "", [MemRead]>:$mean,
-    Arg<LHLO_Buffer, "", [MemRead]>:$stddev,
-    Arg<LHLO_Buffer, "", [MemRead]>:$grad_output,
-    Arg<LHLO_Buffer, "", [MemWrite]>:$grad_operand, // gradient of $operand.
-    Arg<LHLO_Buffer, "", [MemWrite]>:$grad_scale,
-    Arg<LHLO_Buffer, "", [MemWrite]>:$grad_offset,
+    Arg<LHLO_FpBuffer, "", [MemRead]>:$operand,
+    Arg<LHLO_FpBuffer, "", [MemRead]>:$scale,
+    Arg<LHLO_FpBuffer, "", [MemRead]>:$mean,
+    Arg<LHLO_FpBuffer, "", [MemRead]>:$stddev,
+    Arg<LHLO_FpBuffer, "", [MemRead]>:$grad_output,
+    Arg<LHLO_FpBuffer, "", [MemWrite]>:$grad_operand, // gradient of $operand.
+    Arg<LHLO_FpBuffer, "", [MemWrite]>:$grad_scale,
+    Arg<LHLO_FpBuffer, "", [MemWrite]>:$grad_offset,
     F32Attr:$epsilon,
     I64Attr:$feature_index
   );
@@ -63,12 +63,12 @@ def LHLOGPU_BatchNormGradOp : LHLOGPU_Op<"batch_norm_grad">,
 def LHLOGPU_BatchNormInferenceOp : LHLOGPU_Op<"batch_norm_inference">,
     BASE_HLO_BatchNormInferenceOp {
   let arguments = (ins
-    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
-    Arg<LHLO_Buffer, "", [MemRead]>:$scale,
-    Arg<LHLO_Buffer, "", [MemRead]>:$offset,
-    Arg<LHLO_Buffer, "", [MemRead]>:$mean,
-    Arg<LHLO_Buffer, "", [MemRead]>:$stddev,
-    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+    Arg<LHLO_FpBuffer, "", [MemRead]>:$operand,
+    Arg<LHLO_FpBuffer, "", [MemRead]>:$scale,
+    Arg<LHLO_FpBuffer, "", [MemRead]>:$offset,
+    Arg<LHLO_FpBuffer, "", [MemRead]>:$mean,
+    Arg<LHLO_FpBuffer, "", [MemRead]>:$stddev,
+    Arg<LHLO_FpBuffer, "", [MemWrite]>:$output,
     F32Attr:$epsilon,
     I64Attr:$feature_index);
 }
@@ -77,12 +77,12 @@ def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training">,
     BASE_HLO_BatchNormTrainingOp {
   let arguments = (ins
-    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
-    Arg<LHLO_Buffer, "", [MemRead]>:$scale,
-    Arg<LHLO_Buffer, "", [MemRead]>:$offset,
-    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
-    Arg<LHLO_Buffer, "", [MemWrite]>:$batch_mean,
-    Arg<LHLO_Buffer, "", [MemWrite]>:$batch_stddev,
+    Arg<LHLO_FpBuffer, "", [MemRead]>:$operand,
+    Arg<LHLO_FpBuffer, "", [MemRead]>:$scale,
+    Arg<LHLO_FpBuffer, "", [MemRead]>:$offset,
+    Arg<LHLO_FpBuffer, "", [MemWrite]>:$output,
+    Arg<LHLO_FpBuffer, "", [MemWrite]>:$batch_mean,
+    Arg<LHLO_FpBuffer, "", [MemWrite]>:$batch_stddev,
     F32Attr:$epsilon,
     I64Attr:$feature_index
   );