错误处理与容错策略
本文档整理 AI 系统错误处理和容错的最佳实践,确保系统稳定可靠。
错误分类
txt
┌─────────────────────────────────────────────────────┐
│ AI 系统错误分类 │
├─────────────────────────────────────────────────────┤
│ │
│ 输入错误 │
│ ├── 格式错误 │
│ ├── 长度超限 │
│ ├── 内容违规 │
│ └── 编码问题 │
│ │
│ API 错误 │
│ ├── 速率限制 (429) │
│ ├── 服务不可用 (503) │
│ ├── 认证失败 (401) │
│ ├── 额度超限 │
│ └── 超时 │
│ │
│ 模型错误 │
│ ├── 输出截断 │
│ ├── 格式不匹配 │
│ ├── 内容过滤 │
│ ├── 幻觉/错误信息 │
│ └── 拒答 │
│ │
│ 系统错误 │
│ ├── 网络故障 │
│ ├── 下游服务不可用 │
│ ├── 资源耗尽 │
│ └── 配置错误 │
│ │
└─────────────────────────────────────────────────────┘

错误处理框架
统一错误类型
python
from dataclasses import dataclass
from typing import Optional, Any
from enum import Enum
class ErrorType(Enum):
    """Canonical error categories for AI-system failures.

    The string values are stable identifiers, safe to log or serialize.
    """
    # Input errors
    INVALID_INPUT = "invalid_input"
    INPUT_TOO_LONG = "input_too_long"
    CONTENT_FILTERED = "content_filtered"
    # API errors
    RATE_LIMIT = "rate_limit"
    SERVICE_UNAVAILABLE = "service_unavailable"
    AUTH_FAILED = "auth_failed"
    QUOTA_EXCEEDED = "quota_exceeded"
    TIMEOUT = "timeout"
    # Model errors
    OUTPUT_TRUNCATED = "output_truncated"
    FORMAT_ERROR = "format_error"
    MODEL_REFUSAL = "model_refusal"
    HALLUCINATION = "hallucination"
    # System errors
    NETWORK_ERROR = "network_error"
    DOWNSTREAM_ERROR = "downstream_error"
    RESOURCE_EXHAUSTED = "resource_exhausted"
    CONFIG_ERROR = "config_error"
@dataclass
class AIError(Exception):
    """Structured, classifiable error for AI-system failures.

    Subclasses Exception so instances can be raised directly — the retry
    executor in this document does ``raise ai_error``, which would fail with
    a TypeError if this were a plain dataclass.
    """
    error_type: ErrorType
    message: str
    original_error: Optional[Exception] = None
    retryable: bool = False                # explicit per-instance override
    retry_after: Optional[int] = None      # seconds, e.g. from a Retry-After header
    context: Optional[dict] = None         # free-form extra diagnostics

    def __post_init__(self):
        # Callers may omit context entirely; normalize None to an empty dict.
        if self.context is None:
            self.context = {}
        # Initialize the Exception machinery so str()/tracebacks show the message.
        super().__init__(self.message)

    def is_retryable(self) -> bool:
        """True when the error type is transient or explicitly marked retryable.

        NOTE(review): DOWNSTREAM_ERROR is in the retryable set even though
        classify_error marks unknown errors retryable=False — the set wins
        here, so unknown errors are retried. Confirm intended.
        """
        retryable_types = {
            ErrorType.RATE_LIMIT,
            ErrorType.SERVICE_UNAVAILABLE,
            ErrorType.TIMEOUT,
            ErrorType.NETWORK_ERROR,
            ErrorType.DOWNSTREAM_ERROR,
        }
        return self.error_type in retryable_types or self.retryable

    def get_retry_delay(self) -> int:
        """Seconds to wait before retrying: server hint first, then a
        per-type default, then 5 seconds."""
        if self.retry_after:
            return self.retry_after
        default_delays = {
            ErrorType.RATE_LIMIT: 60,
            ErrorType.SERVICE_UNAVAILABLE: 30,
            ErrorType.TIMEOUT: 5,
            ErrorType.NETWORK_ERROR: 5,
            ErrorType.DOWNSTREAM_ERROR: 10,
        }
        return default_delays.get(self.error_type, 5)
python
import asyncio
from typing import Callable, Optional, Any
from functools import wraps
class ErrorHandler:
    """Retrying executor: classifies raw exceptions into AIError and retries
    transient failures with exponential backoff plus optional jitter."""

    def __init__(
        self,
        max_retries: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        exponential_base: float = 2.0,
        jitter: bool = True
    ):
        # max_retries counts *re*-tries: total attempts = max_retries + 1.
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.exponential_base = exponential_base
        self.jitter = jitter

    def classify_error(self, error: Exception) -> AIError:
        """Map a raw exception to a structured AIError.

        Classification is substring matching on str(error), checked in order;
        the first match wins. Unrecognized errors default to a non-retryable
        DOWNSTREAM_ERROR.
        """
        error_str = str(error).lower()
        # Rate limiting (HTTP 429)
        if "rate limit" in error_str or "429" in error_str:
            return AIError(
                error_type=ErrorType.RATE_LIMIT,
                message="Rate limit exceeded",
                original_error=error,
                retryable=True,
                retry_after=self._extract_retry_after(error)
            )
        # Service unavailable (HTTP 503)
        if "503" in error_str or "service unavailable" in error_str:
            return AIError(
                error_type=ErrorType.SERVICE_UNAVAILABLE,
                message="Service temporarily unavailable",
                original_error=error,
                retryable=True
            )
        # Timeout
        if "timeout" in error_str:
            return AIError(
                error_type=ErrorType.TIMEOUT,
                message="Request timed out",
                original_error=error,
                retryable=True
            )
        # Blocked by the provider's content-safety system
        if "content filter" in error_str or "content policy" in error_str:
            return AIError(
                error_type=ErrorType.CONTENT_FILTERED,
                message="Content filtered by safety system",
                original_error=error,
                retryable=False
            )
        # Authentication failure (HTTP 401)
        if "401" in error_str or "unauthorized" in error_str:
            return AIError(
                error_type=ErrorType.AUTH_FAILED,
                message="Authentication failed",
                original_error=error,
                retryable=False
            )
        # Quota exhausted
        if "quota" in error_str or "insufficient" in error_str:
            return AIError(
                error_type=ErrorType.QUOTA_EXCEEDED,
                message="Quota exceeded",
                original_error=error,
                retryable=False
            )
        # Network problems
        if "connection" in error_str or "network" in error_str:
            return AIError(
                error_type=ErrorType.NETWORK_ERROR,
                message="Network error",
                original_error=error,
                retryable=True
            )
        # Default: unknown error
        return AIError(
            error_type=ErrorType.DOWNSTREAM_ERROR,
            message=f"Unknown error: {str(error)}",
            original_error=error,
            retryable=False
        )

    def _extract_retry_after(self, error: Exception) -> Optional[int]:
        """Extract a Retry-After value (seconds) from an HTTP-style error, if any."""
        # Works with exceptions exposing requests/httpx-style .response.headers —
        # presumably the SDK's HTTP error type; verify against the client used.
        if hasattr(error, 'response') and hasattr(error.response, 'headers'):
            retry_after = error.response.headers.get('Retry-After')
            if retry_after:
                try:
                    return int(retry_after)
                except ValueError:
                    # The HTTP-date form of Retry-After is not parsed here.
                    pass
        return None

    def calculate_delay(self, attempt: int) -> float:
        """Exponential-backoff delay for the 0-based *attempt*, capped at
        max_delay, with up to +10% random jitter when enabled."""
        delay = self.base_delay * (self.exponential_base ** attempt)
        delay = min(delay, self.max_delay)
        if self.jitter:
            import random
            delay = delay * (1 + random.random() * 0.1)
        return delay

    async def execute_with_retry(
        self,
        func: Callable,
        *args,
        **kwargs
    ) -> Any:
        """Run func(*args, **kwargs) (sync or async), retrying retryable failures.

        Raises the classified AIError when the error is non-retryable or all
        attempts are exhausted.
        NOTE(review): `raise ai_error` requires AIError to be an Exception
        subclass — confirm the AIError definition used at runtime.
        """
        last_error = None
        for attempt in range(self.max_retries + 1):
            try:
                if asyncio.iscoroutinefunction(func):
                    return await func(*args, **kwargs)
                else:
                    return func(*args, **kwargs)
            except Exception as e:
                ai_error = self.classify_error(e)
                last_error = ai_error
                # Non-retryable: fail fast.
                if not ai_error.is_retryable():
                    raise ai_error
                # Last attempt: give up with the classified error.
                if attempt == self.max_retries:
                    raise ai_error
                # Wait the larger of the error-specific hint and the backoff.
                delay = ai_error.get_retry_delay()
                delay = max(delay, self.calculate_delay(attempt))
                print(f"Attempt {attempt + 1} failed: {ai_error.message}. "
                      f"Retrying in {delay:.2f}s...")
                await asyncio.sleep(delay)
        raise last_error
# Usage example
handler = ErrorHandler(max_retries=3)

async def call_model_with_retry(prompt: str) -> str:
    """Call the model with automatic retry on transient failures.

    ``execute_with_retry`` is a plain coroutine method, not a decorator:
    decorating a function with it binds the name to a coroutine *object*,
    and calling that object raises TypeError. Pass the callable (and its
    arguments) to it directly instead.
    """
    return await handler.execute_with_retry(model.generate, prompt)
多级降级
python
from dataclasses import dataclass
from typing import Optional, Callable, Any
from enum import Enum
class DegradationLevel(Enum):
    """Service degradation levels, ordered best-to-worst.

    Declaration order matters: DegradationManager walks
    list(DegradationLevel) to step one level down or up.
    """
    FULL = "full"          # full functionality
    REDUCED = "reduced"    # reduced functionality
    MINIMAL = "minimal"    # minimal functionality
    OFFLINE = "offline"    # offline mode
@dataclass
class DegradationConfig:
    """Operating parameters for one degradation level."""
    level: DegradationLevel
    fallback_model: Optional[str] = None    # model used on the degraded path
    max_tokens: Optional[int] = None        # response token budget
    timeout_ms: Optional[int] = None        # primary-call timeout
    features_disabled: Optional[list[str]] = None  # features switched off at this level
class DegradationManager:
    """Tracks call outcomes and steps the service between degradation levels.

    `threshold` failures step one level down; `2 * threshold` successes step
    one level back up.
    """

    def __init__(self):
        self.current_level = DegradationLevel.FULL
        # Per-level operating parameters (fallback model, token budget, timeout).
        self.configs = {
            DegradationLevel.FULL: DegradationConfig(
                level=DegradationLevel.FULL,
                fallback_model=None,
                max_tokens=4000,
                timeout_ms=30000
            ),
            DegradationLevel.REDUCED: DegradationConfig(
                level=DegradationLevel.REDUCED,
                fallback_model="gpt-3.5-turbo",
                max_tokens=2000,
                timeout_ms=15000,
                features_disabled=["advanced_reasoning"]
            ),
            DegradationLevel.MINIMAL: DegradationConfig(
                level=DegradationLevel.MINIMAL,
                fallback_model="gpt-3.5-turbo",
                max_tokens=500,
                timeout_ms=5000,
                features_disabled=["advanced_reasoning", "long_context", "streaming"]
            ),
            DegradationLevel.OFFLINE: DegradationConfig(
                level=DegradationLevel.OFFLINE,
                fallback_model=None,
                max_tokens=0,
                timeout_ms=0,
                features_disabled=["all"]
            ),
        }
        self.error_count = 0
        self.success_count = 0
        self.threshold = 5

    def get_config(self) -> DegradationConfig:
        """Return the config for the current degradation level."""
        return self.configs[self.current_level]

    def record_success(self):
        """Record a successful call; may trigger an upgrade."""
        self.success_count += 1
        self._maybe_upgrade()

    def record_failure(self):
        """Record a failed call; may trigger a downgrade."""
        self.error_count += 1
        self._maybe_downgrade()

    def _maybe_downgrade(self):
        """Step one level toward OFFLINE once error_count reaches threshold.

        NOTE(review): the error counter is not reset by successes, so
        sporadic errors accumulate across a healthy period — confirm intended.
        """
        if self.error_count >= self.threshold:
            # Relies on DegradationLevel declaration order FULL -> ... -> OFFLINE.
            levels = list(DegradationLevel)
            current_idx = levels.index(self.current_level)
            if current_idx < len(levels) - 1:
                self.current_level = levels[current_idx + 1]
                self.error_count = 0
                print(f"Degraded to {self.current_level.value} level")

    def _maybe_upgrade(self):
        """Step one level toward FULL after 2x threshold successes."""
        if self.success_count >= self.threshold * 2:
            levels = list(DegradationLevel)
            current_idx = levels.index(self.current_level)
            if current_idx > 0:
                self.current_level = levels[current_idx - 1]
                self.success_count = 0
                print(f"Upgraded to {self.current_level.value} level")

    async def execute_with_degradation(
        self,
        primary_func: Callable,
        fallback_func: Optional[Callable] = None,
        offline_func: Optional[Callable] = None
    ) -> Any:
        """Run primary_func under the current level's timeout, falling back to
        fallback_func (only when already degraded) and finally offline_func.

        Re-raises the primary error when no fallback path succeeds.
        """
        config = self.get_config()
        if config.level == DegradationLevel.OFFLINE:
            if offline_func:
                return await offline_func()
            raise Exception("Service is in offline mode")
        try:
            result = await asyncio.wait_for(
                primary_func(),
                timeout=config.timeout_ms / 1000
            )
            self.record_success()
            return result
        except Exception as e:
            self.record_failure()
            # Try the degraded path first (only when already degraded).
            if fallback_func and config.level in [DegradationLevel.REDUCED, DegradationLevel.MINIMAL]:
                try:
                    return await fallback_func()
                except Exception:
                    pass
            # Last resort: canned offline response.
            if offline_func:
                return await offline_func()
            raise
# Usage example
degradation_manager = DegradationManager()

async def get_response(prompt: str) -> str:
    """Answer *prompt*, degrading through fallback and offline paths as needed."""
    cfg = degradation_manager.get_config()

    async def full_quality():
        return await call_model("gpt-4", prompt, max_tokens=cfg.max_tokens)

    async def reduced_quality():
        return await call_model("gpt-3.5-turbo", prompt, max_tokens=cfg.max_tokens)

    async def static_reply():
        return "抱歉,服务暂时不可用,请稍后重试。"

    return await degradation_manager.execute_with_degradation(
        full_quality, reduced_quality, static_reply
    )
python
from dataclasses import dataclass
from enum import Enum
from datetime import datetime, timedelta
import asyncio
class CircuitState(Enum):
    """Circuit-breaker states."""
    CLOSED = "closed"        # normal operation, calls pass through
    OPEN = "open"            # tripped: calls rejected until cool-down
    HALF_OPEN = "half_open"  # probing: limited calls allowed
@dataclass
class CircuitStats:
    """Success/failure counters and timestamps for a CircuitBreaker."""
    failures: int = 0
    successes: int = 0
    last_failure: Optional[datetime] = None
    last_success: Optional[datetime] = None
class CircuitBreaker:
    """Circuit breaker: CLOSED -> OPEN after repeated failures, OPEN ->
    HALF_OPEN after a cool-down, HALF_OPEN -> CLOSED after enough probe
    successes (and back to OPEN on any probe failure)."""

    def __init__(
        self,
        failure_threshold: int = 5,
        success_threshold: int = 3,
        timeout: timedelta = timedelta(seconds=60),
        half_open_max_calls: int = 3
    ):
        # failure_threshold: failures (while CLOSED) that trip the breaker.
        # success_threshold: successes in HALF_OPEN needed to close again.
        # timeout: cool-down before an OPEN breaker allows a probe call.
        # half_open_max_calls: cap on probe calls while HALF_OPEN.
        self.failure_threshold = failure_threshold
        self.success_threshold = success_threshold
        self.timeout = timeout
        self.half_open_max_calls = half_open_max_calls
        self.state = CircuitState.CLOSED
        self.stats = CircuitStats()
        self.half_open_calls = 0

    def is_allowed(self) -> bool:
        """Whether a call may proceed in the current state."""
        if self.state == CircuitState.CLOSED:
            return True
        if self.state == CircuitState.OPEN:
            # Cool-down elapsed: move to HALF_OPEN and allow a probe.
            if datetime.now() - self.stats.last_failure >= self.timeout:
                self._transition_to(CircuitState.HALF_OPEN)
                return True
            return False
        if self.state == CircuitState.HALF_OPEN:
            # NOTE(review): half_open_calls is only incremented in
            # record_success, so this caps *completed successes* rather than
            # in-flight probe calls — confirm intended.
            return self.half_open_calls < self.half_open_max_calls
        return False

    def record_success(self):
        """Record a successful call; may close a half-open breaker."""
        self.stats.successes += 1
        self.stats.last_success = datetime.now()
        if self.state == CircuitState.HALF_OPEN:
            self.half_open_calls += 1
            # successes was reset on entering HALF_OPEN, so this counts probes.
            if self.stats.successes >= self.success_threshold:
                self._transition_to(CircuitState.CLOSED)

    def record_failure(self):
        """Record a failed call; may trip (or re-trip) the breaker."""
        self.stats.failures += 1
        self.stats.last_failure = datetime.now()
        if self.state == CircuitState.CLOSED:
            if self.stats.failures >= self.failure_threshold:
                self._transition_to(CircuitState.OPEN)
        elif self.state == CircuitState.HALF_OPEN:
            # Any failure during probing re-opens immediately.
            self._transition_to(CircuitState.OPEN)

    def _transition_to(self, new_state: CircuitState):
        """Switch state and reset the counters relevant to the new state."""
        old_state = self.state
        self.state = new_state
        if new_state == CircuitState.CLOSED:
            self.stats = CircuitStats()
            self.half_open_calls = 0
        elif new_state == CircuitState.OPEN:
            self.half_open_calls = 0
        elif new_state == CircuitState.HALF_OPEN:
            self.stats.successes = 0
            self.half_open_calls = 0
        print(f"Circuit breaker: {old_state.value} -> {new_state.value}")

    async def execute(self, func: Callable) -> Any:
        """Await func() under the breaker; raises when the circuit is open."""
        if not self.is_allowed():
            raise Exception("Circuit breaker is open")
        try:
            result = await func()
            self.record_success()
            return result
        except Exception as e:
            self.record_failure()
            raise
# Usage example
circuit_breaker = CircuitBreaker(
    failure_threshold=5,
    success_threshold=3,
    timeout=timedelta(seconds=60)
)

async def call_with_circuit_breaker(prompt: str) -> str:
    """Route a model call through the shared circuit breaker."""
    async def attempt():
        return await call_model("gpt-4", prompt)

    return await circuit_breaker.execute(attempt)
令牌桶限流
python
import asyncio
from dataclasses import dataclass
from datetime import datetime
@dataclass
class TokenBucket:
    """Classic token-bucket limiter state: tokens accrue at `rate` per
    second up to `capacity`, and are refilled lazily from wall-clock time."""
    rate: float      # tokens generated per second
    capacity: float  # maximum tokens the bucket can hold
    tokens: float = 0.0
    last_update: datetime = None

    def __post_init__(self):
        # Start full; refills happen lazily on each consume/wait_time call.
        self.tokens = self.capacity
        self.last_update = datetime.now()

    def consume(self, tokens: float = 1.0) -> bool:
        """Take *tokens* from the bucket; return whether the take succeeded."""
        self._refill()
        if self.tokens < tokens:
            return False
        self.tokens -= tokens
        return True

    def wait_time(self, tokens: float = 1.0) -> float:
        """Seconds until *tokens* will be available (0.0 if available now)."""
        self._refill()
        deficit = tokens - self.tokens
        return deficit / self.rate if deficit > 0 else 0.0

    def _refill(self):
        """Credit tokens for the time elapsed since the last update."""
        now = datetime.now()
        elapsed = (now - self.last_update).total_seconds()
        self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
        self.last_update = now
class RateLimiter:
    """Client-side rate limiter enforcing both a requests-per-minute and a
    tokens-per-minute budget via two token buckets."""

    def __init__(
        self,
        requests_per_minute: int = 60,
        tokens_per_minute: int = 100000
    ):
        # Bucket capacity is ~10 seconds of budget, allowing short bursts.
        self.request_bucket = TokenBucket(
            rate=requests_per_minute / 60.0,
            capacity=requests_per_minute / 6.0
        )
        self.token_bucket = TokenBucket(
            rate=tokens_per_minute / 60.0,
            capacity=tokens_per_minute / 6.0
        )

    async def acquire(self, tokens: int = 1000):
        """Block until one request slot and *tokens* model tokens are debited.

        Loops on consume() instead of a single sleep-then-consume: the
        original ignored the second consume()'s return value, so a permit
        could silently fail to be acquired (rounding, concurrent acquirers).
        """
        # Request-count budget.
        while not self.request_bucket.consume(1):
            await asyncio.sleep(self.request_bucket.wait_time(1))
        # Token budget.
        while not self.token_bucket.consume(tokens):
            await asyncio.sleep(self.token_bucket.wait_time(tokens))

    async def execute(self, func: Callable, tokens: int = 1000) -> Any:
        """Await func() after acquiring rate-limit permits for *tokens*."""
        await self.acquire(tokens)
        return await func()
# Usage example
limiter = RateLimiter(
    requests_per_minute=60,
    tokens_per_minute=100000
)

async def call_with_rate_limit(prompt: str) -> str:
    """Call the model under the shared client-side rate limiter."""
    # Rough budget estimate: ~2 tokens per whitespace-separated word.
    estimated_tokens = 2 * len(prompt.split())

    async def run():
        return await call_model("gpt-4", prompt)

    return await limiter.execute(run, tokens=estimated_tokens)
错误监控
python
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Dict, List
from collections import defaultdict
import json
@dataclass
class ErrorEvent:
    """One recorded error occurrence, as tracked by ErrorMonitor."""
    timestamp: datetime
    error_type: ErrorType
    model: str
    message: str
    retryable: bool
    latency_ms: int  # end-to-end latency of the failed call
class ErrorMonitor:
    """Sliding-window error tracker with simple threshold-based alerting."""

    def __init__(self, window_size: timedelta = timedelta(minutes=5)):
        self.window_size = window_size
        # Per-model list of recent ErrorEvents; pruned on each record_error.
        self.errors: Dict[str, List[ErrorEvent]] = defaultdict(list)
        # NOTE(review): alerts grows without bound — confirm retention policy.
        self.alerts = []

    def record_error(
        self,
        error_type: ErrorType,
        model: str,
        message: str,
        retryable: bool,
        latency_ms: int
    ):
        """Record one error event, prune the window, and evaluate alerts."""
        event = ErrorEvent(
            timestamp=datetime.now(),
            error_type=error_type,
            model=model,
            message=message,
            retryable=retryable,
            latency_ms=latency_ms
        )
        self.errors[model].append(event)
        self._clean_old_errors(model)
        self._check_alerts(model, event)

    def get_error_rate(self, model: str) -> float:
        """Errors per second for *model* over the sliding window."""
        errors = self.errors.get(model, [])
        if not errors:
            return 0.0
        recent = [e for e in errors if datetime.now() - e.timestamp < self.window_size]
        return len(recent) / self.window_size.total_seconds()

    def get_error_summary(self, model: str) -> dict:
        """Aggregate counts and latency for *model* over the window.

        Note: the "by_type" entry is returned as a defaultdict(int).
        """
        errors = self.errors.get(model, [])
        recent = [e for e in errors if datetime.now() - e.timestamp < self.window_size]
        summary = {
            "total_errors": len(recent),
            "error_rate": self.get_error_rate(model),
            "by_type": defaultdict(int),
            "retryable_count": 0,
            "avg_latency_ms": 0
        }
        if recent:
            for e in recent:
                summary["by_type"][e.error_type.value] += 1
                if e.retryable:
                    summary["retryable_count"] += 1
            summary["avg_latency_ms"] = sum(e.latency_ms for e in recent) / len(recent)
        return summary

    def _clean_old_errors(self, model: str):
        """Drop events older than the sliding window."""
        cutoff = datetime.now() - self.window_size
        self.errors[model] = [
            e for e in self.errors[model]
            if e.timestamp >= cutoff
        ]

    def _check_alerts(self, model: str, event: ErrorEvent):
        """Evaluate alert conditions for the latest event."""
        error_rate = self.get_error_rate(model)
        # Aggregate error-rate alert (threshold: 0.1 errors/second).
        if error_rate > 0.1:
            self._send_alert(
                level="warning",
                model=model,
                message=f"High error rate: {error_rate:.2f}/s"
            )
        # Specific error types always alert at critical level.
        if event.error_type == ErrorType.AUTH_FAILED:
            self._send_alert(
                level="critical",
                model=model,
                message="Authentication failed"
            )
        if event.error_type == ErrorType.QUOTA_EXCEEDED:
            self._send_alert(
                level="critical",
                model=model,
                message="Quota exceeded"
            )

    def _send_alert(self, level: str, model: str, message: str):
        """Append an alert record and echo it to stdout."""
        alert = {
            "timestamp": datetime.now().isoformat(),
            "level": level,
            "model": model,
            "message": message
        }
        self.alerts.append(alert)
        print(f"[{level.upper()}] {model}: {message}")
# Usage example
monitor = ErrorMonitor()

async def call_with_monitoring(model: str, prompt: str) -> str:
    """Call the model, recording any failure in the shared error monitor."""
    started = datetime.now()
    try:
        return await call_model(model, prompt)
    except Exception as exc:
        elapsed_ms = int((datetime.now() - started).total_seconds() * 1000)
        classified = ErrorHandler().classify_error(exc)
        monitor.record_error(
            error_type=classified.error_type,
            model=model,
            message=classified.message,
            retryable=classified.is_retryable(),
            latency_ms=elapsed_ms
        )
        raise
错误处理清单
markdown
## 错误处理检查清单
### 必须项
- [ ] 实现统一的错误类型
- [ ] 实现错误分类和识别
- [ ] 实现重试机制(指数退避)
- [ ] 实现降级策略
- [ ] 实现熔断器
### 推荐项
- [ ] 实现限流机制
- [ ] 实现语义缓存
- [ ] 实现错误监控
- [ ] 实现告警机制
- [ ] 实现优雅降级
### 可选项
- [ ] 实现多级降级
- [ ] 实现自适应重试
- [ ] 实现成本感知降级
- [ ] 实现用户友好错误消息
### 避免事项
- [ ] 无限重试
- [ ] 忽略错误
- [ ] 暴露敏感信息
- [ ] 级联失败

错误响应策略
| 错误类型 | 响应策略 | 用户消息 |
|---|---|---|
| RATE_LIMIT | 等待 + 重试 | "请求过于频繁,请稍后重试" |
| SERVICE_UNAVAILABLE | 降级 + 重试 | "服务暂时不可用,正在尝试其他方案" |
| AUTH_FAILED | 告警 + 停止 | "服务配置错误,请联系管理员" |
| QUOTA_EXCEEDED | 告警 + 降级 | "服务额度已用尽,请稍后重试" |
| TIMEOUT | 重试 + 降级 | "请求超时,正在重试..." |
| CONTENT_FILTERED | 拒绝 + 解释 | "请求内容不符合安全规范" |
| OUTPUT_TRUNCATED | 继续生成 | "响应过长,正在继续生成..." |
| FORMAT_ERROR | 重试 + 验证 | "响应格式异常,正在重试..." |
监控指标
yaml
错误监控指标:
错误率:
- 按错误类型统计
- 按模型统计
- 按时间段统计
响应时间:
- P50/P95/P99 延迟
- 超时率
重试统计:
- 重试次数分布
- 重试成功率
降级统计:
- 降级触发次数
- 降级成功率
熔断统计:
- 熔断触发次数
- 熔断持续时间