avatar

Text03's Blog

如果生活把你推到了板边, 记得升龙 →↓↘+👊

  • 首页
  • 链接
  • 关于
主页 trtyolo C++ API 文档 AI整理
文章

trtyolo C++ API 文档 AI整理

发表于 最近 更新于 最近
作者 text03
71~92 分钟 阅读

目录

  • 1. 快速开始
  • 2. 数据结构
    • 2.1 Image
    • 2.2 Mask
    • 2.3 KeyPoint
    • 2.4 Box
    • 2.5 RotatedBox
  • 3. 推理结果结构体
    • 3.1 BaseRes
    • 3.2 ClassifyRes
    • 3.3 DetectRes
    • 3.4 OBBRes
    • 3.5 SegmentRes
    • 3.6 PoseRes
  • 4. 推理配置
    • 4.1 InferOption
  • 5. 模型类
    • 5.1 BaseModel
    • 5.2 ClassifyModel
    • 5.3 DetectModel
    • 5.4 OBBModel
    • 5.5 SegmentModel
    • 5.6 PoseModel
  • 6. 约定与注意事项
    • 6.1 图像内存与像素格式
    • 6.2 批量推理
    • 6.3 线程与并发
    • 6.4 性能报告
    • 6.5 错误处理

1. 快速开始

以下示例展示“推理调用方式”。图像像素格式/数据类型的严格要求需要以你引擎/实现为准(见 6.1)。

1.1 检测(DetectModel)单张图像

#include "trtyolo.hpp"
#include <vector>

int main() {
    trtyolo::InferOption opt;
    opt.setDeviceId(0);
    opt.enableSwapRB();  // 如果输入来自 OpenCV(BGR) 且模型期望 RGB,可启用
    opt.setNormalizeParams({0.f, 0.f, 0.f}, {255.f, 255.f, 255.f}); // 示例:0~255 -> 0~1
    opt.setBorderValue(114.f); // 常见 letterbox padding 值(仅示例)

    trtyolo::DetectModel model("yolov8.engine", opt);

    // 假设你有一张 HWC 排列的图像数据 ptr(例如 cv::Mat.data)
    void* ptr = /* image data pointer */;
    int w = 640, h = 480, c = 3;
    size_t pitch = /* bytes per row, e.g. mat.step */;
    trtyolo::Image img(ptr, w, h, c, pitch);

    trtyolo::DetectRes res = model.predict(img);
    for (int i = 0; i < res.num; ++i) {
        const auto& box = res.boxes[i];
        int cls = res.classes[i];
        float score = res.scores[i];
        // box: left/top/right/bottom
    }
}

1.2 批量推理(多张图像)

std::vector<trtyolo::Image> imgs = { img1, img2, img3 };
std::vector<trtyolo::DetectRes> results = model.predict(imgs);

1.3 并发(clone)

auto worker = model.clone();  // 通常用于每个线程各持有一个实例
auto r = worker->predict(img);


2. 数据结构

2.1 Image

用于描述输入图像:只保存“指针 + 尺寸信息”,不拥有内存。

struct Image {
    void*  ptr;
    int    width;
    int    height;
    int    channels;
    size_t pitch;
};

字段说明

  • ptr:图像数据指针(Host 内存或 GPU 显存,取决于 InferOption 设置)

  • width / height:图像宽高(像素)

  • channels:通道数(常见为 3)

  • pitch:每行字节数(可包含 padding),用于处理非紧密排列的行跨度

构造函数

  • Image(void* data, int width, int height)

    • 适用于无 padding 的紧密排列数据(pitch 将由实现推断/设置)
  • Image(void* data, int width, int height, size_t pitch)

    • 适用于已知行跨度的情况(如 OpenCV 的 mat.step)
  • Image(void* data, int width, int height, int channels, size_t pitch)

    • 显式指定通道数和 pitch

输出

  • 支持 operator<< 便于打印调试:Image(width=..., height=..., channels=..., pitch=..., ptr=...)

2.2 Mask

用于分割任务的掩码输出。

struct Mask {
    std::vector<float> data;
    int width;
    int height;

    Mask(int width, int height);
};

说明

  • data:掩码数据(浮点数组)

  • width / height:掩码尺寸

  • Mask(width, height):会初始化尺寸并分配相应数据空间(具体填充由实现完成)


2.3 KeyPoint

用于姿态估计(Pose)的关键点输出。

struct KeyPoint {
    float x;
    float y;
    std::optional<float> conf;

    KeyPoint(float x, float y, std::optional<float> conf = std::nullopt);
};

说明

  • x / y:关键点坐标

  • conf:可选置信度(存在则输出,否则为空)


2.4 Box

轴对齐矩形框(AABB)。

struct Box {
    float left, top, right, bottom;

    Box(float left, float top, float right, float bottom);
    std::array<int, 4> xyxy() const;
};

说明

  • 坐标为 left, top, right, bottom

  • xyxy():返回 {left, top, right, bottom} 的 int 数组(通常用于绘制/后处理整型化)


2.5 RotatedBox

旋转矩形框(OBB),继承自 Box,增加角度。

struct RotatedBox : public Box {
    float theta; // 弧度,顺时针,从正 x 轴开始

    RotatedBox(float left, float top, float right, float bottom, float theta);
    std::array<int, 8> xyxyxyxy() const;
};

说明

  • theta:弧度制,顺时针方向,从正 x 轴开始测量

  • xyxyxyxy():返回四个顶点的整型坐标 {x1,y1,x2,y2,x3,y3,x4,y4}


3. 推理结果结构体

3.1 BaseRes

所有任务结果的基础字段。

struct BaseRes {
    int num;
    std::vector<int> classes;
    std::vector<float> scores;

    BaseRes() = default;
    BaseRes(int num, const std::vector<int>& classes, const std::vector<float>& scores);
};

说明

  • num:结果数量

  • classes:每个结果的类别 id

  • scores:每个结果的置信度/得分

  • 一般约定:classes.size() == scores.size() == num


3.2 ClassifyRes

分类结果(继承 BaseRes,无新增字段)。

struct ClassifyRes : public BaseRes {};


3.3 DetectRes

检测结果(AABB)。

struct DetectRes : public BaseRes {
    std::vector<Box> boxes;

    DetectRes() = default;
    DetectRes(int num,
              const std::vector<int>& classes,
              const std::vector<float>& scores,
              const std::vector<Box>& boxes);
};

说明

  • boxes:与 classes/scores 一一对应的矩形框

  • 一般约定:boxes.size() == num


3.4 OBBRes

旋转框检测结果(OBB)。

struct OBBRes : public BaseRes {
    std::vector<RotatedBox> boxes;

    OBBRes() = default;
    OBBRes(int num,
           const std::vector<int>& classes,
           const std::vector<float>& scores,
           const std::vector<RotatedBox>& boxes);
};


3.5 SegmentRes

实例分割结果:框 + 掩码。

struct SegmentRes : public BaseRes {
    std::vector<Box> boxes;
    std::vector<Mask> masks;

    SegmentRes() = default;
    SegmentRes(int num,
               const std::vector<int>& classes,
               const std::vector<float>& scores,
               const std::vector<Box>& boxes,
               const std::vector<Mask>& masks);
};

说明

  • masks:每个实例对应一个掩码(与 boxes/classes/scores 对齐)

  • 一般约定:boxes.size() == masks.size() == num


3.6 PoseRes

姿态估计:框 + 关键点列表。

struct PoseRes : public BaseRes {
    std::vector<Box> boxes;
    std::vector<std::vector<KeyPoint>> kpts;

    PoseRes() = default;
    PoseRes(int num,
            const std::vector<int>& classes,
            const std::vector<float>& scores,
            const std::vector<Box>& boxes,
            const std::vector<std::vector<KeyPoint>>& kpts);
};

说明

  • kpts[i]:第 i 个检测目标的关键点序列

  • 一般约定:boxes.size() == kpts.size() == num


4. 推理配置

4.1 InferOption

推理选项配置类(PImpl 隐藏实现)。

class InferOption {
public:
    InferOption();
    ~InferOption();

    void setDeviceId(int id);
    void enableCudaMem();
    void enableManagedMemory();
    void enablePerformanceReport();
    void enableSwapRB();
    void setBorderValue(float border_value);
    void setNormalizeParams(const std::vector<float>& mean, const std::vector<float>& std);
    void setInputDimensions(int width, int height);
};

方法说明

  • setDeviceId(int id)
    设置 GPU 设备号(如 0/1/2…)

  • enableCudaMem()
    指示推理输入数据位于 CUDA 显存中(通常意味着 Image::ptr 是 device pointer)

  • enableManagedMemory()
    启用统一内存(managed memory)

  • enablePerformanceReport()
    启用性能报告统计(用于 BaseModel::performanceReport())

  • enableSwapRB()
    启用通道交换(R/B 互换,常用于 BGR<->RGB 转换)

  • setBorderValue(float border_value)
    设置边界填充值(常用于 letterbox/padding 的填充值)

  • setNormalizeParams(mean, std)
    设置归一化参数(通常为逐通道 mean/std)。
    建议 mean.size() == std.size() == channels(常见 3)

  • setInputDimensions(int width, int height)
    固定输入宽高。未设置时表示输入宽高可变。
    注释提示:适用于输入宽高恒定的场景(如监控视频分析等)。


5. 模型类

5.1 BaseModel

所有模型的基类(PImpl 隐藏实现)。

class BaseModel {
public:
    BaseModel();
    ~BaseModel();

    explicit BaseModel(const std::string& trt_engine_file, const InferOption& infer_option);

    int batch() const;

    std::tuple<std::string, std::string, std::string> performanceReport();
};

方法说明

  • 构造:BaseModel(trt_engine_file, infer_option)

    • trt_engine_file:TensorRT 引擎文件路径(.engine 等)

    • infer_option:推理配置

  • int batch() const
    获取模型支持的 batch 大小(通常与引擎构建有关)

  • performanceReport()
    返回性能报告三元组:

    • 吞吐量(throughput)字符串

    • CPU 延迟字符串

    • GPU 延迟字符串

    仅在 InferOption::enablePerformanceReport() 启用后有意义

头文件里无参构造函数在 public 区域,但注释标明“仅在 clone 方法中使用”。建议按该意图使用。


5.2 ClassifyModel

分类模型。

class ClassifyModel : public BaseModel {
public:
    ClassifyModel();
    ~ClassifyModel();

    explicit ClassifyModel(const std::string& trt_engine_file, const InferOption& infer_option);

    std::unique_ptr<ClassifyModel> clone() const;

    ClassifyRes predict(const Image& image);
    std::vector<ClassifyRes> predict(const std::vector<Image>& images);
};


5.3 DetectModel

检测模型(AABB)。

class DetectModel : public BaseModel {
public:
    DetectModel();
    ~DetectModel();

    explicit DetectModel(const std::string& trt_engine_file, const InferOption& infer_option);

    std::unique_ptr<DetectModel> clone() const;

    DetectRes predict(const Image& image);
    std::vector<DetectRes> predict(const std::vector<Image>& images);
};


5.4 OBBModel

旋转框检测模型(OBB)。

class OBBModel : public BaseModel {
public:
    OBBModel();
    ~OBBModel();

    explicit OBBModel(const std::string& trt_engine_file, const InferOption& infer_option);

    std::unique_ptr<OBBModel> clone() const;

    OBBRes predict(const Image& image);
    std::vector<OBBRes> predict(const std::vector<Image>& images);
};


5.5 SegmentModel

分割模型。

class SegmentModel : public BaseModel {
public:
    SegmentModel();
    ~SegmentModel();

    explicit SegmentModel(const std::string& trt_engine_file, const InferOption& infer_option);

    std::unique_ptr<SegmentModel> clone() const;

    SegmentRes predict(const Image& image);
    std::vector<SegmentRes> predict(const std::vector<Image>& images);
};


5.6 PoseModel

姿态估计模型。

class PoseModel : public BaseModel {
public:
    PoseModel();
    ~PoseModel();

    explicit PoseModel(const std::string& trt_engine_file, const InferOption& infer_option);

    std::unique_ptr<PoseModel> clone() const;

    PoseRes predict(const Image& image);
    std::vector<PoseRes> predict(const std::vector<Image>& images);
};


6. 约定与注意事项

6.1 图像内存与像素格式

头文件只提供了 void* ptr + width/height/channels/pitch,没有在接口层声明:

  • 数据类型(uint8 / float / half)

  • 排列方式(HWC / CHW)

  • 色彩空间(BGR / RGB)

  • ptr 指向 host 还是 device

因此建议按以下方式理解,并以你的实现/引擎为准:

  • 默认情况(未 enableCudaMem):Image::ptr 很可能是 CPU 内存指针

  • enableCudaMem():Image::ptr 很可能应为 CUDA device pointer

  • enableSwapRB() 明确存在:说明库内部可能会对 R/B 通道做交换(通常与 OpenCV 的 BGR 输入有关)

  • setNormalizeParams():说明库内部可能会进行 (x - mean) / std 或类似归一化

实际输入预处理(resize/letterbox/归一化/颜色转换等)是此类库的关键点:如果你需要我把这些“约定”写成更准确的文档,请再提供对应的 .cpp 实现或 README。


6.2 批量推理

predict(const std::vector<Image>& images) 的输入图片数量与 batch() 的关系在头文件中未约束。

建议遵循:

  • 如果引擎是固定 batch:images.size() 与 batch() 保持一致

  • 如果引擎支持动态 batch:按实现支持的范围传入(至少不应超过 batch())


6.3 线程与并发

提供了 clone() 方法,通常意味着:

  • 同一个 Model 实例可能不适合跨线程并发调用

  • 建议“每个线程持有一个 clone 出来的实例”,以减少上下文/stream 争用


6.4 性能报告

  • 先调用 InferOption::enablePerformanceReport()

  • 再通过 BaseModel::performanceReport() 获取报告三元组(吞吐、CPU 延迟、GPU 延迟)

  • 报告以 std::string 形式返回,适合直接打印或上报


6.5 错误处理

头文件未说明错误返回方式(返回码/异常/日志)。常见可能性:

  • 构造模型时:引擎文件不可读、TensorRT 反序列化失败、设备不匹配等

  • 推理时:输入维度/内存类型不匹配等

建议你的调用侧:

  • 在模型构造、predict 周围增加异常捕获(如 try/catch)

  • 记录 engine 路径、输入尺寸、batch 数等关键上下文信息,便于排查


附:接口一览(便于检索)

  • trtyolo::Image

  • trtyolo::Mask

  • trtyolo::KeyPoint

  • trtyolo::Box

  • trtyolo::RotatedBox

  • trtyolo::BaseRes

  • trtyolo::ClassifyRes

  • trtyolo::DetectRes

  • trtyolo::OBBRes

  • trtyolo::SegmentRes

  • trtyolo::PoseRes

  • trtyolo::InferOption

  • trtyolo::BaseModel

  • trtyolo::ClassifyModel

  • trtyolo::DetectModel

  • trtyolo::OBBModel

  • trtyolo::SegmentModel

  • trtyolo::PoseModel

许可协议:  CC BY 4.0
分享

相关文章

下一篇

上一篇

OpenVR 驱动教程

最近更新

  • trtyolo C++ API 文档 AI整理
  • OpenVR 驱动教程
  • OpenVR Driver API 中文文档
  • 每周考研阅读 2021 Text 1

热门标签

Halo

目录

©2026 Text03's Blog. 保留部分权利。

鲁ICP备2025195077号-1 | 鲁公网安备37092302000179号

使用 Halo 主题 Chirpy