返回博客2025年2月4日

利用UnslothUnsloth训练自己的第一款R1自主推理模型

AIUnslothGRPO

Unsloth简介

Unsloth是一款实现LLM高效训练与微调的开源库,它完全兼容 HuggingFace 生态系统(包括Hub、transformers、PEFT、TRL)。该库由Unsloth团队以及开源社区积极开发维护,支持大多数NVIDIA GPU,并且可以使用TRL库中的所有训练器套件(SFTTrainer, DPOTrainer, PPOTrainer, GRPOTrainer)。

Unsloth与GRPO

Unsloth 在推理能力训练方面取得了重大突破。通过优化Group Relative Policy Optimization (GRPO)训练过程,Unsloth团队成功将显存占用降低了80%,这意味着现在只需要7GB显存就能训练出具有自主推理能力的模型。这一突破使得在消费级显卡上训练推理模型成为可能。

具体来说,如果你拥有15GB显存的GPU,就可以将任何参数量在15B以下的模型(如Llama 3.1、Phi-4、Mistral或Qwen2.5等)转化为具备推理能力的模型。更令人兴奋的是,即使只有7GB显存,你也可以使用较小的模型(如Qwen2.5 1.5B)来实现类似效果。这与之前需要双A100(160GB显存)的要求相比,是一个革命性的进步。

GRPO的另一个重要特性是它能够帮助模型自主学习分配更多的思考时间,而无需人工反馈。现在,Unsloth不仅支持完整的微调,还支持QLoRA和LoRA等参数高效微调方法。对于那些只有输入输出数据(如问题和答案),但缺少推理过程的场景,GRPO可以自动生成合理的推理链路。这使得它在法律、医疗等需要严谨推理过程的专业领域特别有价值。

关于Unsloth支持GRPO,训练R1推理模型的细节,请参考官方博客 - https://unsloth.ai/blog/r1-reasoning

Unsloth训练R1推理模型实战

1. 环境准备

首先需要安装必要的依赖包:

!pip install unsloth vllm
!pip install --upgrade pillow
!pip install git+https://github.com/huggingface/trl.git@e95f9fb74a3c3647b86f251b7e230ec51c64b72b

在开始训练之前,我们需要导入Unsloth的核心组件并进行初始化。使用 PatchFastRL 来追加 GRPO 和其他 RL 算法:

from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

2. 模型加载与配置

基础参数设置

max_seq_length = 512  # 最大序列长度
lora_rank = 32        # LoRA秩,越大模型越"智能"但训练更慢

加载预训练模型

使用以下配置加载 Llama 3.1 8B Instruct 模型:

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True,           # 使用4位量化节省显存
    fast_inference = True,         # 启用vLLM加速推理
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6,  # GPU显存使用率
)

PEFT配置

配置参数高效微调(PEFT)模型:

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
)

3. 数据集准备

为了训练推理能力,我们需要准备合适的数据集和提示模板:

import re
from datasets import load_dataset, Dataset

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

我们使用GSM8K数据集作为训练数据,并定义了相应的数据处理函数:

def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split]
    data = data.map(lambda x: {
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    })
    return data

4. 奖励函数设计

为了引导模型学习正确的推理过程,我们设计了多个奖励函数:

  1. 正确性奖励 (correctness_reward_func):

    • 评估模型输出的答案是否正确
  2. 格式奖励 (strict_format_reward_func 和 soft_format_reward_func):

    • 检查输出是否符合预定义的XML格式
    • strict版本要求完全匹配,soft版本允许一定的灵活性
  3. XML结构奖励 (xmlcount_reward_func):

    • 检查XML标签的完整性和正确性
    • 对每个正确的标签给予部分分数
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs)->list[float]:responses=[completion[0]['content']forcompletionincompletions]extracted_responses=[extract_xml_answer(r)forrinresponses]return[0.5ifr.isdigit()else0.0forrinextracted_responses]defstrict_format_reward_func(completions,** kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

5. 训练配置

使用GRPO训练器进行模型训练,主要配置包括:

training_args = GRPOConfig(
    use_vllm = True,
    learning_rate = 5e-6,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    per_device_train_batch_size = 1,
    num_generations = 6,
    max_steps = 250,
    output_dir = "outputs",
)

关键参数说明:

  • learning_rate:采用较小的学习率(5e-6)确保稳定训练
  • warmup_ratio:使用10%的步骤进行预热
  • lr_scheduler_type:采用余弦衰减的学习率调度策略
  • optim:使用8位量化的AdamW优化器节省内存

6. 训练过程

使用配置好的GRPO训练器开始训练:

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()

7. 模型推理与保存

推理测试

训练完成后,我们可以使用以下代码测试模型:

text = tokenizer.apply_chat_template([
    {"role" : "user", "content" : "Calculate pi."},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text

模型保存

支持多种保存格式:

  1. LoRA权重保存
model.save_lora("grpo_saved_lora")
  1. 完整模型保存
  • 16位精度保存
  • 4位量化保存
  • 仅保存LoRA适配器
  1. GGUF格式转换: 支持多种量化选项:
  • q8_0:快速转换,较高资源占用
  • q4_k_m:推荐使用,平衡性能和资源
  • q5_k_m:较高精度选项

结语

以上步骤,完成了一个具有推理能力的R1模型的训练。整个过程展示了如何利用Unsloth工具链高效地训练大语言模型,以及如何通过精心设计的奖励函数来引导模型学习推理能力。

值得注意的是,训练效果与训练时间、数据质量以及奖励函数的设计都密切相关。在实际应用中,可能需要根据具体需求调整这些参数和配置。

准备开始了吗?

先简单说明目标,我会给出最合适的沟通方式。