# Prompt Cache

Prompt caching is an optimization technique that stores frequently reused prompts and their model responses, significantly reducing API call latency and cost and improving application performance.
## Overview

Benefits of prompt caching:

- Avoids redundant computation
- Lowers API call costs
- Speeds up responses
- Improves the user experience
- Scales to high-frequency scenarios
## How it works

- Cache key generation: derive a unique key from the request inputs
- Cache lookup: check whether a cached result already exists
- Cache hit: return the cached result directly
- Cache miss: call the API and cache the result
- Cache maintenance: periodically expire and refresh entries

The simple implementation below walks through this flow end to end.
## Basic usage

### Simple cache implementation
```python
import openai
import hashlib
import json
import time
from typing import Optional, Dict, Any

client = openai.OpenAI(
    api_key="your_api_key",
    base_url="https://realmrouter.cn/v1"
)

class SimplePromptCache:
    def __init__(self, ttl: int = 3600):
        self.cache: Dict[str, Dict[str, Any]] = {}
        self.ttl = ttl  # cache TTL in seconds

    def _generate_key(self, model: str, messages: list, **kwargs) -> str:
        """Generate a cache key from everything that affects the output."""
        cache_data = {
            "model": model,
            "messages": messages,
            "kwargs": {k: v for k, v in kwargs.items() if k != "stream"}
        }
        cache_str = json.dumps(cache_data, sort_keys=True)
        # md5 is acceptable here: keys only need to be unique, not secure
        return hashlib.md5(cache_str.encode()).hexdigest()

    def get(self, model: str, messages: list, **kwargs) -> Optional[str]:
        """Look up a cached result."""
        key = self._generate_key(model, messages, **kwargs)
        if key in self.cache:
            cached_data = self.cache[key]
            if time.time() - cached_data["timestamp"] < self.ttl:
                print(f"Cache hit: {key[:8]}...")
                return cached_data["response"]
            else:
                # Entry has expired; remove it
                del self.cache[key]
        return None

    def set(self, model: str, messages: list, response: str, **kwargs):
        """Store a result in the cache."""
        key = self._generate_key(model, messages, **kwargs)
        self.cache[key] = {
            "response": response,
            "timestamp": time.time()
        }
        print(f"Cache stored: {key[:8]}...")

# Using the cache
cache = SimplePromptCache(ttl=1800)  # expires after 30 minutes

def cached_chat_completion(model: str, messages: list, **kwargs):
    """Chat completion with caching."""
    # Try the cache first
    cached_response = cache.get(model, messages, **kwargs)
    if cached_response:
        return cached_response

    # Cache miss: call the API
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        **kwargs
    )
    result = response.choices[0].message.content

    # Cache the result
    cache.set(model, messages, result, **kwargs)
    return result

# Usage example
messages = [
    {"role": "user", "content": "Please explain what machine learning is"}
]

# First call (hits the API)
response1 = cached_chat_completion("gpt-3.5-turbo", messages)
print("First response:", response1[:50] + "...")

# Second call (served from cache)
response2 = cached_chat_completion("gpt-3.5-turbo", messages)
print("Second response:", response2[:50] + "...")
```
```python
import hashlib
import json
import pickle
import time
from dataclasses import dataclass
from typing import Optional

import redis

@dataclass
class CacheConfig:
    redis_host: str = "localhost"
    redis_port: int = 6379
    redis_db: int = 0
    ttl: int = 3600
    max_size: int = 10000

class AdvancedPromptCache:
    def __init__(self, config: CacheConfig):
        self.config = config
        self.redis_client = redis.Redis(
            host=config.redis_host,
            port=config.redis_port,
            db=config.redis_db,
            decode_responses=False
        )

    def _generate_key(self, model: str, messages: list, **kwargs) -> str:
        """Generate a more precise cache key."""
        cache_data = {
            "model": model,
            "messages": self._normalize_messages(messages),
            "temperature": kwargs.get("temperature", 1.0),
            "max_tokens": kwargs.get("max_tokens"),
            "top_p": kwargs.get("top_p", 1.0)
        }
        cache_str = json.dumps(cache_data, sort_keys=True, ensure_ascii=False)
        return f"prompt_cache:{hashlib.sha256(cache_str.encode()).hexdigest()}"

    def _normalize_messages(self, messages: list) -> list:
        """Normalize the message format."""
        normalized = []
        for msg in messages:
            normalized_msg = {
                "role": msg["role"],
                "content": msg["content"].strip()
            }
            normalized.append(normalized_msg)
        return normalized

    def get(self, model: str, messages: list, **kwargs) -> Optional[str]:
        """Look up a cached result."""
        key = self._generate_key(model, messages, **kwargs)
        try:
            cached_data = self.redis_client.get(key)
            if cached_data:
                data = pickle.loads(cached_data)
                print(f"Redis cache hit: {key[:16]}...")
                return data["response"]
        except Exception as e:
            print(f"Redis get failed: {e}")
        return None

    def set(self, model: str, messages: list, response: str, **kwargs):
        """Store a result in the cache."""
        key = self._generate_key(model, messages, **kwargs)
        try:
            data = {
                "response": response,
                "timestamp": time.time(),
                "model": model
            }
            # Enforce the cache size limit
            if self._get_cache_size() >= self.config.max_size:
                self._evict_oldest()
            self.redis_client.setex(
                key,
                self.config.ttl,
                pickle.dumps(data)
            )
            print(f"Redis cache stored: {key[:16]}...")
        except Exception as e:
            print(f"Redis set failed: {e}")

    def _get_cache_size(self) -> int:
        """Return the current number of cache entries."""
        try:
            return len(self.redis_client.keys("prompt_cache:*"))
        except Exception:
            return 0

    def _evict_oldest(self):
        """Evict the entry closest to expiry."""
        try:
            keys = self.redis_client.keys("prompt_cache:*")
            if keys:
                # The key with the smallest remaining TTL is the oldest
                oldest_key = min(keys, key=lambda k: self.redis_client.ttl(k))
                self.redis_client.delete(oldest_key)
        except Exception as e:
            print(f"Cache eviction failed: {e}")

    def clear(self):
        """Remove all cache entries."""
        try:
            keys = self.redis_client.keys("prompt_cache:*")
            if keys:
                self.redis_client.delete(*keys)
                print("Cache cleared")
        except Exception as e:
            print(f"Cache clear failed: {e}")

# Using the advanced cache
config = CacheConfig(ttl=7200, max_size=5000)
advanced_cache = AdvancedPromptCache(config)
```
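One design caveat: `_get_cache_size` and `_evict_oldest` rely on the Redis `KEYS` command, which scans the entire keyspace and blocks the server. On a large instance, cursor-based `SCAN` is safer; a sketch of a drop-in replacement method using redis-py's `scan_iter`:

```python
def _get_cache_size(self) -> int:
    """Count cache entries with non-blocking SCAN instead of KEYS."""
    try:
        return sum(1 for _ in self.redis_client.scan_iter(match="prompt_cache:*"))
    except Exception:
        return 0
```

## Caching strategies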
### 1. Content-based caching

```python
import hashlib
import time
from difflib import SequenceMatcher
from typing import Optional

class ContentBasedCache:
    """Cache keyed on content similarity."""

    def __init__(self, similarity_threshold: float = 0.9):
        self.cache = {}
        self.similarity_threshold = similarity_threshold

    def _calculate_similarity(self, text1: str, text2: str) -> float:
        """Compute text similarity (simplified)."""
        return SequenceMatcher(None, text1, text2).ratio()

    def find_similar(self, content: str) -> Optional[str]:
        """Find a cached response for similar content."""
        for cached_key, cached_data in self.cache.items():
            similarity = self._calculate_similarity(
                content,
                cached_data["original_content"]
            )
            if similarity >= self.similarity_threshold:
                print(f"Similar cache entry found (similarity: {similarity:.2f})")
                return cached_data["response"]
        return None

    def set(self, content: str, response: str):
        """Store a result in the cache."""
        cache_key = hashlib.md5(content.encode()).hexdigest()
        self.cache[cache_key] = {
            "original_content": content,
            "response": response,
            "timestamp": time.time()
        }
```
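A brief usage sketch (the strings and threshold are illustrative):

```python
content_cache = ContentBasedCache(similarity_threshold=0.9)
content_cache.set("What is machine learning?", "Machine learning is ...")

# A near-identical phrasing reuses the cached answer.
print(content_cache.find_similar("What is machine learning ?"))
```

Note that the linear scan with `SequenceMatcher` costs O(cache size × text length) per lookup; embedding-based similarity is a common production alternative.

### 2. Tiered cache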
```python
import pickle
import time
from typing import Optional

import redis

class TieredCache:
    """Tiered cache: in-process memory plus Redis."""

    def __init__(self, memory_size: int = 1000, redis_ttl: int = 3600):
        self.memory_cache = {}  # in-memory tier
        self.memory_size = memory_size
        self.redis_client = redis.Redis(decode_responses=False)
        self.redis_ttl = redis_ttl

    def get(self, key: str) -> Optional[str]:
        # Check the memory tier first
        if key in self.memory_cache:
            print("Memory cache hit")
            return self.memory_cache[key]["response"]

        # Then check Redis
        try:
            cached_data = self.redis_client.get(key)
            if cached_data:
                data = pickle.loads(cached_data)
                # Promote the entry to the memory tier
                self._promote_to_memory(key, data)
                print("Redis cache hit")
                return data["response"]
        except Exception:
            pass  # treat Redis errors as a miss
        return None

    def set(self, key: str, response: str):
        # Write to the memory tier
        self._set_memory(key, response)
        # Write to Redis
        try:
            data = {"response": response, "timestamp": time.time()}
            self.redis_client.setex(key, self.redis_ttl, pickle.dumps(data))
        except Exception:
            pass  # the memory tier still holds the entry

    def _promote_to_memory(self, key: str, data: dict):
        """Promote an entry to the memory tier."""
        if len(self.memory_cache) >= self.memory_size:
            # Evict the oldest-stored entry (timestamps are set on write,
            # so this is insertion-order eviction rather than strict LRU)
            oldest_key = min(
                self.memory_cache.keys(),
                key=lambda k: self.memory_cache[k]["timestamp"]
            )
            del self.memory_cache[oldest_key]
        self.memory_cache[key] = {
            "response": data["response"],
            "timestamp": time.time()
        }

    def _set_memory(self, key: str, response: str):
        """Write to the memory tier."""
        if len(self.memory_cache) >= self.memory_size:
            oldest_key = min(
                self.memory_cache.keys(),
                key=lambda k: self.memory_cache[k]["timestamp"]
            )
            del self.memory_cache[oldest_key]
        self.memory_cache[key] = {
            "response": response,
            "timestamp": time.time()
        }
```
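Illustrative usage (the key is an arbitrary string, e.g. a hash from a `_generate_key`-style function):

```python
tiered = TieredCache(memory_size=1000, redis_ttl=3600)
tiered.set("prompt_cache:example", "a cached response")
print(tiered.get("prompt_cache:example"))  # served from the memory tier
```

## Practical application scenarios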
### 1. Chatbot

```python
class ChatBotCache:
    """Cache tailored to a chatbot."""

    def __init__(self):
        self.cache = AdvancedPromptCache(CacheConfig())
        self.conversation_cache = {}  # per-user conversation history

    def get_response(self, user_id: str, message: str, conversation_history: list) -> str:
        """Get the bot's response."""
        # Build the full message list
        messages = conversation_history + [
            {"role": "user", "content": message}
        ]

        # Try the cache first (temperature is part of the cache key)
        cached_response = self.cache.get("gpt-3.5-turbo", messages, temperature=0.7)
        if cached_response:
            return cached_response

        # Call the API
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=messages,
            temperature=0.7
        )
        result = response.choices[0].message.content

        # Cache the result
        self.cache.set("gpt-3.5-turbo", messages, result, temperature=0.7)

        # Update the stored conversation history
        if user_id not in self.conversation_cache:
            self.conversation_cache[user_id] = []
        self.conversation_cache[user_id].extend([
            {"role": "user", "content": message},
            {"role": "assistant", "content": result}
        ])

        # Keep the history to a reasonable length
        if len(self.conversation_cache[user_id]) > 10:
            self.conversation_cache[user_id] = self.conversation_cache[user_id][-10:]
        return result

# Usage example
chatbot = ChatBotCache()
response = chatbot.get_response("user123", "Hello, what's the weather like today?", [])
```

Because the cache key covers the entire conversation history, a hit requires the whole history to match; in practice this mainly pays off for recurring conversation openers such as identical first questions.

### 2. Document Q&A system
```python
class DocumentQACache:
    """Cache for a document Q&A system."""

    def __init__(self):
        self.cache = AdvancedPromptCache(CacheConfig())
        self.document_cache = {}  # document content cache

    def answer_question(self, document_id: str, question: str) -> str:
        """Answer a question about a document."""
        # Fetch the document content (cached)
        document_content = self._get_document(document_id)

        # Build the prompt
        messages = [
            {
                "role": "system",
                "content": "You are a document Q&A assistant. Answer questions based on the document provided."
            },
            {
                "role": "user",
                "content": f"""
Document content:
{document_content}

Question: {question}

Please answer based on the document content.
"""
            }
        ]

        # Try the cache first (temperature is part of the cache key)
        cached_response = self.cache.get("gpt-4", messages, temperature=0.1)
        if cached_response:
            return cached_response

        # Call the API
        response = client.chat.completions.create(
            model="gpt-4",
            messages=messages,
            temperature=0.1
        )
        result = response.choices[0].message.content

        # Cache the result
        self.cache.set("gpt-4", messages, result, temperature=0.1)
        return result

    def _get_document(self, document_id: str) -> str:
        """Fetch document content (cached)."""
        if document_id in self.document_cache:
            return self.document_cache[document_id]

        # Real document retrieval logic goes here
        document_content = f"This is the content of document {document_id}..."

        # Cache the document content
        self.document_cache[document_id] = document_content
        return document_content
```

## Performance monitoring
### Cache hit-rate statistics

```python
from typing import Optional

class CacheMetrics:
    """Cache performance metrics."""

    def __init__(self):
        self.total_requests = 0
        self.cache_hits = 0
        self.cache_misses = 0
        self.api_calls = 0

    def record_request(self, is_hit: bool):
        """Record one request."""
        self.total_requests += 1
        if is_hit:
            self.cache_hits += 1
        else:
            self.cache_misses += 1
            self.api_calls += 1

    def get_hit_rate(self) -> float:
        """Return the hit rate."""
        if self.total_requests == 0:
            return 0.0
        return self.cache_hits / self.total_requests

    def get_stats(self) -> dict:
        """Return the collected statistics."""
        return {
            "total_requests": self.total_requests,
            "cache_hits": self.cache_hits,
            "cache_misses": self.cache_misses,
            "api_calls": self.api_calls,
            "hit_rate": self.get_hit_rate(),
            # Approximates savings as the share of requests served from cache
            "cost_savings": f"{(self.cache_hits / self.total_requests * 100):.1f}%" if self.total_requests > 0 else "0%"
        }

class MonitoredCache:
    """Cache with monitoring attached."""

    def __init__(self):
        self.cache = AdvancedPromptCache(CacheConfig())
        self.metrics = CacheMetrics()

    def get(self, model: str, messages: list, **kwargs) -> Optional[str]:
        result = self.cache.get(model, messages, **kwargs)
        self.metrics.record_request(result is not None)
        return result

    def set(self, model: str, messages: list, response: str, **kwargs):
        self.cache.set(model, messages, response, **kwargs)

    def print_stats(self):
        """Print the statistics."""
        stats = self.metrics.get_stats()
        print("=== Cache performance statistics ===")
        print(f"Total requests: {stats['total_requests']}")
        print(f"Cache hits: {stats['cache_hits']}")
        print(f"Cache misses: {stats['cache_misses']}")
        print(f"API calls: {stats['api_calls']}")
        print(f"Hit rate: {stats['hit_rate']:.2%}")
        print(f"Cost savings: {stats['cost_savings']}")

# Usage example
monitored_cache = MonitoredCache()

# Simulate repeated calls
for i in range(10):
    messages = [{"role": "user", "content": f"Test message {i % 3}"}]  # repeating messages
    response = monitored_cache.get("gpt-3.5-turbo", messages)
    if not response:
        # Simulate an API call
        response = f"API response {i}"
        monitored_cache.set("gpt-3.5-turbo", messages, response)

monitored_cache.print_stats()
```

Run against an initially empty cache, the three distinct messages each miss once and the remaining seven requests hit, so the reported hit rate is 70%.

## Best practices
### 1. Cache key design

- Include every parameter that affects the output
- Use a normalized, canonical format (e.g., sorted keys, trimmed whitespace)
- Exclude volatile values such as timestamps
### 2. Caching strategy

- Set a TTL that matches how quickly your data goes stale
- Use tiered caching (memory + Redis) for better performance
- Clean up expired entries periodically
### 3. Error handling

- Fall back to a direct API call when the cache fails
- Log cache errors for diagnosis
- Warm the cache ahead of predictable traffic (see the sketch below)
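Cache warming is mentioned above but not shown elsewhere on this page. A minimal sketch, reusing the `SimplePromptCache` and `client` defined earlier; `COMMON_QUESTIONS` is a hypothetical list of known high-frequency prompts:

```python
# Pre-populate the cache with known high-frequency prompts
# before real traffic arrives.
COMMON_QUESTIONS = [
    "What is machine learning?",
    "How do I reset my password?",
]

def warm_cache(cache: SimplePromptCache, model: str = "gpt-3.5-turbo"):
    for question in COMMON_QUESTIONS:
        messages = [{"role": "user", "content": question}]
        if cache.get(model, messages) is None:
            response = client.chat.completions.create(model=model, messages=messages)
            cache.set(model, messages, response.choices[0].message.content)
```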
### 4. Performance optimization

- Serve hot entries from an in-memory cache
- Use an LRU-style eviction policy
- Monitor the cache hit rate
## Limitations and caveats

- Memory usage: the cache consumes memory
- Data consistency: the cache may return stale data
- Storage cost: large-scale caching incurs storage costs
- Complexity: caching adds moving parts to the system
- Network dependency: the Redis tier depends on network connectivity