
Prompt Cache

Prompt caching is an optimization technique that caches frequently reused prompts and model responses, significantly reducing API call latency and cost and improving application performance.

Overview

Benefits of prompt caching:

  • Avoids redundant computation
  • Lowers API call costs
  • Improves response speed
  • Enhances the user experience
  • Scales to high-frequency scenarios

How It Works

  1. Cache key generation: derive a unique key from the input
  2. Cache lookup: check whether a cached result already exists
  3. Cache hit: return the cached result directly
  4. Cache miss: call the API and cache the result
  5. Cache maintenance: periodically clean up and refresh entries (a minimal sketch of this flow follows)

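A minimal, self-contained sketch of the get-or-call flow, with comments keyed to the steps above; `call_model` is a hypothetical stand-in for a real API call:

python
def call_model(prompt: str) -> str:
    # Hypothetical stand-in for a real API call.
    return f"response to: {prompt}"

def get_or_call(cache: dict, prompt: str) -> str:
    key = hash(prompt)             # 1. generate a cache key from the input
    if key in cache:               # 2. look the key up in the cache
        return cache[key]          # 3. hit: return the cached result directly
    response = call_model(prompt)  # 4. miss: call the API...
    cache[key] = response          #    ...and cache the result
    return response                # 5. expiry/cleanup is omitted in this sketch
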
Basic Usage

Simple Cache Implementation

python
import openai
import hashlib
import json
import time
from typing import Optional, Dict, Any

client = openai.OpenAI(
    api_key="your_api_key",
    base_url="https://realmrouter.cn/v1"
)

class SimplePromptCache:
    def __init__(self, ttl: int = 3600):
        self.cache: Dict[str, Dict[str, Any]] = {}
        self.ttl = ttl  # cache TTL in seconds
    
    def _generate_key(self, model: str, messages: list, **kwargs) -> str:
        """生成缓存键"""
        cache_data = {
            "model": model,
            "messages": messages,
            "kwargs": {k: v for k, v in kwargs.items() if k != "stream"}
        }
        cache_str = json.dumps(cache_data, sort_keys=True)
        return hashlib.md5(cache_str.encode()).hexdigest()
    
    def get(self, model: str, messages: list, **kwargs) -> Optional[str]:
        """获取缓存结果"""
        key = self._generate_key(model, messages, **kwargs)
        
        if key in self.cache:
            cached_data = self.cache[key]
            if time.time() - cached_data["timestamp"] < self.ttl:
                print(f"缓存命中: {key[:8]}...")
                return cached_data["response"]
            else:
                # 缓存过期,删除
                del self.cache[key]
        
        return None
    
    def set(self, model: str, messages: list, response: str, **kwargs):
        """设置缓存"""
        key = self._generate_key(model, messages, **kwargs)
        self.cache[key] = {
            "response": response,
            "timestamp": time.time()
        }
        print(f"缓存设置: {key[:8]}...")

# Using the cache
cache = SimplePromptCache(ttl=1800)  # entries expire after 30 minutes

def cached_chat_completion(model: str, messages: list, **kwargs):
    """带缓存的聊天完成"""
    # 尝试从缓存获取
    cached_response = cache.get(model, messages, **kwargs)
    if cached_response:
        return cached_response
    
    # 缓存未命中,调用API
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        **kwargs
    )
    
    result = response.choices[0].message.content
    
    # Cache the result
    cache.set(model, messages, result, **kwargs)
    
    return result

# Example usage
messages = [
    {"role": "user", "content": "Please explain what machine learning is"}
]

# First call (hits the API)
response1 = cached_chat_completion("gpt-3.5-turbo", messages)
print("First response:", response1[:50] + "...")

# Second call (served from the cache)
response2 = cached_chat_completion("gpt-3.5-turbo", messages)
print("Second response:", response2[:50] + "...")

Advanced Cache Implementation

python
import redis
import pickle
import hashlib
import json
import time
from dataclasses import dataclass
from typing import Optional, List

@dataclass
class CacheConfig:
    redis_host: str = "localhost"
    redis_port: int = 6379
    redis_db: int = 0
    ttl: int = 3600
    max_size: int = 10000

class AdvancedPromptCache:
    def __init__(self, config: CacheConfig):
        self.config = config
        self.redis_client = redis.Redis(
            host=config.redis_host,
            port=config.redis_port,
            db=config.redis_db,
            decode_responses=False
        )
        
    def _generate_key(self, model: str, messages: list, **kwargs) -> str:
        """生成更精确的缓存键"""
        cache_data = {
            "model": model,
            "messages": self._normalize_messages(messages),
            "temperature": kwargs.get("temperature", 1.0),
            "max_tokens": kwargs.get("max_tokens"),
            "top_p": kwargs.get("top_p", 1.0)
        }
        cache_str = json.dumps(cache_data, sort_keys=True, ensure_ascii=False)
        return f"prompt_cache:{hashlib.sha256(cache_str.encode()).hexdigest()}"
    
    def _normalize_messages(self, messages: list) -> list:
        """标准化消息格式"""
        normalized = []
        for msg in messages:
            normalized_msg = {
                "role": msg["role"],
                "content": msg["content"].strip()
            }
            normalized.append(normalized_msg)
        return normalized
    
    def get(self, model: str, messages: list, **kwargs) -> Optional[str]:
        """获取缓存结果"""
        key = self._generate_key(model, messages, **kwargs)
        
        try:
            cached_data = self.redis_client.get(key)
            if cached_data:
                data = pickle.loads(cached_data)
                print(f"Redis缓存命中: {key[:16]}...")
                return data["response"]
        except Exception as e:
            print(f"Redis获取失败: {e}")
        
        return None
    
    def set(self, model: str, messages: list, response: str, **kwargs):
        """设置缓存"""
        key = self._generate_key(model, messages, **kwargs)
        
        try:
            data = {
                "response": response,
                "timestamp": time.time(),
                "model": model
            }
            
            # Enforce the cache size limit
            if self._get_cache_size() >= self.config.max_size:
                self._evict_oldest()
            
            self.redis_client.setex(
                key,
                self.config.ttl,
                pickle.dumps(data)
            )
            print(f"Redis缓存设置: {key[:16]}...")
            
        except Exception as e:
            print(f"Redis设置失败: {e}")
    
    def _get_cache_size(self) -> int:
        """获取当前缓存大小"""
        try:
            return len(self.redis_client.keys("prompt_cache:*"))
        except:
            return 0
    
    def _evict_oldest(self):
        """删除最旧的缓存项"""
        try:
            keys = self.redis_client.keys("prompt_cache:*")
            if keys:
                # 获取最旧的键并删除
                oldest_key = min(keys, key=lambda k: self.redis_client.ttl(k))
                self.redis_client.delete(oldest_key)
        except Exception as e:
            print(f"缓存清理失败: {e}")
    
    def clear(self):
        """清空所有缓存"""
        try:
            keys = self.redis_client.keys("prompt_cache:*")
            if keys:
                self.redis_client.delete(*keys)
            print("缓存已清空")
        except Exception as e:
            print(f"缓存清空失败: {e}")

# Using the advanced cache
config = CacheConfig(ttl=7200, max_size=5000)
advanced_cache = AdvancedPromptCache(config)
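
The Redis-backed cache plugs into the request path much like the simple version. The sketch below assumes the `client` and `advanced_cache` objects defined above; because `get()` and `set()` already swallow Redis errors, a Redis outage degrades gracefully to a plain API call:

python
def cached_completion(model: str, messages: list, **kwargs) -> str:
    """Serve from Redis when possible, otherwise call the API and cache."""
    cached = advanced_cache.get(model, messages, **kwargs)
    if cached is not None:
        return cached

    response = client.chat.completions.create(
        model=model,
        messages=messages,
        **kwargs
    )
    result = response.choices[0].message.content
    advanced_cache.set(model, messages, result, **kwargs)
    return result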

Caching Strategies

1. Content-Based Caching

python
class ContentBasedCache:
    """基于内容相似度的缓存"""
    
    def __init__(self, similarity_threshold: float = 0.9):
        self.cache = {}
        self.similarity_threshold = similarity_threshold
    
    def _calculate_similarity(self, text1: str, text2: str) -> float:
        """计算文本相似度(简化版)"""
        from difflib import SequenceMatcher
        return SequenceMatcher(None, text1, text2).ratio()
    
    def find_similar(self, content: str) -> Optional[str]:
        """查找相似内容的缓存"""
        for cached_key, cached_data in self.cache.items():
            similarity = self._calculate_similarity(
                content, 
                cached_data["original_content"]
            )
            if similarity >= self.similarity_threshold:
                print(f"找到相似缓存 (相似度: {similarity:.2f})")
                return cached_data["response"]
        
        return None
    
    def set(self, content: str, response: str):
        """设置缓存"""
        cache_key = hashlib.md5(content.encode()).hexdigest()
        self.cache[cache_key] = {
            "original_content": content,
            "response": response,
            "timestamp": time.time()
        }
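
A brief usage sketch: near-duplicate prompts can reuse a single response even though their exact hashes differ.

python
# Hypothetical usage of the similarity cache defined above.
similarity_cache = ContentBasedCache(similarity_threshold=0.9)
similarity_cache.set("What is machine learning?", "Machine learning is ...")

# A near-identical prompt reuses the cached answer despite hashing differently.
answer = similarity_cache.find_similar("What is machine learning")
if answer is not None:
    print("Reused cached answer:", answer)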

2. Tiered Caching

python
class TieredCache:
    """分层缓存:内存 + Redis"""
    
    def __init__(self, memory_size: int = 1000, redis_ttl: int = 3600):
        self.memory_cache = {}  # in-memory cache
        self.memory_size = memory_size
        self.redis_client = redis.Redis(decode_responses=False)
        self.redis_ttl = redis_ttl
    
    def get(self, key: str) -> Optional[str]:
        # Check the in-memory cache first
        if key in self.memory_cache:
            print("Memory cache hit")
            return self.memory_cache[key]["response"]
        
        # Fall back to Redis
        try:
            cached_data = self.redis_client.get(key)
            if cached_data:
                data = pickle.loads(cached_data)
                # Promote to the in-memory cache
                self._promote_to_memory(key, data)
                print("Redis cache hit")
                return data["response"]
        except Exception:
            # Treat Redis errors as a cache miss
            pass
        
        return None
    
    def set(self, key: str, response: str):
        # Write to the in-memory cache
        self._set_memory(key, response)
        
        # Write through to Redis
        try:
            data = {"response": response, "timestamp": time.time()}
            self.redis_client.setex(key, self.redis_ttl, pickle.dumps(data))
        except Exception:
            # A Redis outage should not fail the write path
            pass
    
    def _promote_to_memory(self, key: str, data: dict):
        """提升到内存缓存"""
        if len(self.memory_cache) >= self.memory_size:
            # Evict the entry with the oldest timestamp (approximate LRU)
            oldest_key = min(
                self.memory_cache.keys(),
                key=lambda k: self.memory_cache[k]["timestamp"]
            )
            del self.memory_cache[oldest_key]
        
        self.memory_cache[key] = {
            "response": data["response"],
            "timestamp": time.time()
        }
    
    def _set_memory(self, key: str, response: str):
        """设置内存缓存"""
        if len(self.memory_cache) >= self.memory_size:
            oldest_key = min(
                self.memory_cache.keys(),
                key=lambda k: self.memory_cache[k]["timestamp"]
            )
            del self.memory_cache[oldest_key]
        
        self.memory_cache[key] = {
            "response": response,
            "timestamp": time.time()
        }

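A short usage sketch (the keys here are plain strings, typically the hashes produced by a key function such as `_generate_key` above); hot entries are served from process memory, while Redis acts as a shared second tier across processes and restarts:

python
tiered = TieredCache(memory_size=500, redis_ttl=3600)

tiered.set("prompt_cache:abc123", "cached answer")
print(tiered.get("prompt_cache:abc123"))  # served from memory

# After a process restart, the same get() would be served by Redis
# and promoted back into the in-memory tier.
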
Real-World Use Cases

1. Chatbot

python
class ChatBotCache:
    """聊天机器人专用缓存"""
    
    def __init__(self):
        self.cache = AdvancedPromptCache(CacheConfig())
        self.conversation_cache = {}  # per-user conversation history
    
    def get_response(self, user_id: str, message: str) -> str:
        """Generate a bot response, serving repeated questions from the cache."""
        
        # Build the full message list from the stored conversation history
        history = self.conversation_cache.get(user_id, [])
        messages = history + [
            {"role": "user", "content": message}
        ]
        
        # Try the cache first; on a hit we still update the history below
        cached_response = self.cache.get("gpt-3.5-turbo", messages)
        if cached_response is not None:
            result = cached_response
        else:
            # Cache miss: call the API
            response = client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=messages,
                temperature=0.7
            )
            result = response.choices[0].message.content
            # Cache the result
            self.cache.set("gpt-3.5-turbo", messages, result)
        
        # Update the conversation history (runs on cache hits too)
        if user_id not in self.conversation_cache:
            self.conversation_cache[user_id] = []
        
        self.conversation_cache[user_id].extend([
            {"role": "user", "content": message},
            {"role": "assistant", "content": result}
        ])
        
        # Keep the history to a reasonable length
        if len(self.conversation_cache[user_id]) > 10:
            self.conversation_cache[user_id] = self.conversation_cache[user_id][-10:]
        
        return result

# Example usage
chatbot = ChatBotCache()
response = chatbot.get_response("user123", "Hello, what's the weather like today?")

2. Document Q&A System

python
class DocumentQACache:
    """文档问答系统缓存"""
    
    def __init__(self):
        self.cache = AdvancedPromptCache(CacheConfig())
        self.document_cache = {}  # document content cache
    
    def answer_question(self, document_id: str, question: str) -> str:
        """回答文档相关问题"""
        
        # Fetch the document content (cached)
        document_content = self._get_document(document_id)
        
        # Build the prompt
        messages = [
            {
                "role": "system",
                "content": "You are a document Q&A assistant. Answer questions based on the document content provided."
            },
            {
                "role": "user",
                "content": f"""
                Document content:
                {document_content}
                
                Question: {question}
                
                Please answer based on the document content above.
                """
            }
        ]
        
        # Try the cache first
        cached_response = self.cache.get("gpt-4", messages)
        if cached_response is not None:
            return cached_response
        
        # Call the API
        response = client.chat.completions.create(
            model="gpt-4",
            messages=messages,
            temperature=0.1
        )
        
        result = response.choices[0].message.content
        
        # Cache the result
        self.cache.set("gpt-4", messages, result)
        
        return result
    
    def _get_document(self, document_id: str) -> str:
        """获取文档内容(带缓存)"""
        if document_id in self.document_cache:
            return self.document_cache[document_id]
        
        # Real document retrieval logic goes here
        document_content = f"This is the content of document {document_id}..."
        
        # Cache the document content
        self.document_cache[document_id] = document_content
        
        return document_content

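A hypothetical usage sketch; asking the same question about the same document twice is served from the cache on the second call:

python
qa = DocumentQACache()
# First call hits the API; the identical second call is a cache hit.
print(qa.answer_question("doc-001", "What is the main topic?"))
print(qa.answer_question("doc-001", "What is the main topic?"))
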
Performance Monitoring

Cache Hit Rate Statistics

python
class CacheMetrics:
    """缓存性能指标"""
    
    def __init__(self):
        self.total_requests = 0
        self.cache_hits = 0
        self.cache_misses = 0
        self.api_calls = 0
    
    def record_request(self, is_hit: bool):
        """记录请求"""
        self.total_requests += 1
        if is_hit:
            self.cache_hits += 1
        else:
            self.cache_misses += 1
            self.api_calls += 1
    
    def get_hit_rate(self) -> float:
        """获取命中率"""
        if self.total_requests == 0:
            return 0.0
        return self.cache_hits / self.total_requests
    
    def get_stats(self) -> dict:
        """获取统计信息"""
        return {
            "total_requests": self.total_requests,
            "cache_hits": self.cache_hits,
            "cache_misses": self.cache_misses,
            "api_calls": self.api_calls,
            "hit_rate": self.get_hit_rate(),
            "cost_savings": f"{(self.cache_hits / self.total_requests * 100):.1f}%" if self.total_requests > 0 else "0%"
        }

class MonitoredCache:
    """带监控的缓存"""
    
    def __init__(self):
        self.cache = AdvancedPromptCache(CacheConfig())
        self.metrics = CacheMetrics()
    
    def get(self, model: str, messages: list, **kwargs) -> Optional[str]:
        result = self.cache.get(model, messages, **kwargs)
        self.metrics.record_request(result is not None)
        return result
    
    def set(self, model: str, messages: list, response: str, **kwargs):
        self.cache.set(model, messages, response, **kwargs)
    
    def print_stats(self):
        """打印统计信息"""
        stats = self.metrics.get_stats()
        print("=== 缓存性能统计 ===")
        print(f"总请求数: {stats['total_requests']}")
        print(f"缓存命中: {stats['cache_hits']}")
        print(f"缓存未命中: {stats['cache_misses']}")
        print(f"API调用: {stats['api_calls']}")
        print(f"命中率: {stats['hit_rate']:.2%}")
        print(f"成本节省: {stats['cost_savings']}")

# Example usage
monitored_cache = MonitoredCache()

# Simulate repeated calls
for i in range(10):
    messages = [{"role": "user", "content": f"Test message {i % 3}"}]  # repeated messages
    response = monitored_cache.get("gpt-3.5-turbo", messages)
    if response is None:
        # Simulate an API call
        response = f"API response {i}"
        monitored_cache.set("gpt-3.5-turbo", messages, response)

monitored_cache.print_stats()

Best Practices

1. Cache Key Design

  • Include every parameter that affects the output
  • Use a normalized, canonical format
  • Avoid volatile values such as timestamps (a sketch follows this list)

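A minimal sketch of these rules; the `request_id` and `timestamp` fields are hypothetical stand-ins for volatile values that must be excluded before hashing:

python
import hashlib
import json

def make_cache_key(model: str, messages: list, **params) -> str:
    """Build a deterministic key from everything that affects the output."""
    # Drop volatile fields (hypothetical names) that would defeat caching.
    VOLATILE = {"request_id", "timestamp", "stream"}
    stable = {k: v for k, v in params.items() if k not in VOLATILE}
    payload = json.dumps(
        {"model": model, "messages": messages, "params": stable},
        sort_keys=True,       # canonical key ordering
        ensure_ascii=False,   # stable handling of non-ASCII content
    )
    return hashlib.sha256(payload.encode()).hexdigest()
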
2. Caching Strategy

  • Set a TTL that matches your business requirements
  • Use tiered caching to improve performance
  • Clean up expired entries regularly

3. Error Handling

  • Fall back to a direct API call when the cache fails
  • Log cache exceptions
  • Pre-warm the cache for predictable queries (see the sketch after this list)

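A hedged sketch of cache pre-warming, reusing the `cached_chat_completion` helper from the basic example; the list of frequent questions is hypothetical:

python
# Hypothetical list of frequent questions to pre-warm at startup.
COMMON_QUESTIONS = [
    "Please explain what machine learning is",
    "What is deep learning?",
]

def warm_cache():
    """Populate the cache before user traffic arrives."""
    for question in COMMON_QUESTIONS:
        messages = [{"role": "user", "content": question}]
        # The helper caches on a miss, so this both fetches and stores.
        cached_chat_completion("gpt-3.5-turbo", messages)
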
4. Performance Optimization

  • Use an in-memory cache for speed
  • Implement an LRU eviction policy
  • Monitor the cache hit rate

Limitations and Caveats

  1. Memory usage: caches consume memory
  2. Data consistency: a cache may return stale data
  3. Storage cost: large-scale caching incurs storage overhead
  4. Complexity: caching adds system complexity
  5. Network dependency: a Redis cache depends on network connectivity
