# 模型微调最佳实践

本文档整理模型微调(Fine-tuning)的最佳实践。

微调概述

txt
┌─────────────────────────────────────────────────────┐
│                   微调决策框架                       │
├─────────────────────────────────────────────────────┤
│                                                     │
│  是否需要微调?                                     │
│       │                                             │
│       ├── 否 → 使用 Prompt Engineering / RAG        │
│       │                                             │
│       └── 是 → 选择微调方式                         │
│                 │                                   │
│                 ├── 全量微调                        │
│                 │   (所有参数更新)                  │
│                 │                                   │
│                 ├── 参数高效微调                    │
│                 │   (LoRA / QLoRA / Prefix Tuning)  │
│                 │                                   │
│                 └── 指令微调                        │
│                     (Instruction Tuning)            │
│                                                     │
└─────────────────────────────────────────────────────┘

微调决策

何时需要微调

python
from dataclasses import dataclass
from typing import List, Optional
from enum import Enum

class TaskType(Enum):
    """Task categories used as lookup keys for fine-tuning heuristics."""
    CLASSIFICATION = "classification"
    EXTRACTION = "extraction"
    SUMMARIZATION = "summarization"
    GENERATION = "generation"
    CODE = "code"
    DOMAIN_SPECIFIC = "domain_specific"

@dataclass
class FineTuningDecision:
    """Outcome of a fine-tuning feasibility evaluation."""
    needs_fine_tuning: bool               # whether fine-tuning is recommended
    reason: str                           # human-readable rationale (pro/con reasons combined)
    recommended_approach: Optional[str]   # e.g. "prompt_engineering", "lora", "full_fine_tuning"
    estimated_data_needed: Optional[int]  # minimum sample count when fine-tuning is advised
    estimated_cost: Optional[float]       # rough training cost estimate (USD)

class FineTuningAdvisor:
    """Advisor that decides whether a task warrants fine-tuning.

    Per-task heuristics (expected success rates, minimum data volume, prompt
    length ceiling) live in ``task_characteristics``; ``evaluate`` weighs them
    against the caller's prompt-engineering results, data volume, and budget.
    """

    def __init__(self):
        # Heuristic table per task type. Success rates are rough priors, not
        # measurements; min_data_samples gates the fine-tuning recommendation.
        self.task_characteristics = {
            TaskType.CLASSIFICATION: {
                "prompt_engineering_success_rate": 0.85,
                "fine_tuning_success_rate": 0.95,
                "min_data_samples": 1000,
                "max_prompt_length": 512,
            },
            TaskType.EXTRACTION: {
                "prompt_engineering_success_rate": 0.75,
                "fine_tuning_success_rate": 0.92,
                "min_data_samples": 500,
                "max_prompt_length": 1024,
            },
            TaskType.SUMMARIZATION: {
                "prompt_engineering_success_rate": 0.80,
                "fine_tuning_success_rate": 0.90,
                "min_data_samples": 1000,
                "max_prompt_length": 2048,
            },
            TaskType.GENERATION: {
                "prompt_engineering_success_rate": 0.70,
                "fine_tuning_success_rate": 0.88,
                "min_data_samples": 2000,
                "max_prompt_length": 2048,
            },
            TaskType.CODE: {
                "prompt_engineering_success_rate": 0.75,
                "fine_tuning_success_rate": 0.92,
                "min_data_samples": 500,
                "max_prompt_length": 1024,
            },
            TaskType.DOMAIN_SPECIFIC: {
                "prompt_engineering_success_rate": 0.60,
                "fine_tuning_success_rate": 0.95,
                "min_data_samples": 1000,
                "max_prompt_length": 1024,
            },
        }

    def evaluate(
        self,
        task_type: TaskType,
        data_samples: int,
        prompt_engineering_tried: bool = False,
        prompt_engineering_results: Optional[float] = None,
        domain_specific_knowledge: bool = False,
        output_format_consistency: bool = True,
        cost_budget: Optional[float] = None
    ) -> FineTuningDecision:
        """Decide whether the described task should be fine-tuned.

        Args:
            task_type: Category of the task (keys into the heuristic table).
            data_samples: Number of labeled samples available.
            prompt_engineering_tried: Whether prompting was already attempted.
            prompt_engineering_results: Measured accuracy of the prompting
                attempt in [0, 1], if available.
            domain_specific_knowledge: True when the task needs knowledge the
                base model is unlikely to have.
            output_format_consistency: False when strict output formatting is
                required (a point in favor of fine-tuning).
            cost_budget: Optional spending cap in USD.

        Returns:
            A FineTuningDecision with the verdict, rationale, and estimates.
        """
        chars = self.task_characteristics[task_type]

        reasons_for_fine_tuning = []
        reasons_against_fine_tuning = []

        # Data volume check.
        if data_samples < chars["min_data_samples"]:
            reasons_against_fine_tuning.append(
                f"数据量不足(需要至少 {chars['min_data_samples']} 样本,当前 {data_samples})"
            )

        # Prompt-engineering effectiveness check.
        # Fix: compare against None explicitly — a measured accuracy of 0.0 is
        # valid data and must not be skipped as falsy.
        if prompt_engineering_tried and prompt_engineering_results is not None:
            if prompt_engineering_results >= 0.9:
                # Prompting already works well; no fine-tuning needed.
                return FineTuningDecision(
                    needs_fine_tuning=False,
                    reason=f"Prompt Engineering 效果良好(准确率 {prompt_engineering_results:.0%})",
                    recommended_approach="prompt_engineering",
                    estimated_data_needed=0,
                    estimated_cost=0
                )
            elif prompt_engineering_results >= 0.8:
                reasons_against_fine_tuning.append(
                    "Prompt Engineering 效果尚可,可继续优化"
                )
            else:
                reasons_for_fine_tuning.append(
                    f"Prompt Engineering 效果不佳(准确率 {prompt_engineering_results:.0%})"
                )

        # Domain-knowledge requirement.
        if domain_specific_knowledge:
            reasons_for_fine_tuning.append("需要领域特定知识")

        # Strict output-format requirement.
        if not output_format_consistency:
            reasons_for_fine_tuning.append("输出格式要求严格一致")

        # Budget check. Fix: "is not None" so a budget of 0 is honored.
        estimated_cost = self._estimate_cost(task_type, data_samples)
        if cost_budget is not None and estimated_cost > cost_budget:
            reasons_against_fine_tuning.append(
                f"超出预算(预计 {estimated_cost:.2f}$,预算 {cost_budget:.2f}$)"
            )

        # Final verdict: at least one pro, no cons, and enough data.
        needs_fine_tuning = (
            len(reasons_for_fine_tuning) > 0 and
            len(reasons_against_fine_tuning) == 0 and
            data_samples >= chars["min_data_samples"]
        )

        # Approach selection: full fine-tuning only with large datasets,
        # otherwise the parameter-efficient LoRA route.
        if needs_fine_tuning:
            if data_samples >= 10000:
                recommended_approach = "full_fine_tuning"
            else:
                recommended_approach = "lora"
        else:
            recommended_approach = "prompt_engineering"

        return FineTuningDecision(
            needs_fine_tuning=needs_fine_tuning,
            reason=self._format_reason(reasons_for_fine_tuning, reasons_against_fine_tuning),
            recommended_approach=recommended_approach,
            estimated_data_needed=chars["min_data_samples"] if needs_fine_tuning else 0,
            estimated_cost=estimated_cost if needs_fine_tuning else 0
        )

    def _estimate_cost(self, task_type: TaskType, data_samples: int) -> float:
        """Linear cost estimate: base rate per 1k samples, per task type."""
        base_cost_per_1k_samples = {
            TaskType.CLASSIFICATION: 5.0,
            TaskType.EXTRACTION: 8.0,
            TaskType.SUMMARIZATION: 10.0,
            TaskType.GENERATION: 15.0,
            TaskType.CODE: 12.0,
            TaskType.DOMAIN_SPECIFIC: 15.0,
        }

        cost_per_1k = base_cost_per_1k_samples[task_type]
        return (data_samples / 1000) * cost_per_1k

    def _format_reason(self, for_ft: List[str], against_ft: List[str]) -> str:
        """Join pro/con reasons into one human-readable string."""
        parts = []

        if for_ft:
            parts.append("支持微调: " + "; ".join(for_ft))

        if against_ft:
            parts.append("反对微调: " + "; ".join(against_ft))

        return " | ".join(parts)

# Usage example
ft_advisor = FineTuningAdvisor()

# Ask the advisor whether this domain-specific task should be fine-tuned.
ft_decision = ft_advisor.evaluate(
    task_type=TaskType.DOMAIN_SPECIFIC,
    data_samples=2000,
    prompt_engineering_tried=True,
    prompt_engineering_results=0.65,
    domain_specific_knowledge=True,
    output_format_consistency=True,
)

print(f"需要微调: {ft_decision.needs_fine_tuning}")
print(f"原因: {ft_decision.reason}")
print(f"推荐方法: {ft_decision.recommended_approach}")
print(f"预计成本: ${ft_decision.estimated_cost:.2f}")

数据准备

数据收集与清洗

python
import json
import random
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

@dataclass
class TrainingSample:
    """A single supervised training example (input/output pair)."""
    input_text: str   # model input (prompt / instruction)
    output_text: str  # expected model output
    metadata: Optional[Dict[str, Any]] = None  # free-form extras carried over from the raw record

class DataPreparer:
    """Cleans, splits, augments, formats, and saves fine-tuning data.

    Type names are forward-referenced ("TrainingSample") so this snippet can
    be imported independently of the dataclass definition.
    """

    def __init__(
        self,
        min_input_length: int = 10,
        max_input_length: int = 4096,
        min_output_length: int = 10,
        max_output_length: int = 2048
    ):
        """Configure the character-length bounds used to filter raw records."""
        self.min_input_length = min_input_length
        self.max_input_length = max_input_length
        self.min_output_length = min_output_length
        self.max_output_length = max_output_length

    def prepare(
        self,
        raw_data: List[Dict[str, Any]],
        input_key: str = "input",
        output_key: str = "output",
        preprocess_fn: Optional[Callable] = None
    ) -> List["TrainingSample"]:
        """Convert raw dict records into validated TrainingSample objects.

        Records whose input/output fall outside the configured length bounds
        are silently dropped.

        Args:
            raw_data: Raw records, each a dict.
            input_key / output_key: Keys holding the input/output text.
            preprocess_fn: Optional hook ``(input, output) -> (input, output)``
                applied before validation.
        """
        samples = []

        for item in raw_data:
            # Missing keys become empty strings and are then rejected by the
            # minimum-length check.
            input_text = item.get(input_key, "")
            output_text = item.get(output_key, "")

            if preprocess_fn:
                input_text, output_text = preprocess_fn(input_text, output_text)

            if not self._validate_length(input_text, output_text):
                continue

            samples.append(TrainingSample(
                input_text=input_text.strip(),
                output_text=output_text.strip(),
                metadata=item.get("metadata")
            ))

        return samples

    def _validate_length(self, input_text: str, output_text: str) -> bool:
        """Return True when both texts fall within the configured bounds."""
        return (
            self.min_input_length <= len(input_text) <= self.max_input_length
            and self.min_output_length <= len(output_text) <= self.max_output_length
        )

    def split(
        self,
        samples: List["TrainingSample"],
        train_ratio: float = 0.8,
        val_ratio: float = 0.1,
        test_ratio: float = 0.1,
        shuffle: bool = True
    ) -> tuple[List["TrainingSample"], List["TrainingSample"], List["TrainingSample"]]:
        """Split samples into train/val/test partitions.

        The test partition receives the remainder, so ``test_ratio`` is
        informational only. Fix: shuffling operates on a copy so the caller's
        list is never mutated in place.
        """
        pool = list(samples)
        if shuffle:
            random.shuffle(pool)

        total = len(pool)
        train_end = int(total * train_ratio)
        val_end = train_end + int(total * val_ratio)

        return pool[:train_end], pool[train_end:val_end], pool[val_end:]

    def augment(
        self,
        samples: List["TrainingSample"],
        methods: List[str] = None
    ) -> List["TrainingSample"]:
        """Return the samples plus any augmentations produced by ``methods``.

        Supported methods: "paraphrase" and "shuffle" (both currently stubs
        that return None, i.e. no augmentation is added).
        """
        if methods is None:
            methods = ["paraphrase", "shuffle"]

        augmented = list(samples)

        for sample in samples:
            for method in methods:
                if method == "paraphrase":
                    paraphrased = self._paraphrase(sample)
                    if paraphrased:
                        augmented.append(paraphrased)
                elif method == "shuffle":
                    # Shuffle-based augmentation (intended for list-like outputs).
                    shuffled = self._shuffle(sample)
                    if shuffled:
                        augmented.append(shuffled)

        return augmented

    def _paraphrase(self, sample: "TrainingSample") -> Optional["TrainingSample"]:
        """Paraphrase augmentation stub — a real version would call a model."""
        return None

    def _shuffle(self, sample: "TrainingSample") -> Optional["TrainingSample"]:
        """Shuffle augmentation stub."""
        return None

    def format_for_training(
        self,
        samples: List["TrainingSample"],
        format_type: str = "alpaca"
    ) -> List[Dict[str, str]]:
        """Convert samples into a named training format.

        Supported format_type values: "alpaca", "chat", "completion".

        Raises:
            ValueError: For an unknown format_type (previously an unknown
                value silently produced an empty list).
        """
        if format_type not in ("alpaca", "chat", "completion"):
            raise ValueError(f"Unknown format_type: {format_type}")

        formatted = []

        for sample in samples:
            if format_type == "alpaca":
                formatted.append({
                    "instruction": sample.input_text,
                    "input": "",
                    "output": sample.output_text
                })
            elif format_type == "chat":
                formatted.append({
                    "messages": [
                        {"role": "user", "content": sample.input_text},
                        {"role": "assistant", "content": sample.output_text}
                    ]
                })
            else:  # "completion"
                formatted.append({
                    "prompt": sample.input_text,
                    "completion": sample.output_text
                })

        return formatted

    def save(
        self,
        samples: List[Dict[str, str]],
        output_path: str,
        format: str = "jsonl"
    ):
        """Write formatted samples to disk as UTF-8 JSONL or JSON.

        Raises:
            ValueError: For an unknown format (previously silently a no-op).
        """
        if format == "jsonl":
            with open(output_path, 'w', encoding='utf-8') as f:
                for sample in samples:
                    f.write(json.dumps(sample, ensure_ascii=False) + '\n')
        elif format == "json":
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(samples, f, ensure_ascii=False, indent=2)
        else:
            raise ValueError(f"Unknown format: {format}")

# Usage example
data_preparer = DataPreparer()

# Raw records to be turned into training samples.
raw_records = [
    {"input": "翻译:Hello", "output": "你好"},
    {"input": "翻译:World", "output": "世界"},
    # ... more data
]

prepared = data_preparer.prepare(raw_records)
train_set, val_set, test_set = data_preparer.split(prepared)

# Convert to the chat format.
chat_formatted = data_preparer.format_for_training(train_set, format_type="chat")

# Persist as JSONL.
data_preparer.save(chat_formatted, "train.jsonl")

LoRA 微调

python
from dataclasses import dataclass
from typing import Optional, List
import json

@dataclass
class LoRAConfig:
    """Configuration bundle for a LoRA fine-tuning run."""
    # Model configuration
    base_model: str   # HF hub id or local path of the base model
    output_dir: str   # where checkpoints and the final adapter are written

    # LoRA configuration
    lora_r: int = 8            # adapter rank
    lora_alpha: int = 16       # LoRA scaling factor
    lora_dropout: float = 0.05
    lora_target_modules: Optional[List[str]] = None  # defaulted in __post_init__

    # Training configuration
    num_train_epochs: int = 3
    per_device_train_batch_size: int = 4
    per_device_eval_batch_size: int = 4
    gradient_accumulation_steps: int = 4
    learning_rate: float = 2e-4
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1

    # Optimization configuration
    bf16: bool = True
    fp16: bool = False
    gradient_checkpointing: bool = True
    optim: str = "adamw_torch"

    # Logging / checkpointing configuration
    logging_steps: int = 10
    save_steps: int = 500
    eval_steps: int = 500

    def __post_init__(self):
        # Default LoRA targets: the attention query/value projection layers.
        if self.lora_target_modules is None:
            self.lora_target_modules = ["q_proj", "v_proj"]

class LoRATrainer:
    """Generates LoRA setup/training code snippets and sanity-checks a config."""

    def __init__(self, config: "LoRAConfig"):
        """Store the LoRA configuration used by all generators below."""
        self.config = config

    def prepare_model(self) -> str:
        """Return a code snippet that loads the base model and applies LoRA."""
        code = f"""
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType

# 加载基础模型
model = AutoModelForCausalLM.from_pretrained(
    "{self.config.base_model}",
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("{self.config.base_model}")

# LoRA 配置
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r={self.config.lora_r},
    lora_alpha={self.config.lora_alpha},
    lora_dropout={self.config.lora_dropout},
    target_modules={self.config.lora_target_modules}
)

# 应用 LoRA
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
"""
        return code

    def prepare_training_script(self) -> str:
        """Return a complete HF Trainer script for this configuration."""
        script = f"""
from transformers import TrainingArguments, Trainer
from datasets import load_dataset

# 训练参数
training_args = TrainingArguments(
    output_dir="{self.config.output_dir}",
    num_train_epochs={self.config.num_train_epochs},
    per_device_train_batch_size={self.config.per_device_train_batch_size},
    per_device_eval_batch_size={self.config.per_device_eval_batch_size},
    gradient_accumulation_steps={self.config.gradient_accumulation_steps},
    learning_rate={self.config.learning_rate},
    weight_decay={self.config.weight_decay},
    warmup_ratio={self.config.warmup_ratio},
    bf16={self.config.bf16},
    fp16={self.config.fp16},
    gradient_checkpointing={self.config.gradient_checkpointing},
    optim="{self.config.optim}",
    logging_steps={self.config.logging_steps},
    save_steps={self.config.save_steps},
    eval_steps={self.config.eval_steps},
    evaluation_strategy="steps",
    save_strategy="steps",
    load_best_model_at_end=True,
)

# 加载数据集
train_dataset = load_dataset("json", data_files="train.jsonl", split="train")
eval_dataset = load_dataset("json", data_files="val.jsonl", split="train")

# 创建训练器
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)

# 开始训练
trainer.train()

# 保存模型
trainer.save_model("{self.config.output_dir}")
tokenizer.save_pretrained("{self.config.output_dir}")
"""
        return script

    def estimate_resources(self, num_samples: int = 1000) -> dict:
        """Rough estimate of VRAM and training-time requirements.

        Fix: the previous version built a model-size table but then always
        looked up "llama-7b", ignoring ``config.base_model``. The parameter
        count is now parsed from the model name itself.

        Args:
            num_samples: Expected number of training samples (default 1000,
                matching the previous hard-coded assumption).
        """
        import re  # local import: only this method needs it

        # Parse the parameter count in billions from the model name,
        # e.g. "meta-llama/Llama-2-7b-hf" -> 7. Fall back to 7B when absent.
        match = re.search(r"(\d+)\s*b", self.config.base_model.lower())
        params_b = int(match.group(1)) if match else 7

        # Rough full fine-tuning footprint: ~2 GB of VRAM per billion params.
        full_vram = params_b * 2

        # LoRA trains only small adapters: assume ~30% of the full footprint.
        lora_vram = full_vram * 0.3

        # Activation memory grows with batch size (~0.5 GB per sample here).
        batch_vram = self.config.per_device_train_batch_size * 0.5
        total_vram = lora_vram + batch_vram

        # Assume ~1 minute per 1000 samples per epoch.
        estimated_time_minutes = num_samples / 1000 * self.config.num_train_epochs

        return {
            "vram_gb": total_vram,
            "estimated_time_minutes": estimated_time_minutes,
            "trainable_params_percent": (self.config.lora_r / 4096) * 100,  # rough estimate
        }

    def validate_config(self) -> List[str]:
        """Return a list of human-readable configuration errors (empty = OK)."""
        errors = []

        if self.config.lora_r <= 0:
            errors.append("lora_r 必须大于 0")

        if self.config.learning_rate <= 0 or self.config.learning_rate > 1:
            errors.append("learning_rate 应该在 (0, 1) 范围内")

        if self.config.num_train_epochs <= 0:
            errors.append("num_train_epochs 必须大于 0")

        if self.config.per_device_train_batch_size <= 0:
            errors.append("per_device_train_batch_size 必须大于 0")

        return errors

# Usage example
lora_cfg = LoRAConfig(
    base_model="meta-llama/Llama-2-7b-hf",
    output_dir="./lora-output",
    lora_r=8,
    lora_alpha=16,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    learning_rate=2e-4,
)

lora_trainer = LoRATrainer(lora_cfg)

# Validate the configuration before anything else.
config_errors = lora_trainer.validate_config()
if config_errors:
    print("配置错误:", config_errors)
else:
    print("配置有效")

# Resource estimation.
est = lora_trainer.estimate_resources()
print(f"显存需求: {est['vram_gb']:.1f} GB")
print(f"预计训练时间: {est['estimated_time_minutes']:.1f} 分钟")
print(f"可训练参数比例: {est['trainable_params_percent']:.2f}%")

# Generate the training script.
training_script = lora_trainer.prepare_training_script()
print(training_script)

训练监控

python
from dataclasses import dataclass
from typing import List, Dict, Any, Optional
from datetime import datetime
import json
import matplotlib.pyplot as plt

@dataclass
class TrainingMetrics:
    """Metrics captured at one training step."""
    step: int
    epoch: float
    loss: float               # training loss at this step
    learning_rate: float
    eval_loss: Optional[float] = None      # only set on evaluation steps
    eval_accuracy: Optional[float] = None  # only set on evaluation steps
    timestamp: Optional[datetime] = None   # wall-clock time of the record

class TrainingMonitor:
    """Tracks training metrics with early stopping, plotting, and log export."""

    def __init__(self):
        self.metrics: List[TrainingMetrics] = []   # chronological records
        self.best_loss: float = float('inf')       # best eval loss seen so far
        self.best_step: int = 0                    # step where best_loss occurred
        self.no_improve_count: int = 0             # consecutive evals without improvement

    def record(
        self,
        step: int,
        epoch: float,
        loss: float,
        learning_rate: float,
        eval_loss: Optional[float] = None,
        eval_accuracy: Optional[float] = None
    ):
        """Record one step's metrics and update best/early-stop bookkeeping."""
        metric = TrainingMetrics(
            step=step,
            epoch=epoch,
            loss=loss,
            learning_rate=learning_rate,
            eval_loss=eval_loss,
            eval_accuracy=eval_accuracy,
            timestamp=datetime.now()
        )
        self.metrics.append(metric)

        # Fix: compare against None explicitly — an eval_loss of exactly 0.0
        # is a valid (perfect) value and must not be ignored as falsy.
        if eval_loss is not None and eval_loss < self.best_loss:
            self.best_loss = eval_loss
            self.best_step = step
            self.no_improve_count = 0
        elif eval_loss is not None:
            self.no_improve_count += 1

    def should_stop(self, patience: int = 3) -> bool:
        """Return True when `patience` evaluations passed without improvement."""
        return self.no_improve_count >= patience

    def get_summary(self) -> Dict[str, Any]:
        """Summarize the run (losses, best step, wall-clock duration)."""
        if not self.metrics:
            return {}

        train_losses = [m.loss for m in self.metrics]
        # Fix: "is not None" so a legitimate eval loss of 0.0 is kept.
        eval_losses = [m.eval_loss for m in self.metrics if m.eval_loss is not None]

        return {
            "total_steps": self.metrics[-1].step,
            "total_epochs": self.metrics[-1].epoch,
            "final_train_loss": train_losses[-1],
            "min_train_loss": min(train_losses),
            "final_eval_loss": eval_losses[-1] if eval_losses else None,
            "min_eval_loss": min(eval_losses) if eval_losses else None,
            "best_step": self.best_step,
            "best_loss": self.best_loss,
            "total_time": (self.metrics[-1].timestamp - self.metrics[0].timestamp).total_seconds(),
        }

    def plot(self, output_path: Optional[str] = None):
        """Plot loss and learning-rate curves; save to file or show on screen."""
        if not self.metrics:
            return

        steps = [m.step for m in self.metrics]
        train_losses = [m.loss for m in self.metrics]
        # Fix: keep eval points whose loss is 0.0 (filter on None, not falsiness).
        eval_losses = [m.eval_loss for m in self.metrics if m.eval_loss is not None]
        eval_steps = [m.step for m in self.metrics if m.eval_loss is not None]

        fig, axes = plt.subplots(1, 2, figsize=(12, 4))

        # Loss curves
        axes[0].plot(steps, train_losses, label='Train Loss')
        if eval_losses:
            axes[0].plot(eval_steps, eval_losses, label='Eval Loss')
        axes[0].set_xlabel('Step')
        axes[0].set_ylabel('Loss')
        axes[0].set_title('Training Loss')
        axes[0].legend()
        axes[0].grid(True)

        # Learning-rate curve
        learning_rates = [m.learning_rate for m in self.metrics]
        axes[1].plot(steps, learning_rates)
        axes[1].set_xlabel('Step')
        axes[1].set_ylabel('Learning Rate')
        axes[1].set_title('Learning Rate Schedule')
        axes[1].grid(True)

        plt.tight_layout()

        if output_path:
            plt.savefig(output_path)
        else:
            plt.show()

    def save_log(self, output_path: str):
        """Dump the summary plus all per-step metrics to a JSON file."""
        log = {
            "summary": self.get_summary(),
            "metrics": [
                {
                    "step": m.step,
                    "epoch": m.epoch,
                    "loss": m.loss,
                    "learning_rate": m.learning_rate,
                    "eval_loss": m.eval_loss,
                    "eval_accuracy": m.eval_accuracy,
                    "timestamp": m.timestamp.isoformat() if m.timestamp else None
                }
                for m in self.metrics
            ]
        }

        with open(output_path, 'w') as f:
            json.dump(log, f, indent=2)

# Usage example
import random  # fix: this snippet uses random but never imported it

monitor = TrainingMonitor()

# Simulate a training run: slowly decreasing, noisy losses.
for step in range(0, 1000, 10):
    train_loss = 2.0 - step * 0.001 + random.random() * 0.1
    eval_loss = 2.1 - step * 0.001 + random.random() * 0.15

    monitor.record(
        step=step,
        epoch=step / 200,
        loss=train_loss,
        learning_rate=2e-4 * (1 - step / 1000),
        # Evaluate only every 100 steps.
        eval_loss=eval_loss if step % 100 == 0 else None
    )

    # Early-stopping check.
    if monitor.should_stop(patience=5):
        print(f"早停于 step {step}")
        break

# Run summary.
summary = monitor.get_summary()
print(f"最终损失: {summary['final_train_loss']:.4f}")
print(f"最佳损失: {summary['min_eval_loss']:.4f}")

# Plot the curves.
monitor.plot("training_curves.png")

# Persist the log.
monitor.save_log("training_log.json")

模型评估

python
from dataclasses import dataclass
from typing import List, Dict, Any, Callable
import json

@dataclass
class EvaluationResult:
    """Result of evaluating one model on one task."""
    model_name: str
    task: str
    metrics: Dict[str, float]        # metric name -> score
    samples: List[Dict[str, Any]]    # per-sample input / reference / prediction records
    timestamp: str                   # ISO-8601 time of the evaluation

class ModelEvaluator:
    """Evaluates and compares models on held-out data with simple metrics.

    Type names are forward-referenced so this snippet can be imported
    independently of the dataclass definitions.
    """

    def __init__(self):
        # All evaluation results produced by this evaluator, in call order.
        self.evaluations: List["EvaluationResult"] = []

    def evaluate(
        self,
        model,
        test_data: List["TrainingSample"],
        task: str,
        metrics: Optional[List[str]] = None
    ) -> "EvaluationResult":
        """Run the model over the test set and compute the requested metrics.

        Args:
            model: Object exposing ``generate(text) -> str`` (and optionally
                a ``name`` attribute).
            test_data: Held-out samples with input_text/output_text.
            task: Free-form task label recorded in the result.
            metrics: Metric names; defaults to accuracy/f1/exact_match.
        """
        if metrics is None:
            metrics = ["accuracy", "f1", "exact_match"]

        predictions = []
        references = []

        # Generate one prediction per sample.
        for sample in test_data:
            pred = model.generate(sample.input_text)
            predictions.append(pred)
            references.append(sample.output_text)

        # Compute each requested metric.
        results = {}
        for metric in metrics:
            if metric == "accuracy":
                results["accuracy"] = self._accuracy(predictions, references)
            elif metric == "f1":
                results["f1"] = self._f1_score(predictions, references)
            elif metric == "exact_match":
                results["exact_match"] = self._exact_match(predictions, references)
            elif metric == "bleu":
                results["bleu"] = self._bleu(predictions, references)
            elif metric == "rouge":
                results["rouge"] = self._rouge(predictions, references)

        # Per-sample records for error analysis.
        samples = [
            {
                "input": test_data[i].input_text,
                "reference": references[i],
                "prediction": predictions[i],
                "correct": predictions[i] == references[i]
            }
            for i in range(len(test_data))
        ]

        result = EvaluationResult(
            model_name=model.name if hasattr(model, 'name') else "unknown",
            task=task,
            metrics=results,
            samples=samples,
            timestamp=datetime.now().isoformat()
        )

        self.evaluations.append(result)

        return result

    def _accuracy(self, predictions: List[str], references: List[str]) -> float:
        """Fraction of exact prediction/reference matches (0.0 for empty input)."""
        if not predictions:
            # Fix: guard against ZeroDivisionError on an empty test set.
            return 0.0
        correct = sum(p == r for p, r in zip(predictions, references))
        return correct / len(predictions)

    def _f1_score(self, predictions: List[str], references: List[str]) -> float:
        """F1 score. Simplified placeholder: a real version would compute
        token-level F1; this falls back to exact-match accuracy."""
        return self._accuracy(predictions, references)

    def _exact_match(self, predictions: List[str], references: List[str]) -> float:
        """Exact-match rate (identical to accuracy here)."""
        return self._accuracy(predictions, references)

    def _bleu(self, predictions: List[str], references: List[str]) -> float:
        """BLEU placeholder — a real version would use nltk or sacrebleu."""
        return 0.0

    def _rouge(self, predictions: List[str], references: List[str]) -> float:
        """ROUGE placeholder — a real version would use rouge-score."""
        return 0.0

    def compare(
        self,
        baseline_model,
        tuned_model,
        test_data: List["TrainingSample"],
        task: str
    ) -> Dict[str, Any]:
        """Evaluate both models on the same data and report per-metric deltas."""
        baseline_result = self.evaluate(baseline_model, test_data, task)
        tuned_result = self.evaluate(tuned_model, test_data, task)

        comparison = {
            "baseline": {
                "model": baseline_result.model_name,
                "metrics": baseline_result.metrics
            },
            "tuned": {
                "model": tuned_result.model_name,
                "metrics": tuned_result.metrics
            },
            "improvement": {}
        }

        for metric in baseline_result.metrics:
            baseline_value = baseline_result.metrics[metric]
            tuned_value = tuned_result.metrics[metric]

            # Relative improvement is undefined for a zero baseline; report 0.
            if baseline_value > 0:
                improvement = (tuned_value - baseline_value) / baseline_value * 100
            else:
                improvement = 0

            comparison["improvement"][metric] = {
                "baseline": baseline_value,
                "tuned": tuned_value,
                "improvement_percent": improvement
            }

        return comparison

    def generate_report(self, output_path: str):
        """Write a JSON report covering every recorded evaluation."""
        report = {
            "total_evaluations": len(self.evaluations),
            "evaluations": [
                {
                    "model": e.model_name,
                    "task": e.task,
                    "metrics": e.metrics,
                    "timestamp": e.timestamp
                }
                for e in self.evaluations
            ]
        }

        with open(output_path, 'w') as f:
            json.dump(report, f, indent=2)

# Usage example
# NOTE(review): `tuned_model`, `base_model`, and `test_samples` are assumed to
# be provided by the surrounding code (model objects exposing `.generate(text)`
# plus held-out TrainingSample records) — this snippet is illustrative and not
# runnable on its own.
evaluator = ModelEvaluator()

# Evaluate the fine-tuned model on the held-out test set.
result = evaluator.evaluate(
    model=tuned_model,
    test_data=test_samples,
    task="classification",
    metrics=["accuracy", "f1"]
)

print(f"准确率: {result.metrics['accuracy']:.2%}")
print(f"F1 分数: {result.metrics['f1']:.2%}")

# Compare baseline vs. fine-tuned on the same data.
comparison = evaluator.compare(
    baseline_model=base_model,
    tuned_model=tuned_model,
    test_data=test_samples,
    task="classification"
)

print(f"准确率提升: {comparison['improvement']['accuracy']['improvement_percent']:.1f}%")

# Write all recorded evaluations to a JSON report.
evaluator.generate_report("evaluation_report.json")

最佳实践总结

微调检查清单

markdown
## 微调检查清单

### 准备阶段

- [ ] 明确任务目标
- [ ] 评估是否需要微调
- [ ] 确定数据量是否足够
- [ ] 选择合适的微调方式

### 数据准备

- [ ] 收集高质量数据
- [ ] 清洗和预处理数据
- [ ] 划分训练/验证/测试集
- [ ] 数据增强(如需要)

### 训练配置

- [ ] 选择合适的基础模型
- [ ] 设置合理的超参数
- [ ] 配置学习率调度
- [ ] 设置早停策略

### 训练监控

- [ ] 监控训练损失
- [ ] 监控验证损失
- [ ] 监控学习率变化
- [ ] 监控显存使用

### 评估验证

- [ ] 在测试集上评估
- [ ] 与基线模型对比
- [ ] 分析错误案例
- [ ] 评估部署成本

### 部署阶段

- [ ] 模型压缩优化
- [ ] 推理性能测试
- [ ] A/B 测试验证
- [ ] 监控线上效果

微调方式选择

| 方式 | 适用场景 | 数据量 | 成本 | 效果 |
|------|----------|--------|------|------|
| Prompt Engineering | 通用任务 | 无 | 低 | 中 |
| Few-shot Learning | 快速原型 | 很小 | 低 | 中 |
| Instruction Tuning | 指令遵循 | 中 | 中 | 高 |
| LoRA | 特定任务 | 中 | 低 | 高 |
| QLoRA | 资源受限 | 中 | 低 | 高 |
| 全量微调 | 特定领域 | 大 | 最高 | 最高 |