Sync with latest MLIR. (#26)
parent f384e3187e
commit 22a6bdc574

Changed paths: src/builder, src/dialect/onnx, src/pass, src/transform
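
Note: this sync tracks the upstream MLIR change that turned mlir::Value from a
pointer-like handle into a lightweight value type, so member access switches
from '->' to '.' wherever a Value is touched, while true pointers such as
Operation* keep '->'. A minimal sketch of the new idiom (illustrative names,
not code from this commit):

    #include "mlir/IR/Value.h"

    // mlir::Value is now held and passed by value; members use '.'.
    void copyType(mlir::Value operand, mlir::Value result) {
      // Pre-sync spelling was: result->setType(operand->getType());
      result.setType(operand.getType());
    }

The hunks below apply this rewrite mechanically across the builder, the ONNX
dialect, the passes, and the transforms.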
@@ -680,7 +680,7 @@ private:
       auto tensor_val =
           frontend_symbols_.GetTensorByOnnxName(output_tensor_legalized_name);
-      ret_types.emplace_back(tensor_val->getType());
+      ret_types.emplace_back(tensor_val.getType());
       ret_vals.push_back(tensor_val);
     }

@@ -65,14 +65,14 @@ ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
 // Exp
 /// Infer the output shape of the ONNXExpOp. This method is required by the
 /// shape inference interface.
-void ONNXExpOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
+void ONNXExpOp::inferShapes() { getResult().setType(getOperand().getType()); }

 //===----------------------------------------------------------------------===//
 // Tanh
 /// Infer the output shape of the ONNXTanhOp. This method is required by the
 /// shape inference interface.
 void ONNXTanhOp::inferShapes() {
-  getResult()->setType(getOperand()->getType());
+  getResult().setType(getOperand().getType());
 }

 //===----------------------------------------------------------------------===//
@@ -80,7 +80,7 @@ void ONNXTanhOp::inferShapes() {
 /// Infer the output shape of the ONNXSinhOp. This method is required by the
 /// shape inference interface.
 void ONNXSinhOp::inferShapes() {
-  getResult()->setType(getOperand()->getType());
+  getResult().setType(getOperand().getType());
 }

 //===----------------------------------------------------------------------===//
@@ -88,27 +88,27 @@ void ONNXSinhOp::inferShapes() {
 /// Infer the output shape of the ONNXCoshOp. This method is required by the
 /// shape inference interface.
 void ONNXCoshOp::inferShapes() {
-  getResult()->setType(getOperand()->getType());
+  getResult().setType(getOperand().getType());
 }

 //===----------------------------------------------------------------------===//
 // Cos
 /// Infer the output shape of the ONNXCosOp. This method is required by the
 /// shape inference interface.
-void ONNXCosOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
+void ONNXCosOp::inferShapes() { getResult().setType(getOperand().getType()); }

 //===----------------------------------------------------------------------===//
 // Log
 /// Infer the output shape of the ONNXLogOp. This method is required by the
 /// shape inference interface.
-void ONNXLogOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
+void ONNXLogOp::inferShapes() { getResult().setType(getOperand().getType()); }

 //===----------------------------------------------------------------------===//
 // HardSigmoid
 /// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
 /// the shape inference interface.
 void ONNXHardSigmoidOp::inferShapes() {
-  getResult()->setType(getOperand()->getType());
+  getResult().setType(getOperand().getType());
 }

 //===----------------------------------------------------------------------===//
@@ -116,21 +116,21 @@ void ONNXHardSigmoidOp::inferShapes() {
 /// Infer the output shape of the ONNXSigmoidOp. This method is required by the
 /// shape inference interface.
 void ONNXSigmoidOp::inferShapes() {
-  getResult()->setType(getOperand()->getType());
+  getResult().setType(getOperand().getType());
 }

 //===----------------------------------------------------------------------===//
 // Elu
 /// Infer the output shape of the ONNXEluOp. This method is required by the
 /// shape inference interface.
-void ONNXEluOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
+void ONNXEluOp::inferShapes() { getResult().setType(getOperand().getType()); }

 //===----------------------------------------------------------------------===//
 // Relu
 /// Infer the output shape of the ONNXReluOp. This method is required by the
 /// shape inference interface.
 void ONNXReluOp::inferShapes() {
-  getResult()->setType(getOperand()->getType());
+  getResult().setType(getOperand().getType());
 }

 //===----------------------------------------------------------------------===//
@@ -138,7 +138,7 @@ void ONNXReluOp::inferShapes() {
 /// Infer the output shape of the ONNXLeakyReluOp. This method is required by
 /// the shape inference interface.
 void ONNXLeakyReluOp::inferShapes() {
-  getResult()->setType(getOperand()->getType());
+  getResult().setType(getOperand().getType());
 }

 //===----------------------------------------------------------------------===//
@@ -146,7 +146,7 @@ void ONNXLeakyReluOp::inferShapes() {
 /// Infer the output shape of the ONNXSeluOp. This method is required by
 /// the shape inference interface.
 void ONNXSeluOp::inferShapes() {
-  getResult()->setType(getOperand()->getType());
+  getResult().setType(getOperand().getType());
 }

 //===----------------------------------------------------------------------===//
@@ -154,7 +154,7 @@ void ONNXSeluOp::inferShapes() {
 /// Infer the output shape of the ONNXReciprocalOp. This method is required by
 /// the shape inference interface.
 void ONNXReciprocalOp::inferShapes() {
-  getResult()->setType(getOperand()->getType());
+  getResult().setType(getOperand().getType());
 }

 //===----------------------------------------------------------------------===//
@@ -162,12 +162,12 @@ void ONNXReciprocalOp::inferShapes() {
 /// Infer the output shape of the ONNXAddOp. This method is required by the
 /// shape inference interface.
 void ONNXAddOp::inferShapes() {
-  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
-      !getOperand(1)->getType().isa<RankedTensorType>())
+  if (!getOperand(0).getType().isa<RankedTensorType>() ||
+      !getOperand(1).getType().isa<RankedTensorType>())
     return;
-  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
-  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
-  getResult()->setType(getBroadcastedType(lhsTy, rhsTy));
+  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
+  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
 }

 //===----------------------------------------------------------------------===//
@@ -175,12 +175,12 @@ void ONNXAddOp::inferShapes() {
 /// Infer the output shape of the ONNXMulOp. This method is required by the
 /// shape inference interface.
 void ONNXMulOp::inferShapes() {
-  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
-      !getOperand(1)->getType().isa<RankedTensorType>())
+  if (!getOperand(0).getType().isa<RankedTensorType>() ||
+      !getOperand(1).getType().isa<RankedTensorType>())
     return;
-  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
-  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
-  getResult()->setType(getBroadcastedType(lhsTy, rhsTy));
+  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
+  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
 }

 //===----------------------------------------------------------------------===//
@@ -188,12 +188,12 @@ void ONNXMulOp::inferShapes() {
 /// Infer the output shape of the ONNXDivOp. This method is required by the
 /// shape inference interface.
 void ONNXDivOp::inferShapes() {
-  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
-      !getOperand(1)->getType().isa<RankedTensorType>())
+  if (!getOperand(0).getType().isa<RankedTensorType>() ||
+      !getOperand(1).getType().isa<RankedTensorType>())
     return;
-  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
-  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
-  getResult()->setType(getBroadcastedType(lhsTy, rhsTy));
+  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
+  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
 }

 //===----------------------------------------------------------------------===//
@@ -201,12 +201,12 @@ void ONNXDivOp::inferShapes() {
 /// Infer the output shape of the ONNXSubOp. This method is required by the
 /// shape inference interface.
 void ONNXSubOp::inferShapes() {
-  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
-      !getOperand(1)->getType().isa<RankedTensorType>())
+  if (!getOperand(0).getType().isa<RankedTensorType>() ||
+      !getOperand(1).getType().isa<RankedTensorType>())
     return;
-  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
-  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
-  getResult()->setType(getBroadcastedType(lhsTy, rhsTy));
+  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
+  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
 }

 //===----------------------------------------------------------------------===//
@@ -214,12 +214,12 @@ void ONNXSubOp::inferShapes() {
 /// Infer the output shape of the ONNXAndOp. This method is required by the
 /// shape inference interface.
 void ONNXAndOp::inferShapes() {
-  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
-      !getOperand(1)->getType().isa<RankedTensorType>())
+  if (!getOperand(0).getType().isa<RankedTensorType>() ||
+      !getOperand(1).getType().isa<RankedTensorType>())
     return;
-  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
-  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
-  getResult()->setType(getBroadcastedType(lhsTy, rhsTy));
+  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
+  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
 }

 //===----------------------------------------------------------------------===//
@@ -227,12 +227,12 @@ void ONNXAndOp::inferShapes() {
 /// Infer the output shape of the ONNXOrOp. This method is required by the
 /// shape inference interface.
 void ONNXOrOp::inferShapes() {
-  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
-      !getOperand(1)->getType().isa<RankedTensorType>())
+  if (!getOperand(0).getType().isa<RankedTensorType>() ||
+      !getOperand(1).getType().isa<RankedTensorType>())
     return;
-  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
-  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
-  getResult()->setType(getBroadcastedType(lhsTy, rhsTy));
+  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
+  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
 }

 //===----------------------------------------------------------------------===//
@@ -240,12 +240,12 @@ void ONNXOrOp::inferShapes() {
 /// Infer the output shape of the ONNXXorOp. This method is required by the
 /// shape inference interface.
 void ONNXXorOp::inferShapes() {
-  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
-      !getOperand(1)->getType().isa<RankedTensorType>())
+  if (!getOperand(0).getType().isa<RankedTensorType>() ||
+      !getOperand(1).getType().isa<RankedTensorType>())
     return;
-  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
-  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
-  getResult()->setType(getBroadcastedType(lhsTy, rhsTy));
+  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
+  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
 }

 //===----------------------------------------------------------------------===//
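
Note: the seven binary elementwise ops above (Add, Mul, Div, Sub, And, Or,
Xor) migrate one shared shape-inference idiom. A minimal standalone sketch of
it, assuming the unqualified getBroadcastedType these methods call is MLIR's
OpTrait::util::getBroadcastedType from mlir/Dialect/Traits.h (illustrative
names, not code from this commit):

    #include "mlir/Dialect/Traits.h"
    #include "mlir/IR/StandardTypes.h" // RankedTensorType in MLIR of this era

    // Give up unless both operands are ranked tensors, then broadcast the
    // two shapes together to form the result type.
    mlir::Type inferBinaryShape(mlir::Value lhs, mlir::Value rhs) {
      auto lhsTy = lhs.getType().dyn_cast<mlir::RankedTensorType>();
      auto rhsTy = rhs.getType().dyn_cast<mlir::RankedTensorType>();
      if (!lhsTy || !rhsTy)
        return {}; // shape not known yet; the ops simply return
      return mlir::OpTrait::util::getBroadcastedType(lhsTy, rhsTy);
    }
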
@@ -256,15 +256,15 @@ void ONNXXorOp::inferShapes() {
 /// shape inference interface.
 void ONNXSumOp::inferShapes() {
   for (int i = 0; i < getNumOperands(); ++i) {
-    if (!getOperand(i)->getType().cast<RankedTensorType>())
+    if (!getOperand(i).getType().cast<RankedTensorType>())
       return;
   }
-  Type resultTy = getOperand(0)->getType().cast<RankedTensorType>();
+  Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
   for (int i = 1; i < getNumOperands(); ++i) {
-    Type nextTy = getOperand(i)->getType().cast<RankedTensorType>();
+    Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
     resultTy = getBroadcastedType(resultTy, nextTy);
   }
-  getResult()->setType(resultTy);
+  getResult().setType(resultTy);
 }

 //===----------------------------------------------------------------------===//
@@ -273,15 +273,15 @@ void ONNXSumOp::inferShapes() {
 /// shape inference interface.
 void ONNXMaxOp::inferShapes() {
   for (int i = 0; i < getNumOperands(); ++i) {
-    if (!getOperand(i)->getType().cast<RankedTensorType>())
+    if (!getOperand(i).getType().cast<RankedTensorType>())
       return;
   }
-  Type resultTy = getOperand(0)->getType().cast<RankedTensorType>();
+  Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
   for (int i = 1; i < getNumOperands(); ++i) {
-    Type nextTy = getOperand(i)->getType().cast<RankedTensorType>();
+    Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
     resultTy = getBroadcastedType(resultTy, nextTy);
   }
-  getResult()->setType(resultTy);
+  getResult().setType(resultTy);
 }

 //===----------------------------------------------------------------------===//
@@ -290,15 +290,15 @@ void ONNXMaxOp::inferShapes() {
 /// shape inference interface.
 void ONNXMinOp::inferShapes() {
   for (int i = 0; i < getNumOperands(); ++i) {
-    if (!getOperand(i)->getType().cast<RankedTensorType>())
+    if (!getOperand(i).getType().cast<RankedTensorType>())
       return;
   }
-  Type resultTy = getOperand(0)->getType().cast<RankedTensorType>();
+  Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
   for (int i = 1; i < getNumOperands(); ++i) {
-    Type nextTy = getOperand(i)->getType().cast<RankedTensorType>();
+    Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
     resultTy = getBroadcastedType(resultTy, nextTy);
   }
-  getResult()->setType(resultTy);
+  getResult().setType(resultTy);
 }

 //===----------------------------------------------------------------------===//
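
Note: Sum, Max, and Min generalize the same idiom to a variadic operand list
by folding getBroadcastedType across the operands (their guard uses
cast<RankedTensorType> where isa<> looks intended, since cast asserts rather
than yielding null on a mismatch). A sketch of the fold, under the same
assumptions as the sketch above:

    // Fold the broadcasted type across all operands, assuming every operand
    // has a ranked tensor type.
    mlir::Type inferVariadicShape(mlir::Operation *op) {
      mlir::Type resultTy = op->getOperand(0).getType();
      for (unsigned i = 1, e = op->getNumOperands(); i < e; ++i)
        resultTy = mlir::OpTrait::util::getBroadcastedType(
            resultTy, op->getOperand(i).getType());
      return resultTy;
    }
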
@@ -306,7 +306,7 @@ void ONNXMinOp::inferShapes() {
 /// Infer the output shape of the ONNXIdentityOp. This method is required by the
 /// shape inference interface.
 void ONNXIdentityOp::inferShapes() {
-  getResult()->setType(getOperand()->getType());
+  getResult().setType(getOperand().getType());
 }

 //===----------------------------------------------------------------------===//
@@ -315,15 +315,15 @@ void ONNXIdentityOp::inferShapes() {

 void ONNXMatMulOp::inferShapes() {
   // Cannot infer shape if no shape exists.
-  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
-      !getOperand(1)->getType().isa<RankedTensorType>())
+  if (!getOperand(0).getType().isa<RankedTensorType>() ||
+      !getOperand(1).getType().isa<RankedTensorType>())
     return;
-  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
-  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
+  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   SmallVector<int64_t, 2> dims;
   dims.emplace_back(lhsTy.getShape()[0]);
   dims.emplace_back(rhsTy.getShape()[1]);
-  getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
+  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }

 // TODO:
@@ -336,30 +336,30 @@ void ONNXMatMulOp::inferShapes() {

 void ONNXGemmOp::inferShapes() {
   // Cannot infer shape if no shape exists.
-  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
-      !getOperand(1)->getType().isa<RankedTensorType>())
+  if (!getOperand(0).getType().isa<RankedTensorType>() ||
+      !getOperand(1).getType().isa<RankedTensorType>())
     return;
-  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
-  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
+  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   SmallVector<int64_t, 2> dims;
   dims.emplace_back(lhsTy.getShape()[0]);
   dims.emplace_back(rhsTy.getShape()[1]);
-  getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
+  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }

 // FullGemm

 void ONNXFullGemmOp::inferShapes() {
   // Cannot infer shape if no shape exists.
-  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
-      !getOperand(1)->getType().isa<RankedTensorType>())
+  if (!getOperand(0).getType().isa<RankedTensorType>() ||
+      !getOperand(1).getType().isa<RankedTensorType>())
     return;
-  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
-  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
+  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   SmallVector<int64_t, 2> dims;
   dims.emplace_back(lhsTy.getShape()[0]);
   dims.emplace_back(rhsTy.getShape()[1]);
-  getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
+  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }

 // TODO:
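
Note: MatMul, Gemm, and FullGemm all compute the same rank-2 result type,
rows from the left operand and columns from the right; the surrounding TODOs
suggest the general N-D and broadcast cases are still open. A sketch of that
computation (illustrative names, not code from this commit):

    #include "llvm/ADT/SmallVector.h"

    // Result is (lhs rows) x (rhs cols) with the left element type.
    mlir::RankedTensorType matMul2DType(mlir::RankedTensorType lhsTy,
                                        mlir::RankedTensorType rhsTy) {
      llvm::SmallVector<int64_t, 2> dims{lhsTy.getShape()[0],
                                         rhsTy.getShape()[1]};
      return mlir::RankedTensorType::get(dims, lhsTy.getElementType());
    }
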
@@ -372,11 +372,11 @@ void ONNXFullGemmOp::inferShapes() {

 void ONNXReshapeOp::inferShapes() {
   // Cannot infer shape if no shape tensor is specified.
-  if (!getOperand(1)->getType().isa<RankedTensorType>())
+  if (!getOperand(1).getType().isa<RankedTensorType>())
     emitError("Shape tensor not ranked.");

-  auto inputTensorTy = getOperand(0)->getType().cast<RankedTensorType>();
-  auto shapeTensorTy = getOperand(1)->getType().cast<RankedTensorType>();
+  auto inputTensorTy = getOperand(0).getType().cast<RankedTensorType>();
+  auto shapeTensorTy = getOperand(1).getType().cast<RankedTensorType>();

   // Only rank 1 shape tensors are supported.
   if (shapeTensorTy.getShape().size() != 1)
@@ -392,7 +392,7 @@ void ONNXReshapeOp::inferShapes() {
   for (int i = 0; i < outputRank; ++i)
     dims.emplace_back(-1);

-  getResult()->setType(
+  getResult().setType(
       RankedTensorType::get(dims, inputTensorTy.getElementType()));
 }

@@ -402,16 +402,16 @@ void ONNXReshapeOp::inferShapes() {

 void ONNXTransposeOp::inferShapes() {
   // Cannot infer shape if no shape exists.
-  if (!getOperand()->getType().isa<RankedTensorType>())
+  if (!getOperand().getType().isa<RankedTensorType>())
     emitError("Shape tensor not ranked.");

   // Naive transposition which handles the default case of
   // reversing the shape of the tensor (similar to numpy.transpose).
   // TODO: Once attributes are supported we can handle the case where the
   // transposition uses a permutation vector to interchange the axes.
-  auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
+  auto arrayTy = getOperand().getType().cast<RankedTensorType>();
   SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
-  getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
+  getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
 }

 //===----------------------------------------------------------------------===//

@@ -61,7 +61,7 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
     Value maxDim = nullptr;
     for (int i = 0; i < operands.size(); i++) {
       auto operandShape =
-          operands[i]->getType().cast<MemRefType>().getShape();
+          operands[i].getType().cast<MemRefType>().getShape();
       int operandDimIdx = operandShape.size() - 1 - reversedIdx;

       if (operandDimIdx < 0)
@@ -162,7 +162,7 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
     int dimIdx = rank - 1 - reversedIdx;
     sharedDimCount[dimIdx] = 0;
     for (int i = 0; i < operands.size(); ++i) {
-      auto shape = operands[i]->getType().cast<MemRefType>().getShape();
+      auto shape = operands[i].getType().cast<MemRefType>().getShape();
       if (reversedIdx <= shape.size() - 1)
         sharedDimCount[dimIdx]++;
     }
@@ -174,7 +174,7 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
   // more than one, since they are potentially broadcasted dimensions.
   for (int i = 0; i < operands.size(); ++i) {
     std::map<int, Value> broadcastedDims;
-    auto shape = operands[i]->getType().cast<MemRefType>().getShape();
+    auto shape = operands[i].getType().cast<MemRefType>().getShape();
     int size = shape.size();
     for (int j = 0; j < shape.size(); ++j) {
       if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) {
@@ -198,7 +198,7 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
     std::map<int, Value> broadcastedDims) {
   // `operand` must has a ranked type. This should have been checked by the
   // shape inference pass.
-  auto operandShape = operand->getType().cast<MemRefType>().getShape();
+  auto operandShape = operand.getType().cast<MemRefType>().getShape();
   auto rank = operandShape.size();
   auto loopCount = loopIVs.size();

@@ -319,7 +319,7 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
   /* Lower UnaryOp to Ops in the Standard dialect.
   */
   auto loc = op->getLoc();
-  Type element_type = operands.front()->getType();
+  Type element_type = operands.front().getType();
   if (element_type.isa<IntegerType>()) {
     return rewriter.create<ScalarIOp<UnaryOp>>(loc, result_types, operands,
                                                mlir::None);

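Note: this hunk shows the one distinction to keep in mind during the
migration: Operation* is still a real pointer, so op->getLoc() keeps '->',
while the Values in operands switch to '.'. A tiny illustrative sketch (not
commit code):

    // Pointer/value split after the sync.
    mlir::Type firstOperandType(mlir::Operation *op) {
      mlir::Location loc = op->getLoc();  // Operation*: '->' stays
      (void)loc;                          // only the type is returned below
      return op->getOperand(0).getType(); // Value: '->' becomes '.'
    }
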
@@ -24,7 +24,7 @@ include "dialect/onnx/onnx.td"
 /// dag benefitsAdded = (addBenefit 0)
 /// >;

-def HasOneUse : Constraint<CPred<"$0->hasOneUse()">>;
+def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;

 //===----------------------------------------------------------------------===//
 // Pattern-Match and Rewrite

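Note: the TableGen change follows the same rule: the string inside CPred is
spliced verbatim into generated C++, where $0 binds an mlir::Value, so
$0->hasOneUse() must become $0.hasOneUse() in step with the handwritten
sources.
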
@@ -55,7 +55,7 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
     for (size_t i = 0; i < nestedForOps.size() - 1; i++) {
       auto iterateIV = iterateOp.bodyRegion().front().getArgument(0);
       auto forIV = nestedForOps[i].getBody()->getArgument(0);
-      iterateIV->replaceAllUsesWith(forIV);
+      iterateIV.replaceAllUsesWith(forIV);
       iterateOp.bodyRegion().front().eraseArgument(0);
     }

@@ -65,7 +65,7 @@ public:

     // First operand.
     Type dstType =
-        operands[0]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
+        operands[0].getType().cast<LLVM::LLVMType>().getStructElementType(1);
     Value alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
         loc, dstType, operands[0], rewriter.getI64ArrayAttr(1));
     Value alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
@@ -73,7 +73,7 @@ public:

     // Second operand.
     Type srcType =
-        operands[1]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
+        operands[1].getType().cast<LLVM::LLVMType>().getStructElementType(1);
     Value alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
         loc, srcType, operands[1], rewriter.getI64ArrayAttr(1));
     Value alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
@@ -253,7 +253,7 @@ public:
     // Get the first memref returned, convert to a dynamic memref and store
     // it in the wrapped Output.
     auto outMemRef = outputMemRefs.getResult(0);
-    auto outMemRefTy = outMemRef->getType().dyn_cast<LLVMType>();
+    auto outMemRefTy = outMemRef.getType().dyn_cast<LLVMType>();
     auto outMemRefRank =
         outMemRefTy.getStructElementType(3).getArrayNumElements();
     auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(