diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 9beda13..56a1ea1 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -49,17 +49,17 @@ SmallVector GetNParallelLoopsAttrs(unsigned nParallelLoops) { } template -Value getResultValue(Operation* op) { +Value GetResultValue(Operation* op) { return isLHLO ? op->getOperand(op->getNumOperands() - 1) : op->getResult(0); } template -ShapedType getHloOpResultType(Operation* op) { - return getResultValue(op).getType().template cast(); +ShapedType GetHloOpResultType(Operation* op) { + return GetResultValue(op).getType().template cast(); } template -bool verifyHloOpBufferOrTensorSemantics(Operation* op) { +bool VerifyHloOpBufferOrTensorSemantics(Operation* op) { auto verify_type = [&](Value val) -> bool { return (isLHLO && val.getType().isa()) || (!isLHLO && val.getType().isa()); @@ -293,8 +293,8 @@ class DataMovementOpConverter : public OpConversionPattern { LogicalResult matchAndRewrite( OpTy op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { - if (!verifyHloOpBufferOrTensorSemantics(op)) return failure(); - auto result_type = getHloOpResultType(op); + if (!VerifyHloOpBufferOrTensorSemantics(op)) return failure(); + auto result_type = GetHloOpResultType(op); SmallVector indexing_maps = Derived::getIndexingMaps(op, &rewriter); @@ -331,7 +331,7 @@ class BroadcastConverter ShapedType input_type = broadcast_op.operand().getType().template cast(); unsigned input_rank = input_type.getRank(); - unsigned nloops = getHloOpResultType(broadcast_op).getRank(); + unsigned nloops = GetHloOpResultType(broadcast_op).getRank(); // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to // the input's dimensions. @@ -365,7 +365,7 @@ class HloBroadcastInDimConverter static SmallVector getIndexingMaps( mhlo::BroadcastInDimOp broadcast_op, Builder* b) { - auto result_type = getHloOpResultType(broadcast_op); + auto result_type = GetHloOpResultType(broadcast_op); auto operand_type = broadcast_op.operand().getType().template cast(); unsigned nloops = result_type.getRank(); @@ -563,7 +563,7 @@ class TransposeConverter isLHLO>::DataMovementOpConverter; static SmallVector getIndexingMaps(OpTy op, Builder* b) { auto result_type = - getHloOpResultType(op).template cast(); + GetHloOpResultType(op).template cast(); auto nloops = result_type.getRank(); SmallVector input_exprs; input_exprs.resize(result_type.getRank()); @@ -587,11 +587,11 @@ class ReshapeOpConverter : public OpConversionPattern { LogicalResult matchAndRewrite( OpTy reshape_op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { - if (!verifyHloOpBufferOrTensorSemantics(reshape_op)) + if (!VerifyHloOpBufferOrTensorSemantics(reshape_op)) return failure(); ShapedType operand_type = reshape_op.operand().getType().template cast(); - ShapedType result_type = getHloOpResultType(reshape_op); + ShapedType result_type = GetHloOpResultType(reshape_op); if (!operand_type.hasStaticShape() || !result_type.hasStaticShape()) return failure(); @@ -696,7 +696,7 @@ class IotaConverter : public OpConversionPattern { LogicalResult matchAndRewrite( OpTy iota_op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { - ShapedType result_shaped_type = getHloOpResultType(iota_op); + ShapedType result_shaped_type = GetHloOpResultType(iota_op); if (!result_shaped_type) return failure(); auto result_element_type = result_shaped_type.getElementType(); @@ -867,7 +867,7 @@ class ReverseConverter isLHLO>::DataMovementOpConverter; static SmallVector getIndexingMaps(OpTy op, Builder* b) { auto result_type = - getHloOpResultType(op).template cast(); + GetHloOpResultType(op).template cast(); auto nloops = result_type.getRank(); SmallVector input_exprs; input_exprs.reserve(nloops);