Skip to content

多模态处理最佳实践

本文档整理多模态(文本、图像、音频、视频)AI 处理的最佳实践。

多模态概述

txt
┌─────────────────────────────────────────────────────┐
│                   多模态处理架构                     │
├─────────────────────────────────────────────────────┤
│                                                     │
│  输入层                                             │
│  ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐                   │
│  │文本 │ │图像 │ │音频 │ │视频 │                   │
│  └──┬──┘ └──┬──┘ └──┬──┘ └──┬──┘                   │
│     │       │       │       │                       │
│     ↓       ↓       ↓       ↓                       │
│  编码层                                             │
│  ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐                   │
│  │文本 │ │图像 │ │音频 │ │视频 │                   │
│  │编码 │ │编码 │ │编码 │ │编码 │                   │
│  └──┬──┘ └──┬──┘ └──┬──┘ └──┬──┘                   │
│     │       │       │       │                       │
│     └───────┴───────┴───────┘                       │
│                 │                                   │
│                 ↓                                   │
│  融合层                                             │
│  ┌─────────────────────────────┐                   │
│  │    多模态特征融合            │                   │
│  │  Attention │ Cross-Modal    │                   │
│  └──────────────┬──────────────┘                   │
│                 │                                   │
│                 ↓                                   │
│  输出层                                             │
│  ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐                   │
│  │文本 │ │图像 │ │音频 │ │视频 │                   │
│  │生成 │ │生成 │ │生成 │ │生成 │                   │
│  └─────┘ └─────┘ └─────┘ └─────┘                   │
│                                                     │
└─────────────────────────────────────────────────────┘

图像处理

图像输入处理

python
from dataclasses import dataclass
from typing import List, Optional, Tuple
from enum import Enum
import base64
from io import BytesIO

class ImageFormat(Enum):
    """Supported image encodings; each value doubles as the MIME subtype."""

    JPEG = "jpeg"
    PNG = "png"
    WEBP = "webp"
    GIF = "gif"

@dataclass
class ImageInput:
    """Encoded image bytes plus the metadata used when sending them to a model.

    ``width`` and ``height`` are supplied by the caller; they are not checked
    against ``data``.
    """

    data: bytes  # encoded image file contents
    format: ImageFormat
    width: int  # pixels
    height: int  # pixels
    caption: Optional[str] = None

    def to_base64(self) -> str:
        """Return ``data`` as a standard Base64 ASCII string."""
        encoded = base64.standard_b64encode(self.data)
        return encoded.decode("ascii")

    def to_data_url(self) -> str:
        """Return a ``data:`` URL embedding the image as Base64."""
        mime_type = f"image/{self.format.value}"
        return f"data:{mime_type};base64,{self.to_base64()}"

class ImageProcessor:
    """Prepare images for multimodal model input.

    ``process`` downscales images larger than ``max_width`` x ``max_height``
    (keeping the aspect ratio) and re-encodes images whose byte size exceeds
    ``max_size_mb`` as JPEG at decreasing quality. Pillow is imported lazily
    by the methods that need it.
    """

    def __init__(
        self,
        max_width: int = 2048,
        max_height: int = 2048,
        max_size_mb: int = 20,
        preferred_format: ImageFormat = ImageFormat.JPEG
    ):
        """
        Args:
            max_width: Largest allowed width in pixels.
            max_height: Largest allowed height in pixels.
            max_size_mb: Largest allowed encoded size in MiB (1024 * 1024 bytes).
            preferred_format: Stored on the instance; not consulted by any
                method of this class.
        """
        self.max_width = max_width
        self.max_height = max_height
        self.max_size_mb = max_size_mb
        self.preferred_format = preferred_format

    def process(self, image: ImageInput) -> ImageInput:
        """Resize and/or compress ``image`` so it fits the configured limits.

        The resize decision uses the declared ``image.width``/``image.height``.
        May return ``image`` itself when no change was needed or possible.
        """
        if image.width > self.max_width or image.height > self.max_height:
            image = self._resize(image)

        if len(image.data) > self.max_size_mb * 1024 * 1024:
            image = self._compress(image)

        return image

    @staticmethod
    def _jpeg_compatible(img):
        """Return ``img`` converted to RGB when its mode cannot be written as JPEG."""
        # JPEG has no alpha channel or palette; Pillow refuses to save
        # RGBA/LA/P images as JPEG, so convert first (alpha is dropped).
        if img.mode in ("RGB", "L", "CMYK"):
            return img
        return img.convert("RGB")

    def _resize(self, image: ImageInput) -> ImageInput:
        """Downscale to fit within max_width x max_height, keeping aspect ratio.

        The scale factor comes from the decoded pixel size, so a wrong declared
        width/height can neither distort nor upscale the image. If the decoded
        image already fits, the original bytes are kept and only the
        dimensions are corrected.
        """
        from PIL import Image

        with Image.open(BytesIO(image.data)) as img:
            width, height = img.size
            ratio = min(self.max_width / width, self.max_height / height)
            if ratio >= 1:
                return ImageInput(
                    data=image.data,
                    format=image.format,
                    width=width,
                    height=height,
                    caption=image.caption
                )
            new_width = max(1, int(width * ratio))
            new_height = max(1, int(height * ratio))
            resized = img.resize((new_width, new_height), Image.LANCZOS)

        if image.format is ImageFormat.JPEG:
            resized = self._jpeg_compatible(resized)

        buffer = BytesIO()
        resized.save(buffer, format=image.format.value.upper())

        return ImageInput(
            data=buffer.getvalue(),
            format=image.format,
            width=new_width,
            height=new_height,
            caption=image.caption
        )

    def _compress(self, image: ImageInput) -> ImageInput:
        """Re-encode as JPEG at quality 95, 90, ..., 55 until it fits max_size_mb.

        Images with alpha or a palette are converted to RGB first. Returns the
        original ``image`` unchanged if none of the tried qualities fits.
        """
        from PIL import Image

        limit = self.max_size_mb * 1024 * 1024
        with Image.open(BytesIO(image.data)) as img:
            width, height = img.size
            rgb = self._jpeg_compatible(img)
            for quality in range(95, 50, -5):
                buffer = BytesIO()
                rgb.save(buffer, format="JPEG", quality=quality)
                encoded = buffer.getvalue()
                if len(encoded) <= limit:
                    return ImageInput(
                        data=encoded,
                        format=ImageFormat.JPEG,
                        width=width,
                        height=height,
                        caption=image.caption
                    )

        return image

    def create_grid(
        self,
        images: List[ImageInput],
        grid_size: Optional[Tuple[int, int]] = None
    ) -> ImageInput:
        """Tile ``images`` row by row into a single JPEG (quality 90).

        Every cell is as large as the largest decoded image, and each image is
        centred in its cell on a black background.

        Args:
            images: Images to tile, in row-major order.
            grid_size: ``(rows, cols)``; defaults to a near-square layout.

        Raises:
            ValueError: if ``images`` is empty or ``grid_size`` has fewer
                cells than there are images.
        """
        if not images:
            raise ValueError("No images provided")

        import math
        from contextlib import ExitStack

        from PIL import Image

        if grid_size is None:
            cols = math.ceil(math.sqrt(len(images)))
            rows = math.ceil(len(images) / cols)
            grid_size = (rows, cols)

        rows, cols = grid_size
        if rows <= 0 or cols <= 0 or rows * cols < len(images):
            # Images past the last cell would be pasted outside the canvas
            # and silently lost.
            raise ValueError(
                f"grid_size {grid_size} has room for fewer than {len(images)} images"
            )

        with ExitStack() as stack:
            opened = [stack.enter_context(Image.open(BytesIO(item.data))) for item in images]

            # Size cells from the decoded images so centring never goes negative.
            cell_width = max(pic.width for pic in opened)
            cell_height = max(pic.height for pic in opened)

            grid = Image.new('RGB', (cols * cell_width, rows * cell_height))
            for index, pic in enumerate(opened):
                row, col = divmod(index, cols)
                x = col * cell_width + (cell_width - pic.width) // 2
                y = row * cell_height + (cell_height - pic.height) // 2
                grid.paste(pic, (x, y))

        buffer = BytesIO()
        grid.save(buffer, format="JPEG", quality=90)

        return ImageInput(
            data=buffer.getvalue(),
            format=ImageFormat.JPEG,
            width=cols * cell_width,
            height=rows * cell_height
        )

# Usage example
processor = ImageProcessor(max_width=1024, max_height=1024)

# Process a single image
with open("image.jpg", "rb") as f:
    image_data = f.read()

# Read the real pixel size from the file instead of hard-coding it:
# ImageProcessor.process decides whether to resize from these fields.
from PIL import Image

with Image.open(BytesIO(image_data)) as decoded:
    image_width, image_height = decoded.size

image = ImageInput(
    data=image_data,
    format=ImageFormat.JPEG,
    width=image_width,
    height=image_height
)

processed = processor.process(image)
print(f"处理后的尺寸: {processed.width}x{processed.height}")

图像理解

python
from typing import List, Dict, Any, Optional
from dataclasses import dataclass

@dataclass
class BoundingBox:
    """A labelled rectangle found in an image.

    ImageUnderstanding.detect_objects asks the model for corner coordinates
    in a 0-1000 range and a 0-1 confidence; values are not validated.
    """

    x1: float  # left
    y1: float  # top
    x2: float  # right
    y2: float  # bottom
    label: str
    confidence: float

@dataclass
class ImageAnalysis:
    """Bundle of image-analysis outputs produced by ImageUnderstanding.analyze."""

    description: str
    objects: List[BoundingBox]
    text_content: List[str]  # recognised text lines, in reading order
    colors: List[str]
    tags: List[str]
    confidence: float

class ImageUnderstanding:
    """Image understanding helpers on top of a chat-style multimodal model.

    Every request is sent as ``model_client.generate(messages=[...])`` with a
    single user message holding one text part and one ``image_url`` part (the
    image as a Base64 data URL); the reply text is read from
    ``response.content``.
    """

    # Prompts used by describe(), keyed by detail level.
    _DESCRIBE_PROMPTS = {
        "brief": "用一句话描述这张图片。",
        "medium": "详细描述这张图片的内容,包括主要对象、场景和氛围。",
        "detailed": "非常详细地描述这张图片,包括:1) 主要对象及其位置、颜色、大小;2) 背景和环境;3) 图像的整体风格和氛围;4) 任何有趣的细节。"
    }

    def __init__(self, model_client):
        self.model_client = model_client

    async def _ask(self, prompt: str, image: ImageInput) -> str:
        """Send ``prompt`` together with ``image`` and return the reply text."""
        response = await self.model_client.generate(
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": image.to_data_url()}}
                    ]
                }
            ]
        )
        return response.content

    async def describe(
        self,
        image: ImageInput,
        detail_level: str = "medium"
    ) -> str:
        """Describe ``image`` in natural language.

        Args:
            image: Image to describe.
            detail_level: ``"brief"``, ``"medium"`` or ``"detailed"``; any
                other value falls back to ``"medium"``.

        Returns:
            The model's description.
        """
        prompt = self._DESCRIBE_PROMPTS.get(detail_level, self._DESCRIBE_PROMPTS["medium"])
        return await self._ask(prompt, image)

    async def extract_text(self, image: ImageInput) -> List[str]:
        """Return the text lines found in ``image``, in reading order.

        Blank lines are dropped. Returns an empty list when the model reports
        no text.
        """
        reply = await self._ask("提取图像中的所有文字,按阅读顺序列出。如果没有文字,返回空列表。", image)

        lines = (line.strip() for line in reply.strip().split('\n'))
        # The prompt asks for an "empty list" when there is no text; a literal
        # "[]" reply must not be returned as if it were recognised text.
        return [line for line in lines if line and line != "[]"]

    @staticmethod
    def _parse_box_line(line: str) -> Optional[BoundingBox]:
        """Parse one ``label x1 y1 x2 y2 confidence`` line; None if it does not fit.

        Accepts the bracketed layout the prompt requests
        (``[cat] [10] [20] [300] [400] [0.9]``) as well as plain
        whitespace-separated fields. Reading the five numbers from the end
        keeps multi-word labels intact; if that fails, the layout "one-word
        label followed by five numbers (plus trailing junk)" is tried.
        """
        import re

        fields = re.findall(r"\[([^\]]*)\]", line)
        if len(fields) < 6:
            fields = line.split()
        if len(fields) < 6:
            return None

        for label_fields, number_fields in ((fields[:-5], fields[-5:]), (fields[:1], fields[1:6])):
            try:
                x1, y1, x2, y2, confidence = (float(value) for value in number_fields)
            except ValueError:
                continue
            return BoundingBox(
                label=" ".join(label_fields).strip("[] "),
                x1=x1,
                y1=y1,
                x2=x2,
                y2=y2,
                confidence=confidence
            )
        return None

    async def detect_objects(
        self,
        image: ImageInput,
        object_types: Optional[List[str]] = None
    ) -> List[BoundingBox]:
        """Detect objects in ``image``.

        Args:
            image: Image to inspect.
            object_types: Optional object categories to emphasise in the prompt.

        Returns:
            One BoundingBox per reply line that parses; other lines are skipped.
            Coordinates are whatever the model returned (the prompt requests a
            0-1000 range) and are not validated.
        """
        object_prompt = ""
        if object_types:
            object_prompt = f"特别关注以下类型的物体:{', '.join(object_types)}。"

        prompt = f"""检测图像中的所有物体。{object_prompt}

对于每个物体,提供:
1. 物体名称
2. 边界框坐标(x1, y1, x2, y2,范围 0-1000)
3. 置信度(0-1)

输出格式:
[物体名称] [x1] [y1] [x2] [y2] [置信度]"""
        reply = await self._ask(prompt, image)

        boxes = []
        for line in reply.strip().split('\n'):
            box = self._parse_box_line(line)
            if box is not None:
                boxes.append(box)

        return boxes

    async def analyze(
        self,
        image: ImageInput,
        tasks: Optional[List[str]] = None
    ) -> ImageAnalysis:
        """Run the requested analyses one after another and bundle the results.

        Args:
            image: Image to analyse.
            tasks: Subset of ``"describe"``, ``"objects"``, ``"text"``,
                ``"colors"``, ``"tags"``; defaults to all of them.
                NOTE: ``"colors"`` and ``"tags"`` are accepted but not
                implemented, so those fields are always empty and
                ``confidence`` is always 0.0.

        Returns:
            An ImageAnalysis; skipped tasks leave their fields empty.
        """
        if tasks is None:
            tasks = ["describe", "objects", "text", "colors", "tags"]

        results = {}

        if "describe" in tasks:
            results["description"] = await self.describe(image, "detailed")

        if "objects" in tasks:
            results["objects"] = await self.detect_objects(image)

        if "text" in tasks:
            results["text"] = await self.extract_text(image)

        return ImageAnalysis(
            description=results.get("description", ""),
            objects=results.get("objects", []),
            text_content=results.get("text", []),
            colors=[],
            tags=[],
            confidence=0.0
        )

# Usage example
async def analyze_image():
    """Load ``photo.jpg``, fit it to the processor limits, and print its analysis.

    ``model_client`` is not defined in this snippet; it is expected to be a
    module-level multimodal client supplied by the application.
    """
    import io

    from PIL import Image

    processor = ImageProcessor()
    understanding = ImageUnderstanding(model_client)

    # Load the image
    with open("photo.jpg", "rb") as f:
        data = f.read()

    # Use the real pixel size rather than hard-coded numbers: process()
    # decides whether to resize from these fields.
    with Image.open(io.BytesIO(data)) as decoded:
        width, height = decoded.size

    image = ImageInput(
        data=data,
        format=ImageFormat.JPEG,
        width=width,
        height=height
    )

    processed = processor.process(image)

    # Run every analysis task
    analysis = await understanding.analyze(processed)

    print(f"描述: {analysis.description}")
    print(f"物体: {[obj.label for obj in analysis.objects]}")
    print(f"文字: {analysis.text_content}")

音频处理

音频输入处理

python
from dataclasses import dataclass
from typing import List, Optional
from enum import Enum
import wave
import struct

class AudioFormat(Enum):
    """Container/codec identifiers describing ``AudioInput.data``."""

    WAV = "wav"
    MP3 = "mp3"
    FLAC = "flac"
    OGG = "ogg"

@dataclass
class AudioInput:
    """Encoded audio bytes together with their stream parameters."""

    data: bytes  # encoded audio file contents
    format: AudioFormat
    sample_rate: int  # Hz
    channels: int
    duration_seconds: float

    def to_base64(self) -> str:
        """Return ``data`` as a standard Base64 ASCII string."""
        import base64

        encoded = base64.standard_b64encode(self.data)
        return encoded.decode("ascii")

class AudioProcessor:
    """Normalise audio clips: cap the duration, resample, and remix channels.

    Sample-level work is implemented for PCM WAV input (``AudioFormat.WAV``)
    that the standard-library ``wave`` module can parse. Resampling and
    channel conversion additionally require 16-bit samples. Any other input
    (compressed formats, other sample widths, undecodable bytes) is returned
    unchanged by the corresponding step, so check the returned metadata.
    """

    def __init__(
        self,
        max_duration_seconds: int = 300,
        target_sample_rate: int = 16000,
        target_channels: int = 1
    ):
        """
        Args:
            max_duration_seconds: Longer clips are cut to their first
                ``max_duration_seconds`` seconds.
            target_sample_rate: Output sample rate in Hz.
            target_channels: Output channel count.
        """
        self.max_duration_seconds = max_duration_seconds
        self.target_sample_rate = target_sample_rate
        self.target_channels = target_channels

    def process(self, audio: AudioInput) -> AudioInput:
        """Trim, resample and convert channels, each only when needed.

        Each step is triggered by the metadata on ``audio``.
        """
        # Check the duration
        if audio.duration_seconds > self.max_duration_seconds:
            audio = self._trim(audio, 0, self.max_duration_seconds)

        # Resample
        if audio.sample_rate != self.target_sample_rate:
            audio = self._resample(audio, self.target_sample_rate)

        # Convert the channel layout
        if audio.channels != self.target_channels:
            audio = self._convert_channels(audio, self.target_channels)

        return audio

    @staticmethod
    def _read_wav(audio: AudioInput) -> Optional[tuple]:
        """Return ``(channels, sample_width, sample_rate, frames)`` for WAV input.

        Returns None when ``audio`` is not WAV or its bytes cannot be parsed.
        """
        if audio.format != AudioFormat.WAV:
            return None
        import io

        try:
            with wave.open(io.BytesIO(audio.data), "rb") as reader:
                return (
                    reader.getnchannels(),
                    reader.getsampwidth(),
                    reader.getframerate(),
                    reader.readframes(reader.getnframes()),
                )
        except (wave.Error, EOFError, struct.error):
            return None

    @staticmethod
    def _write_wav(channels: int, sample_width: int, sample_rate: int, frames: bytes) -> AudioInput:
        """Wrap raw PCM ``frames`` in a WAV container.

        Returns an ``AudioInput`` with metadata that matches the new data.
        """
        import io

        buffer = io.BytesIO()
        with wave.open(buffer, "wb") as writer:
            writer.setnchannels(channels)
            writer.setsampwidth(sample_width)
            writer.setframerate(sample_rate)
            writer.writeframes(frames)

        frame_count = len(frames) // (channels * sample_width)
        return AudioInput(
            data=buffer.getvalue(),
            format=AudioFormat.WAV,
            sample_rate=sample_rate,
            channels=channels,
            duration_seconds=frame_count / sample_rate
        )

    @staticmethod
    def _decode_pcm16(frames: bytes, channels: int):
        """Decode 16-bit frames from ``wave.readframes`` into float64 samples.

        The result has shape (n_frames, channels).
        """
        import numpy as np

        n_frames = len(frames) // (2 * channels)
        if n_frames == 0:
            return np.zeros((0, channels))
        # The wave module hands out native-endian samples, hence np.int16.
        samples = np.frombuffer(frames, dtype=np.int16, count=n_frames * channels)
        return samples.reshape(n_frames, channels).astype(np.float64)

    @staticmethod
    def _encode_pcm16(samples) -> bytes:
        """Round and clip a (n_frames, channels) float array back to 16-bit frames."""
        import numpy as np

        return np.clip(np.round(samples), -32768, 32767).astype(np.int16).tobytes()

    def _trim(
        self,
        audio: AudioInput,
        start: float,
        end: float
    ) -> AudioInput:
        """Keep only the audio between ``start`` and ``end`` seconds."""
        decoded = self._read_wav(audio)
        if decoded is None:
            return audio

        channels, sample_width, sample_rate, frames = decoded
        frame_size = channels * sample_width
        first = max(0, int(start * sample_rate))
        last = max(first, int(end * sample_rate))
        return self._write_wav(
            channels, sample_width, sample_rate, frames[first * frame_size:last * frame_size]
        )

    def _resample(
        self,
        audio: AudioInput,
        target_sample_rate: int
    ) -> AudioInput:
        """Resample 16-bit WAV audio to ``target_sample_rate`` by linear interpolation.

        No anti-aliasing filter is applied, so downsampling can fold high
        frequencies into the output. This is adequate for speech
        pre-processing, not for high-fidelity audio.
        """
        decoded = self._read_wav(audio)
        if decoded is None or target_sample_rate <= 0:
            return audio

        channels, sample_width, sample_rate, frames = decoded
        if sample_width != 2 or sample_rate <= 0 or sample_rate == target_sample_rate:
            return audio

        import numpy as np

        samples = self._decode_pcm16(frames, channels)
        n_in = samples.shape[0]
        if n_in == 0:
            return self._write_wav(channels, 2, target_sample_rate, b"")

        n_out = int(n_in * target_sample_rate / sample_rate)
        source_times = np.arange(n_in) / sample_rate
        target_times = np.arange(n_out) / target_sample_rate
        resampled = np.column_stack(
            [np.interp(target_times, source_times, samples[:, ch]) for ch in range(channels)]
        )
        return self._write_wav(channels, 2, target_sample_rate, self._encode_pcm16(resampled))

    def _convert_channels(
        self,
        audio: AudioInput,
        target_channels: int
    ) -> AudioInput:
        """Convert the channel count of 16-bit WAV audio.

        Supported conversions:
        - any layout to mono, by averaging the channels;
        - mono to N channels, by duplicating the single channel.

        Other conversions have no canonical mapping, so the audio is returned
        unchanged.
        """
        decoded = self._read_wav(audio)
        if decoded is None or target_channels <= 0:
            return audio

        channels, sample_width, sample_rate, frames = decoded
        if sample_width != 2 or channels == target_channels:
            return audio

        import numpy as np

        samples = self._decode_pcm16(frames, channels)
        if target_channels == 1:
            mixed = samples.mean(axis=1, keepdims=True)
        elif channels == 1:
            mixed = np.repeat(samples, target_channels, axis=1)
        else:
            return audio

        return self._write_wav(target_channels, 2, sample_rate, self._encode_pcm16(mixed))

    def split_by_silence(
        self,
        audio: AudioInput,
        min_silence_duration: float = 0.5,
        silence_threshold_db: float = -40
    ) -> List[AudioInput]:
        """Split ``audio`` at silent stretches.

        NOTE: placeholder. It always returns ``[audio]`` unchanged and ignores
        ``min_silence_duration`` and ``silence_threshold_db``.
        """
        return [audio]

# Usage example: cap clips at 60 seconds.
processor = AudioProcessor(max_duration_seconds=60)

# Describe a 120 s, 44.1 kHz stereo WAV clip; b"..." stands in for real file bytes.
audio = AudioInput(
    data=b"...", format=AudioFormat.WAV,
    sample_rate=44100, channels=2, duration_seconds=120
)

processed = processor.process(audio)
print(f"处理后: {processed.duration_seconds}秒, {processed.sample_rate}Hz, {processed.channels}声道")

语音识别与合成

python
from dataclasses import dataclass
from typing import List, Optional
from enum import Enum

class LanguageCode(Enum):
    """Language tags for speech requests; ``value`` is what is sent to the clients."""

    ZH_CN = "zh-CN"  # Simplified Chinese (mainland China)
    EN_US = "en-US"  # English (United States)
    JA_JP = "ja-JP"  # Japanese
    KO_KR = "ko-KR"  # Korean

@dataclass
class TranscriptionResult:
    """Text recognised from speech, plus the recogniser's metadata."""

    text: str  # full transcript
    segments: List[dict]  # per-segment entries, copied from the ASR response as-is
    language: str  # language tag of ``text``
    confidence: float

@dataclass
class SynthesisResult:
    """Audio produced by text-to-speech."""

    audio_data: bytes  # audio bytes returned by the TTS client
    format: AudioFormat
    sample_rate: int  # Hz
    duration_seconds: float  # may be an estimate rather than a measured length

class SpeechProcessor:
    """Async facade over a speech-recognition (ASR) client and a text-to-speech (TTS) client."""

    def __init__(self, asr_client, tts_client):
        # Both clients are duck-typed: asr_client must provide an async
        # transcribe(...) and tts_client an async synthesize(...), called below.
        self.asr_client = asr_client
        self.tts_client = tts_client

    async def transcribe(
        self,
        audio: AudioInput,
        language: LanguageCode = LanguageCode.ZH_CN,
        enable_diarization: bool = False
    ) -> TranscriptionResult:
        """Convert speech to text.

        Sends ``audio.data`` to the ASR client. The reply's text, segments,
        language and confidence are copied into a TranscriptionResult.
        """
        reply = await self.asr_client.transcribe(
            audio=audio.data,
            language=language.value,
            enable_diarization=enable_diarization
        )

        copied = {
            name: getattr(reply, name)
            for name in ("text", "segments", "language", "confidence")
        }
        return TranscriptionResult(**copied)

    async def synthesize(
        self,
        text: str,
        voice: str = "default",
        language: LanguageCode = LanguageCode.ZH_CN,
        speed: float = 1.0
    ) -> SynthesisResult:
        """Convert text to speech.

        The result is always labelled MP3 at 24000 Hz.
        """
        reply = await self.tts_client.synthesize(
            text=text,
            voice=voice,
            language=language.value,
            speed=speed
        )
        audio_bytes = reply.audio

        # NOTE(review): this divides the byte count by 24000 samples/s and
        # 2 bytes/sample, i.e. it assumes 16-bit mono PCM, while the result is
        # labelled MP3. Treat the duration as a rough estimate only.
        estimated_seconds = len(audio_bytes) / 24000 / 2

        return SynthesisResult(
            audio_data=audio_bytes,
            format=AudioFormat.MP3,
            sample_rate=24000,
            duration_seconds=estimated_seconds
        )

    async def translate_speech(
        self,
        audio: AudioInput,
        source_language: LanguageCode,
        target_language: LanguageCode
    ) -> TranscriptionResult:
        """Transcribe ``audio`` and translate the transcript.

        Only ``text`` and ``language`` change. ``segments`` and ``confidence``
        are those of the source-language transcription.
        """
        source = await self.transcribe(audio, source_language)

        translated = await self._translate(
            source.text,
            source_language.value,
            target_language.value
        )

        return TranscriptionResult(
            text=translated,
            segments=source.segments,
            language=target_language.value,
            confidence=source.confidence
        )

    async def _translate(
        self,
        text: str,
        source_lang: str,
        target_lang: str
    ) -> str:
        """Translate ``text``.

        NOTE: placeholder. It returns ``text`` unchanged.
        """
        return text

# Usage example
async def process_speech():
    """Demonstrate a speech-to-text call followed by a text-to-speech call.

    ``asr_client``, ``tts_client`` and ``processed_audio`` are not defined in
    this snippet; they are expected to exist at module level.
    """
    speech = SpeechProcessor(asr_client, tts_client)

    # Speech -> text
    result = await speech.transcribe(audio=processed_audio, language=LanguageCode.ZH_CN)
    print(f"转录: {result.text}")

    # Text -> speech
    spoken = await speech.synthesize(text="你好,世界", voice="female-1", language=LanguageCode.ZH_CN)
    print(f"合成长度: {spoken.duration_seconds}秒")

视频处理

视频分析

python
from dataclasses import dataclass
from typing import List, Optional
import subprocess
import json

@dataclass
class VideoInfo:
    """Facts about a video file, as filled in by VideoProcessor.get_info from ffprobe."""

    duration_seconds: float
    width: int  # pixels, first video stream
    height: int  # pixels, first video stream
    fps: float
    codec: str  # video codec name
    bitrate: int  # container bit rate
    has_audio: bool
    audio_codec: Optional[str]  # None when there is no audio stream

@dataclass
class VideoFrame:
    """One frame sampled from a video."""

    timestamp: float  # seconds from the start of the video
    image: ImageInput
    description: Optional[str] = None  # caption, when the frame has been analysed

@dataclass
class VideoAnalysis:
    """Aggregate result of a video analysis."""

    info: VideoInfo
    key_frames: List[VideoFrame]  # sampled frames, captioned when analysed
    transcript: Optional[str]  # None when audio was not transcribed
    summary: str
    tags: List[str]

class VideoProcessor:
    """Inspect videos and extract frames or audio via the ffprobe/ffmpeg CLIs.

    Requires the ``ffprobe`` and ``ffmpeg`` executables on ``PATH``.
    """

    def __init__(
        self,
        max_duration_seconds: int = 600,
        max_width: int = 1920,
        max_height: int = 1080,
        frame_sample_interval: float = 1.0
    ):
        """
        Args:
            max_duration_seconds: Stored limit. NOTE(review): not enforced by
                any method of this class yet.
            max_width: Stored limit; not enforced yet (frames keep the source size).
            max_height: Stored limit; not enforced yet.
            frame_sample_interval: Default spacing, in seconds, between frames
                sampled by ``extract_frames``.
        """
        self.max_duration_seconds = max_duration_seconds
        self.max_width = max_width
        self.max_height = max_height
        self.frame_sample_interval = frame_sample_interval

    @staticmethod
    def _parse_frame_rate(rate: str) -> float:
        """Convert an ffprobe rational such as ``"30000/1001"`` to a float.

        Parsed with ``fractions.Fraction`` instead of ``eval`` so that probe
        output is never executed as code. Malformed values and zero
        denominators (e.g. ``"0/0"``) yield 0.0.
        """
        from fractions import Fraction

        try:
            return float(Fraction(rate))
        except (TypeError, ValueError, ZeroDivisionError):
            return 0.0

    def get_info(self, video_path: str) -> VideoInfo:
        """Probe ``video_path`` and summarise its first video and audio streams.

        Raises:
            subprocess.CalledProcessError: if ffprobe exits with an error
                (for instance when the file cannot be opened).
            ValueError: if the file contains no video stream.
        """
        cmd = [
            "ffprobe",
            "-v", "quiet",
            "-print_format", "json",
            "-show_format",
            "-show_streams",
            video_path
        ]

        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        data = json.loads(result.stdout)

        streams = data.get("streams", [])
        video_stream = next((s for s in streams if s.get("codec_type") == "video"), None)
        audio_stream = next((s for s in streams if s.get("codec_type") == "audio"), None)
        if video_stream is None:
            raise ValueError(f"No video stream found in {video_path!r}")

        container = data.get("format", {})

        return VideoInfo(
            duration_seconds=float(container.get("duration", 0)),
            width=int(video_stream.get("width", 0)),
            height=int(video_stream.get("height", 0)),
            fps=self._parse_frame_rate(video_stream.get("r_frame_rate", "0/1")),
            codec=video_stream.get("codec_name", ""),
            bitrate=int(container.get("bit_rate", 0)),
            has_audio=audio_stream is not None,
            audio_codec=audio_stream.get("codec_name") if audio_stream else None
        )

    def extract_frames(
        self,
        video_path: str,
        interval: Optional[float] = None,
        max_frames: int = 30
    ) -> List[VideoFrame]:
        """Grab one JPEG frame every ``interval`` seconds, starting at 0 s.

        Frames are written to a private temporary directory that is removed
        afterwards. A fixed ``/tmp/frame_{i}.jpg`` would collide between
        concurrent calls, and ffmpeg asks before overwriting leftovers.
        Frames report the source video's width/height, because ffmpeg is not
        asked to scale.

        Args:
            video_path: Path to the video file.
            interval: Seconds between frames; defaults to ``frame_sample_interval``.
            max_frames: Upper bound on the number of frames returned.

        Returns:
            Frames in timestamp order. A clip shorter than ``interval`` still
            yields its first frame.

        Raises:
            ValueError: if ``interval`` is not positive.
            subprocess.CalledProcessError: if ffmpeg fails on a frame.
        """
        import os
        import tempfile

        if interval is None:
            interval = self.frame_sample_interval
        if interval <= 0:
            raise ValueError(f"interval must be positive, got {interval}")

        info = self.get_info(video_path)

        # Number of frames to extract
        num_frames = min(int(info.duration_seconds / interval), max_frames)
        if num_frames == 0 and info.duration_seconds > 0 and max_frames > 0:
            num_frames = 1  # clip shorter than one interval: still sample t=0

        frames = []
        with tempfile.TemporaryDirectory() as workdir:
            for index in range(num_frames):
                timestamp = index * interval

                output_path = os.path.join(workdir, f"frame_{index}.jpg")
                cmd = [
                    "ffmpeg",
                    "-nostdin",
                    "-ss", str(timestamp),
                    "-i", video_path,
                    "-vframes", "1",
                    "-q:v", "2",
                    output_path
                ]
                subprocess.run(cmd, capture_output=True, check=True)

                with open(output_path, "rb") as f:
                    frame_data = f.read()

                frames.append(VideoFrame(
                    timestamp=timestamp,
                    image=ImageInput(
                        data=frame_data,
                        format=ImageFormat.JPEG,
                        width=info.width,
                        height=info.height
                    )
                ))

        return frames

    def extract_audio(
        self,
        video_path: str,
        output_path: str
    ) -> AudioInput:
        """Extract the audio track to ``output_path`` and return its contents.

        The output is 16 kHz mono 16-bit PCM WAV. ``-y`` lets ffmpeg overwrite
        an existing ``output_path``. Without it ffmpeg asks first, and a
        refused overwrite would leave a stale file from an earlier run to be
        read back here.

        Raises:
            subprocess.CalledProcessError: if ffmpeg fails.
        """
        cmd = [
            "ffmpeg",
            "-nostdin",
            "-y",
            "-i", video_path,
            "-vn",
            "-acodec", "pcm_s16le",
            "-ar", "16000",
            "-ac", "1",
            output_path
        ]

        subprocess.run(cmd, capture_output=True, check=True)

        with open(output_path, "rb") as f:
            audio_data = f.read()

        return AudioInput(
            data=audio_data,
            format=AudioFormat.WAV,
            sample_rate=16000,
            channels=1,
            duration_seconds=self.get_info(video_path).duration_seconds
        )

class VideoUnderstanding:
    """Combine frame captions and an audio transcript into a VideoAnalysis."""

    def __init__(self, video_processor: VideoProcessor, image_understanding: ImageUnderstanding, speech_processor: SpeechProcessor):
        self.video_processor = video_processor
        self.image_understanding = image_understanding
        self.speech_processor = speech_processor

    async def analyze(
        self,
        video_path: str,
        extract_audio: bool = True,
        analyze_frames: bool = True,
        max_frames: int = 10
    ) -> VideoAnalysis:
        """Analyse ``video_path`` from sampled frames and, optionally, its audio.

        Args:
            video_path: Path to the video file.
            extract_audio: Transcribe the audio track when the video has one.
            analyze_frames: Caption up to ``max_frames`` sampled frames.
            max_frames: Frame cap passed to ``VideoProcessor.extract_frames``.

        Returns:
            A VideoAnalysis. ``tags`` is always empty.
            NOTE: no model is called for the summary yet. ``summary`` holds
            the assembled summary prompt (placeholder), or ``""`` when no
            frames were analysed.
        """
        import os
        import tempfile

        info = self.video_processor.get_info(video_path)

        key_frames = []
        transcript = None
        summaries = []
        summary = ""

        # Extract and caption key frames
        if analyze_frames:
            frames = self.video_processor.extract_frames(
                video_path,
                max_frames=max_frames
            )

            for frame in frames:
                description = await self.image_understanding.describe(
                    frame.image,
                    detail_level="brief"
                )
                frame.description = description
                key_frames.append(frame)
                summaries.append(f"[{frame.timestamp:.1f}s] {description}")

        # Extract and transcribe the audio
        if extract_audio and info.has_audio:
            # A private temporary file avoids a shared /tmp/audio.wav that
            # concurrent analyses would clobber, and that could otherwise be
            # read back stale from an earlier video. extract_audio reads the
            # file into memory, so the directory can be removed right after.
            with tempfile.TemporaryDirectory() as workdir:
                audio = self.video_processor.extract_audio(
                    video_path,
                    os.path.join(workdir, "audio.wav")
                )
            transcription = await self.speech_processor.transcribe(audio)
            transcript = transcription.text

        # Build the summary
        if summaries:
            summary_prompt = f"""根据以下关键帧描述,总结这个视频的主要内容:

{chr(10).join(summaries)}

音频转录(如果有):
{transcript or '无'}

请提供一个简洁的总结:"""
            # Placeholder: a real implementation would send summary_prompt
            # to a model; for now the prompt itself is returned.
            summary = summary_prompt

        return VideoAnalysis(
            info=info,
            key_frames=key_frames,
            transcript=transcript,
            summary=summary,
            tags=[]
        )

# Usage example
async def analyze_video():
    """Analyse ``video.mp4`` and print a short report.

    ``image_understanding`` and ``speech_processor`` are not defined in this
    snippet; they are expected to exist at module level.
    """
    video_processor = VideoProcessor()
    understanding = VideoUnderstanding(
        video_processor,
        image_understanding,
        speech_processor
    )

    analysis = await understanding.analyze(
        "video.mp4",
        extract_audio=True,
        analyze_frames=True,
        max_frames=10
    )

    # Show at most the first 100 characters of the transcript. Add an
    # ellipsis only when text was actually cut off, and never after "无".
    transcript = analysis.transcript
    if transcript:
        preview = transcript[:100] + ("..." if len(transcript) > 100 else "")
    else:
        preview = "无"

    print(f"视频时长: {analysis.info.duration_seconds}秒")
    print(f"关键帧: {len(analysis.key_frames)}个")
    print(f"转录: {preview}")
    print(f"总结: {analysis.summary}")

多模态融合

跨模态检索

python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

@dataclass
class MultiModalItem:
    """A retrievable record carrying any combination of modalities.

    All modalities are optional. MultiModalRetriever indexes and scores only
    ``text``, ``image`` and ``audio``. ``id`` is required, including for
    query items.
    """

    id: str
    text: Optional[str] = None
    image: Optional[ImageInput] = None
    audio: Optional[AudioInput] = None
    video_path: Optional[str] = None  # not used by MultiModalRetriever
    # Free-form caller data, not used for scoring. Defaults to None, hence Optional.
    metadata: Optional[Dict[str, Any]] = None

@dataclass
class SearchResult:
    """One ranked hit from a multimodal search."""

    item: MultiModalItem
    score: float  # weighted sum of per-modality cosine similarities
    matched_modalities: List[str]  # modalities that contributed, e.g. "text", "image"

class MultiModalRetriever:
    """In-memory retriever ranking items by weighted cross-modal cosine similarity.

    At index time each item's text, image bytes and audio bytes are embedded
    with the matching encoder. A query is embedded the same way. An item's
    score sums ``cosine_similarity * weight`` over the modalities present in
    both the query and the item.
    """

    # Weights for modalities that search() callers do not override.
    _DEFAULT_WEIGHTS = {"text": 0.4, "image": 0.4, "audio": 0.2}

    def __init__(self, text_encoder, image_encoder, audio_encoder):
        """
        Args:
            text_encoder: Object whose ``encode`` is called with a ``str``.
            image_encoder: Object whose ``encode`` is called with ``ImageInput.data``.
            audio_encoder: Object whose ``encode`` is called with ``AudioInput.data``.
        """
        self.text_encoder = text_encoder
        self.image_encoder = image_encoder
        self.audio_encoder = audio_encoder

        self.items: List[MultiModalItem] = []
        self.text_embeddings: Dict[str, np.ndarray] = {}
        self.image_embeddings: Dict[str, np.ndarray] = {}
        self.audio_embeddings: Dict[str, np.ndarray] = {}

    def index(self, items: List[MultiModalItem]):
        """Embed and store ``items``; modalities an item lacks are skipped.

        Embeddings are stored per ``id``. Re-indexing an id overwrites the
        embeddings for the modalities the new item provides.
        """
        for item in items:
            self.items.append(item)

            if item.text:
                self.text_embeddings[item.id] = self.text_encoder.encode(item.text)

            if item.image:
                self.image_embeddings[item.id] = self.image_encoder.encode(item.image.data)

            if item.audio:
                self.audio_embeddings[item.id] = self.audio_encoder.encode(item.audio.data)

    def search(
        self,
        query: MultiModalItem,
        top_k: int = 10,
        weights: Optional[Dict[str, float]] = None
    ) -> List[SearchResult]:
        """Rank indexed items against ``query``.

        Args:
            query: Item whose text/image/audio serve as the query; unset
                fields are ignored.
            top_k: Maximum number of results.
            weights: Per-modality weights. Keys missing here fall back to the
                defaults (text 0.4, image 0.4, audio 0.2) instead of raising
                KeyError.

        Returns:
            Up to ``top_k`` results sorted by descending score. Items that
            share no modality with the query are omitted. They carry no
            similarity evidence, and their implicit score of 0 would otherwise
            outrank items with negative cosine similarity.
        """
        merged_weights = {**self._DEFAULT_WEIGHTS, **(weights or {})}

        # Encode the query, modality by modality
        query_vectors = {}
        if query.text:
            query_vectors["text"] = self.text_encoder.encode(query.text)
        if query.image:
            query_vectors["image"] = self.image_encoder.encode(query.image.data)
        if query.audio:
            query_vectors["audio"] = self.audio_encoder.encode(query.audio.data)

        stores = {
            "text": self.text_embeddings,
            "image": self.image_embeddings,
            "audio": self.audio_embeddings,
        }

        # Keyed by item id (a later item with the same id wins). The item is
        # kept alongside its score, so results need no second scan of self.items.
        scored: Dict[str, Tuple[float, List[str], MultiModalItem]] = {}
        for item in self.items:
            total_score = 0.0
            matched_modalities = []
            for modality, query_vector in query_vectors.items():
                item_vector = stores[modality].get(item.id)
                if item_vector is None:
                    continue
                similarity = self._cosine_similarity(query_vector, item_vector)
                total_score += similarity * merged_weights[modality]
                matched_modalities.append(modality)
            if matched_modalities:
                scored[item.id] = (total_score, matched_modalities, item)

        ranked = sorted(scored.values(), key=lambda entry: entry[0], reverse=True)

        return [
            SearchResult(item=item, score=score, matched_modalities=modalities)
            for score, modalities, item in ranked[:top_k]
        ]

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Return the cosine similarity of ``a`` and ``b``; 0.0 if either is a zero vector."""
        denominator = np.linalg.norm(a) * np.linalg.norm(b)
        if denominator == 0:
            return 0.0
        return float(np.dot(a, b) / denominator)

# Usage example
async def multi_modal_search():
    """Index two sample items and run a text + image query against them.

    ``text_encoder``, ``image_encoder``, ``audio_encoder``, ``cat_image``,
    ``dog_image`` and ``query_image`` are placeholders expected at module level.
    """
    retriever = MultiModalRetriever(text_encoder, image_encoder, audio_encoder)

    # Index the items
    items = [
        MultiModalItem(
            id="item1",
            text="一只猫在草地上",
            image=cat_image,
            metadata={"source": "user1"}
        ),
        MultiModalItem(
            id="item2",
            text="狗在海边玩耍",
            image=dog_image,
            metadata={"source": "user2"}
        ),
    ]

    retriever.index(items)

    # Search. MultiModalItem.id is a required field, so the query needs one
    # too; omitting it raises TypeError.
    query = MultiModalItem(
        id="query",
        text="可爱的动物",
        image=query_image
    )

    results = retriever.search(query, top_k=5)

    for result in results:
        print(f"ID: {result.item.id}, Score: {result.score:.2f}")
        print(f"匹配模态: {result.matched_modalities}")

最佳实践总结

多模态处理检查清单

markdown
## 多模态处理检查清单

### 图像处理

- [ ] 图像尺寸限制
- [ ] 图像大小压缩
- [ ] 格式转换
- [ ] 图像质量检查
- [ ] 敏感内容过滤

### 音频处理

- [ ] 音频时长限制
- [ ] 采样率标准化
- [ ] 声道转换
- [ ] 音频质量检查
- [ ] 静音检测

### 视频处理

- [ ] 视频时长限制
- [ ] 关键帧提取
- [ ] 音频分离
- [ ] 视频压缩
- [ ] 格式兼容性

### 融合策略

- [ ] 模态权重设置
- [ ] 特征对齐
- [ ] 时序同步
- [ ] 缺失模态处理
- [ ] 结果融合

### 性能优化

- [ ] 批量处理
- [ ] 异步编码
- [ ] 缓存策略
- [ ] 模型量化
- [ ] GPU 加速