From 22a6bdc5746b978c4edb2e5d040d6a6157caff62 Mon Sep 17 00:00:00 2001 From: Tian Jin Date: Mon, 13 Jan 2020 12:21:29 -0500 Subject: [PATCH] Sync with latest MLIR. (#26) --- src/builder/frontend_dialect_transformer.cpp | 2 +- src/dialect/onnx/onnx_ops.cpp | 166 +++++++++---------- src/pass/lower_frontend_to_krnl.cpp | 10 +- src/pass/onnx_combine.td | 2 +- src/transform/lower_krnl.cpp | 2 +- src/transform/lower_to_llvm.cpp | 6 +- 6 files changed, 94 insertions(+), 94 deletions(-) diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp index c091903..43f8aa9 100644 --- a/src/builder/frontend_dialect_transformer.cpp +++ b/src/builder/frontend_dialect_transformer.cpp @@ -680,7 +680,7 @@ private: auto tensor_val = frontend_symbols_.GetTensorByOnnxName(output_tensor_legalized_name); - ret_types.emplace_back(tensor_val->getType()); + ret_types.emplace_back(tensor_val.getType()); ret_vals.push_back(tensor_val); } diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 6ff0fad..4b68fe7 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -65,14 +65,14 @@ ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location, // Exp /// Infer the output shape of the ONNXExpOp. This method is required by the /// shape inference interface. -void ONNXExpOp::inferShapes() { getResult()->setType(getOperand()->getType()); } +void ONNXExpOp::inferShapes() { getResult().setType(getOperand().getType()); } //===----------------------------------------------------------------------===// // Tanh /// Infer the output shape of the ONNXTanhOp. This method is required by the /// shape inference interface. void ONNXTanhOp::inferShapes() { - getResult()->setType(getOperand()->getType()); + getResult().setType(getOperand().getType()); } //===----------------------------------------------------------------------===// @@ -80,7 +80,7 @@ void ONNXTanhOp::inferShapes() { /// Infer the output shape of the ONNXSinhOp. This method is required by the /// shape inference interface. void ONNXSinhOp::inferShapes() { - getResult()->setType(getOperand()->getType()); + getResult().setType(getOperand().getType()); } //===----------------------------------------------------------------------===// @@ -88,27 +88,27 @@ void ONNXSinhOp::inferShapes() { /// Infer the output shape of the ONNXCoshOp. This method is required by the /// shape inference interface. void ONNXCoshOp::inferShapes() { - getResult()->setType(getOperand()->getType()); + getResult().setType(getOperand().getType()); } //===----------------------------------------------------------------------===// // Cos /// Infer the output shape of the ONNXCosOp. This method is required by the /// shape inference interface. -void ONNXCosOp::inferShapes() { getResult()->setType(getOperand()->getType()); } +void ONNXCosOp::inferShapes() { getResult().setType(getOperand().getType()); } //===----------------------------------------------------------------------===// // Log /// Infer the output shape of the ONNXLogOp. This method is required by the /// shape inference interface. -void ONNXLogOp::inferShapes() { getResult()->setType(getOperand()->getType()); } +void ONNXLogOp::inferShapes() { getResult().setType(getOperand().getType()); } //===----------------------------------------------------------------------===// // HardSigmoid /// Infer the output shape of the ONNXHardSigmoidOp. This method is required by /// the shape inference interface. void ONNXHardSigmoidOp::inferShapes() { - getResult()->setType(getOperand()->getType()); + getResult().setType(getOperand().getType()); } //===----------------------------------------------------------------------===// @@ -116,21 +116,21 @@ void ONNXHardSigmoidOp::inferShapes() { /// Infer the output shape of the ONNXSigmoidOp. This method is required by the /// shape inference interface. void ONNXSigmoidOp::inferShapes() { - getResult()->setType(getOperand()->getType()); + getResult().setType(getOperand().getType()); } //===----------------------------------------------------------------------===// // Elu /// Infer the output shape of the ONNXEluOp. This method is required by the /// shape inference interface. -void ONNXEluOp::inferShapes() { getResult()->setType(getOperand()->getType()); } +void ONNXEluOp::inferShapes() { getResult().setType(getOperand().getType()); } //===----------------------------------------------------------------------===// // Relu /// Infer the output shape of the ONNXReluOp. This method is required by the /// shape inference interface. void ONNXReluOp::inferShapes() { - getResult()->setType(getOperand()->getType()); + getResult().setType(getOperand().getType()); } //===----------------------------------------------------------------------===// @@ -138,7 +138,7 @@ void ONNXReluOp::inferShapes() { /// Infer the output shape of the ONNXLeakyReluOp. This method is required by /// the shape inference interface. void ONNXLeakyReluOp::inferShapes() { - getResult()->setType(getOperand()->getType()); + getResult().setType(getOperand().getType()); } //===----------------------------------------------------------------------===// @@ -146,7 +146,7 @@ void ONNXLeakyReluOp::inferShapes() { /// Infer the output shape of the ONNXSeluOp. This method is required by /// the shape inference interface. void ONNXSeluOp::inferShapes() { - getResult()->setType(getOperand()->getType()); + getResult().setType(getOperand().getType()); } //===----------------------------------------------------------------------===// @@ -154,7 +154,7 @@ void ONNXSeluOp::inferShapes() { /// Infer the output shape of the ONNXReciprocalOp. This method is required by /// the shape inference interface. void ONNXReciprocalOp::inferShapes() { - getResult()->setType(getOperand()->getType()); + getResult().setType(getOperand().getType()); } //===----------------------------------------------------------------------===// @@ -162,12 +162,12 @@ void ONNXReciprocalOp::inferShapes() { /// Infer the output shape of the ONNXAddOp. This method is required by the /// shape inference interface. void ONNXAddOp::inferShapes() { - if (!getOperand(0)->getType().isa() || - !getOperand(1)->getType().isa()) + if (!getOperand(0).getType().isa() || + !getOperand(1).getType().isa()) return; - auto lhsTy = getOperand(0)->getType().cast(); - auto rhsTy = getOperand(1)->getType().cast(); - getResult()->setType(getBroadcastedType(lhsTy, rhsTy)); + auto lhsTy = getOperand(0).getType().cast(); + auto rhsTy = getOperand(1).getType().cast(); + getResult().setType(getBroadcastedType(lhsTy, rhsTy)); } //===----------------------------------------------------------------------===// @@ -175,12 +175,12 @@ void ONNXAddOp::inferShapes() { /// Infer the output shape of the ONNXMulOp. This method is required by the /// shape inference interface. void ONNXMulOp::inferShapes() { - if (!getOperand(0)->getType().isa() || - !getOperand(1)->getType().isa()) + if (!getOperand(0).getType().isa() || + !getOperand(1).getType().isa()) return; - auto lhsTy = getOperand(0)->getType().cast(); - auto rhsTy = getOperand(1)->getType().cast(); - getResult()->setType(getBroadcastedType(lhsTy, rhsTy)); + auto lhsTy = getOperand(0).getType().cast(); + auto rhsTy = getOperand(1).getType().cast(); + getResult().setType(getBroadcastedType(lhsTy, rhsTy)); } //===----------------------------------------------------------------------===// @@ -188,12 +188,12 @@ void ONNXMulOp::inferShapes() { /// Infer the output shape of the ONNXDivOp. This method is required by the /// shape inference interface. void ONNXDivOp::inferShapes() { - if (!getOperand(0)->getType().isa() || - !getOperand(1)->getType().isa()) + if (!getOperand(0).getType().isa() || + !getOperand(1).getType().isa()) return; - auto lhsTy = getOperand(0)->getType().cast(); - auto rhsTy = getOperand(1)->getType().cast(); - getResult()->setType(getBroadcastedType(lhsTy, rhsTy)); + auto lhsTy = getOperand(0).getType().cast(); + auto rhsTy = getOperand(1).getType().cast(); + getResult().setType(getBroadcastedType(lhsTy, rhsTy)); } //===----------------------------------------------------------------------===// @@ -201,12 +201,12 @@ void ONNXDivOp::inferShapes() { /// Infer the output shape of the ONNXSubOp. This method is required by the /// shape inference interface. void ONNXSubOp::inferShapes() { - if (!getOperand(0)->getType().isa() || - !getOperand(1)->getType().isa()) + if (!getOperand(0).getType().isa() || + !getOperand(1).getType().isa()) return; - auto lhsTy = getOperand(0)->getType().cast(); - auto rhsTy = getOperand(1)->getType().cast(); - getResult()->setType(getBroadcastedType(lhsTy, rhsTy)); + auto lhsTy = getOperand(0).getType().cast(); + auto rhsTy = getOperand(1).getType().cast(); + getResult().setType(getBroadcastedType(lhsTy, rhsTy)); } //===----------------------------------------------------------------------===// @@ -214,12 +214,12 @@ void ONNXSubOp::inferShapes() { /// Infer the output shape of the ONNXAndOp. This method is required by the /// shape inference interface. void ONNXAndOp::inferShapes() { - if (!getOperand(0)->getType().isa() || - !getOperand(1)->getType().isa()) + if (!getOperand(0).getType().isa() || + !getOperand(1).getType().isa()) return; - auto lhsTy = getOperand(0)->getType().cast(); - auto rhsTy = getOperand(1)->getType().cast(); - getResult()->setType(getBroadcastedType(lhsTy, rhsTy)); + auto lhsTy = getOperand(0).getType().cast(); + auto rhsTy = getOperand(1).getType().cast(); + getResult().setType(getBroadcastedType(lhsTy, rhsTy)); } //===----------------------------------------------------------------------===// @@ -227,12 +227,12 @@ void ONNXAndOp::inferShapes() { /// Infer the output shape of the ONNXOrOp. This method is required by the /// shape inference interface. void ONNXOrOp::inferShapes() { - if (!getOperand(0)->getType().isa() || - !getOperand(1)->getType().isa()) + if (!getOperand(0).getType().isa() || + !getOperand(1).getType().isa()) return; - auto lhsTy = getOperand(0)->getType().cast(); - auto rhsTy = getOperand(1)->getType().cast(); - getResult()->setType(getBroadcastedType(lhsTy, rhsTy)); + auto lhsTy = getOperand(0).getType().cast(); + auto rhsTy = getOperand(1).getType().cast(); + getResult().setType(getBroadcastedType(lhsTy, rhsTy)); } //===----------------------------------------------------------------------===// @@ -240,12 +240,12 @@ void ONNXOrOp::inferShapes() { /// Infer the output shape of the ONNXXorOp. This method is required by the /// shape inference interface. void ONNXXorOp::inferShapes() { - if (!getOperand(0)->getType().isa() || - !getOperand(1)->getType().isa()) + if (!getOperand(0).getType().isa() || + !getOperand(1).getType().isa()) return; - auto lhsTy = getOperand(0)->getType().cast(); - auto rhsTy = getOperand(1)->getType().cast(); - getResult()->setType(getBroadcastedType(lhsTy, rhsTy)); + auto lhsTy = getOperand(0).getType().cast(); + auto rhsTy = getOperand(1).getType().cast(); + getResult().setType(getBroadcastedType(lhsTy, rhsTy)); } //===----------------------------------------------------------------------===// @@ -256,15 +256,15 @@ void ONNXXorOp::inferShapes() { /// shape inference interface. void ONNXSumOp::inferShapes() { for (int i = 0; i < getNumOperands(); ++i) { - if (!getOperand(i)->getType().cast()) + if (!getOperand(i).getType().cast()) return; } - Type resultTy = getOperand(0)->getType().cast(); + Type resultTy = getOperand(0).getType().cast(); for (int i = 1; i < getNumOperands(); ++i) { - Type nextTy = getOperand(i)->getType().cast(); + Type nextTy = getOperand(i).getType().cast(); resultTy = getBroadcastedType(resultTy, nextTy); } - getResult()->setType(resultTy); + getResult().setType(resultTy); } //===----------------------------------------------------------------------===// @@ -273,15 +273,15 @@ void ONNXSumOp::inferShapes() { /// shape inference interface. void ONNXMaxOp::inferShapes() { for (int i = 0; i < getNumOperands(); ++i) { - if (!getOperand(i)->getType().cast()) + if (!getOperand(i).getType().cast()) return; } - Type resultTy = getOperand(0)->getType().cast(); + Type resultTy = getOperand(0).getType().cast(); for (int i = 1; i < getNumOperands(); ++i) { - Type nextTy = getOperand(i)->getType().cast(); + Type nextTy = getOperand(i).getType().cast(); resultTy = getBroadcastedType(resultTy, nextTy); } - getResult()->setType(resultTy); + getResult().setType(resultTy); } //===----------------------------------------------------------------------===// @@ -290,15 +290,15 @@ void ONNXMaxOp::inferShapes() { /// shape inference interface. void ONNXMinOp::inferShapes() { for (int i = 0; i < getNumOperands(); ++i) { - if (!getOperand(i)->getType().cast()) + if (!getOperand(i).getType().cast()) return; } - Type resultTy = getOperand(0)->getType().cast(); + Type resultTy = getOperand(0).getType().cast(); for (int i = 1; i < getNumOperands(); ++i) { - Type nextTy = getOperand(i)->getType().cast(); + Type nextTy = getOperand(i).getType().cast(); resultTy = getBroadcastedType(resultTy, nextTy); } - getResult()->setType(resultTy); + getResult().setType(resultTy); } //===----------------------------------------------------------------------===// @@ -306,7 +306,7 @@ void ONNXMinOp::inferShapes() { /// Infer the output shape of the ONNXIdentityOp. This method is required by the /// shape inference interface. void ONNXIdentityOp::inferShapes() { - getResult()->setType(getOperand()->getType()); + getResult().setType(getOperand().getType()); } //===----------------------------------------------------------------------===// @@ -315,15 +315,15 @@ void ONNXIdentityOp::inferShapes() { void ONNXMatMulOp::inferShapes() { // Cannot infer shape if no shape exists. - if (!getOperand(0)->getType().isa() || - !getOperand(1)->getType().isa()) + if (!getOperand(0).getType().isa() || + !getOperand(1).getType().isa()) return; - auto lhsTy = getOperand(0)->getType().cast(); - auto rhsTy = getOperand(1)->getType().cast(); + auto lhsTy = getOperand(0).getType().cast(); + auto rhsTy = getOperand(1).getType().cast(); SmallVector dims; dims.emplace_back(lhsTy.getShape()[0]); dims.emplace_back(rhsTy.getShape()[1]); - getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType())); + getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType())); } // TODO: @@ -336,30 +336,30 @@ void ONNXMatMulOp::inferShapes() { void ONNXGemmOp::inferShapes() { // Cannot infer shape if no shape exists. - if (!getOperand(0)->getType().isa() || - !getOperand(1)->getType().isa()) + if (!getOperand(0).getType().isa() || + !getOperand(1).getType().isa()) return; - auto lhsTy = getOperand(0)->getType().cast(); - auto rhsTy = getOperand(1)->getType().cast(); + auto lhsTy = getOperand(0).getType().cast(); + auto rhsTy = getOperand(1).getType().cast(); SmallVector dims; dims.emplace_back(lhsTy.getShape()[0]); dims.emplace_back(rhsTy.getShape()[1]); - getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType())); + getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType())); } // FullGemm void ONNXFullGemmOp::inferShapes() { // Cannot infer shape if no shape exists. - if (!getOperand(0)->getType().isa() || - !getOperand(1)->getType().isa()) + if (!getOperand(0).getType().isa() || + !getOperand(1).getType().isa()) return; - auto lhsTy = getOperand(0)->getType().cast(); - auto rhsTy = getOperand(1)->getType().cast(); + auto lhsTy = getOperand(0).getType().cast(); + auto rhsTy = getOperand(1).getType().cast(); SmallVector dims; dims.emplace_back(lhsTy.getShape()[0]); dims.emplace_back(rhsTy.getShape()[1]); - getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType())); + getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType())); } // TODO: @@ -372,11 +372,11 @@ void ONNXFullGemmOp::inferShapes() { void ONNXReshapeOp::inferShapes() { // Cannot infer shape if no shape tensor is specified. - if (!getOperand(1)->getType().isa()) + if (!getOperand(1).getType().isa()) emitError("Shape tensor not ranked."); - auto inputTensorTy = getOperand(0)->getType().cast(); - auto shapeTensorTy = getOperand(1)->getType().cast(); + auto inputTensorTy = getOperand(0).getType().cast(); + auto shapeTensorTy = getOperand(1).getType().cast(); // Only rank 1 shape tensors are supported. if (shapeTensorTy.getShape().size() != 1) @@ -392,7 +392,7 @@ void ONNXReshapeOp::inferShapes() { for (int i = 0; i < outputRank; ++i) dims.emplace_back(-1); - getResult()->setType( + getResult().setType( RankedTensorType::get(dims, inputTensorTy.getElementType())); } @@ -402,16 +402,16 @@ void ONNXReshapeOp::inferShapes() { void ONNXTransposeOp::inferShapes() { // Cannot infer shape if no shape exists. - if (!getOperand()->getType().isa()) + if (!getOperand().getType().isa()) emitError("Shape tensor not ranked."); // Naive transposition which handles the default case of // reversing the shape of the tensor (similar to numpy.transpose). // TODO: Once attributes are supported we can handle the case where the // transposition uses a permutation vector to interchange the axes. - auto arrayTy = getOperand()->getType().cast(); + auto arrayTy = getOperand().getType().cast(); SmallVector dims(llvm::reverse(arrayTy.getShape())); - getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); + getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); } //===----------------------------------------------------------------------===// diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp index a17484f..a578479 100644 --- a/src/pass/lower_frontend_to_krnl.cpp +++ b/src/pass/lower_frontend_to_krnl.cpp @@ -61,7 +61,7 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc, Value maxDim = nullptr; for (int i = 0; i < operands.size(); i++) { auto operandShape = - operands[i]->getType().cast().getShape(); + operands[i].getType().cast().getShape(); int operandDimIdx = operandShape.size() - 1 - reversedIdx; if (operandDimIdx < 0) @@ -162,7 +162,7 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, int dimIdx = rank - 1 - reversedIdx; sharedDimCount[dimIdx] = 0; for (int i = 0; i < operands.size(); ++i) { - auto shape = operands[i]->getType().cast().getShape(); + auto shape = operands[i].getType().cast().getShape(); if (reversedIdx <= shape.size() - 1) sharedDimCount[dimIdx]++; } @@ -174,7 +174,7 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, // more than one, since they are potentially broadcasted dimensions. for (int i = 0; i < operands.size(); ++i) { std::map broadcastedDims; - auto shape = operands[i]->getType().cast().getShape(); + auto shape = operands[i].getType().cast().getShape(); int size = shape.size(); for (int j = 0; j < shape.size(); ++j) { if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) { @@ -198,7 +198,7 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, std::map broadcastedDims) { // `operand` must has a ranked type. This should have been checked by the // shape inference pass. - auto operandShape = operand->getType().cast().getShape(); + auto operandShape = operand.getType().cast().getShape(); auto rank = operandShape.size(); auto loopCount = loopIVs.size(); @@ -319,7 +319,7 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, /* Lower UnaryOp to Ops in the Standard dialect. */ auto loc = op->getLoc(); - Type element_type = operands.front()->getType(); + Type element_type = operands.front().getType(); if (element_type.isa()) { return rewriter.create>(loc, result_types, operands, mlir::None); diff --git a/src/pass/onnx_combine.td b/src/pass/onnx_combine.td index 199e27a..8a40928 100644 --- a/src/pass/onnx_combine.td +++ b/src/pass/onnx_combine.td @@ -24,7 +24,7 @@ include "dialect/onnx/onnx.td" /// dag benefitsAdded = (addBenefit 0) /// >; -def HasOneUse : ConstrainthasOneUse()">>; +def HasOneUse : Constraint>; //===----------------------------------------------------------------------===// // Pattern-Match and Rewrite diff --git a/src/transform/lower_krnl.cpp b/src/transform/lower_krnl.cpp index 5578862..da16f81 100644 --- a/src/transform/lower_krnl.cpp +++ b/src/transform/lower_krnl.cpp @@ -55,7 +55,7 @@ struct KrnlIterateOpLowering : public OpRewritePattern { for (size_t i = 0; i < nestedForOps.size() - 1; i++) { auto iterateIV = iterateOp.bodyRegion().front().getArgument(0); auto forIV = nestedForOps[i].getBody()->getArgument(0); - iterateIV->replaceAllUsesWith(forIV); + iterateIV.replaceAllUsesWith(forIV); iterateOp.bodyRegion().front().eraseArgument(0); } diff --git a/src/transform/lower_to_llvm.cpp b/src/transform/lower_to_llvm.cpp index b61b848..be6d291 100644 --- a/src/transform/lower_to_llvm.cpp +++ b/src/transform/lower_to_llvm.cpp @@ -65,7 +65,7 @@ public: // First operand. Type dstType = - operands[0]->getType().cast().getStructElementType(1); + operands[0].getType().cast().getStructElementType(1); Value alignedDstMemory = rewriter.create( loc, dstType, operands[0], rewriter.getI64ArrayAttr(1)); Value alignedInt8PtrDstMemory = rewriter.create( @@ -73,7 +73,7 @@ public: // Second operand. Type srcType = - operands[1]->getType().cast().getStructElementType(1); + operands[1].getType().cast().getStructElementType(1); Value alignedSrcMemory = rewriter.create( loc, srcType, operands[1], rewriter.getI64ArrayAttr(1)); Value alignedInt8PtrSrcMemory = rewriter.create( @@ -253,7 +253,7 @@ public: // Get the first memref returned, convert to a dynamic memref and store // it in the wrapped Output. auto outMemRef = outputMemRefs.getResult(0); - auto outMemRefTy = outMemRef->getType().dyn_cast(); + auto outMemRefTy = outMemRef.getType().dyn_cast(); auto outMemRefRank = outMemRefTy.getStructElementType(3).getArrayNumElements(); auto outMemRefRankVal = rewriter.create(