onnx转换kmodel效果不好

Viewed 79

问题描述


原 ONNX 模型与转换得到的 kmodel 的输出余弦相似度只有 0.3 多,这正常吗?更换校正集几乎没有作用:无论用真实图片还是随机生成的数据作为校正集,转换出来的 kmodel 余弦相似度都只有 0.3 多,没有什么变化。不做量化时,转换出的 kmodel 余弦相似度可达 0.9 以上。所用模型是由 PaddlePaddle 平台导出的 CRNN 文字识别 ONNX 模型。

复现步骤


python3 script.py --target k230 --model ./inference.onnx --dataset_path ./rec_imgs4 --input_width 100 --input_height 32 --ptq_option 0

转换脚本代码:

import os
import argparse
import numpy as np
from PIL import Image
import onnxsim
import onnx
import nncase
import shutil
import math

def parse_model_input_output(model_file, input_shape):
    """Load an ONNX model and describe its real (non-initializer) inputs.

    Some exporters also list weights (initializers) in ``graph.input``;
    those are skipped so only the tensors the caller must feed remain.

    Args:
        model_file: Path to the ``.onnx`` file.
        input_shape: Fallback shape. Every input dimension whose
            ``dim_value`` is 0 (symbolic/unknown, e.g. a dynamic batch or
            width) is replaced by the entry of ``input_shape`` at the same
            position. The same fallback is applied to every input.

    Returns:
        ``(onnx_model, inputs)`` where ``inputs`` is a list of dicts with
        keys ``'name'``, ``'dtype'`` (numpy dtype) and ``'shape'`` (list of
        ints), in ``graph.input`` order.
    """
    onnx_model = onnx.load(model_file)
    initializer_names = {init.name for init in onnx_model.graph.initializer}

    # onnx.mapping is deprecated since onnx 1.13. Prefer the helper when the
    # installed onnx provides it, but keep working with older releases.
    to_np_dtype = getattr(onnx.helper, 'tensor_dtype_to_np_dtype', None)
    if to_np_dtype is None:
        def to_np_dtype(elem_type):
            return onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type]

    inputs = []
    for graph_input in onnx_model.graph.input:
        if graph_input.name in initializer_names:
            continue
        tensor_type = graph_input.type.tensor_type
        # dim_value == 0 means the dimension is symbolic (dim_param) or
        # unknown, so substitute the caller-supplied size. Note that zip()
        # truncates to the shorter of the model rank and len(input_shape).
        shape = [
            dim.dim_value if dim.dim_value != 0 else fallback
            for dim, fallback in zip(tensor_type.shape.dim, input_shape)
        ]
        inputs.append({
            'name': graph_input.name,
            'dtype': to_np_dtype(tensor_type.elem_type),
            'shape': shape,
        })

    return onnx_model, inputs


def onnx_simplify(model_file, dump_dir, input_shape):
    """Simplify an ONNX model with onnx-simplifier, pinning its input shapes.

    Dynamic input dimensions are filled from ``input_shape`` (see
    ``parse_model_input_output``) so the simplifier can fold shape logic.

    Args:
        model_file: Path to the source ``.onnx`` model.
        dump_dir: Existing directory that receives ``simplified.onnx``.
        input_shape: Shape used to fill dynamic input dimensions.

    Returns:
        Path of the written simplified model (``<dump_dir>/simplified.onnx``).

    Raises:
        RuntimeError: if onnxsim reports that the simplified model could not
            be validated.
    """
    onnx_model, inputs = parse_model_input_output(model_file, input_shape)
    onnx_model = onnx.shape_inference.infer_shapes(onnx_model)

    # The installed onnx/onnxsim may be too old for the model's IR version,
    # so its checker rejects the model. As a best effort, lower ir_version to
    # the newest version the installed onnx understands.
    # NOTE(review): this only rewrites the header field. Content relying on
    # newer IR features is not converted; upgrading onnx/onnxsim is safer.
    try:
        if onnx_model.ir_version > onnx.IR_VERSION:
            print(f"[warn] downgrade ir_version {onnx_model.ir_version} -> {onnx.IR_VERSION}")
            onnx_model.ir_version = onnx.IR_VERSION
    except Exception as e:
        print("[warn] fail to adjust ir_version:", e)

    input_shapes = {info['name']: info['shape'] for info in inputs}

    onnx_model, check = onnxsim.simplify(onnx_model, input_shapes=input_shapes)
    # Explicit raise instead of assert, so the check also runs under `python -O`.
    if not check:
        raise RuntimeError("Simplified ONNX model could not be validated")

    simplified_file = os.path.join(dump_dir, 'simplified.onnx')
    onnx.save_model(onnx_model, simplified_file)
    return simplified_file


def read_model_file(model_file):
    """Return the entire contents of ``model_file`` as ``bytes``."""
    with open(model_file, mode='rb') as handle:
        return handle.read()


def generate_data(shape, batch, calib_dir):
    """Build the PTQ calibration tensor from images found in ``calib_dir``.

    Each image is converted to RGB, resized to the model input size and laid
    out as CHW uint8. This matches the uint8 kmodel input domain configured
    in ``main`` (``input_type = 'uint8'``).

    Args:
        shape: Model input shape ``[N, C, H, W]``. Only ``shape[2]`` (H) and
            ``shape[3]`` (W) are used.
        batch: Number of calibration samples to produce.
        calib_dir: Directory with calibration images. Files without a common
            image extension (e.g. label ``.txt`` files) are ignored.

    Returns:
        uint8 array of shape ``(batch, 1, 1, 3, H, W)``. It holds one
        single-input sample list per calibration sample.

    Raises:
        ValueError: if ``calib_dir`` holds fewer than ``batch`` images.
    """
    image_exts = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp')
    # Sort so the selected calibration subset is reproducible;
    # os.listdir() returns entries in arbitrary order.
    img_paths = sorted(
        os.path.join(calib_dir, name)
        for name in os.listdir(calib_dir)
        if name.lower().endswith(image_exts)
    )
    if len(img_paths) < batch:
        raise ValueError(
            f"calibration images not enough: need {batch}, "
            f"found {len(img_paths)} in {calib_dir}")

    height, width = shape[2], shape[3]
    # NOTE(review): a plain resize stretches text lines. If deployment uses
    # aspect-preserving resize + padding (PaddleOCR recognition usually
    # does; confirm), do the same here so calibration statistics match
    # real inputs.
    data = []
    for img_path in img_paths[:max(batch, 0)]:
        # Close the file handle; convert() loads pixels into memory first.
        with Image.open(img_path) as src:
            rgb = src.convert('RGB').resize((width, height), Image.BILINEAR)
        chw = np.transpose(np.asarray(rgb, dtype=np.uint8), (2, 0, 1))
        data.append([chw[np.newaxis, ...]])
    return np.array(data)


# --ptq_option presets: option -> (calibrate_method, quant_type, w_quant_type).
# Observed with nncase 2.9.0 on this CRNN model: option 1 fails in the
# compiler's TileLSTM pass (assert(allocation.IsOk)).
_PTQ_PRESETS = {
    0: ('NoClip', 'uint8', 'uint8'),
    1: ('NoClip', 'uint8', 'int16'),
    2: ('NoClip', 'int16', 'uint8'),
    3: ('Kld', 'uint8', 'uint8'),
    4: ('Kld', 'uint8', 'int16'),
    5: ('Kld', 'int16', 'uint8'),
}


def _parse_args():
    """Parse and validate the conversion script's command line."""
    parser = argparse.ArgumentParser(prog="nncase")
    parser.add_argument("--target", default="k230", type=str, help='target to run,k230/cpu')
    parser.add_argument("--model", type=str, required=True, help='model file')
    parser.add_argument("--dataset_path", type=str, help='calibration_dataset')
    parser.add_argument("--input_width", type=int, default=320, help='model input_width')
    parser.add_argument("--input_height", type=int, default=320, help='model input_height')
    parser.add_argument("--ptq_option", type=int, default=0, choices=sorted(_PTQ_PRESETS),
                        help='ptq_option:0,1,2,3,4,5')
    parser.add_argument("--samples_count", type=int, default=10,
                        help='number of calibration samples')
    parser.add_argument("--no_ptq", action='store_true',
                        help='skip quantization and compile a float kmodel')
    parser.add_argument("--dump_quant_error", action='store_true',
                        help='dump IR and per-layer quantization error into the dump dir')
    args = parser.parse_args()
    if not args.no_ptq and not args.dataset_path:
        parser.error("--dataset_path is required unless --no_ptq is given")
    return args


def _make_ptq_options(args, input_shape):
    """Create nncase PTQ options for the selected preset and attach calibration data."""
    calibrate_method, quant_type, w_quant_type = _PTQ_PRESETS[args.ptq_option]
    ptq_options = nncase.PTQTensorOptions()
    ptq_options.samples_count = args.samples_count
    ptq_options.calibrate_method = calibrate_method
    ptq_options.quant_type = quant_type
    ptq_options.w_quant_type = w_quant_type
    if args.dump_quant_error:
        # Per-layer quantization error, to locate the layers (e.g. the LSTM)
        # that lose accuracy. nncase documents this as effective only
        # together with compile_options.dump_ir.
        ptq_options.dump_quant_error = True
    ptq_options.set_tensor_data(
        generate_data(input_shape, ptq_options.samples_count, args.dataset_path))
    return ptq_options


def main():
    """Convert an ONNX CRNN text-recognition model into a kmodel.

    Steps:
      1. Simplify the ONNX graph with a fixed input shape.
      2. Configure in-kmodel preprocessing.
      3. Import the model.
      4. Apply post-training quantization, unless --no_ptq is given.
      5. Compile and write ``<model path without extension>.kmodel``.
    """
    args = _parse_args()

    # CRNN: height is fixed at 32. Width is rounded up to a multiple of the
    # horizontal stride.
    # NOTE(review): stride 4 is inferred from the model emitting 25 time
    # steps for W=100; confirm against the actual network.
    stride_x = 4
    input_height = 32
    if args.input_height != input_height:
        print(f"[warn] --input_height={args.input_height} ignored; "
              f"CRNN input height is fixed to {input_height}")
    input_width = int(math.ceil(args.input_width / float(stride_x))) * stride_x

    # Model input shape; its dimension order must match input_layout below.
    input_shape = [1, 3, input_height, input_width]

    dump_dir = 'tmp'
    os.makedirs(dump_dir, exist_ok=True)

    model_file = onnx_simplify(args.model, dump_dir, input_shape)

    compile_options = nncase.CompileOptions()
    compile_options.target = args.target
    # Let the kmodel perform the preprocessing configured below.
    compile_options.preprocess = True
    # Channels are passed through unchanged (RGB in, RGB to the model).
    # NOTE(review): PaddleOCR models are commonly trained on BGR images
    # decoded with cv2. Confirm the channel order this ONNX model expects;
    # if it is BGR, set swapRB = True.
    compile_options.swapRB = False
    compile_options.input_shape = input_shape
    # kmodel input dtype: 'uint8' or 'float32'.
    compile_options.input_type = 'uint8'
    # uint8 input is dequantized into input_range, then normalized as
    # (x - mean) / std. For mean = std = 0.5 this maps [0, 1] to [-1, 1].
    compile_options.input_range = [0, 1]
    compile_options.mean = [0.5, 0.5, 0.5]
    compile_options.std = [0.5, 0.5, 0.5]
    # ONNX models use NCHW.
    compile_options.input_layout = "NCHW"
    if args.dump_quant_error:
        compile_options.dump_ir = True
        compile_options.dump_dir = dump_dir

    compiler = nncase.Compiler(compile_options)

    model_content = read_model_file(model_file)
    import_options = nncase.ImportOptions()
    compiler.import_onnx(model_content, import_options)

    if args.no_ptq:
        print("[info] --no_ptq given: compiling without quantization")
    else:
        compiler.use_ptq(_make_ptq_options(args, input_shape))

    compiler.compile()

    kmodel = compiler.gencode_tobytes()
    kmodel_name = os.path.splitext(args.model)[0] + ".kmodel"
    with open(kmodel_name, 'wb') as f:
        f.write(kmodel)


if __name__ == '__main__':
    main()

如果尝试使用 int16 量化(--ptq_option 1,即激活 uint8、权重 int16),编译时会报错:

warn: Nncase.Hosting.PluginLoader[0]
      NNCASE_PLUGIN_PATH is not set.
[warn] downgrade ir_version 10 -> 7
Unhandled exception. System.AggregateException: One or more errors occurred. (assert(allocation.IsOk) error!
 File "/home/gitlab-runner/builds/zaC7hZ1H/1/maix2-ai-sw/k510-gnne-compiler/modules/Nncase.Modules.K230/Transform/Rules/Tile/TileLSTM.cs", line 201 .)
 ---> System.InvalidOperationException: assert(allocation.IsOk) error!
 File "/home/gitlab-runner/builds/zaC7hZ1H/1/maix2-ai-sw/k510-gnne-compiler/modules/Nncase.Modules.K230/Transform/Rules/Tile/TileLSTM.cs", line 201 .
   at Nncase.Passes.Rules.K230.TileUtilities.Assert(Boolean v, String vStr, String path, Int32 line)
   at Nncase.Passes.Rules.K230.TileLSTM.SearchGlbParameters()
   at Nncase.Passes.Rules.K230.TileLSTM.GetReplace(Expr output, Call midCall, IReadOnlyList`1 midCallParams)
   at Nncase.Passes.Rules.K230.TileLSTM.GetReplace(IMatchResult __result, RunPassContext __context)
   at Nncase.Passes.Rules.Tile.K230FusionConvertVisitor.Process(Fusion fusion)
   at Nncase.Passes.Rules.Tile.K230FusionConvertVisitor.RewriteLeafFusion(Fusion expr)
   at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
   at Nncase.IR.ExprRewriter`1.Rewrite(Expr expr, TContext context)
   at Nncase.IR.ExprRewriter.Rewrite(Expr expr)
   at Nncase.Passes.Rules.Tile.CheckedConvertMutator.RewriteLeafFusion(Fusion expr)
   at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
   at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
   at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
   at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
   at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
   at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
   at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
   at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
   at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
   at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
   at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
   at Nncase.IR.ExprVisitor`3.VisitTuple(Tuple expr, TContext context)
   at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
   at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
   at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
   at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
   at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
   at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
   at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
   at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
   at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
   at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
   at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
   at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
   at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
   at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
   at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
   at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
   at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
   at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
   at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
   at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
   at Nncase.IR.ExprVisitor`3.VisitFunction(Function expr, TContext context)
   at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
   at Nncase.IR.ExprRewriter`1.Rewrite(Expr expr, TContext context)
   at Nncase.IR.ExprRewriter.Rewrite(Expr expr)
   at Nncase.Passes.Rules.Tile.K230FusionToTirPass.RunCoreAsync(IRModule module, RunPassContext options)
   at Nncase.Passes.Pass`2.RunAsync(TInput input, RunPassContext context)
   at Nncase.Passes.PassManager.ModulePassGroup.RunAsync(IRModule module)
   at Nncase.Passes.PassManager.RunAsync(IRModule module)
   at Nncase.Compiler.Compiler.RunPassAsync(Action`1 register, String name, IProgress`1 progress, CancellationToken token)
   at Nncase.Compiler.Compiler.CompileAsync(IProgress`1 progress, CancellationToken token)
   --- End of inner exception stack trace ---
   at System.Threading.Tasks.Task.ThrowIfExceptional(Boolean includeTaskCanceledExceptions)
   at System.Threading.Tasks.Task.Wait(Int32 millisecondsTimeout, CancellationToken cancellationToken)
   at Nncase.Compiler.Interop.CApi.CompilerCompile(IntPtr compilerHandle)

硬件板卡


创乐博k230

软件版本


CanMV-K230-V3_sdcard__nncase_v2.9.0.img.gz

1 Answers

0.3不正常,可能是预处理步骤不一样?

预处理步骤已经改成完全一致,但还是不行;不量化时转换出的 kmodel 识别没有问题。请问有什么办法可以排查问题出在哪里?

不量化是怎么设置的?

就是把这一行注释掉:
compiler.use_ptq(ptq_options)