OCR recognition model problem under the Linux SDK


Problem description


image.png
I trained my own OCR digit recognition model. Can it directly replace the model in the demo?
After replacing it directly, I get the result below:
image.png

Hardware board


K230 01studio

Software version


Self-built Linux SDK, git commit 13b1e97db90a521e5d8ba23ce70e9f028b56182e

Which parameters do I need to adjust?

4 Answers

What did you train it with? The online training platform or AICube?

I trained it on the online platform.

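Then there is probably nothing you can tune. The online training platform fixes its parameters to keep things simple, and the model you trained does not perform very well either. If you have your own compute resources, you could try installing AICube to train and tune the model.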

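So a model trained the officially supported way can be dropped in directly without problems, and it is only the poor training result that makes the final recognition bad. Is that right?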

best_OCR_RLNet_can2_5.0s_20260202163233.kmodel: why is the model RLNet rather than CRNN?

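To keep things simple for users, the online training platform uses default parameters and does not expose the many complex hyperparameters. It is also only suited to simple projects. If you need a more complex application where you tune the configuration and train repeatedly, either train with AICube yourself (this requires your own GPU), or train with open-source code of your choice, convert the result to a kmodel, and write the pre-processing and post-processing code yourself.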
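To illustrate what "writing the post-processing yourself" can involve, here is a minimal sketch of greedy CTC decoding, a common post-processing step for CRNN-style text recognition models. It assumes the model outputs a (T, num_classes) score matrix and that class index 0 is the CTC blank, so class i maps to charset[i - 1]. Neither assumption is confirmed anywhere in this thread, and ctc_greedy_decode is a hypothetical helper name, so check both against your own model before relying on it.

```python
import numpy as np

BLANK = 0  # assumption: class index 0 is the CTC blank


def ctc_greedy_decode(scores, charset):
    """Greedy CTC decode: argmax per time step, collapse repeats, drop blanks.

    scores:  array-like of shape (T, num_classes)
    charset: string or list where charset[i - 1] is the character for class i
    """
    best = np.argmax(np.asarray(scores), axis=-1)  # best class per time step
    chars = []
    prev = None
    for idx in best:
        idx = int(idx)
        # Emit a character only when the class changes and is not the blank.
        if idx != prev and idx != BLANK:
            chars.append(charset[idx - 1])
        prev = idx
    return "".join(chars)
```

For a digit model the charset would be "0123456789", giving 11 output classes including the blank. A blank between two identical classes is what allows a repeated digit such as "00" to be decoded.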

Unhandled exception. System.AggregateException: One or more errors occurred. (assert(_value == 0) error!
File "/home/gitlab-runner/builds/zaC7hZ1H/1/maix2-ai-sw/k510-gnne-compiler/modules/Nncase.Modules.K230/Transform/Rules/Tile/TileHelper/SpaceSearcher.cs", line 1240 .)
---> System.InvalidOperationException: assert(_value == 0) error!
File "/home/gitlab-runner/builds/zaC7hZ1H/1/maix2-ai-sw/k510-gnne-compiler/modules/Nncase.Modules.K230/Transform/Rules/Tile/TileHelper/SpaceSearcher.cs", line 1240 .
at Nncase.Passes.Rules.K230.TileUtilities.Assert(Boolean v, String vStr, String path, Int32 line)
at Nncase.Passes.Rules.K230.Ccr.set_Value(Int32 value)
at Nncase.Passes.Rules.K230.CcrHandler.SetItems(List`1 ccrsToSet)
at Nncase.Passes.Rules.K230.GnneActionUpdater.UpdateCcr(List`1 ccrsToSet, List`1 ccrsToClr)
at Nncase.Passes.Rules.K230.GnneActionUpdater.UpdateMfuAct1(SegmentND ifmap1, SegmentND ifmap2, SegmentND ofmap, ACT1_SOURCE_TYPE sourceType1, ACT1_SOURCE_TYPE sourceType2, DataType quantType1, DataType quantType2, DataType quantTypeD, DeQuantizeParam deqParams1, DeQuantizeParam deqParams2, Int32 rshiftBits1, Int32 rshiftBits2, Int32 rshiftBitsD, Boolean is16Segments, Int32 iPp, List`1 ccrsToSet, List`1 ccrsToClr, Int32 offsetS1, Int32 offsetS2, Int32 offsetD, Int32 offsetAct1, ItemName src2ItemName, ItemName src1ItemName, ItemName act1Name, MFU_ACT1_FUNCTION act1Mode, ItemName dstItemName, List`1 src1Stride, List`1 src2Stride, List`1 dstStride, Boolean isByChannel)
at Nncase.Passes.Rules.K230.TileLayerGroup.BuildScheduleAct1(GnneActionUpdater actionUpdater, TiledGlb glb, Buffer ddrAct, Buffer ddrIf2, Boolean firstSlice)
at Nncase.Passes.Rules.K230.TileLayerGroup.BuildAct1(TiledGlb glb, List`1 ifBuffers, Boolean firstSlice, Boolean isFirstSlice, List`1 ifBuffersCopy)
at Nncase.Passes.Rules.K230.TileLayerGroup.BuildSchedule(FusionInfo fusionInfo)
at Nncase.Passes.Rules.Tile.FusionConvertVisitor.BuildSchedule(FusionType fusionType, FusionInfo fusionInfo)
at Nncase.Passes.Rules.Tile.MultiFusionChecker.Convert()
at Nncase.Passes.Rules.Tile.CheckedConvertMutator.RewriteLeafFusion(Fusion expr)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
at Nncase.IR.ExprVisitor`3.VisitTuple(Tuple expr, TContext context)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
at Nncase.IR.ExprVisitor`3.VisitFunction(Function expr, TContext context)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.Rewrite(Expr expr, TContext context)
at Nncase.IR.ExprRewriter.Rewrite(Expr expr)
at Nncase.Passes.Rules.Tile.K230FusionToTirPass.RunCoreAsync(IRModule module, RunPassContext options)
at Nncase.Passes.Pass`2.RunAsync(TInput input, RunPassContext context)
at Nncase.Passes.PassManager.ModulePassGroup.RunAsync(IRModule module)
at Nncase.Passes.PassManager.RunAsync(IRModule module)
at Nncase.Compiler.Compiler.RunPassAsync(Action`1 register, String name, IProgress`1 progress, CancellationToken token)
at Nncase.Compiler.Compiler.CompileAsync(IProgress`1 progress, CancellationToken token)
--- End of inner exception stack trace ---
at System.Threading.Tasks.Task.ThrowIfExceptional(Boolean includeTaskCanceledExceptions)
at System.Threading.Tasks.Task.Wait(Int32 millisecondsTimeout, CancellationToken cancellationToken)
at Nncase.Compiler.Interop.CApi.CompilerCompile(IntPtr compilerHandle)
Aborted (core dumped)

Which conversion script did you use?

I'll paste the code below.

import os
import glob
import cv2
import numpy as np
ONNX_PATH = "./ocr/PaddleOCR/inference_model/rec_student_export/rec_fixed.onnx"
OUT_KMODEL = "./ocr/PaddleOCR/inference_model/rec_student_export/rec_int8.kmodel"
CALIB_DIR = "./ocr/PaddleOCR/images_numeric"

REC_IMAGE_SHAPE = (3, 48, 320)  # C,H,W


def set_plugin_path():
    if os.environ.get("NNCASE_PLUGIN_PATH"):
        return
    import site

    candidates = []
    try:
        candidates.extend(site.getsitepackages())
    except Exception:
        pass
    candidates.append(site.getusersitepackages())

    for base in candidates:
        mod_dir = os.path.join(base, "nncase", "modules")
        if os.path.isdir(mod_dir):
            os.environ["NNCASE_PLUGIN_PATH"] = mod_dir
            return


def resize_to_nchw_uint8(img, rec_image_shape):
    imgC, imgH, imgW = rec_image_shape
    assert imgC == img.shape[2]
    max_wh_ratio = imgW / float(imgH)
    imgW = int(imgH * max_wh_ratio)
    h, w = img.shape[:2]
    ratio = w / float(h)
    if int(np.ceil(imgH * ratio)) > imgW:
        resized_w = imgW
    else:
        resized_w = int(np.ceil(imgH * ratio))
    resized_image = cv2.resize(img, (resized_w, imgH))
    resized_image = resized_image.astype("uint8")
    resized_image = resized_image.transpose((2, 0, 1))
    padding_im = np.zeros((imgC, imgH, imgW), dtype=np.uint8)
    padding_im[:, :, 0:resized_w] = resized_image
    return padding_im


def load_calib_data(calib_dir, max_samples=50):
    exts = ("*.jpg", "*.jpeg", "*.png", "*.bmp")
    img_paths = []
    for e in exts:
        img_paths.extend(glob.glob(os.path.join(calib_dir, e)))
    img_paths = sorted(img_paths)[:max_samples]
    if not img_paths:
        raise RuntimeError(f"No calibration images found in {calib_dir}")

    data_list = []
    for p in img_paths:
        img = cv2.imread(p)
        if img is None:
            continue
        nchw = resize_to_nchw_uint8(img, REC_IMAGE_SHAPE)
        nchw = np.expand_dims(nchw, axis=0)  # NCHW
        data_list.append([nchw])

    if not data_list:
        raise RuntimeError("No valid calibration data could be loaded.")
    return data_list


def main():
    set_plugin_path()
    import nncase
    assert nncase.check_target("k230"), "nncase target k230 not available"

    with open(ONNX_PATH, "rb") as f:
        onnx_content = f.read()

    compile_options = nncase.CompileOptions()
    compile_options.target = "k230"
    compile_options.input_shape = list((1,) + REC_IMAGE_SHAPE)
    compile_options.preprocess = True
    compile_options.input_type = "uint8"
    compile_options.input_layout = "NCHW"
    compile_options.output_layout = "NCHW"
    compile_options.input_range = [0, 255]
    compile_options.mean = [127.5, 127.5, 127.5]
    compile_options.std = [127.5, 127.5, 127.5]

    compiler = nncase.Compiler(compile_options)

    import_options = nncase.ImportOptions()
    compiler.import_onnx(onnx_content, import_options)

    ptq = nncase.PTQTensorOptions()
    # Load the calibration set first so that samples_count matches the number
    # of images actually read (load_calib_data skips unreadable files).
    calib_data = load_calib_data(CALIB_DIR, max_samples=50)
    ptq.samples_count = len(calib_data)
    ptq.calibrate_method = "Kld"
    ptq.quant_type = "uint8"
    ptq.w_quant_type = "int8"
    ptq.use_mix_quant = True
    ptq.set_tensor_data(calib_data)

    compiler.use_ptq(ptq)
    compiler.compile()

    os.makedirs(os.path.dirname(OUT_KMODEL), exist_ok=True)
    with open(OUT_KMODEL, "wb") as f:
        compiler.gencode(f)

    print(f"kmodel saved: {OUT_KMODEL}")


if __name__ == "__main__":
    main()

Did you write the code yourself?

GPT wrote it...

Then the conversion code is probably the problem.
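When a conversion script is suspect, one cheap first step is to confirm that the calibration samples have the shape and dtype the compile options declare. The sketch below is a plain NumPy check with a hypothetical helper name (check_calib_samples). It takes a flat list of arrays and calls no nncase API, and it is not claimed to explain the assertion in the stack trace above. It only rules out one class of mistakes before compiling.

```python
import numpy as np


def check_calib_samples(samples, expected_shape, expected_dtype=np.uint8):
    """Return a list of problems found in the calibration samples.

    samples:        flat list of arrays, one per calibration image
    expected_shape: shape declared to the compiler, e.g. (1, 3, 48, 320)
    """
    problems = []
    if len(samples) == 0:
        problems.append("no calibration samples")
    expected_shape = tuple(expected_shape)
    expected_dtype = np.dtype(expected_dtype)
    for i, sample in enumerate(samples):
        arr = np.asarray(sample)
        if arr.shape != expected_shape:
            problems.append(
                f"sample {i}: shape {arr.shape}, expected {expected_shape}"
            )
        if arr.dtype != expected_dtype:
            problems.append(
                f"sample {i}: dtype {arr.dtype}, expected {expected_dtype}"
            )
    return problems
```

With the script above, where each entry of the list returned by load_calib_data is a one-element list, this could be called as check_calib_samples([d[0] for d in data_list], (1,) + REC_IMAGE_SHAPE). An empty result means the samples match the declared input.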