在上一篇文章中,我们已经带大家了解了多输入多输出(MIMO)能力的架构设计思路。
今天,小编将继续深入解析如何将架构设计真正落地到可运行代码,并带来一套可复用的核心实现。会介绍多输入多输出支持框架的关键组成部分。通过清晰的结构化设计、类型安全的接口抽象,为复杂的嵌入式 AI 模型建立一个高扩展性、高可维护性的基础底座。
下面,我们就将通过头文件设计、基础数据结构构建、生命周期管理等内容一步步展示一个完整的MIMO支持框架是如何搭建起来的。
话不多说,上代码!(代码超载预警)
头文件设计:构建类型安全的基础
首先,我们需要一个类型和接口定义完备、可扩展性强的头文件model.h。
这一部分为后续的MIMO管理、张量访问、预处理、模型统计等功能奠定了坚实基础。
#ifndefMODEL_H
#defineMODEL_H
#include"tensorflow/lite/c/common.h"
// =============================================================================
// 配置常量
// =============================================================================
#defineMAX_INPUT_TENSORS 8 // 最大输入张量数量
#defineMAX_OUTPUT_TENSORS 8 // 最大输出张量数量
#defineMAX_TENSOR_DIMS 6 // 最大张量维度数
#defineMODEL_NAME_MAX_LEN 64 // 模型名称最大长度
// =============================================================================
// 状态码定义
// =============================================================================
typedefenum{
kStatus_Success =0,
kStatus_Fail = 1,
kStatus_InvalidParam =2,
kStatus_OutOfRange =3,
kStatus_NotInitialized =4,
kStatus_InsufficientMemory =5
}status_t;
// =============================================================================
// 张量相关类型定义
// =============================================================================
typedefenum{
kTensorType_FLOAT32 =0,
kTensorType_UINT8 =1,
kTensorType_INT8 =2,
kTensorType_INT32 =3,
kTensorType_BOOL =4,
kTensorType_UNKNOWN =255
}tensor_type_t;
typedefstruct{
intsize; // 维度数量
int data[MAX_TENSOR_DIMS]; // 各维度的大小
}tensor_dims_t;
// 单个张量的完整信息
typedefstruct{
intindex; // 张量索引
tensor_dims_t dims; // 维度信息
tensor_type_t type; // 数据类型
uint8_t* data; // 数据指针
size_t size_bytes; // 数据大小(字节)
constchar* name; // 张量名称(可选)
}tensor_info_t;
// 多张量信息结构
typedefstruct{
intcount; // 张量数量
tensor_info_t tensors[MAX_INPUT_TENSORS]; // 张量信息数组
}multi_tensor_info_t;
// =============================================================================
// 模型统计信息
// =============================================================================
typedefstruct{
size_t arena_used_bytes; // 已使用的内存
size_t arena_total_bytes; // 总内存大小
int input_count; // 输入张量数量
int output_count; // 输出张量数量
constchar* model_name; // 模型名称
}model_stats_t;
// =============================================================================
// 核心接口声明
// =============================================================================
// 模型生命周期管理
status_tMODEL_Init(void);
status_tMODEL_Deinit(void);
status_tMODEL_RunInference(void);
// 模型信息查询
intMODEL_GetInputTensorCount(void);
intMODEL_GetOutputTensorCount(void);
status_tMODEL_GetModelStats(model_stats_t* stats);
constchar*MODEL_GetModelName(void);
// 单张量操作接口
uint8_t*MODEL_GetInputTensorData(intindex, tensor_dims_t* dims,tensor_type_t* type);
uint8_t*MODEL_GetOutputTensorData(intindex, tensor_dims_t* dims,tensor_type_t* type);
// 增强的单张量接口
status_tMODEL_GetInputTensorInfo(intindex, tensor_info_t* info);
status_tMODEL_GetOutputTensorInfo(intindex, tensor_info_t* info);
// 批量操作接口
status_t MODEL_GetAllInputTensors(multi_tensor_info_t* input_info);
status_t MODEL_GetAllOutputTensors(multi_tensor_info_t* output_info);
// 数据预处理接口
status_t MODEL_ConvertInput(inttensor_index,uint8_t* data,
const tensor_dims_t* dims,tensor_type_ttype);
// 工具函数
size_t MODEL_GetTensorSizeBytes(consttensor_dims_t* dims,tensor_type_ttype);
constchar* MODEL_GetTensorTypeName(tensor_type_ttype);
status_t MODEL_ValidateTensorDims(consttensor_dims_t* dims);
#endif// MODEL_H 核心实现:从设计到代码
接下来,进入到实际实现部分。为了提高代码可读性,整体实现拆分为以下模块:
全局变量与初始化
内部工具函数
生命周期管理(Init / Deinit / Invoke)
全局变量和初始化:
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include"fsl_debug_console.h"
#include"model.h"
#include"model_data.h"
// =============================================================================
// 全局变量
// =============================================================================
staticconsttflite::Model* s_model =nullptr;
statictflite::MicroInterpreter* s_interpreter =nullptr;
staticbools_model_initialized =false;
// 张量内存区域 - 根据具体模型调整大小
staticuint8_ts_tensorArena[kTensorArenaSize] __ALIGNED(16);
// 外部函数声明
externtflite::MicroOpResolver &MODEL_GetOpsResolver();
// =============================================================================
// 内部辅助函数
// =============================================================================
// 获取数据类型的字节大小
staticsize_tGetTypeSize(tensor_type_ttype)
{
switch(type) {
case kTensorType_FLOAT32:
case kTensorType_INT32:
return 4;
case kTensorType_UINT8:
case kTensorType_INT8:
case kTensorType_BOOL:
return 1;
default:
return 0;
}
}
// TensorFlow Lite类型转换为我们的类型
statictensor_type_tConvertTfLiteType(TfLiteType tf_type)
{
switch(tf_type) {
case kTfLiteFloat32:
return kTensorType_FLOAT32;
case kTfLiteUInt8:
return kTensorType_UINT8;
case kTfLiteInt8:
return kTensorType_INT8;
case kTfLiteInt32:
return kTensorType_INT32;
case kTfLiteBool:
return kTensorType_BOOL;
default:
return kTensorType_UNKNOWN;
}
}
// 从TensorFlow Lite张量提取信息
staticstatus_tExtractTensorInfo(TfLiteTensor* tf_tensor, intindex,tensor_info_t* info)
{
if(tf_tensor == nullptr|| info ==nullptr) {
return kStatus_InvalidParam;
}
// 基本信息
info->index = index;
info->type = ConvertTfLiteType(tf_tensor->type);
info->data = tf_tensor->data.uint8;
if (info->type == kTensorType_UNKNOWN) {
PRINTF("Unsupported tensor type: %d
", tf_tensor->type);
return kStatus_Fail;
}
// 维度信息
info->dims.size = tf_tensor->dims->size;
if (info->dims.size > MAX_TENSOR_DIMS) {
PRINTF("Tensor dimensions exceed maximum: %d > %d
",
info->dims.size, MAX_TENSOR_DIMS);
return kStatus_OutOfRange;
}
size_t total_elements =1;
for(inti =0; i < info->dims.size; i ) {
info->dims.data[i] = tf_tensor->dims->data[i];
total_elements *= info->dims.data[i];
}
// 计算数据大小
info->size_bytes = total_elements *GetTypeSize(info->type);
// 张量名称(如果可用)
info->name = nullptr; // TensorFlow Lite Micro通常不保存名称
return kStatus_Success;
} 模型生命周期管理
这部分主要包括:
模型初始化(加载模型 / 创建解释器 / 分配张量内存)
模型反初始化
执行推理(Invoke)
//
模型生命周期管理
//
status_tMODEL_Init(void)
{
if (s_model_initialized) {
PRINTF("Model already initialized
");
return kStatus_Success;
}
// 加载模型
s_model= tflite::GetModel(model_data);
if (s_model->version()!=TFLITE_SCHEMA_VERSION) {
PRINTF("Model schema version %d not supported (expected %d)
",
s_model->version(),TFLITE_SCHEMA_VERSION);
return kStatus_Fail;
}
// 获取操作解析器
tflite::MicroOpResolverµ_op_resolver= MODEL_GetOpsResolver();
// 创建解释器
static tflite::MicroInterpreterstatic_interpreter(
s_model, micro_op_resolver, s_tensorArena, kTensorArenaSize);
s_interpreter= &static_interpreter;
// 分配张量内存
TfLiteStatus allocate_status=s_interpreter->AllocateTensors();
if (allocate_status!=kTfLiteOk) {
PRINTF("AllocateTensors() failed with status: %d
", allocate_status);
return kStatus_InsufficientMemory;
}
s_model_initialized=true;
// 打印模型信息
PRINTF("Model '%s' initialized successfully:
", MODEL_GetModelName());
PRINTF("- Input tensors: %d
", s_interpreter->inputs_size());
PRINTF("- Output tensors: %d
", s_interpreter->outputs_size());
PRINTF("- Arena used: %zu bytes
", s_interpreter->arena_used_bytes());
return kStatus_Success;
}
status_tMODEL_Deinit(void)
{
if (!s_model_initialized) {
return kStatus_NotInitialized;
}
// TensorFlow Lite Micro使用静态内存,无需显式释放
s_model= nullptr;
s_interpreter= nullptr;
s_model_initialized=false;
PRINTF("Model deinitialized
");
return kStatus_Success;
}
status_tMODEL_RunInference(void)
{
if (!s_model_initialized||s_interpreter==nullptr) {
PRINTF("Model not initialized
");
return kStatus_NotInitialized;
}
TfLiteStatus invoke_status=s_interpreter->Invoke();
if (invoke_status!=kTfLiteOk) {
PRINTF("Model inference failed with status: %d
", invoke_status);
return kStatus_Fail;
}
return kStatus_Success;
}
信息查询接口
包含:
输入/输出张量数量查询
模型统计信息读取
模型名称查询
//
模型信息查询
//
intMODEL_GetInputTensorCount(void)
{
if (!s_model_initialized || s_interpreter ==nullptr) {
return0;
}
return s_interpreter->inputs_size();
}
intMODEL_GetOutputTensorCount(void)
{
if (!s_model_initialized || s_interpreter ==nullptr) {
return0;
}
return s_interpreter->outputs_size();
}
status_t MODEL_GetModelStats(model_stats_t* stats)
{
if(stats == nullptr) {
return kStatus_InvalidParam;
}
if (!s_model_initialized || s_interpreter ==nullptr) {
return kStatus_NotInitialized;
}
stats->arena_used_bytes = s_interpreter->arena_used_bytes();
stats->arena_total_bytes = kTensorArenaSize;
stats->input_count = s_interpreter->inputs_size();
stats->output_count = s_interpreter->outputs_size();
stats->model_name =MODEL_GetModelName();
return kStatus_Success;
}
constchar*MODEL_GetModelName(void)
{
return MODEL_NAME;
}
下期预告
由于篇幅有限,本篇重点展示了:
头文件设计:类型安全、结构清晰
核心实现框架:生命周期管理 内部工具函数
基本模型信息查询接口
在下一篇(系列最终章)中,我们将重点讲解:
张量数据访问接口(Input/Output Data APIs)完整实现
批量张量操作的高效实现方案
更实际的代码示例与最佳实践
关注
5200文章
20511浏览量
334928免责声明:本文为转载,非本网原创内容,不代表本网观点。其原创性以及文中陈述文字和内容未经本站证实,对本文以及其中全部或者部分内容、文字的真实性、完整性、及时性本站不作任何保证或承诺,请读者仅作参考,并请自行核实相关内容。
如有疑问请发送邮件至:bangqikeconnect@gmail.com