利用UnslothUnsloth训练自己的第一款R1自主推理模型
Unsloth简介
Unsloth是一款实现LLM高效训练与微调的开源库,它完全兼容 HuggingFace 生态系统(包括Hub、transformers、PEFT、TRL)。该库由Unsloth团队以及开源社区积极开发维护,支持大多数NVIDIA GPU,并且可以使用TRL库中的所有训练器套件(SFTTrainer, DPOTrainer, PPOTrainer, GRPOTrainer)。
Unsloth与GRPO
Unsloth 在推理能力训练方面取得了重大突破。通过优化Group Relative Policy Optimization (GRPO)训练过程,Unsloth团队成功将显存占用降低了80%,这意味着现在只需要7GB显存就能训练出具有自主推理能力的模型。这一突破使得在消费级显卡上训练推理模型成为可能。
具体来说,如果你拥有15GB显存的GPU,就可以将任何参数量在15B以下的模型(如Llama 3.1、Phi-4、Mistral或Qwen2.5等)转化为具备推理能力的模型。更令人兴奋的是,即使只有7GB显存,你也可以使用较小的模型(如Qwen2.5 1.5B)来实现类似效果。这与之前需要双A100(160GB显存)的要求相比,是一个革命性的进步。
GRPO的另一个重要特性是它能够帮助模型自主学习分配更多的思考时间,而无需人工反馈。现在,Unsloth不仅支持完整的微调,还支持QLoRA和LoRA等参数高效微调方法。对于那些只有输入输出数据(如问题和答案),但缺少推理过程的场景,GRPO可以自动生成合理的推理链路。这使得它在法律、医疗等需要严谨推理过程的专业领域特别有价值。
关于Unsloth支持GRPO,训练R1推理模型的细节,请参考官方博客 - https://unsloth.ai/blog/r1-reasoning。
Unsloth训练R1推理模型实战
1. 环境准备
首先需要安装必要的依赖包:
!pip install unsloth vllm
!pip install --upgrade pillow
!pip install git+https://github.com/huggingface/trl.git@e95f9fb74a3c3647b86f251b7e230ec51c64b72b
在开始训练之前,我们需要导入Unsloth的核心组件并进行初始化。使用 PatchFastRL 来追加 GRPO 和其他 RL 算法:
from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)
2. 模型加载与配置
基础参数设置
max_seq_length = 512 # 最大序列长度
lora_rank = 32 # LoRA秩,越大模型越"智能"但训练更慢
加载预训练模型
使用以下配置加载 Llama 3.1 8B Instruct 模型:
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",
max_seq_length = max_seq_length,
load_in_4bit = True, # 使用4位量化节省显存
fast_inference = True, # 启用vLLM加速推理
max_lora_rank = lora_rank,
gpu_memory_utilization = 0.6, # GPU显存使用率
)
PEFT配置
配置参数高效微调(PEFT)模型:
model = FastLanguageModel.get_peft_model(
model,
r = lora_rank,
target_modules = [
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
],
lora_alpha = lora_rank,
use_gradient_checkpointing = "unsloth",
random_state = 3407,
)
3. 数据集准备
为了训练推理能力,我们需要准备合适的数据集和提示模板:
import re
from datasets import load_dataset, Dataset
# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
def extract_xml_answer(text: str) -> str:
answer = text.split("<answer>")[-1]
answer = answer.split("</answer>")[0]
return answer.strip()
def extract_hash_answer(text: str) -> str | None:
if "####" not in text:
return None
return text.split("####")[1].strip()
# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
data = data.map(lambda x: { # type: ignore
'prompt': [
{'role': 'system', 'content': SYSTEM_PROMPT},
{'role': 'user', 'content': x['question']}
],
'answer': extract_hash_answer(x['answer'])
}) # type: ignore
return data # type: ignore
dataset = get_gsm8k_questions()
我们使用GSM8K数据集作为训练数据,并定义了相应的数据处理函数:
def get_gsm8k_questions(split = "train") -> Dataset:
data = load_dataset('openai/gsm8k', 'main')[split]
data = data.map(lambda x: {
'prompt': [
{'role': 'system', 'content': SYSTEM_PROMPT},
{'role': 'user', 'content': x['question']}
],
'answer': extract_hash_answer(x['answer'])
})
return data
4. 奖励函数设计
为了引导模型学习正确的推理过程,我们设计了多个奖励函数:
-
正确性奖励 (correctness_reward_func):
- 评估模型输出的答案是否正确
-
格式奖励 (strict_format_reward_func 和 soft_format_reward_func):
- 检查输出是否符合预定义的XML格式
- strict版本要求完全匹配,soft版本允许一定的灵活性
-
XML结构奖励 (xmlcount_reward_func):
- 检查XML标签的完整性和正确性
- 对每个正确的标签给予部分分数
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
responses = [completion[0]['content'] for completion in completions]
q = prompts[0][-1]['content']
extracted_responses = [extract_xml_answer(r) for r in responses]
print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
def int_reward_func(completions, **kwargs)->list[float]:responses=[completion[0]['content']forcompletionincompletions]extracted_responses=[extract_xml_answer(r)forrinresponses]return[0.5ifr.isdigit()else0.0forrinextracted_responses]defstrict_format_reward_func(completions,** kwargs) -> list[float]:
"""Reward function that checks if the completion has a specific format."""
pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
responses = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, r) for r in responses]
return [0.5 if match else 0.0 for match in matches]
def soft_format_reward_func(completions, **kwargs) -> list[float]:
"""Reward function that checks if the completion has a specific format."""
pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
responses = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, r) for r in responses]
return [0.5 if match else 0.0 for match in matches]
def count_xml(text) -> float:
count = 0.0
if text.count("<reasoning>\n") == 1:
count += 0.125
if text.count("\n</reasoning>\n") == 1:
count += 0.125
if text.count("\n<answer>\n") == 1:
count += 0.125
count -= len(text.split("\n</answer>\n")[-1])*0.001
if text.count("\n</answer>") == 1:
count += 0.125
count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
return count
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
contents = [completion[0]["content"] for completion in completions]
return [count_xml(c) for c in contents]
5. 训练配置
使用GRPO训练器进行模型训练,主要配置包括:
training_args = GRPOConfig(
use_vllm = True,
learning_rate = 5e-6,
warmup_ratio = 0.1,
lr_scheduler_type = "cosine",
optim = "paged_adamw_8bit",
per_device_train_batch_size = 1,
num_generations = 6,
max_steps = 250,
output_dir = "outputs",
)
关键参数说明:
- learning_rate:采用较小的学习率(5e-6)确保稳定训练
- warmup_ratio:使用10%的步骤进行预热
- lr_scheduler_type:采用余弦衰减的学习率调度策略
- optim:使用8位量化的AdamW优化器节省内存
6. 训练过程
使用配置好的GRPO训练器开始训练:
trainer = GRPOTrainer(
model = model,
processing_class = tokenizer,
reward_funcs = [
xmlcount_reward_func,
soft_format_reward_func,
strict_format_reward_func,
int_reward_func,
correctness_reward_func,
],
args = training_args,
train_dataset = dataset,
)
trainer.train()
7. 模型推理与保存
推理测试
训练完成后,我们可以使用以下代码测试模型:
text = tokenizer.apply_chat_template([
{"role" : "user", "content" : "Calculate pi."},
], tokenize = False, add_generation_prompt = True)
from vllm import SamplingParams
sampling_params = SamplingParams(
temperature = 0.8,
top_p = 0.95,
max_tokens = 1024,
)
output = model.fast_generate(
[text],
sampling_params = sampling_params,
lora_request = None,
)[0].outputs[0].text
模型保存
支持多种保存格式:
- LoRA权重保存:
model.save_lora("grpo_saved_lora")
- 完整模型保存:
- 16位精度保存
- 4位量化保存
- 仅保存LoRA适配器
- GGUF格式转换: 支持多种量化选项:
- q8_0:快速转换,较高资源占用
- q4_k_m:推荐使用,平衡性能和资源
- q5_k_m:较高精度选项
结语
以上步骤,完成了一个具有推理能力的R1模型的训练。整个过程展示了如何利用Unsloth工具链高效地训练大语言模型,以及如何通过精心设计的奖励函数来引导模型学习推理能力。
值得注意的是,训练效果与训练时间、数据质量以及奖励函数的设计都密切相关。在实际应用中,可能需要根据具体需求调整这些参数和配置。