[MLIR:LHLO_GPU] Add additional constraints for batchnorm

- Constrain batchnorm inputs and outputs to be fp memrefs.

PiperOrigin-RevId: 348665747
This commit is contained in:
Rahul Joshi 2020-12-22 11:29:31 -08:00 committed by TensorFlow MLIR Team
parent ccdd07f8e4
commit bc367971ec
1 changed files with 20 additions and 20 deletions

View File

@ -47,14 +47,14 @@ def I32Buffer : MemRefOf<[I32]>;
def LHLOGPU_BatchNormGradOp : LHLOGPU_Op<"batch_norm_grad">, def LHLOGPU_BatchNormGradOp : LHLOGPU_Op<"batch_norm_grad">,
BASE_HLO_BatchNormGradOp { BASE_HLO_BatchNormGradOp {
let arguments = (ins let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand, Arg<LHLO_FpBuffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scale, Arg<LHLO_FpBuffer, "", [MemRead]>:$scale,
Arg<LHLO_Buffer, "", [MemRead]>:$mean, Arg<LHLO_FpBuffer, "", [MemRead]>:$mean,
Arg<LHLO_Buffer, "", [MemRead]>:$stddev, Arg<LHLO_FpBuffer, "", [MemRead]>:$stddev,
Arg<LHLO_Buffer, "", [MemRead]>:$grad_output, Arg<LHLO_FpBuffer, "", [MemRead]>:$grad_output,
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_operand, // gradient of $operand. Arg<LHLO_FpBuffer, "", [MemWrite]>:$grad_operand, // gradient of $operand.
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_scale, Arg<LHLO_FpBuffer, "", [MemWrite]>:$grad_scale,
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_offset, Arg<LHLO_FpBuffer, "", [MemWrite]>:$grad_offset,
F32Attr:$epsilon, F32Attr:$epsilon,
I64Attr:$feature_index I64Attr:$feature_index
); );
@ -63,12 +63,12 @@ def LHLOGPU_BatchNormGradOp : LHLOGPU_Op<"batch_norm_grad">,
def LHLOGPU_BatchNormInferenceOp : LHLOGPU_Op<"batch_norm_inference">, def LHLOGPU_BatchNormInferenceOp : LHLOGPU_Op<"batch_norm_inference">,
BASE_HLO_BatchNormInferenceOp { BASE_HLO_BatchNormInferenceOp {
let arguments = (ins let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand, Arg<LHLO_FpBuffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scale, Arg<LHLO_FpBuffer, "", [MemRead]>:$scale,
Arg<LHLO_Buffer, "", [MemRead]>:$offset, Arg<LHLO_FpBuffer, "", [MemRead]>:$offset,
Arg<LHLO_Buffer, "", [MemRead]>:$mean, Arg<LHLO_FpBuffer, "", [MemRead]>:$mean,
Arg<LHLO_Buffer, "", [MemRead]>:$stddev, Arg<LHLO_FpBuffer, "", [MemRead]>:$stddev,
Arg<LHLO_Buffer, "", [MemWrite]>:$output, Arg<LHLO_FpBuffer, "", [MemWrite]>:$output,
F32Attr:$epsilon, F32Attr:$epsilon,
I64Attr:$feature_index); I64Attr:$feature_index);
} }
@ -77,12 +77,12 @@ def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training">,
BASE_HLO_BatchNormTrainingOp { BASE_HLO_BatchNormTrainingOp {
let arguments = (ins let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand, Arg<LHLO_FpBuffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scale, Arg<LHLO_FpBuffer, "", [MemRead]>:$scale,
Arg<LHLO_Buffer, "", [MemRead]>:$offset, Arg<LHLO_FpBuffer, "", [MemRead]>:$offset,
Arg<LHLO_Buffer, "", [MemWrite]>:$output, Arg<LHLO_FpBuffer, "", [MemWrite]>:$output,
Arg<LHLO_Buffer, "", [MemWrite]>:$batch_mean, Arg<LHLO_FpBuffer, "", [MemWrite]>:$batch_mean,
Arg<LHLO_Buffer, "", [MemWrite]>:$batch_stddev, Arg<LHLO_FpBuffer, "", [MemWrite]>:$batch_stddev,
F32Attr:$epsilon, F32Attr:$epsilon,
I64Attr:$feature_index I64Attr:$feature_index
); );