使用nncase将onnx转kmodel时,面对多个参数输入该如何设置

Viewed 138

问题描述


我想实现基于vit的双模态的情感检测,使用视频和音频模态,在训练后得到pth文件,转为onnx后再转为onnxsim,但是在将onnx转kmodel时报错,下面附上代码。

export_kmodel.py代码

import os
import nncase
import numpy as np
import torch
import onnx
from onnxsim import simplify  # 引入 simplifier
from torch.utils.data import DataLoader
from dataset import RavdessDataset

# ================= 基础配置 =================
RAW_ONNX_FILE = "fusion_model.onnx"  # 原始导出的模型
SIM_ONNX_FILE = "fusion_model_sim.onnx"  # 简化后的模型 (将生成这个文件)
KMODEL_FILE = "fusion_model.kmodel"  # 最终输出
DATASET_DIR = r"D:\Learning\Python_learning\dataset\P-Ravdess"
CALIB_COUNT = 50


# ===========================================

def simplify_model(input_path, output_path):
    """
    使用 onnx-simplifier 优化模型结构
    """
    print(f"[*] 正在运行 ONNX Simplifier...")
    print(f"    输入: {input_path}")

    # 1. 加载模型
    model = onnx.load(input_path)

    # 2. 执行简化
    # check_n=3 表示会随机生成3组数据验证简化前后精度是否一致
    model_sim, check = simplify(model, check_n=3)

    if not check:
        print("❌ 警告: ONNX Sim 校验失败!简化后的模型可能精度受损。")
        # 视情况决定是否继续,这里选择抛出异常终止
        raise RuntimeError("ONNX Simplifier validation failed.")

    # 3. 保存简化后的模型
    onnx.save(model_sim, output_path)
    print(f"✅ 简化成功! 已保存至: {output_path}")
    return True


def get_calib_data():
    print("[*] 读取校准数据...")
    dataset = RavdessDataset(DATASET_DIR)
    loader = DataLoader(dataset, batch_size=1, shuffle=True)

    calib_data = []
    for i, (img, aud, _) in enumerate(loader):
        if i >= CALIB_COUNT: break

        img_np = img.numpy().astype(np.float32)
        aud_np = aud.numpy().astype(np.float32)

        # 默认假设:列表顺序即为输入顺序
        calib_data.append([img_np, aud_np])

    return calib_data


def main():
    print("------------------------------------------------")
    print("  NNCase Standard Converter (with ONNX-Sim)")
    print("------------------------------------------------")

    if not os.path.exists(RAW_ONNX_FILE):
        print(f"❌ 错误: 找不到原始文件 {RAW_ONNX_FILE}")
        return

    # ==========================================
    # 第一步:运行 ONNX Simplifier
    # ==========================================
    try:
        simplify_model(RAW_ONNX_FILE, SIM_ONNX_FILE)
    except Exception as e:
        print(f"❌ 模型简化失败: {e}")
        return

    # ==========================================
    # 第二步:NNCase 编译流程
    # ==========================================

    # 1. 设置编译选项
    compile_options = nncase.CompileOptions()
    compile_options.target = "k230"
    compile_options.input_type = "float32"
    compile_options.input_layout = "NCHW"
    compile_options.dump_ir = False

    compiler = nncase.Compiler(compile_options)

    # 2. 导入 **简化后** 的模型
    print(f"[*] 正在导入简化模型: {SIM_ONNX_FILE}")
    with open(SIM_ONNX_FILE, 'rb') as f:
        compiler.import_onnx(f.read(), nncase.ImportOptions())

    # 3. 配置量化
    print("[*] 配置量化参数 (PTQ)...")
    ptq_options = nncase.PTQTensorOptions()
    ptq_options.w_quant_type = "int8"
    ptq_options.i_quant_type = "int8"
    ptq_options.calibrate_method = "Kld"

    # 4. 加载数据
    calib_data = get_calib_data()
    ptq_options.set_tensor_data(calib_data)
    ptq_options.samples_count = len(calib_data)

    compiler.use_ptq(ptq_options)

    # 5. 编译
    print("[*] 开始编译...")
    try:
        compiler.compile()

        kmodel = compiler.gencode_tobytes()
        with open(KMODEL_FILE, 'wb') as f:
            f.write(kmodel)

        print(f"\n✅ 转换成功! 文件已保存: {KMODEL_FILE}")

    except Exception as e:
        print(f"\n❌ 编译发生错误: {e}")


if __name__ == "__main__":
    main()

我这里输入图片和音频两个参数,参数分别为:

图片:[1,3,224,224]
音频:[1,1,128,128]

但是在运行时报错,说尺寸不合规,终端输出如下:

warn: Nncase.Hosting.PluginLoader[0]
      NNCASE_PLUGIN_PATH is not set.
------------------------------------------------
  NNCase Standard Converter (with ONNX-Sim)
------------------------------------------------
[*] 正在运行 ONNX Simplifier...
    输入: fusion_model.onnx
Checking 0/3...
Checking 1/3...
Checking 2/3...
✅ 简化成功! 已保存至: fusion_model_sim.onnx
[*] 正在导入简化模型: fusion_model_sim.onnx
[*] 配置量化参数 (PTQ)...
[*] 读取校准数据...
[train] 数据集加载完毕: 1440 个样本
[*] 开始编译...

报错位置:
Unhandled exception. System.AggregateException: One or more errors occurred. (Feed Value Is Invalid, need f32[1,1,128,128] but get f32[1,3,224,224]!)
 ---> System.InvalidOperationException: Feed Value Is Invalid, need f32[1,1,128,128] but get f32[1,3,224,224]!
   at Nncase.Quantization.CalibrationEvaluator.<>c__DisplayClass11_0.<Visit>b__0()
   at Nncase.Quantization.CalibrationEvaluator.VisitLeaf(ENode enode, Func`1 valueGetter)
   at Nncase.Quantization.CalibrationEvaluator.Visit(ENode enode, Var var)
   at Nncase.Quantization.CalibrationEvaluator.Visit(ENode enode)
   at Nncase.Quantization.CalibrationEvaluator.Visit(EClass eclass)
   at Nncase.Quantization.CalibrationEvaluator.Visit(ENode enode, Func`2 valueGetter)

dataset.py代码

import os
import random
import glob
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchaudio
import torchvision.transforms as transforms
import numpy as np


class RavdessDataset(Dataset):
    def __init__(self, root_dir, phase='train', target_sample_rate=16000, target_len=128):
        """
        root_dir: 数据集根目录
        phase: 'train' 或 'val'
        target_sample_rate: 目标采样率 (默认 16000)
        target_len: 音频时间轴长度 (默认 128)
        """
        self.root_dir = root_dir
        self.phase = phase
        self.audio_root = os.path.join(root_dir, 'Audio')
        self.visual_root = os.path.join(root_dir, 'Visual')
        self.target_sample_rate = target_sample_rate
        self.target_len = target_len

        # 检查路径是否存在
        if not os.path.exists(self.audio_root) or not os.path.exists(self.visual_root):
            raise ValueError(f"数据集路径错误!请检查 {self.audio_root} 和 {self.visual_root}")

        self.classes = sorted(os.listdir(self.audio_root))
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

        self.samples = []
        # 遍历收集样本
        for cls_name in self.classes:
            cls_folder = os.path.join(self.audio_root, cls_name)
            if not os.path.isdir(cls_folder): continue

            audio_files = glob.glob(os.path.join(cls_folder, "*.wav"))

            for audio_path in audio_files:
                # 文件名关联逻辑
                filename = os.path.basename(audio_path)
                video_name_no_ext = filename.replace(".wav", "")
                visual_folder = os.path.join(self.visual_root, cls_name, video_name_no_ext)

                # 只有当音频和视频都存在时才加入列表
                if os.path.exists(visual_folder) and len(glob.glob(os.path.join(visual_folder, "*.jpg"))) > 0:
                    self.samples.append((audio_path, visual_folder, self.class_to_idx[cls_name]))

        print(f"[{phase}] 数据集加载完毕: {len(self.samples)} 个样本")

        # === 视觉预处理 ===
        self.visual_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # ImageNet 标准归一化
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # === 音频特征提取器 ===
        self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.target_sample_rate,
            n_mels=128,  # 对应频域高度 (Height)
            n_fft=1024,
            hop_length=512
        )
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        audio_path, visual_folder, label = self.samples[idx]

        # -----------------------------------------------------------
        # 1. 视觉处理 (Visual Processing) -> [3, 224, 224]
        # -----------------------------------------------------------
        frames = glob.glob(os.path.join(visual_folder, "*.jpg"))

        # 容错:如果没图片,生成黑图
        if len(frames) == 0:
            image = torch.zeros((3, 224, 224), dtype=torch.float32)
        else:
            # 训练时随机抽帧,测试时固定抽中间帧
            if self.phase == 'train':
                selected_frame = random.choice(frames)
            else:
                selected_frame = frames[len(frames) // 2]

            try:
                img_pil = Image.open(selected_frame).convert('RGB')
                image = self.visual_transform(img_pil)
            except Exception as e:
                print(f"图片读取错误 {selected_frame}: {e}")
                image = torch.zeros((3, 224, 224), dtype=torch.float32)

        # -----------------------------------------------------------
        # 2. 音频处理 (Audio Processing) -> [1, 128, 128]
        # -----------------------------------------------------------
        try:
            waveform, sr = torchaudio.load(audio_path)

            # 🔥 修复 1: 强制重采样 (Resample)
            if sr != self.target_sample_rate:
                resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.target_sample_rate)
                waveform = resampler(waveform)

            # 🔥 修复 2: 强制单声道 (Mix to Mono)
            # 如果是立体声 [2, Time],取平均值变成 [1, Time]
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            # 生成 Mel 谱图 -> [1, 128, Time]
            spec = self.mel_spectrogram(waveform)
            spec = self.amplitude_to_db(spec)

            # 🔥 修复 3: 固定时间长度 (Padding / Truncate)
            current_len = spec.shape[2]
            if current_len < self.target_len:
                # 补零
                padding = self.target_len - current_len
                spec = torch.nn.functional.pad(spec, (0, padding))
            else:
                # 截断
                spec = spec[:, :, :self.target_len]

            # 确保是 3 维 [1, 128, 128]
            # MelSpectrogram 通常保留 Channel 维,但为了保险起见:
            if spec.dim() == 2:
                spec = spec.unsqueeze(0)

        except Exception as e:
            print(f"音频处理错误 {audio_path}: {e}")
            spec = torch.zeros((1, 128, self.target_len), dtype=torch.float32)

        return image, spec, label


# 简单的自测代码
if __name__ == "__main__":
    # 请修改为你的实际路径
    root = r"D:\Learning\Python_learning\dataset\P-Ravdess"
    ds = RavdessDataset(root)
    img, aud, lbl = ds[0]

    print("\n✅ Dataset 自检通过!")
    print(f"   Image Shape: {img.shape}  (Expect: [3, 224, 224])")
    print(f"   Audio Shape: {aud.shape}  (Expect: [1, 128, 128])")

    # 验证是否为单声道
    if aud.shape[0] != 1:
        print("❌ 警告:音频不是单声道!")

后面我尝试将数据顺序颠倒,先输入音频再输入图片,但是还是报错,这次是:

need f32[1,3,224,224] but get f32[1,1,128,128]!

头大了,不知道是哪里的问题,是nncase对多输入模型有严格的规范吗?

硬件板卡


亚博K230视觉识别模块

软件版本


nncase_v2.8.0_onnx_v1.14.0_onnx-simplifier_v0.4.33

其他信息


验证了模型是没有问题的,准确率很高,就是卡在转换为kmodel这里了

硬件板卡


亚博K230视觉识别模块

软件版本


nncase_v2.8.0_onnx_v1.14.0_onnx-simplifier_v0.4.33

其他信息


在之前转换过单模态图像输入模型,当时成功转换了,环境应该没有问题。

2 Answers

你好,可以参考如下代码,主要是设置input_shapes为多个。

import os
import argparse
import numpy as np
from PIL import Image
import onnxsim
import onnx
import nncase


def parse_model_input_output(model_file):
    onnx_model = onnx.load(model_file)
    input_all = [node.name for node in onnx_model.graph.input]
    input_initializer = [node.name for node in onnx_model.graph.initializer]
    input_names = list(set(input_all) - set(input_initializer))
    input_tensors = [
        node for node in onnx_model.graph.input if node.name in input_names]

    # input
    inputs = []
    for _, e in enumerate(input_tensors):
        onnx_type = e.type.tensor_type
        input_dict = {}
        input_dict['name'] = e.name
        input_dict['dtype'] = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[onnx_type.elem_type]
        input_dict['shape'] = [(i.dim_value if i.dim_value != 0 else d) for i, d in zip(
            onnx_type.shape.dim, [1, 3, 256, 256])]
        inputs.append(input_dict)

    return onnx_model, inputs


def onnx_simplify(model_file, dump_dir):
    onnx_model, inputs = parse_model_input_output(model_file)
    onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
    input_shapes = {}
    for input in inputs:
        input_shapes[input['name']] = input['shape']

    onnx_model, check = onnxsim.simplify(onnx_model, input_shapes=input_shapes)
    assert check, "Simplified ONNX model could not be validated"

    model_file = os.path.join(dump_dir, 'simplified.onnx')
    onnx.save_model(onnx_model, model_file)
    return model_file


def read_model_file(model_file):
    with open(model_file, 'rb') as f:
        model_content = f.read()
    return model_content


def generate_data_ramdom(shapes, sample):
    data_all = []
    for shape in shapes:
        data = []
        for i in range(sample):
            data.append(np.random.randint(0, 1, shape).astype(np.float32))
        data_all.append(data)
    print(data_all)
    return data_all


def generate_data(shape, batch, calib_dir):
    img_paths = [os.path.join(calib_dir, p) for p in os.listdir(calib_dir)]
    data = []
    for i in range(batch):
        assert i < len(img_paths), "calibration images not enough."
        img_data = Image.open(img_paths[i]).convert('RGB')
        img_data = img_data.resize((shape[3], shape[2]), Image.BILINEAR)
        img_data = np.asarray(img_data, dtype=np.uint8)
        img_data = np.transpose(img_data, (2, 0, 1))
        data.append([img_data[np.newaxis, ...]])
    return data


def main():
    parser = argparse.ArgumentParser(prog="nncase")
    parser.add_argument("--target", type=str, default='cpu', help='target to run')
    parser.add_argument("--model", type=str,
                        default='./resnet_50_size-256.onnx',
                        help='model file')
    parser.add_argument("--dataset", type=str, default='random')
    parser.add_argument("--num_data", type=int, default=5)
    parser.add_argument("--ptq", type=str, help='ptq method,such as int8,int16,wint16,NoClip_int16,NoClip_wint16',default='int8')
    args = parser.parse_args()

    # input_shape = [1, 3, 256, 256]
    input_shapes = [[1, 48, 8, 8], [1, 48, 16, 16]]

    ptq_method = args.ptq

    dump_dir = 'tmp/nanotracker_head'
    if not os.path.exists(dump_dir):
        os.makedirs(dump_dir)

    # onnx simplify
    model_file = onnx_simplify(args.model, dump_dir)

    # compile_options
    compile_options = nncase.CompileOptions()
    compile_options.target = args.target
    compile_options.preprocess = False
    compile_options.swapRB = True
    # compile_options.input_shape = input_shape
    compile_options.input_type = 'uint8'
    compile_options.input_range = [0, 255]
    compile_options.mean = [0, 0, 0]
    compile_options.std = [1, 1, 1]
    compile_options.input_layout = 'NCHW'
    # compile_options.output_layout = 'NHWC'
    compile_options.dump_ir = True
    compile_options.dump_asm = True
    compile_options.dump_dir = dump_dir

    # compile_options.quant_type = 'uint8'

    # compile_options.quant_type = 'uint8'
    if ptq_method=="int8" or ptq_method == "int16":
        compile_options.quant_type = ptq_method
    elif ptq_method == "wint16":
        compile_options.w_quant_type = 'int16'
    elif ptq_method == "NoClip_int16":
        compile_options.calibrate_method = 'NoClip'
        compile_options.quant_type = 'int16'
    elif ptq_method == "NoClip_wint16":
        compile_options.calibrate_method = 'NoClip'
        compile_options.w_quant_type = 'int16'
    else:
        pass

    compile_options.calibrate_method = 'NoClip'
    compile_options.quant_type = 'uint8'

    # compiler
    compiler = nncase.Compiler(compile_options)

    # import
    model_content = read_model_file(model_file)
    import_options = nncase.ImportOptions()
    compiler.import_onnx(model_content, import_options)

    # ptq_options
    ptq_options = nncase.PTQTensorOptions()
    # ptq_options.samples_count = args.num_data
    if args.dataset == 'random':
        ptq_options.set_tensor_data(generate_data_ramdom(input_shapes, ptq_options.samples_count))
    else:
        # ptq_options.set_tensor_data(generate_data(input_shape, ptq_options.samples_count, args.dataset))
        input_shape = [
            [1, 48, 8, 8],
            [1, 48, 16, 16]
        ]
        
        calib_data = [ [],[] ]
        for idx in range(167+1):
            # in_name = 'input_%03d' % idx
            input_0_path = os.path.join( args.dataset,"crop_output_{}.bin".format( idx ) ) 
            bin_data = np.fromfile(input_0_path, np.float32)
            bin_data = np.reshape(bin_data, input_shape[0])
            calib_data[0].append(bin_data)

            input_1_path = os.path.join( args.dataset,"src_output_{}.bin".format( idx ) ) 
            bin_data = np.fromfile(input_1_path, np.float32)
            bin_data = np.reshape(bin_data, input_shape[1])
            calib_data[1].append(bin_data)

        ptq_options.samples_count = len( calib_data[0] )
        ptq_options.set_tensor_data(calib_data)


    compiler.use_ptq(ptq_options)

    # compile
    compiler.compile()

    # kmodel
    kmodel = compiler.gencode_tobytes()
    with open(os.path.join('./data/models', 'nanotracker_head_calib_{}.kmodel'.format(args.target)), 'wb') as f:
        f.write(kmodel)


if __name__ == '__main__':
    main()

成功跑通了,非常感谢👍👍

我之前一直以为nncase需要的数据格式是: [ [样本1_图片, 样本1_音频], [样本2_图片, 样本2_音频], ... ]

实际它要的结构是: [ [所有图片的列表], [所有音频的列表] ]

总的来说就是按输入通道分组!通道一放所有的图片,通道二放所有的音频。