import argparse
import ctypes
import enum
import os
import threading
from typing import Optional, Sequence, Tuple

import numpy as np

# CPU bit-mask constants defined in the RKLLM C header. OR them together to
# build RKLLMExtendParam.enabled_cpus_mask.
CPU0 = (1 << 0)  # 0x01
CPU1 = (1 << 1)  # 0x02
CPU2 = (1 << 2)  # 0x04
CPU3 = (1 << 3)  # 0x08
CPU4 = (1 << 4)  # 0x10
CPU5 = (1 << 5)  # 0x20
CPU6 = (1 << 6)  # 0x40
CPU7 = (1 << 7)  # 0x80


# --- Enums ---

class LLMCallState(enum.IntEnum):
    """State value delivered as the last argument of ``LLMResultCallback``."""

    RKLLM_RUN_NORMAL = 0
    RKLLM_RUN_WAITING = 1
    RKLLM_RUN_FINISH = 2
    RKLLM_RUN_ERROR = 3


class RKLLMInputType(enum.IntEnum):
    """Selects which member of ``RKLLMInput``'s union holds the input."""

    RKLLM_INPUT_PROMPT = 0
    RKLLM_INPUT_TOKEN = 1
    RKLLM_INPUT_EMBED = 2
    RKLLM_INPUT_MULTIMODAL = 3


class RKLLMInferMode(enum.IntEnum):
    """Inference mode stored in ``RKLLMInferParam.mode``."""

    RKLLM_INFER_GENERATE = 0
    RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
    RKLLM_INFER_GET_LOGITS = 2


# --- Structures ---

class RKLLMExtendParam(ctypes.Structure):
    """Extended runtime parameters embedded in ``RKLLMParam``."""

    base_domain_id: ctypes.c_int32
    embed_flash: ctypes.c_int8
    enabled_cpus_num: ctypes.c_int8
    enabled_cpus_mask: ctypes.c_uint32
    n_batch: ctypes.c_uint8
    use_cross_attn: ctypes.c_int8
    reserved: ctypes.c_uint8 * 104
    _fields_ = [
        ("base_domain_id", ctypes.c_int32),  # Base domain ID
        ("embed_flash", ctypes.c_int8),  # Look up word embeddings from flash (1 = enabled, 0 = disabled)
        ("enabled_cpus_num", ctypes.c_int8),  # Number of CPUs used for inference
        ("enabled_cpus_mask", ctypes.c_uint32),  # Bit mask of enabled CPUs (see CPU0..CPU7)
        ("n_batch", ctypes.c_uint8),  # Samples processed per forward pass; >1 enables batching (default 1)
        ("use_cross_attn", ctypes.c_int8),  # Enable cross-attention (non-zero = enabled, 0 = disabled)
        ("reserved", ctypes.c_uint8 * 104),  # Reserved
    ]


class RKLLMParam(ctypes.Structure):
    """Model initialization parameters passed to ``rkllm_init``."""

    model_path: ctypes.c_char_p
    max_context_len: ctypes.c_int32
    max_new_tokens: ctypes.c_int32
    top_k: ctypes.c_int32
    n_keep: ctypes.c_int32
    top_p: ctypes.c_float
    temperature: ctypes.c_float
    repeat_penalty: ctypes.c_float
    frequency_penalty: ctypes.c_float
    presence_penalty: ctypes.c_float
    mirostat: ctypes.c_int32
    mirostat_tau: ctypes.c_float
    mirostat_eta: ctypes.c_float
    skip_special_token: ctypes.c_bool
    is_async: ctypes.c_bool
    img_start: ctypes.c_char_p
    img_end: ctypes.c_char_p
    img_content: ctypes.c_char_p
    extend_param: RKLLMExtendParam
    _fields_ = [
        ("model_path", ctypes.c_char_p),  # Path to the model file
        ("max_context_len", ctypes.c_int32),  # Maximum number of tokens in the context window
        ("max_new_tokens", ctypes.c_int32),  # Maximum number of new tokens to generate
        ("top_k", ctypes.c_int32),  # Top-K sampling parameter
        ("n_keep", ctypes.c_int32),  # KV-cache entries kept when the context window shifts
        ("top_p", ctypes.c_float),  # Top-P (nucleus) sampling parameter
        ("temperature", ctypes.c_float),  # Sampling temperature (randomness of token choice)
        ("repeat_penalty", ctypes.c_float),  # Penalty for repeated tokens
        ("frequency_penalty", ctypes.c_float),  # Penalty for frequent tokens
        ("presence_penalty", ctypes.c_float),  # Penalty for tokens already present in the input
        ("mirostat", ctypes.c_int32),  # Mirostat sampling strategy flag (0 = disabled)
        ("mirostat_tau", ctypes.c_float),  # Mirostat tau parameter
        ("mirostat_eta", ctypes.c_float),  # Mirostat eta parameter
        ("skip_special_token", ctypes.c_bool),  # Skip special tokens in the output
        ("is_async", ctypes.c_bool),  # Run inference asynchronously
        ("img_start", ctypes.c_char_p),  # Marker for the start of an image in multimodal input
        ("img_end", ctypes.c_char_p),  # Marker for the end of an image in multimodal input
        ("img_content", ctypes.c_char_p),  # Image content marker
        ("extend_param", RKLLMExtendParam),  # Extended parameters
    ]


class RKLLMLoraAdapter(ctypes.Structure):
    """LoRA adapter description passed to ``rkllm_load_lora``."""

    lora_adapter_path: ctypes.c_char_p
    lora_adapter_name: ctypes.c_char_p
    scale: ctypes.c_float
    _fields_ = [
        ("lora_adapter_path", ctypes.c_char_p),
        ("lora_adapter_name", ctypes.c_char_p),
        ("scale", ctypes.c_float),
    ]


class RKLLMEmbedInput(ctypes.Structure):
    """Embedding input: a float buffer covering ``n_tokens`` tokens."""

    embed: ctypes.POINTER(ctypes.c_float)
    n_tokens: ctypes.c_size_t
    _fields_ = [
        ("embed", ctypes.POINTER(ctypes.c_float)),
        ("n_tokens", ctypes.c_size_t),
    ]


class RKLLMTokenInput(ctypes.Structure):
    """Raw token-id input."""

    input_ids: ctypes.POINTER(ctypes.c_int32)
    n_tokens: ctypes.c_size_t
    _fields_ = [
        ("input_ids", ctypes.POINTER(ctypes.c_int32)),
        ("n_tokens", ctypes.c_size_t),
    ]


class RKLLMMultiModelInput(ctypes.Structure):
    """Multimodal input: a prompt plus precomputed image embeddings."""

    prompt: ctypes.c_char_p
    image_embed: ctypes.POINTER(ctypes.c_float)
    n_image_tokens: ctypes.c_size_t
    n_image: ctypes.c_size_t
    image_width: ctypes.c_size_t
    image_height: ctypes.c_size_t
    _fields_ = [
        ("prompt", ctypes.c_char_p),
        ("image_embed", ctypes.POINTER(ctypes.c_float)),
        ("n_image_tokens", ctypes.c_size_t),
        ("n_image", ctypes.c_size_t),
        ("image_width", ctypes.c_size_t),
        ("image_height", ctypes.c_size_t),
    ]


class RKLLMCrossAttnParam(ctypes.Structure):
    """
    Cross-attention parameters.

    Used when the decoder performs cross-attention. Provides the encoder
    output (key/value caches), token positions and an attention mask.

    - encoder_k_cache must be stored contiguously with layout
      [num_layers][num_tokens][num_kv_heads][head_dim]
    - encoder_v_cache must be stored contiguously with layout
      [num_layers][num_kv_heads][head_dim][num_tokens]
    """

    encoder_k_cache: ctypes.POINTER(ctypes.c_float)
    encoder_v_cache: ctypes.POINTER(ctypes.c_float)
    encoder_mask: ctypes.POINTER(ctypes.c_float)
    encoder_pos: ctypes.POINTER(ctypes.c_int32)
    num_tokens: ctypes.c_int
    _fields_ = [
        ("encoder_k_cache", ctypes.POINTER(ctypes.c_float)),  # Encoder key cache (num_layers * num_tokens * num_kv_heads * head_dim)
        ("encoder_v_cache", ctypes.POINTER(ctypes.c_float)),  # Encoder value cache (num_layers * num_kv_heads * head_dim * num_tokens)
        ("encoder_mask", ctypes.POINTER(ctypes.c_float)),  # Encoder attention mask (array of num_tokens)
        ("encoder_pos", ctypes.POINTER(ctypes.c_int32)),  # Encoder token positions (array of num_tokens)
        ("num_tokens", ctypes.c_int),  # Number of tokens in the encoder sequence
    ]


class RKLLMPerfStat(ctypes.Structure):
    """
    Performance statistics.

    Holds timing and token counts for the prefill and generation stages.
    """

    prefill_time_ms: ctypes.c_float
    prefill_tokens: ctypes.c_int
    generate_time_ms: ctypes.c_float
    generate_tokens: ctypes.c_int
    memory_usage_mb: ctypes.c_float
    _fields_ = [
        ("prefill_time_ms", ctypes.c_float),  # Total prefill time (ms)
        ("prefill_tokens", ctypes.c_int),  # Tokens processed during prefill
        ("generate_time_ms", ctypes.c_float),  # Total generation time (ms)
        ("generate_tokens", ctypes.c_int),  # Tokens processed during generation
        ("memory_usage_mb", ctypes.c_float),  # VmHWM resident memory during inference (MB)
    ]


class _RKLLMInputUnion(ctypes.Union):
    """Union of all input payloads; ``RKLLMInput.input_type`` selects the member."""

    prompt_input: ctypes.c_char_p
    embed_input: RKLLMEmbedInput
    token_input: RKLLMTokenInput
    multimodal_input: RKLLMMultiModelInput
    _fields_ = [
        ("prompt_input", ctypes.c_char_p),
        ("embed_input", RKLLMEmbedInput),
        ("token_input", RKLLMTokenInput),
        ("multimodal_input", RKLLMMultiModelInput),
    ]


class RKLLMInput(ctypes.Structure):
    """
    LLM input.

    Represents the different input kinds through a union; the properties
    below check ``input_type`` before reading or writing a union member.
    """

    role: ctypes.c_char_p
    enable_thinking: ctypes.c_bool
    input_type: ctypes.c_int
    _union_data: _RKLLMInputUnion
    _fields_ = [
        ("role", ctypes.c_char_p),  # Message role: "user" (user input), "tool" (function result)
        ("enable_thinking", ctypes.c_bool),  # Whether Qwen3 models use "thinking mode"
        ("input_type", ctypes.c_int),  # RKLLMInputType value (prompt, token, embed, multimodal)
        ("_union_data", _RKLLMInputUnion),  # Union payload
    ]

    # Properties to make accessing union members easier
    @property
    def prompt_input(self) -> bytes:  # c_char_p reads back as bytes
        if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
            return self._union_data.prompt_input
        raise AttributeError("Not a prompt input")

    @prompt_input.setter
    def prompt_input(self, value: bytes):  # c_char_p accepts bytes
        if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
            self._union_data.prompt_input = value
        else:
            raise AttributeError("Not a prompt input")

    @property
    def embed_input(self) -> RKLLMEmbedInput:
        if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
            return self._union_data.embed_input
        raise AttributeError("Not an embed input")

    @embed_input.setter
    def embed_input(self, value: RKLLMEmbedInput):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
            self._union_data.embed_input = value
        else:
            raise AttributeError("Not an embed input")

    @property
    def token_input(self) -> RKLLMTokenInput:
        if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
            return self._union_data.token_input
        raise AttributeError("Not a token input")

    @token_input.setter
    def token_input(self, value: RKLLMTokenInput):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
            self._union_data.token_input = value
        else:
            raise AttributeError("Not a token input")

    @property
    def multimodal_input(self) -> RKLLMMultiModelInput:
        if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
            return self._union_data.multimodal_input
        raise AttributeError("Not a multimodal input")

    @multimodal_input.setter
    def multimodal_input(self, value: RKLLMMultiModelInput):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
            self._union_data.multimodal_input = value
        else:
            raise AttributeError("Not a multimodal input")


class RKLLMLoraParam(ctypes.Structure):  # For inference
    """Selects a previously loaded LoRA adapter by name for one inference."""

    lora_adapter_name: ctypes.c_char_p
    _fields_ = [
        ("lora_adapter_name", ctypes.c_char_p),
    ]


class RKLLMPromptCacheParam(ctypes.Structure):  # For inference
    """Prompt-cache options for one inference."""

    save_prompt_cache: ctypes.c_int  # bool-like
    prompt_cache_path: ctypes.c_char_p
    _fields_ = [
        ("save_prompt_cache", ctypes.c_int),  # bool-like
        ("prompt_cache_path", ctypes.c_char_p),
    ]


class RKLLMInferParam(ctypes.Structure):
    """Per-call inference options passed to ``rkllm_run``/``rkllm_run_async``."""

    mode: ctypes.c_int
    lora_params: ctypes.POINTER(RKLLMLoraParam)
    prompt_cache_params: ctypes.POINTER(RKLLMPromptCacheParam)
    keep_history: ctypes.c_int  # bool-like
    _fields_ = [
        ("mode", ctypes.c_int),  # RKLLMInferMode value, passed as a plain int
        ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
        ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
        ("keep_history", ctypes.c_int),  # bool-like
    ]


class RKLLMResultLastHiddenLayer(ctypes.Structure):
    """Last-hidden-layer output: ``num_tokens * embd_size`` floats."""

    hidden_states: ctypes.POINTER(ctypes.c_float)
    embd_size: ctypes.c_int
    num_tokens: ctypes.c_int
    _fields_ = [
        ("hidden_states", ctypes.POINTER(ctypes.c_float)),
        ("embd_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int),
    ]


class RKLLMResultLogits(ctypes.Structure):
    """Logits output: ``num_tokens * vocab_size`` floats."""

    logits: ctypes.POINTER(ctypes.c_float)
    vocab_size: ctypes.c_int
    num_tokens: ctypes.c_int
    _fields_ = [
        ("logits", ctypes.POINTER(ctypes.c_float)),
        ("vocab_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int),
    ]


class RKLLMResult(ctypes.Structure):
    """
    LLM inference result.

    Carries the generated text, token id, last hidden layer, logits and
    performance statistics.
    """

    text: ctypes.c_char_p
    token_id: ctypes.c_int32
    last_hidden_layer: RKLLMResultLastHiddenLayer
    logits: RKLLMResultLogits
    perf: RKLLMPerfStat
    _fields_ = [
        ("text", ctypes.c_char_p),  # Generated text
        ("token_id", ctypes.c_int32),  # Generated token ID
        ("last_hidden_layer", RKLLMResultLastHiddenLayer),  # Last hidden layer (when requested)
        ("logits", RKLLMResultLogits),  # Model output logits
        ("perf", RKLLMPerfStat),  # Performance statistics (prefill and generation)
    ]


# --- Typedefs ---
LLMHandle = ctypes.c_void_p

# --- Callback Function Type ---
# Callback used to deliver LLM results.
# Arguments:
#   result   - pointer to the LLM result (RKLLMResult*)
#   userdata - user data pointer for the callback
#   state    - LLM call state (an LLMCallState value, e.g. finished or error)
# Return value:
#   0 - continue inference normally
#   1 - pause inference. Return 1 when you want to modify or intervene in the
#       result (e.g. edit the output, inject a new prompt); later call
#       rkllm_run with the updated content to resume.
LLMResultCallback = ctypes.CFUNCTYPE(
    ctypes.c_int,  # return: processing status
    ctypes.POINTER(RKLLMResult),  # LLM result pointer
    ctypes.c_void_p,  # user data pointer
    ctypes.c_int,  # LLM call state (LLMCallState value)
)


class RKLLMRuntime:
    """ctypes wrapper around ``librkllmrt.so``.

    Most methods forward to the corresponding ``rkllm_*`` C function and raise
    ``RuntimeError`` when it returns a non-zero status. The object is a
    context manager; ``destroy()`` runs on exit and on garbage collection.
    """

    def __init__(self, library_path="./librkllmrt.so"):
        """Load the shared library and declare the C function signatures.

        :param library_path: Path handed to ``ctypes.CDLL``.
        :raises OSError: If the library cannot be loaded.
        """
        try:
            self.lib = ctypes.CDLL(library_path)
        except OSError as e:
            raise OSError(f"Failed to load RKLLM library from {library_path}. "
                          f"Ensure it's in your LD_LIBRARY_PATH or provide the full path. Error: {e}") from e
        self._setup_functions()
        self.llm_handle = LLMHandle()
        self._c_callback = None  # To keep the callback object alive
        self._user_callback = None
        self._userdata_ref = None  # Keeps run()/run_async() userdata objects alive

    def _setup_functions(self):
        """Declare ``restype``/``argtypes`` for every C function used."""
        # RKLLMParam rkllm_createDefaultParam();
        self.lib.rkllm_createDefaultParam.restype = RKLLMParam
        self.lib.rkllm_createDefaultParam.argtypes = []

        # int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
        self.lib.rkllm_init.restype = ctypes.c_int
        self.lib.rkllm_init.argtypes = [
            ctypes.POINTER(LLMHandle),
            ctypes.POINTER(RKLLMParam),
            LLMResultCallback,
        ]

        # int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
        self.lib.rkllm_load_lora.restype = ctypes.c_int
        self.lib.rkllm_load_lora.argtypes = [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)]

        # int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
        self.lib.rkllm_load_prompt_cache.restype = ctypes.c_int
        self.lib.rkllm_load_prompt_cache.argtypes = [LLMHandle, ctypes.c_char_p]

        # int rkllm_release_prompt_cache(LLMHandle handle);
        self.lib.rkllm_release_prompt_cache.restype = ctypes.c_int
        self.lib.rkllm_release_prompt_cache.argtypes = [LLMHandle]

        # int rkllm_destroy(LLMHandle handle);
        self.lib.rkllm_destroy.restype = ctypes.c_int
        self.lib.rkllm_destroy.argtypes = [LLMHandle]

        # int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
        self.lib.rkllm_run.restype = ctypes.c_int
        self.lib.rkllm_run.argtypes = [
            LLMHandle,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p,  # userdata
        ]

        # int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
        # Assuming async also takes userdata for the callback context
        self.lib.rkllm_run_async.restype = ctypes.c_int
        self.lib.rkllm_run_async.argtypes = [
            LLMHandle,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p,  # userdata
        ]

        # int rkllm_abort(LLMHandle handle);
        self.lib.rkllm_abort.restype = ctypes.c_int
        self.lib.rkllm_abort.argtypes = [LLMHandle]

        # int rkllm_is_running(LLMHandle handle);
        self.lib.rkllm_is_running.restype = ctypes.c_int  # 0 if running, non-zero otherwise
        self.lib.rkllm_is_running.argtypes = [LLMHandle]

        # int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt, int* start_pos, int* end_pos);
        self.lib.rkllm_clear_kv_cache.restype = ctypes.c_int
        self.lib.rkllm_clear_kv_cache.argtypes = [
            LLMHandle,
            ctypes.c_int,
            ctypes.POINTER(ctypes.c_int),  # start_pos
            ctypes.POINTER(ctypes.c_int),  # end_pos
        ]

        # int rkllm_get_kv_cache_size(LLMHandle handle, int* cache_sizes);
        self.lib.rkllm_get_kv_cache_size.restype = ctypes.c_int
        self.lib.rkllm_get_kv_cache_size.argtypes = [LLMHandle, ctypes.POINTER(ctypes.c_int)]

        # int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
        self.lib.rkllm_set_chat_template.restype = ctypes.c_int
        self.lib.rkllm_set_chat_template.argtypes = [
            LLMHandle,
            ctypes.c_char_p,
            ctypes.c_char_p,
            ctypes.c_char_p,
        ]

        # int rkllm_set_function_tools(LLMHandle handle, const char* system_prompt, const char* tools, const char* tool_response_str);
        self.lib.rkllm_set_function_tools.restype = ctypes.c_int
        self.lib.rkllm_set_function_tools.argtypes = [
            LLMHandle,
            ctypes.c_char_p,  # system_prompt
            ctypes.c_char_p,  # tools
            ctypes.c_char_p,  # tool_response_str
        ]

        # int rkllm_set_cross_attn_params(LLMHandle handle, RKLLMCrossAttnParam* cross_attn_params);
        self.lib.rkllm_set_cross_attn_params.restype = ctypes.c_int
        self.lib.rkllm_set_cross_attn_params.argtypes = [LLMHandle, ctypes.POINTER(RKLLMCrossAttnParam)]

    def create_default_param(self) -> RKLLMParam:
        """Creates a default RKLLMParam structure."""
        return self.lib.rkllm_createDefaultParam()

    def init(self, param: RKLLMParam, callback_func) -> int:
        """
        Initializes the LLM.
        :param param: RKLLMParam structure.
        :param callback_func: A Python function that matches the signature:
                              def my_callback(result_ptr, userdata_ptr, state_enum):
                                  result = result_ptr.contents # RKLLMResult
                                  # Process result
                                  # userdata can be retrieved if passed during run, or ignored
                                  # state = LLMCallState(state_enum)
        :return: 0 for success, non-zero for failure.
        :raises ValueError: If callback_func is not callable.
        :raises RuntimeError: If rkllm_init returns a non-zero status.
        """
        if not callable(callback_func):
            raise ValueError("callback_func must be a callable Python function.")

        self._user_callback = callback_func
        # Keep a reference to the ctypes callback object to prevent it from being garbage collected.
        # Always register a trampoline so we can swap the Python-level handler when needed.
        self._c_callback = LLMResultCallback(self._callback_trampoline)

        ret = self.lib.rkllm_init(ctypes.byref(self.llm_handle), ctypes.byref(param), self._c_callback)
        if ret != 0:
            raise RuntimeError(f"rkllm_init failed with error code {ret}")
        return ret

    def load_lora(self, lora_adapter: RKLLMLoraAdapter) -> int:
        """Loads a Lora adapter."""
        ret = self.lib.rkllm_load_lora(self.llm_handle, ctypes.byref(lora_adapter))
        if ret != 0:
            raise RuntimeError(f"rkllm_load_lora failed with error code {ret}")
        return ret

    def load_prompt_cache(self, prompt_cache_path: str) -> int:
        """Loads a prompt cache from a file."""
        c_path = prompt_cache_path.encode('utf-8')
        ret = self.lib.rkllm_load_prompt_cache(self.llm_handle, c_path)
        if ret != 0:
            raise RuntimeError(f"rkllm_load_prompt_cache failed for {prompt_cache_path} with error code {ret}")
        return ret

    def release_prompt_cache(self) -> int:
        """Releases the prompt cache from memory."""
        ret = self.lib.rkllm_release_prompt_cache(self.llm_handle)
        if ret != 0:
            raise RuntimeError(f"rkllm_release_prompt_cache failed with error code {ret}")
        return ret

    def destroy(self) -> int:
        """Destroys the LLM instance and releases resources.

        Safe to call more than once and on a partially constructed object
        (e.g. when ``__init__`` failed before ``lib``/``llm_handle`` were
        set), because ``__del__`` calls it unconditionally.

        :return: The ``rkllm_destroy`` status, or 0 if there was nothing to destroy.
        """
        handle = getattr(self, "llm_handle", None)
        lib = getattr(self, "lib", None)
        if lib is None or handle is None or not handle.value:
            return 0  # Already destroyed or not initialized
        ret = lib.rkllm_destroy(handle)
        self.llm_handle = LLMHandle()  # Reset handle so a second call is a no-op
        if ret != 0:
            # Don't raise here as it might be called in __del__
            print(f"Warning: rkllm_destroy failed with error code {ret}")
        return ret

    def _make_c_userdata(self, userdata):
        """Box ``userdata`` as a ``void*`` for the C API and keep it alive.

        The callback receives the pointer as ``userdata_ptr``; recover the
        object with
        ``ctypes.cast(userdata_ptr, ctypes.POINTER(ctypes.py_object)).contents.value``.
        The Python object, its ``py_object`` box and the pointer are all stored
        on ``self`` so the address stays valid after an async call returns.
        """
        if userdata is None:
            return None
        boxed = ctypes.py_object(userdata)
        boxed_ptr = ctypes.pointer(boxed)
        c_userdata = ctypes.cast(boxed_ptr, ctypes.c_void_p)
        self._userdata_ref = (userdata, boxed, boxed_ptr, c_userdata)
        return c_userdata

    def run(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
        """Runs an LLM inference task synchronously.

        :param userdata: Optional Python object forwarded to the callback as a
                         ``void*`` (see ``_make_c_userdata``), or None.
        :raises RuntimeError: If rkllm_run returns a non-zero status.
        """
        c_userdata = self._make_c_userdata(userdata)
        ret = self.lib.rkllm_run(self.llm_handle, ctypes.byref(rkllm_input),
                                 ctypes.byref(rkllm_infer_params), c_userdata)
        if ret != 0:
            raise RuntimeError(f"rkllm_run failed with error code {ret}")
        return ret

    def run_async(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
        """Runs an LLM inference task asynchronously.

        The userdata box is kept on ``self`` because the callback may fire
        after this call has returned.

        :raises RuntimeError: If rkllm_run_async returns a non-zero status.
        """
        c_userdata = self._make_c_userdata(userdata)
        ret = self.lib.rkllm_run_async(self.llm_handle, ctypes.byref(rkllm_input),
                                       ctypes.byref(rkllm_infer_params), c_userdata)
        if ret != 0:
            raise RuntimeError(f"rkllm_run_async failed with error code {ret}")
        return ret

    def abort(self) -> int:
        """Aborts an ongoing LLM task."""
        ret = self.lib.rkllm_abort(self.llm_handle)
        if ret != 0:
            raise RuntimeError(f"rkllm_abort failed with error code {ret}")
        return ret

    def is_running(self) -> bool:
        """Checks if an LLM task is currently running. Returns True if running."""
        # The C API returns 0 if running, non-zero otherwise, so invert it
        # into a natural boolean.
        return self.lib.rkllm_is_running(self.llm_handle) == 0

    def clear_kv_cache(self, keep_system_prompt: bool, start_pos: Optional[Sequence[int]] = None,
                       end_pos: Optional[Sequence[int]] = None) -> int:
        """
        Clear part or all of the KV cache.

        Parameters:
        - keep_system_prompt: keep the system prompt in the cache (True keeps,
          False clears). Ignored when a specific range [start_pos, end_pos) is given.
        - start_pos: per-batch start positions (inclusive) of the ranges to clear.
        - end_pos: per-batch end positions (exclusive) of the ranges to clear.

        If both start_pos and end_pos are None the whole cache is cleared and
        keep_system_prompt applies. If start_pos[i] < end_pos[i], only that
        range is cleared and keep_system_prompt is ignored.

        Note: start_pos/end_pos only take effect when keep_history == 0 and
        generation has been paused by returning 1 from the callback.

        Returns: 0 on success.
        Raises: ValueError if only one of start_pos/end_pos is given or their
        lengths differ; RuntimeError if the C call fails.
        """
        if (start_pos is None) != (end_pos is None):
            # Passing only one bound used to be silently ignored.
            raise ValueError("start_pos and end_pos must be provided together.")

        c_start_pos = None
        c_end_pos = None
        if start_pos is not None:
            if len(start_pos) != len(end_pos):
                raise ValueError("start_pos和end_pos数组长度必须相同")
            c_start_pos = (ctypes.c_int * len(start_pos))(*start_pos)
            c_end_pos = (ctypes.c_int * len(end_pos))(*end_pos)

        ret = self.lib.rkllm_clear_kv_cache(
            self.llm_handle,
            ctypes.c_int(1 if keep_system_prompt else 0),
            c_start_pos,
            c_end_pos,
        )
        if ret != 0:
            raise RuntimeError(f"rkllm_clear_kv_cache失败,错误代码:{ret}")
        return ret

    def set_chat_template(self, system_prompt: str, prompt_prefix: str, prompt_postfix: str) -> int:
        """Sets the chat template for the LLM. Empty/None parts are sent as empty strings."""
        c_system = system_prompt.encode('utf-8') if system_prompt else b""
        c_prefix = prompt_prefix.encode('utf-8') if prompt_prefix else b""
        c_postfix = prompt_postfix.encode('utf-8') if prompt_postfix else b""

        ret = self.lib.rkllm_set_chat_template(self.llm_handle, c_system, c_prefix, c_postfix)
        if ret != 0:
            raise RuntimeError(f"rkllm_set_chat_template failed with error code {ret}")
        return ret

    def get_kv_cache_size(self, n_batch: int) -> list:
        """
        Return the current KV-cache size for this handle.

        The C function reports the number of positions currently stored in
        the model's KV cache, one entry per batch.

        Parameters:
        - n_batch: number of batches; sizes the output buffer handed to C.
          NOTE(review): it presumably must be at least the model's configured
          n_batch, since C writes one int per batch — confirm.

        Returns:
        - list: cache size of each batch.

        Raises:
        - ValueError: if n_batch < 1 (C would write into an empty buffer).
        """
        if n_batch < 1:
            raise ValueError("n_batch must be >= 1.")
        cache_sizes = (ctypes.c_int * n_batch)()
        ret = self.lib.rkllm_get_kv_cache_size(self.llm_handle, cache_sizes)
        if ret != 0:
            raise RuntimeError(f"rkllm_get_kv_cache_size失败,错误代码:{ret}")
        return list(cache_sizes)

    def set_function_tools(self, system_prompt: str, tools: str, tool_response_str: str) -> int:
        """
        Configure function calling: system prompt, tool definitions and tool-response tag.

        Parameters:
        - system_prompt: system prompt defining the model's context or behavior.
        - tools: JSON string describing available functions (names,
          descriptions, parameters).
        - tool_response_str: unique tag identifying function-call results in
          the conversation, so the tokenizer can tell tool output apart from
          normal dialogue turns.

        Returns: 0 on success; raises RuntimeError otherwise.
        """
        c_system = system_prompt.encode('utf-8') if system_prompt else b""
        c_tools = tools.encode('utf-8') if tools else b""
        c_tool_response = tool_response_str.encode('utf-8') if tool_response_str else b""

        ret = self.lib.rkllm_set_function_tools(self.llm_handle, c_system, c_tools, c_tool_response)
        if ret != 0:
            raise RuntimeError(f"rkllm_set_function_tools失败,错误代码:{ret}")
        return ret

    def set_cross_attn_params(self, cross_attn_params: RKLLMCrossAttnParam) -> int:
        """
        Set cross-attention parameters for the LLM decoder.

        Parameters:
        - cross_attn_params: encoder-side inputs for cross-attention (see
          RKLLMCrossAttnParam for the required memory layouts).

        Returns: 0 on success; raises RuntimeError otherwise.
        """
        ret = self.lib.rkllm_set_cross_attn_params(self.llm_handle, ctypes.byref(cross_attn_params))
        if ret != 0:
            raise RuntimeError(f"rkllm_set_cross_attn_params失败,错误代码:{ret}")
        return ret

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.destroy()

    def __del__(self):
        self.destroy()  # Ensure resources are freed if object is garbage collected

    def _callback_trampoline(self, result_ptr, userdata_ptr, state_enum):
        """
        Bridge callback that forwards to the currently active Python handler.

        This keeps the C callback pointer stable while allowing per-call overrides.
        """
        handler = self._user_callback
        if handler is None:
            return 0
        try:
            return handler(result_ptr, userdata_ptr, state_enum)
        except Exception as exc:
            # Avoid propagating exceptions through the C callback boundary.
            print(f"[rkllm_binding] Callback raised an exception: {exc}")
            return 0

    def forward_embed(
        self,
        embeds: np.ndarray,
        *,
        keep_history: bool = False,
        timeout: Optional[float] = None,
        return_last_only: bool = False,
    ) -> np.ndarray:
        """
        Run a single forward pass with embedding input and return the last hidden layer.

        Args:
            embeds: Float32 embeddings shaped (T, H) or (1, T, H). Batch>1 is not supported.
            keep_history: When False, KV cache will be cleared after the call.
                          When True, cache is kept; call clear_kv_cache() manually if needed.
            timeout: Optional timeout (seconds) for waiting on the callback.
                     With None the wait is unbounded. On timeout a best-effort
                     abort() is issued before TimeoutError is raised.
            return_last_only: If True, return only the last token's vector, shape (H,).

        Returns:
            np.ndarray of hidden states shaped (1, num_tokens, embd_size) as
            reported by the runtime, or the last token's vector (embd_size,).

        Raises:
            ValueError: bad shape or empty input.
            TimeoutError: the callback did not deliver a result in time.
            RuntimeError: the runtime reported an error or no hidden states.
        """
        if embeds is None:
            raise ValueError("embeds must not be None.")

        np_embeds = np.asarray(embeds, dtype=np.float32)
        if np_embeds.ndim == 3:
            if np_embeds.shape[0] != 1:
                raise ValueError("Only batch size 1 is supported for forward_embed.")
            num_tokens = np_embeds.shape[1]
        elif np_embeds.ndim == 2:
            num_tokens = np_embeds.shape[0]
        else:
            raise ValueError("embeds must have shape (T, H) or (1, T, H).")

        # `flat` must stay alive until inference completes: the C side reads
        # directly from its memory (zero-copy, no per-element Python unpacking).
        flat = np.ascontiguousarray(np_embeds.reshape(-1), dtype=np.float32)
        if num_tokens == 0 or flat.size == 0:
            raise ValueError("embeds must contain at least one token.")

        rk_input = RKLLMInput()
        rk_input.input_type = RKLLMInputType.RKLLM_INPUT_EMBED
        embed_input = RKLLMEmbedInput()
        embed_input.embed = flat.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
        embed_input.n_tokens = num_tokens
        rk_input._union_data.embed_input = embed_input

        infer_params = RKLLMInferParam()
        infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER
        infer_params.keep_history = 1 if keep_history else 0
        infer_params.lora_params = None
        infer_params.prompt_cache_params = None

        done = threading.Event()
        result_holder = {"hidden": None, "error": None}

        def _capture_hidden(result_ptr, userdata_ptr, state_enum):
            state = LLMCallState(state_enum)
            if state == LLMCallState.RKLLM_RUN_ERROR:
                result_holder["error"] = "RKLLM reported an error state."
                done.set()
                return 0
            if not result_ptr:
                result_holder["error"] = "Empty result pointer received."
                done.set()
                return 0

            result = result_ptr.contents
            layer = result.last_hidden_layer
            if layer.hidden_states and layer.embd_size > 0:
                hidden = np.ctypeslib.as_array(
                    layer.hidden_states,
                    shape=(1, layer.num_tokens, layer.embd_size),
                ).copy()
                # hidden is (1, T, H): the last token vector is hidden[0, -1].
                result_holder["hidden"] = hidden[0, -1].copy() if return_last_only else hidden
                done.set()
                return 1  # Pause further work; we already have the hidden states.

            if state == LLMCallState.RKLLM_RUN_FINISH:
                done.set()
            return 0

        previous_callback = self._user_callback
        self._user_callback = _capture_hidden
        try:
            self.run(rk_input, infer_params)
            if not done.wait(timeout):
                try:
                    self.abort()
                except Exception:
                    # Best effort only: the timeout is the error we report.
                    pass
                raise TimeoutError("forward_embed timed out waiting for hidden states.")
        finally:
            self._user_callback = previous_callback

        if result_holder["error"]:
            raise RuntimeError(result_holder["error"])
        if result_holder["hidden"] is None:
            raise RuntimeError("forward_embed did not receive hidden states.")

        try:
            if not keep_history:
                self.clear_kv_cache(True)
        except Exception:
            # Cache clearing best-effort; keep the forward result usable even if clearing fails.
            pass

        return result_holder["hidden"]


# --- Demo CLI ---

def _cli_parse_arguments(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse and cross-validate the demo CLI arguments.

    :param argv: Argument list to parse; None (default) reads ``sys.argv``.
    :return: The parsed namespace. Invalid combinations exit via ``parser.error``.
    """
    parser = argparse.ArgumentParser(
        description="Demo application showcasing rkllm_binding usage."
    )
    parser.add_argument(
        "model",
        help="Path to the .rkllm model file used for inference."
    )
    parser.add_argument(
        "--lib",
        default="./librkllmrt.so",
        help="Path to librkllmrt.so. Defaults to ./librkllmrt.so."
    )

    # Core generation parameters
    parser.add_argument("--max-context-len", type=int, default=512, help="Maximum context length.")
    parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum number of new tokens to generate.")
    parser.add_argument("--top-k", type=int, default=1, help="Top-K sampling parameter.")
    parser.add_argument("--top-p", type=float, default=0.0, help="Top-P (nucleus) sampling parameter.")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature.")
    parser.add_argument("--repeat-penalty", type=float, default=1.1, help="Penalty applied to repeated tokens.")
    parser.add_argument("--n-keep", type=int, default=0, help="Number of tokens to keep when context slides.")
    parser.add_argument("--mirostat", type=int, default=0, help="Enable Mirostat sampling (0 disables).")
    parser.add_argument("--mirostat-tau", type=float, default=5.0, help="Mirostat tau parameter.")
    parser.add_argument("--mirostat-eta", type=float, default=0.1, help="Mirostat eta parameter.")
    parser.add_argument(
        "--skip-special-token",
        action="store_true",
        help="Skip special tokens when generating output."
    )

    # Input management
    parser.add_argument(
        "--input-type",
        choices=("prompt", "token", "multimodal"),
        default="prompt",
        help="Select prompt, raw token, or multimodal (image + prompt) input."
    )
    parser.add_argument("--prompt", help="Prompt text to send to the model.")
    parser.add_argument("--prompt-file", help="Path to a UTF-8 text file containing the prompt.")
    parser.add_argument(
        "--token-ids",
        type=int,
        nargs="+",
        help="Raw token IDs (space separated). Only valid when --input-type token."
    )
    parser.add_argument("--role", default="user", help="Role metadata for the input message (e.g., user/system).")
    parser.add_argument(
        "--enable-thinking",
        action="store_true",
        help="Enable thinking mode for supported models."
    )
    parser.add_argument("--image", help="Path to an image file used when --input-type multimodal.")
    parser.add_argument("--vision-encoder", help="Path to the ONNX vision encoder model.")
    parser.add_argument(
        "--encoder-provider",
        help="Comma separated ONNX Runtime providers (e.g., 'CPUExecutionProvider')."
    )
    parser.add_argument(
        "--encoder-threads",
        type=int,
        help="Thread count hint for ONNX Runtime session."
    )
    parser.add_argument(
        "--encoder-input-shape",
        help="Override encoder input spatial size as HxW or H,W (e.g., 392x392)."
    )
    parser.add_argument(
        "--norm",
        choices=("imagenet", "divide_255", "divide_128_sub_1"),
        default="imagenet",
        help="Image normalization preset."
    )
    parser.add_argument(
        "--norm-mean",
        type=float,
        nargs=3,
        metavar=("R", "G", "B"),
        help="Override normalization mean (RGB order)."
    )
    parser.add_argument(
        "--norm-std",
        type=float,
        nargs=3,
        metavar=("R", "G", "B"),
        help="Override normalization std (RGB order)."
    )
    parser.add_argument(
        "--image-background",
        type=int,
        nargs=3,
        metavar=("R", "G", "B"),
        default=(128, 128, 128),
        help="Background color used when padding image to target size."
    )
    parser.add_argument("--img-start-token", help="Override image start token string passed to the model.")
    parser.add_argument("--img-end-token", help="Override image end token string passed to the model.")
    parser.add_argument("--img-content-token", help="Override image content token string passed to the model.")

    # Inference options
    parser.add_argument(
        "--mode",
        choices=("generate", "hidden", "logits"),
        default="generate",
        help="Inference mode: generate tokens, return last hidden layer, or logits."
    )
    parser.add_argument(
        "--no-keep-history",
        action="store_true",
        help="Do not keep dialogue history on the device."
    )

    # Output options
    parser.add_argument(
        "--stream",
        action="store_true",
        default=True,
        help="Stream tokens to stdout as they arrive from the callback."
    )
    # --stream defaults to True, so a separate switch is needed to turn it off.
    parser.add_argument(
        "--no-stream",
        dest="stream",
        action="store_false",
        help="Disable token streaming (streaming is on by default)."
    )
    parser.add_argument(
        "--hide-stats",
        action="store_true",
        help="Suppress performance statistics after inference."
    )

    args = parser.parse_args(argv)

    if args.prompt and args.prompt_file:
        parser.error("Arguments --prompt and --prompt-file cannot be used together.")

    if args.input_type == "prompt":
        if not args.prompt and not args.prompt_file:
            parser.error("Provide --prompt or --prompt-file when --input-type is prompt.")
        if args.token_ids:
            parser.error("--token-ids is only valid when --input-type token.")
    elif args.input_type == "token":
        if not args.token_ids:
            parser.error("--token-ids is required when --input-type token.")
        if args.prompt or args.prompt_file:
            parser.error("--prompt/--prompt-file cannot be combined with --input-type token.")
    else:  # multimodal
        if args.token_ids:
            parser.error("--token-ids cannot be used with --input-type multimodal.")
        if not args.prompt and not args.prompt_file:
            parser.error("Provide --prompt or --prompt-file when --input-type is multimodal.")
        if not args.image:
            parser.error("--image is required when --input-type multimodal.")
        if not args.vision_encoder:
            parser.error("--vision-encoder is required when --input-type multimodal.")

    if args.image_background:
        for component in args.image_background:
            if component < 0 or component > 255:
                parser.error("--image-background values must be in the range [0, 255].")

    return args


def _load_prompt_from_args(args: argparse.Namespace) -> str:
    """Return the prompt text from --prompt or, failing that, --prompt-file.

    :raises RuntimeError: If the file cannot be read or no prompt was given.
    """
    if args.prompt:
        return args.prompt
    if args.prompt_file:
        try:
            with open(args.prompt_file, "r", encoding="utf-8") as fp:
                return fp.read()
        except OSError as exc:
            raise RuntimeError(f"Failed to read prompt file '{args.prompt_file}': {exc}") from exc
    raise RuntimeError("Prompt text is required but not provided.")


def _mode_to_enum(mode: str) -> int:
    """Map a --mode choice to its RKLLMInferMode value (KeyError if unknown)."""
    mapping = {
        "generate": RKLLMInferMode.RKLLM_INFER_GENERATE,
        "hidden": RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER,
        "logits": RKLLMInferMode.RKLLM_INFER_GET_LOGITS,
    }
    return mapping[mode]


def _parse_hw_string(value: str) -> Tuple[int, int]:
    """Parse "HxW", "H,W" or "H W" into a (height, width) tuple of positive ints.

    :raises ValueError: On malformed, non-integer or non-positive input.
    """
    separators = ("x", "X", ",", " ")
    token = value.strip()
    for sep in separators:
        if sep in token:
            parts = [p for p in token.split(sep) if p]
            break
    else:
        parts = [token]

    if len(parts) != 2:
        raise ValueError(f"Unable to parse height/width from '{value}'. Expected format like 392x392.")

    try:
        height = int(parts[0])
        width = int(parts[1])
    except ValueError as exc:
        raise ValueError(f"Height/width must be integers, got '{value}'.") from exc

    if height <= 0 or width <= 0:
        raise ValueError("Height and width must be positive integers.")

    return height, width


def _infer_hw_from_onnx_shape(shape: Sequence) -> Tuple[Optional[int], Optional[int]]:
    """Read (H, W) from the last two dims of an ONNX input shape.

    Symbolic (string) or missing dims become None; shapes with fewer than
    four dims yield (None, None).
    """
    if shape is None or len(shape) < 4:
        return None, None

    height = shape[-2]
    width = shape[-1]

    if isinstance(height, str) or height is None:
        height = None
    if isinstance(width, str) or width is None:
        width = None

    return height, width


def _parse_providers(provider_str: Optional[str]) -> Optional[list]:
    """Split a comma-separated provider list; return None if nothing remains."""
    if not provider_str:
        return None
    providers = [item.strip() for item in provider_str.split(",") if item.strip()]
    return providers or None


def _load_vision_encoder_session(encoder_path: str, providers: Optional[list], threads: Optional[int]):
    """Create an ONNX Runtime session for the vision encoder.

    :raises RuntimeError: If onnxruntime is missing or the model fails to load.
    """
    try:
        import onnxruntime as ort
    except ImportError as exc:
        raise RuntimeError("onnxruntime is required for multimodal inference. Please install onnxruntime.") from exc

    sess_options = ort.SessionOptions()
    if threads and threads > 0:
        sess_options.intra_op_num_threads = threads

    try:
        if providers:
            session = ort.InferenceSession(encoder_path, sess_options=sess_options, providers=providers)
        else:
            session = ort.InferenceSession(encoder_path, sess_options=sess_options)
    except Exception as exc:
        raise RuntimeError(f"Failed to load vision encoder '{encoder_path}': {exc}") from exc

    return session


def _letterbox_resize(image, target_hw: Tuple[int, int], background_color: Sequence[int]):
    """Resize an RGB image to fit target_hw (aspect preserved) and center it on a padded canvas.

    :return: (canvas, resized_h, resized_w)
    :raises RuntimeError: If cv2 is missing or the image is not a valid 3-channel image.
    """
    try:
        import cv2
        import numpy as np
    except ImportError as exc:
        raise RuntimeError("OpenCV (cv2) and numpy are required for multimodal preprocessing.") from exc

    target_h, target_w = target_hw
    if image.ndim != 3 or image.shape[2] != 3:
        raise RuntimeError("Expected RGB image with 3 channels.")

    src_h, src_w = image.shape[:2]
    if src_h == 0 or src_w == 0:
        raise RuntimeError("Loaded image has invalid dimensions.")

    scale = min(target_w / src_w, target_h / src_h)
    resized_w = max(1, int(round(src_w * scale)))
    resized_h = max(1, int(round(src_h * scale)))

    resized = cv2.resize(image, (resized_w, resized_h), interpolation=cv2.INTER_LINEAR)
    canvas = np.full((target_h, target_w, 3), background_color, dtype=resized.dtype)
    top = (target_h - resized_h) // 2
    left = (target_w - resized_w) // 2
    canvas[top:top + resized_h, left:left + resized_w] = resized
    return canvas, resized_h, resized_w


def _normalize_image(image, method: str, mean: Optional[Sequence[float]], std: Optional[Sequence[float]]):
    """Convert an image to float32 and apply the selected normalization preset.

    Optional mean/std (RGB order) override or extend the preset.
    :raises RuntimeError: On an unknown method.
    """
    import numpy as np

    img = image.astype(np.float32)
    mean_arr = np.array(mean, dtype=np.float32) if mean else None
    std_arr = np.array(std, dtype=np.float32) if std else None

    if method == "imagenet":
        img = img / 255.0
        if mean_arr is None:
            mean_arr = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
        if std_arr is None:
            std_arr = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)
        img = (img - mean_arr) / std_arr
    elif method == "divide_255":
        img = img / 255.0
        if mean_arr is not None:
            img = img - mean_arr
        if std_arr is not None:
            img = img / std_arr
    elif method == "divide_128_sub_1":
        img = img / 128.0 - 1.0
        if mean_arr is not None:
            img = img - mean_arr
        if std_arr is not None:
            img = img / std_arr
    else:
        raise RuntimeError(f"Unsupported normalization method '{method}'.")

    return img


def _encode_image_to_embedding(
    session,
    image_path: str,
    input_name: str,
    output_name: str,
    target_hw: Tuple[int, int],
    background_color: Sequence[int],
    norm_method: str,
    norm_mean: Optional[Sequence[float]],
    norm_std: Optional[Sequence[float]]
):
    try:
        import cv2
        import numpy as np
    except ImportError as exc:
        raise RuntimeError("OpenCV (cv2) and numpy are required for multimodal preprocessing.") from exc

    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        raise RuntimeError(f"Failed to read image from '{image_path}'.")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    padded, resized_h, resized_w = _letterbox_resize(image, target_hw, background_color)
    normalized = _normalize_image(padded, norm_method, norm_mean, norm_std)
    tensor = np.transpose(normalized, (2, 0, 1))  # HWC -> CHW
    tensor = np.expand_dims(tensor, axis=0)  # Add batch dimension
    tensor = np.ascontiguousarray(tensor, dtype=np.float32)

    try:
        output_list = session.run([output_name], {input_name: tensor})
    except Exception as exc:
        raise RuntimeError(f"Vision encoder inference failed: {exc}") from exc

    if not output_list:
        raise RuntimeError("Vision encoder returned no outputs.")

    embedding = output_list[0]
    if embedding.ndim == 3:
        if embedding.shape[0] != 1:
            raise RuntimeError("Vision encoder output batch dimension must be 1 for a single image.")
        n_tokens = embedding.shape[1]
    elif embedding.ndim == 2:
        n_tokens = embedding.shape[0]
    else:
        raise RuntimeError(f"Unsupported vision encoder output shape {embedding.shape}.")

    flat_embedding = embedding.reshape(-1).astype(np.float32, copy=False)
    flat_embedding = np.ascontiguousarray(flat_embedding)
    return
flat_embedding, n_tokens, target_hw if __name__ == "__main__": import os os.environ["RKLLM_LOG_LEVEL"] = "1" args = _cli_parse_arguments() prompt_text = None if args.input_type == "prompt": prompt_text = _load_prompt_from_args(args) token_id_array = None token_input_struct = None generated_chunks = [] perf_snapshot = { "prefill_tokens": 0, "prefill_time_ms": 0.0, "generate_tokens": 0, "generate_time_ms": 0.0, "memory_usage_mb": 0.0, } def demo_callback(result_ptr, userdata_ptr, state_enum): state = LLMCallState(state_enum) result = result_ptr.contents current_text = "" if result.text: current_text = result.text.decode("utf-8", errors="ignore") generated_chunks.append(current_text) if args.stream and current_text: print(current_text, end="", flush=True) perf_snapshot.update( prefill_tokens=result.perf.prefill_tokens, prefill_time_ms=result.perf.prefill_time_ms, generate_tokens=result.perf.generate_tokens, generate_time_ms=result.perf.generate_time_ms, memory_usage_mb=result.perf.memory_usage_mb, ) if state == LLMCallState.RKLLM_RUN_ERROR: print("\n[Callback] 推理过程中出现错误。") return 0 try: with RKLLMRuntime(library_path=args.lib) as rk_llm: params = rk_llm.create_default_param() params.model_path = os.path.abspath(args.model).encode("utf-8") params.max_context_len = args.max_context_len params.max_new_tokens = args.max_new_tokens params.top_k = args.top_k params.top_p = float(args.top_p) params.temperature = float(args.temperature) params.repeat_penalty = float(args.repeat_penalty) params.n_keep = args.n_keep params.mirostat = args.mirostat params.mirostat_tau = float(args.mirostat_tau) params.mirostat_eta = float(args.mirostat_eta) params.skip_special_token = bool(args.skip_special_token) params.is_async = False rk_llm.init(params, demo_callback) rk_input = RKLLMInput() rk_input.role = args.role.encode("utf-8") rk_input.enable_thinking = bool(args.enable_thinking) if args.input_type == "prompt": rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT 
rk_input._union_data.prompt_input = prompt_text.encode("utf-8") else: rk_input.input_type = RKLLMInputType.RKLLM_INPUT_TOKEN token_id_array = (ctypes.c_int32 * len(args.token_ids))(*args.token_ids) token_input_struct = RKLLMTokenInput() token_input_struct.input_ids = token_id_array token_input_struct.n_tokens = len(args.token_ids) rk_input._union_data.token_input = token_input_struct infer_params = RKLLMInferParam() infer_params.mode = _mode_to_enum(args.mode) infer_params.keep_history = 0 if args.no_keep_history else 1 infer_params.lora_params = None infer_params.prompt_cache_params = None if args.stream: print("=== Streaming Output ===") rk_llm.run(rk_input, infer_params) except OSError as exc: print(f"无法加载 RKLLM 运行时库:{exc}") except RuntimeError as exc: print(f"推理失败:{exc}") except Exception as exc: print(f"发生未预期的错误:{exc}") else: if args.stream: print() # Ensure newline after streaming output final_text = "".join(generated_chunks) if final_text: print("=== 生成结果 ===") print(final_text) else: print("未收到生成文本。") if not args.hide_stats: print("=== 性能统计 ===") print( f"预填充: {perf_snapshot['prefill_tokens']} tokens / {perf_snapshot['prefill_time_ms']:.2f} ms" ) print( f"生成: {perf_snapshot['generate_tokens']} tokens / {perf_snapshot['generate_time_ms']:.2f} ms" ) print(f"最大常驻内存: {perf_snapshot['memory_usage_mb']:.2f} MB")