Skip to content

错误处理与容错策略

本文档整理了 AI 系统错误处理与容错的最佳实践,帮助确保系统稳定可靠。

错误分类

txt
┌─────────────────────────────────────────────────────┐
│                   AI 系统错误分类                     │
├─────────────────────────────────────────────────────┤
│                                                     │
│  输入错误                                           │
│  ├── 格式错误                                       │
│  ├── 长度超限                                       │
│  ├── 内容违规                                       │
│  └── 编码问题                                       │
│                                                     │
│  API 错误                                           │
│  ├── 速率限制 (429)                                 │
│  ├── 服务不可用 (503)                               │
│  ├── 认证失败 (401)                                 │
│  ├── 额度超限                                       │
│  └── 超时                                           │
│                                                     │
│  模型错误                                           │
│  ├── 输出截断                                       │
│  ├── 格式不匹配                                     │
│  ├── 内容过滤                                       │
│  ├── 幻觉/错误信息                                  │
│  └── 拒答                                           │
│                                                     │
│  系统错误                                           │
│  ├── 网络故障                                       │
│  ├── 下游服务不可用                                 │
│  ├── 资源耗尽                                       │
│  └── 配置错误                                       │
│                                                     │
└─────────────────────────────────────────────────────┘

错误处理框架

统一错误类型

python
from dataclasses import dataclass
from typing import Optional, Any
from enum import Enum

class ErrorType(Enum):
    """Categories an AI-call failure can be classified into.

    Members are grouped by where the failure originates: the caller's input,
    the provider API, the model's output, or the surrounding system.
    """

    # Input errors
    INVALID_INPUT = "invalid_input"
    INPUT_TOO_LONG = "input_too_long"
    CONTENT_FILTERED = "content_filtered"

    # API errors
    RATE_LIMIT = "rate_limit"
    SERVICE_UNAVAILABLE = "service_unavailable"
    AUTH_FAILED = "auth_failed"
    QUOTA_EXCEEDED = "quota_exceeded"
    TIMEOUT = "timeout"

    # Model errors
    OUTPUT_TRUNCATED = "output_truncated"
    FORMAT_ERROR = "format_error"
    MODEL_REFUSAL = "model_refusal"
    HALLUCINATION = "hallucination"

    # System errors
    NETWORK_ERROR = "network_error"
    DOWNSTREAM_ERROR = "downstream_error"
    RESOURCE_EXHAUSTED = "resource_exhausted"
    CONFIG_ERROR = "config_error"


@dataclass
class AIError(Exception):
    """Structured, raisable description of a failed AI call.

    Subclasses ``Exception`` so instances can be raised and caught;
    ``ErrorHandler.execute_with_retry`` does ``raise ai_error``, which fails
    with ``TypeError`` for a class that does not derive from BaseException.

    Attributes:
        error_type: category of the failure.
        message: human-readable description; also returned by ``str()``.
        original_error: the underlying exception, if any.
        retryable: explicit opt-in to retrying, in addition to the
            categories that are always retryable (see ``is_retryable``).
        retry_after: suggested wait in seconds (e.g. from a Retry-After
            header); 0 means "retry immediately", None means unknown.
        context: free-form extra data; a fresh empty dict by default.
    """

    error_type: ErrorType
    message: str
    original_error: Optional[Exception] = None
    retryable: bool = False
    retry_after: Optional[int] = None  # seconds
    context: Optional[dict] = None

    def __post_init__(self):
        # Give every instance its own dict instead of sharing a default.
        if self.context is None:
            self.context = {}
        # The dataclass-generated __init__ does not call Exception.__init__,
        # so populate ``args`` explicitly (used by tracebacks and pickling).
        super().__init__(self.message)

    def __str__(self) -> str:
        return self.message

    def is_retryable(self) -> bool:
        """Return True if this error may be retried.

        Transient categories (rate limit, unavailable service, timeout,
        network, downstream) are always retryable; any other category is
        retryable only when ``retryable`` was set. ``retryable=False`` does
        not veto a transient category.
        """
        retryable_types = {
            ErrorType.RATE_LIMIT,
            ErrorType.SERVICE_UNAVAILABLE,
            ErrorType.TIMEOUT,
            ErrorType.NETWORK_ERROR,
            ErrorType.DOWNSTREAM_ERROR,
        }
        return self.error_type in retryable_types or self.retryable

    def get_retry_delay(self) -> int:
        """Return the delay in seconds before retrying.

        An explicit ``retry_after`` (including 0) wins; otherwise a
        per-category default is used, falling back to 5 seconds.
        """
        # ``is not None`` so that a server-sent "Retry-After: 0" is honoured.
        if self.retry_after is not None:
            return self.retry_after

        default_delays = {
            ErrorType.RATE_LIMIT: 60,
            ErrorType.SERVICE_UNAVAILABLE: 30,
            ErrorType.TIMEOUT: 5,
            ErrorType.NETWORK_ERROR: 5,
            ErrorType.DOWNSTREAM_ERROR: 10,
        }
        return default_delays.get(self.error_type, 5)

错误处理器

python
import asyncio
from typing import Callable, Optional, Any
from functools import wraps

class ErrorHandler:
    """Classify exceptions into ``AIError`` and retry transient failures.

    Before each retry the handler waits for the larger of the exponential
    backoff from ``calculate_delay`` and the error's own suggested delay
    (``AIError.get_retry_delay``).
    """

    def __init__(
        self,
        max_retries: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        exponential_base: float = 2.0,
        jitter: bool = True
    ):
        """Configure the retry policy.

        Args:
            max_retries: retries after the first attempt, i.e. up to
                ``max_retries + 1`` calls in total; negative values act as 0.
            base_delay: backoff in seconds before the first retry.
            max_delay: cap on the exponential backoff in seconds (applied
                before jitter, so a jittered delay may exceed it by <= 10%).
            exponential_base: growth factor of the backoff per attempt.
            jitter: add up to 10% random extra delay.
        """
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.exponential_base = exponential_base
        self.jitter = jitter

    def classify_error(self, error: Exception) -> "AIError":
        """Map an arbitrary exception to an ``AIError``.

        Structured signals (an integer HTTP status attribute, or the
        exception type) are checked first because they are more reliable
        than message text. For example, ``asyncio.wait_for`` raises a
        ``TimeoutError`` whose message is typically empty, which the
        ``"timeout"`` substring check alone would miss. The message-substring
        heuristics then apply in their original order. Anything unmatched
        becomes ``DOWNSTREAM_ERROR``.
        """
        error_type = self._match_structured(error)
        if error_type is None:
            error_type = self._match_message(str(error).lower())
        return self._build_error(error_type, error)

    @staticmethod
    def _extract_status_code(error: Exception) -> Optional[int]:
        """Return an integer HTTP status found on the error or its response."""
        candidates = (
            getattr(error, "status_code", None),
            getattr(error, "status", None),
            getattr(getattr(error, "response", None), "status_code", None),
        )
        for value in candidates:
            # bool is an int subclass; never treat True/False as a status.
            if isinstance(value, int) and not isinstance(value, bool):
                return value
        return None

    def _match_structured(self, error: Exception) -> "Optional[ErrorType]":
        """Classify by HTTP status or exception type; None if inconclusive."""
        status = self._extract_status_code(error)
        if status == 429:
            return ErrorType.RATE_LIMIT
        if status == 503:
            return ErrorType.SERVICE_UNAVAILABLE
        if status == 401:
            return ErrorType.AUTH_FAILED
        # asyncio.TimeoutError is a distinct class before Python 3.11.
        if isinstance(error, (TimeoutError, asyncio.TimeoutError)):
            return ErrorType.TIMEOUT
        if isinstance(error, ConnectionError):
            return ErrorType.NETWORK_ERROR
        return None

    @staticmethod
    def _match_message(error_str: str) -> "Optional[ErrorType]":
        """Classify by substrings of the lower-cased message; None if no match."""
        if "rate limit" in error_str or "429" in error_str:
            return ErrorType.RATE_LIMIT
        if "503" in error_str or "service unavailable" in error_str:
            return ErrorType.SERVICE_UNAVAILABLE
        if "timeout" in error_str:
            return ErrorType.TIMEOUT
        if "content filter" in error_str or "content policy" in error_str:
            return ErrorType.CONTENT_FILTERED
        if "401" in error_str or "unauthorized" in error_str:
            return ErrorType.AUTH_FAILED
        if "quota" in error_str or "insufficient" in error_str:
            return ErrorType.QUOTA_EXCEEDED
        if "connection" in error_str or "network" in error_str:
            return ErrorType.NETWORK_ERROR
        return None

    def _build_error(
        self, error_type: "Optional[ErrorType]", error: Exception
    ) -> "AIError":
        """Wrap ``error`` in an ``AIError`` of ``error_type`` (None = unknown)."""
        if error_type is None:
            # NOTE: AIError.is_retryable() treats DOWNSTREAM_ERROR as transient
            # even with retryable=False, so unclassified errors are retried.
            return AIError(
                error_type=ErrorType.DOWNSTREAM_ERROR,
                message=f"Unknown error: {str(error)}",
                original_error=error,
                retryable=False
            )

        # category -> (message, retryable)
        known = {
            ErrorType.RATE_LIMIT: ("Rate limit exceeded", True),
            ErrorType.SERVICE_UNAVAILABLE: ("Service temporarily unavailable", True),
            ErrorType.TIMEOUT: ("Request timed out", True),
            ErrorType.CONTENT_FILTERED: ("Content filtered by safety system", False),
            ErrorType.AUTH_FAILED: ("Authentication failed", False),
            ErrorType.QUOTA_EXCEEDED: ("Quota exceeded", False),
            ErrorType.NETWORK_ERROR: ("Network error", True),
        }
        message, retryable = known[error_type]
        # Only rate-limit responses are expected to carry Retry-After.
        retry_after = (
            self._extract_retry_after(error)
            if error_type == ErrorType.RATE_LIMIT
            else None
        )
        return AIError(
            error_type=error_type,
            message=message,
            original_error=error,
            retryable=retryable,
            retry_after=retry_after
        )

    def _extract_retry_after(self, error: Exception) -> Optional[int]:
        """Return the ``Retry-After`` delay in whole seconds, if present.

        Accepts both header forms: a number of seconds (``"120"``) and an
        HTTP-date (``"Wed, 21 Oct 2015 07:28:00 GMT"``), the latter converted
        to the remaining seconds (never negative). Returns None when there is
        no response, no headers, no header, or it cannot be parsed.
        """
        headers = getattr(getattr(error, "response", None), "headers", None)
        if headers is None:
            return None
        raw = headers.get('Retry-After')
        if not raw:
            return None
        raw = str(raw).strip()
        try:
            return int(raw)
        except ValueError:
            pass

        import math
        from datetime import datetime, timezone
        from email.utils import parsedate_to_datetime

        try:
            retry_at = parsedate_to_datetime(raw)
        except (TypeError, ValueError):
            # TypeError on Python < 3.10, ValueError afterwards.
            return None
        if retry_at.tzinfo is None:
            # A "-0000" zone parses to a naive datetime; HTTP-dates are GMT.
            retry_at = retry_at.replace(tzinfo=timezone.utc)
        remaining = (retry_at - datetime.now(timezone.utc)).total_seconds()
        return max(0, math.ceil(remaining))

    def calculate_delay(self, attempt: int) -> float:
        """Return the backoff in seconds before retry number ``attempt`` (0-based).

        ``base_delay * exponential_base ** attempt``, capped at ``max_delay``,
        then (if ``jitter``) scaled by a random factor in [1.0, 1.1).
        """
        delay = self.base_delay * (self.exponential_base ** attempt)
        delay = min(delay, self.max_delay)

        if self.jitter:
            import random
            delay = delay * (1 + random.random() * 0.1)

        return delay

    async def execute_with_retry(
        self,
        func: Callable,
        *args,
        **kwargs
    ) -> Any:
        """Call ``func(*args, **kwargs)``, retrying transient failures.

        ``func`` may be a coroutine function, a plain function, or any
        callable returning an awaitable (e.g. ``lambda: client.generate(p)``).
        Awaitable results are awaited inside the retry loop, so their
        failures are classified and retried too.

        Raises:
            AIError: the classified error (chained to the original exception)
                when it is not retryable or the retry budget is exhausted.
        """
        import inspect

        # Always make at least one attempt, even if max_retries is negative.
        attempts = max(self.max_retries, 0) + 1

        for attempt in range(attempts):
            try:
                result = func(*args, **kwargs)
                if inspect.isawaitable(result):
                    result = await result
                return result

            except Exception as e:
                ai_error = self.classify_error(e)

                # Non-retryable, or out of attempts: surface the classified error.
                if not ai_error.is_retryable() or attempt == attempts - 1:
                    raise ai_error from e

                # Respect the error's own hint, but never back off less than
                # the exponential schedule.
                delay = max(ai_error.get_retry_delay(), self.calculate_delay(attempt))

                print(f"Attempt {attempt + 1} failed: {ai_error.message}. "
                      f"Retrying in {delay:.2f}s...")

                await asyncio.sleep(delay)

# Usage example
handler = ErrorHandler(max_retries=3)


async def call_model_with_retry(prompt: str) -> str:
    """Generate a completion for ``prompt``, retrying transient failures.

    ``execute_with_retry`` is a coroutine method that takes the callable and
    its arguments; it is not a decorator. Decorating a nested ``_call`` with
    it would bind ``_call`` to an un-awaited coroutine object, so calling
    ``_call()`` would raise ``TypeError``.
    """
    # ``model`` is assumed to be a client defined elsewhere whose
    # ``generate(prompt)`` is a coroutine method -- TODO confirm.
    return await handler.execute_with_retry(model.generate, prompt)

降级策略

多级降级

python
from dataclasses import dataclass
from typing import Optional, Callable, Any
from enum import Enum

class DegradationLevel(Enum):
    """Service levels from best (FULL) to worst (OFFLINE).

    Declaration order is significant: ``DegradationManager`` moves one step
    along ``list(DegradationLevel)`` when it degrades or recovers.
    """

    FULL = "full"              # full functionality
    REDUCED = "reduced"        # reduced functionality
    MINIMAL = "minimal"        # minimal functionality
    OFFLINE = "offline"        # offline mode


@dataclass
class DegradationConfig:
    """Call settings that apply while the service runs at ``level``.

    Attributes:
        level: the degradation level this config belongs to.
        fallback_model: model name intended for the fallback path, if any.
        max_tokens: ``max_tokens`` value for model calls at this level.
        timeout_ms: timeout of the primary call in milliseconds; None
            disables the timeout.
        features_disabled: names of features switched off at this level.
            None is normalized to an empty list so membership tests such as
            ``"streaming" in cfg.features_disabled`` are always safe.
    """

    level: DegradationLevel
    fallback_model: Optional[str] = None
    max_tokens: Optional[int] = None
    timeout_ms: Optional[int] = None
    features_disabled: Optional[list[str]] = None

    def __post_init__(self):
        if self.features_disabled is None:
            self.features_disabled = []


class DegradationManager:
    """Move the service between degradation levels based on call outcomes.

    ``threshold`` consecutive failures move one level down
    (FULL -> REDUCED -> MINIMAL -> OFFLINE); ``2 * threshold`` consecutive
    successes move one level up. Both counters restart on every level change,
    so a long success streak at FULL cannot cause an instant re-upgrade right
    after a downgrade.

    While OFFLINE no primary call is made, so no success can be recorded to
    trigger an upgrade. After ``offline_cooldown_s`` seconds the manager
    therefore returns to MINIMAL to probe the service again.
    """

    def __init__(self, offline_cooldown_s: float = 60.0):
        self.current_level = DegradationLevel.FULL
        self.configs = {
            DegradationLevel.FULL: DegradationConfig(
                level=DegradationLevel.FULL,
                fallback_model=None,
                max_tokens=4000,
                timeout_ms=30000
            ),
            DegradationLevel.REDUCED: DegradationConfig(
                level=DegradationLevel.REDUCED,
                fallback_model="gpt-3.5-turbo",
                max_tokens=2000,
                timeout_ms=15000,
                features_disabled=["advanced_reasoning"]
            ),
            DegradationLevel.MINIMAL: DegradationConfig(
                level=DegradationLevel.MINIMAL,
                fallback_model="gpt-3.5-turbo",
                max_tokens=500,
                timeout_ms=5000,
                features_disabled=["advanced_reasoning", "long_context", "streaming"]
            ),
            DegradationLevel.OFFLINE: DegradationConfig(
                level=DegradationLevel.OFFLINE,
                fallback_model=None,
                max_tokens=0,
                timeout_ms=0,
                features_disabled=["all"]
            ),
        }
        self.error_count = 0
        self.success_count = 0
        self.threshold = 5
        # Seconds to stay OFFLINE before probing again at MINIMAL.
        self.offline_cooldown_s = offline_cooldown_s
        # time.monotonic() value when OFFLINE was entered; None otherwise.
        self._offline_since: Optional[float] = None

    def get_config(self) -> DegradationConfig:
        """Return the config of the current level.

        Leaves OFFLINE first if its cooldown has elapsed, so callers that read
        the config before executing see the level that will actually run.
        """
        self._maybe_leave_offline()
        return self.configs[self.current_level]

    def record_success(self):
        """Count a success; a success breaks any failure streak."""
        self.success_count += 1
        self.error_count = 0
        self._maybe_upgrade()

    def record_failure(self):
        """Count a failure; a failure breaks any success streak."""
        self.error_count += 1
        self.success_count = 0
        self._maybe_downgrade()

    def _set_level(self, level: DegradationLevel):
        """Switch to ``level`` and restart both counters."""
        import time
        self.current_level = level
        self.error_count = 0
        self.success_count = 0
        self._offline_since = (
            time.monotonic() if level == DegradationLevel.OFFLINE else None
        )

    def _maybe_downgrade(self):
        """Move one level down after ``threshold`` consecutive failures."""
        if self.error_count >= self.threshold:
            levels = list(DegradationLevel)
            current_idx = levels.index(self.current_level)
            if current_idx < len(levels) - 1:
                self._set_level(levels[current_idx + 1])
                print(f"Degraded to {self.current_level.value} level")

    def _maybe_upgrade(self):
        """Move one level up after ``2 * threshold`` consecutive successes."""
        if self.success_count >= self.threshold * 2:
            levels = list(DegradationLevel)
            current_idx = levels.index(self.current_level)
            if current_idx > 0:
                self._set_level(levels[current_idx - 1])
                print(f"Upgraded to {self.current_level.value} level")

    def _maybe_leave_offline(self):
        """Return from OFFLINE to MINIMAL once ``offline_cooldown_s`` has passed."""
        if self.current_level != DegradationLevel.OFFLINE:
            return
        import time
        now = time.monotonic()
        if self._offline_since is None:
            # OFFLINE was set directly (not via _set_level): start the cooldown now.
            self._offline_since = now
            return
        if now - self._offline_since >= self.offline_cooldown_s:
            self._set_level(DegradationLevel.MINIMAL)
            print(f"Upgraded to {self.current_level.value} level")

    async def execute_with_degradation(
        self,
        primary_func: Callable,
        fallback_func: Optional[Callable] = None,
        offline_func: Optional[Callable] = None
    ) -> Any:
        """Run ``primary_func`` under the current level's timeout, degrading on failure.

        Args:
            primary_func: zero-argument coroutine function for the normal path.
            fallback_func: zero-argument coroutine function tried after a
                primary failure, but only when the level (read before the
                call) is REDUCED or MINIMAL.
            offline_func: zero-argument coroutine function used while OFFLINE,
                and as the last resort after a failure.

        Raises:
            Exception: "Service is in offline mode" when OFFLINE without an
                ``offline_func``; otherwise the primary call's exception when
                neither the fallback nor the offline path returns a result.
        """
        config = self.get_config()

        if config.level == DegradationLevel.OFFLINE:
            if offline_func:
                return await offline_func()
            raise Exception("Service is in offline mode")

        timeout = None if config.timeout_ms is None else config.timeout_ms / 1000

        try:
            result = await asyncio.wait_for(primary_func(), timeout=timeout)
        except Exception:
            self.record_failure()

            # Try the fallback path.
            if fallback_func and config.level in (DegradationLevel.REDUCED, DegradationLevel.MINIMAL):
                try:
                    return await fallback_func()
                except Exception:
                    # Best effort: a failing fallback falls through to offline.
                    pass

            # Try the offline path.
            if offline_func:
                return await offline_func()

            raise

        self.record_success()
        return result

# Usage example
degradation_manager = DegradationManager()


async def get_response(prompt: str) -> str:
    """Answer ``prompt`` through the degradation manager.

    The config is read once up front, so every path uses the same level's
    ``max_tokens``; the fallback path uses that level's ``fallback_model``.
    """
    # ``call_model(model, prompt, max_tokens=...)`` is assumed to be an async
    # helper defined elsewhere -- TODO confirm.
    config = degradation_manager.get_config()

    async def primary():
        return await call_model("gpt-4", prompt, max_tokens=config.max_tokens)

    async def fallback():
        # The manager only invokes the fallback at REDUCED/MINIMAL, whose
        # configs name a fallback model; the literal is a safety default.
        fallback_model = config.fallback_model or "gpt-3.5-turbo"
        return await call_model(fallback_model, prompt, max_tokens=config.max_tokens)

    async def offline():
        return "抱歉,服务暂时不可用,请稍后重试。"

    return await degradation_manager.execute_with_degradation(
        primary, fallback, offline
    )

熔断器模式

python
from dataclasses import dataclass
from enum import Enum
from datetime import datetime, timedelta
import asyncio

class CircuitState(Enum):
    """States of a ``CircuitBreaker``."""

    CLOSED = "closed"        # normal operation: calls pass through
    OPEN = "open"            # tripped: calls are rejected
    HALF_OPEN = "half_open"  # probing: limited trial calls decide the next state


@dataclass
class CircuitStats:
    """Counters and timestamps for the breaker's current period."""

    failures: int = 0
    successes: int = 0
    last_failure: datetime = None  # None until a failure is recorded
    last_success: datetime = None  # None until a success is recorded


class CircuitBreaker:
    """Stop calling a failing dependency, then probe it again after a pause.

    CLOSED: calls pass; ``failure_threshold`` consecutive failures open the
    circuit (a success resets the failure streak). OPEN: calls are rejected
    until ``timeout`` has passed since the last failure, then the circuit
    becomes HALF_OPEN. HALF_OPEN: ``execute`` runs at most
    ``half_open_max_calls`` trial calls at a time; ``success_threshold``
    successes close the circuit and any failure reopens it.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        success_threshold: int = 3,
        timeout: timedelta = timedelta(seconds=60),
        half_open_max_calls: int = 3
    ):
        """
        Args:
            failure_threshold: consecutive failures (while CLOSED) that open
                the circuit.
            success_threshold: successes (while HALF_OPEN) that close it.
            timeout: time to stay OPEN after the last failure before probing.
            half_open_max_calls: maximum trial calls ``execute`` runs
                concurrently while HALF_OPEN.
        """
        self.failure_threshold = failure_threshold
        self.success_threshold = success_threshold
        self.timeout = timeout
        self.half_open_max_calls = half_open_max_calls

        self.state = CircuitState.CLOSED
        self.stats = CircuitStats()
        # Trial calls currently in flight through execute() while HALF_OPEN.
        self.half_open_calls = 0

    def is_allowed(self) -> bool:
        """Return True if a call may proceed now.

        May move OPEN -> HALF_OPEN as a side effect. Does not reserve a
        half-open slot; ``execute`` does that.
        """
        if self.state == CircuitState.CLOSED:
            return True

        if self.state == CircuitState.OPEN:
            last_failure = self.stats.last_failure
            # record_failure() sets last_failure before opening the circuit;
            # if it is missing anyway, treat the timeout as elapsed.
            if last_failure is None or datetime.now() - last_failure >= self.timeout:
                self._transition_to(CircuitState.HALF_OPEN)
                return True
            return False

        if self.state == CircuitState.HALF_OPEN:
            return self.half_open_calls < self.half_open_max_calls

        return False

    def record_success(self):
        """Record a successful call."""
        self.stats.successes += 1
        self.stats.last_success = datetime.now()

        if self.state == CircuitState.CLOSED:
            # Only consecutive failures may trip the breaker.
            self.stats.failures = 0
        elif self.state == CircuitState.HALF_OPEN:
            if self.stats.successes >= self.success_threshold:
                self._transition_to(CircuitState.CLOSED)

    def record_failure(self):
        """Record a failed call, opening the circuit when warranted."""
        self.stats.failures += 1
        self.stats.last_failure = datetime.now()

        if self.state == CircuitState.CLOSED:
            if self.stats.failures >= self.failure_threshold:
                self._transition_to(CircuitState.OPEN)

        elif self.state == CircuitState.HALF_OPEN:
            self._transition_to(CircuitState.OPEN)

    def _transition_to(self, new_state: CircuitState):
        """Switch to ``new_state`` and reset the counters that state starts from."""
        old_state = self.state
        self.state = new_state

        if new_state == CircuitState.CLOSED:
            self.stats = CircuitStats()
            self.half_open_calls = 0

        elif new_state == CircuitState.OPEN:
            self.half_open_calls = 0

        elif new_state == CircuitState.HALF_OPEN:
            self.stats.successes = 0
            self.half_open_calls = 0

        print(f"Circuit breaker: {old_state.value} -> {new_state.value}")

    async def execute(self, func: Callable) -> Any:
        """Await ``func()`` through the breaker and record the outcome.

        Raises:
            Exception: "Circuit breaker is open" when the call is rejected
                (OPEN, or HALF_OPEN with every trial slot taken).
            Whatever ``func`` raises, after the failure has been recorded.
        """
        if not self.is_allowed():
            raise Exception("Circuit breaker is open")

        # Reserve a trial slot so concurrent callers cannot exceed
        # half_open_max_calls while earlier probes are still pending.
        is_probe = self.state == CircuitState.HALF_OPEN
        if is_probe:
            self.half_open_calls += 1

        try:
            result = await func()
        except Exception:
            self.record_failure()
            raise
        finally:
            # Release the slot (also on cancellation) unless a state
            # transition has already reset the counter.
            if is_probe and self.state == CircuitState.HALF_OPEN and self.half_open_calls > 0:
                self.half_open_calls -= 1

        self.record_success()
        return result

# Usage example: one module-wide breaker guarding the model call.
circuit_breaker = CircuitBreaker(
    failure_threshold=5,
    success_threshold=3,
    timeout=timedelta(seconds=60),
)


async def call_with_circuit_breaker(prompt: str) -> str:
    """Send ``prompt`` to the model through the shared circuit breaker."""
    async def _invoke():
        # ``call_model`` is assumed to be an async helper defined elsewhere.
        return await call_model("gpt-4", prompt)

    return await circuit_breaker.execute(_invoke)

限流与背压

令牌桶限流

python
import asyncio
from dataclasses import dataclass
from datetime import datetime

@dataclass
class TokenBucket:
    """Token bucket that refills continuously at ``rate`` tokens per second.

    The bucket starts full: ``__post_init__`` overwrites ``tokens`` with
    ``capacity`` regardless of the value passed in.
    """

    rate: float  # tokens added per second
    capacity: float  # maximum number of tokens the bucket can hold
    tokens: float = 0.0
    last_update: datetime = None

    def __post_init__(self):
        self.tokens = self.capacity
        self.last_update = datetime.now()

    def consume(self, tokens: float = 1.0) -> bool:
        """Take ``tokens`` if available; return whether they were taken."""
        self._refill()

        if self.tokens >= tokens:
            self.tokens -= tokens
            return True
        return False

    def wait_time(self, tokens: float = 1.0) -> float:
        """Return seconds until ``tokens`` will be available (0.0 if now).

        Reserves nothing. If ``tokens`` exceeds ``capacity``, the returned
        wait can never actually be satisfied.
        """
        self._refill()

        if self.tokens >= tokens:
            return 0.0

        needed = tokens - self.tokens
        return needed / self.rate

    def _refill(self):
        """Add the tokens accrued since the last update, up to ``capacity``."""
        now = datetime.now()
        # datetime.now() is naive wall-clock time and can move backwards
        # (clock sync, DST); a negative interval must not drain the bucket.
        elapsed = max(0.0, (now - self.last_update).total_seconds())
        self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
        self.last_update = now


class RateLimiter:
    """Limit both request count and token volume per minute.

    Each bucket holds 10 seconds' worth of its per-minute budget, which
    bounds the size of a burst.
    """

    def __init__(
        self,
        requests_per_minute: int = 60,
        tokens_per_minute: int = 100000
    ):
        self.request_bucket = TokenBucket(
            rate=requests_per_minute / 60.0,
            capacity=requests_per_minute / 6.0  # 10 seconds' worth
        )
        self.token_bucket = TokenBucket(
            rate=tokens_per_minute / 60.0,
            capacity=tokens_per_minute / 6.0
        )

    async def acquire(self, tokens: int = 1000):
        """Wait for one request slot and ``tokens`` tokens, then take them.

        A token request larger than the token bucket's capacity could never
        be satisfied, so it is clamped to the capacity. The caller waits for
        a full bucket and drains it instead of waiting forever.
        """
        await self._take(self.request_bucket, 1)
        await self._take(self.token_bucket, tokens)

    @staticmethod
    async def _take(bucket: TokenBucket, amount: float):
        """Sleep until ``bucket`` can supply ``amount`` (capped at capacity), then consume it."""
        if bucket.capacity > 0:
            amount = min(amount, bucket.capacity)
        # Loop: another coroutine may take the refilled tokens while we sleep.
        while not bucket.consume(amount):
            await asyncio.sleep(bucket.wait_time(amount))

    async def execute(self, func: Callable, tokens: int = 1000) -> Any:
        """Acquire capacity for ``tokens``, then await ``func()``."""
        await self.acquire(tokens)
        return await func()

# Usage example
limiter = RateLimiter(
    requests_per_minute=60,
    tokens_per_minute=100000
)


async def call_with_rate_limit(prompt: str) -> str:
    """Call the model once the limiter admits the prompt's estimated token cost.

    The estimate is a rough heuristic. Counting whitespace-separated words
    badly underestimates text without spaces (e.g. Chinese, where
    ``prompt.split()`` yields a single item), so a character-based lower
    bound is also applied. Use a real tokenizer for accurate budgeting.
    """
    estimated_tokens = max(len(prompt.split()) * 2, len(prompt) // 2)
    return await limiter.execute(
        # ``call_model`` is assumed to be an async helper defined elsewhere.
        lambda: call_model("gpt-4", prompt),
        tokens=estimated_tokens
    )

监控与告警

错误监控

python
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Dict, List
from collections import defaultdict
import json

@dataclass
class ErrorEvent:
    """A single failed model call, as stored by the error monitor."""

    timestamp: datetime      # when the error was recorded
    error_type: ErrorType    # classified category of the failure
    model: str               # model the failed call targeted
    message: str             # human-readable error description
    retryable: bool          # whether the error was judged retryable
    latency_ms: int          # time spent before the call failed, in ms

class ErrorMonitor:
    """Keep a sliding window of model errors and raise alerts on them.

    Errors are grouped per model and dropped once older than ``window_size``.
    Alerts of the same kind for the same model are suppressed for
    ``alert_cooldown``. Without that, every error recorded while the rate
    stays above the threshold (and every auth/quota error) would append yet
    another identical alert. Pass ``timedelta(0)`` to disable suppression.
    """

    def __init__(
        self,
        window_size: timedelta = timedelta(minutes=5),
        alert_cooldown: timedelta = timedelta(minutes=1),
    ):
        self.window_size = window_size
        self.alert_cooldown = alert_cooldown
        self.errors: Dict[str, List[ErrorEvent]] = defaultdict(list)
        self.alerts = []
        # (model, level, alert kind) -> when that kind of alert was last sent
        self._last_alert_at: Dict[tuple, datetime] = {}

    def record_error(
        self,
        error_type: "ErrorType",
        model: str,
        message: str,
        retryable: bool,
        latency_ms: int
    ):
        """Record one failed call for ``model``, prune old events, and evaluate alerts."""
        event = ErrorEvent(
            timestamp=datetime.now(),
            error_type=error_type,
            model=model,
            message=message,
            retryable=retryable,
            latency_ms=latency_ms
        )

        self.errors[model].append(event)
        self._clean_old_errors(model)
        self._check_alerts(model, event)

    def get_error_rate(self, model: str) -> float:
        """Return errors per second for ``model`` over the current window.

        This is a frequency (errors divided by window seconds), not the
        fraction of failed calls; successful calls are not tracked here.
        """
        errors = self.errors.get(model, [])
        if not errors:
            return 0.0

        recent = [e for e in errors if datetime.now() - e.timestamp < self.window_size]
        return len(recent) / self.window_size.total_seconds()

    def get_error_summary(self, model: str) -> dict:
        """Summarize ``model``'s errors in the window.

        Returns a dict with ``total_errors``, ``error_rate`` (errors/second),
        ``by_type`` (a defaultdict of error-type value -> count),
        ``retryable_count`` and ``avg_latency_ms`` (0 when there are none).
        """
        errors = self.errors.get(model, [])
        recent = [e for e in errors if datetime.now() - e.timestamp < self.window_size]

        summary = {
            "total_errors": len(recent),
            "error_rate": self.get_error_rate(model),
            "by_type": defaultdict(int),
            "retryable_count": 0,
            "avg_latency_ms": 0
        }

        if recent:
            for e in recent:
                summary["by_type"][e.error_type.value] += 1
                if e.retryable:
                    summary["retryable_count"] += 1

            summary["avg_latency_ms"] = sum(e.latency_ms for e in recent) / len(recent)

        return summary

    def _clean_old_errors(self, model: str):
        """Drop ``model``'s events older than the window."""
        cutoff = datetime.now() - self.window_size
        self.errors[model] = [
            e for e in self.errors[model]
            if e.timestamp >= cutoff
        ]

    def _check_alerts(self, model: str, event: "ErrorEvent"):
        """Evaluate alert conditions after ``event`` was recorded."""
        error_rate = self.get_error_rate(model)

        # Error-rate alert: more than 0.1 errors per second in the window.
        if error_rate > 0.1:
            self._send_alert(
                level="warning",
                model=model,
                message=f"High error rate: {error_rate:.2f}/s",
                # The message embeds the changing rate, so dedupe on a fixed key.
                key="high_error_rate"
            )

        # Alerts for specific error types.
        if event.error_type == ErrorType.AUTH_FAILED:
            self._send_alert(
                level="critical",
                model=model,
                message="Authentication failed"
            )

        if event.error_type == ErrorType.QUOTA_EXCEEDED:
            self._send_alert(
                level="critical",
                model=model,
                message="Quota exceeded"
            )

    def _send_alert(self, level: str, model: str, message: str, key: str = ""):
        """Record and print an alert unless the same kind was sent within the cooldown.

        ``key`` identifies the alert kind for de-duplication and defaults to
        ``message``, which only works when the message text is stable.
        """
        now = datetime.now()
        dedup_key = (model, level, key or message)
        last_sent = self._last_alert_at.get(dedup_key)
        if last_sent is not None and now - last_sent < self.alert_cooldown:
            return
        self._last_alert_at[dedup_key] = now

        alert = {
            "timestamp": now.isoformat(),
            "level": level,
            "model": model,
            "message": message
        }
        self.alerts.append(alert)
        print(f"[{level.upper()}] {model}: {message}")

# Usage example
monitor = ErrorMonitor()
# One shared classifier: classify_error keeps no per-call state, so there is
# no need to build a new ErrorHandler for every failure.
_error_classifier = ErrorHandler()


async def call_with_monitoring(model: str, prompt: str) -> str:
    """Call ``model`` and record any failure in ``monitor`` before re-raising it."""
    import time

    # Monotonic clock: latency is not distorted by wall-clock adjustments.
    started = time.monotonic()

    try:
        # ``call_model`` is assumed to be an async helper defined elsewhere.
        return await call_model(model, prompt)

    except Exception as e:
        latency_ms = int((time.monotonic() - started) * 1000)

        ai_error = _error_classifier.classify_error(e)

        monitor.record_error(
            error_type=ai_error.error_type,
            model=model,
            message=ai_error.message,
            retryable=ai_error.is_retryable(),
            latency_ms=latency_ms
        )

        raise

最佳实践总结

错误处理清单

markdown
## 错误处理检查清单

### 必须项

- [ ] 实现统一的错误类型
- [ ] 实现错误分类和识别
- [ ] 实现重试机制(指数退避)
- [ ] 实现降级策略
- [ ] 实现熔断器

### 推荐项

- [ ] 实现限流机制
- [ ] 实现语义缓存
- [ ] 实现错误监控
- [ ] 实现告警机制
- [ ] 实现优雅降级

### 可选项

- [ ] 实现多级降级
- [ ] 实现自适应重试
- [ ] 实现成本感知降级
- [ ] 实现用户友好错误消息

### 避免事项(逐项确认已规避)

- [ ] 无限重试:重试次数与总等待时间均有上限
- [ ] 忽略错误:没有被静默吞掉的异常
- [ ] 暴露敏感信息:错误消息中不含密钥、内部堆栈等内容
- [ ] 级联失败:已通过超时、熔断和限流隔离故障

错误响应策略

| 错误类型 | 响应策略 | 用户消息 |
| --- | --- | --- |
| RATE_LIMIT | 等待 + 重试 | "请求过于频繁,请稍后重试" |
| SERVICE_UNAVAILABLE | 降级 + 重试 | "服务暂时不可用,正在尝试其他方案" |
| AUTH_FAILED | 告警 + 停止 | "服务配置错误,请联系管理员" |
| QUOTA_EXCEEDED | 告警 + 降级 | "服务额度已用尽,请稍后重试" |
| TIMEOUT | 重试 + 降级 | "请求超时,正在重试..." |
| CONTENT_FILTERED | 拒绝 + 解释 | "请求内容不符合安全规范" |
| OUTPUT_TRUNCATED | 继续生成 | "响应过长,正在继续生成..." |
| FORMAT_ERROR | 重试 + 验证 | "响应格式异常,正在重试..." |

监控指标

yaml
错误监控指标:
  错误率:
    - 按错误类型统计
    - 按模型统计
    - 按时间段统计

  响应时间:
    - P50/P95/P99 延迟
    - 超时率

  重试统计:
    - 重试次数分布
    - 重试成功率

  降级统计:
    - 降级触发次数
    - 降级成功率

  熔断统计:
    - 熔断触发次数
    - 熔断持续时间