如何设计推荐系统
一、问题描述
1.1 业务背景
推荐系统是互联网产品的核心引擎,广泛应用于:
- 内容平台:抖音、快手、今日头条、B站
- 电商平台:淘宝、京东、拼多多、亚马逊
- 音乐视频:网易云音乐、Spotify、Netflix、YouTube
- 社交平台:微博、小红书、Instagram
1.2 核心目标
业务目标:
- 提高用户时长:增加用户停留时间
- 提升转化率:提高点击率(CTR)、购买率(CVR)
- 增强用户粘性:降低流失率
- 发现长尾内容:让优质内容被发现
技术目标:
- 准确性:推荐用户感兴趣的内容
- 多样性:避免推荐单一类型
- 新鲜度:及时推荐最新内容
- 实时性:响应用户兴趣变化
1.3 技术挑战
冷启动问题:
- 新用户无历史行为
- 新物品无交互数据
- 新系统无足够训练数据
数据稀疏性:
- 海量物品,用户只交互极少部分
- 用户-物品交互矩阵极度稀疏
实时性要求:
- 毫秒级响应
- 实时捕捉用户兴趣变化
- 新内容快速进入推荐池
规模挑战:
- 亿级用户
- 千万级物品
- 百亿级交互数据
1.4 面试考察点
- 推荐算法:协同过滤、矩阵分解、深度学习
- 系统架构:召回、排序、重排三层架构
- 工程实现:特征工程、模型训练、线上服务
- 冷启动解决:如何处理新用户、新物品
- 效果评估:离线指标、在线ABTest
二、需求分析
2.1 功能性需求
| 需求 | 描述 | 优先级 |
|---|---|---|
| FR1 | 首页推荐流 | P0 |
| FR2 | 相关推荐(看了又看) | P0 |
| FR3 | 个性化推荐 | P0 |
| FR4 | 实时推荐 | P1 |
| FR5 | 多样性控制 | P1 |
| FR6 | 新内容推荐 | P1 |
| FR7 | 推荐解释 | P2 |
2.2 非功能性需求
性能需求:
- 推荐延迟:<100ms(P99)
- 推荐QPS:10万+
- 候选物品:百万级
- 最终推荐:10-100个
准确性需求:
- 点击率(CTR):>3%
- 转化率(CVR):>0.5%
- 用户停留时长:>30分钟/天
多样性需求:
- 同类物品占比:<40%
- 新物品占比:>10%
- 来源多样性:多个召回策略
2.3 业务约束
- 合规要求:不推荐违规、敏感内容
- 商业目标:平衡用户体验和商业变现
- 内容质量:优质内容优先推荐
- 新人扶持:给予新创作者曝光机会
三、技术选型
3.1 推荐算法对比
| 算法 | 原理 | 优点 | 缺点 | 适用场景 |
|---|---|---|---|---|
| UserCF | 相似用户推荐 | 发现新兴趣 | 计算量大 | 用户少、物品多 |
| ItemCF | 相似物品推荐 | 实时性好 | 泛化能力弱 | 物品少、用户多 |
| 矩阵分解 | 隐向量表示 | 准确率高 | 冷启动差 | 协同过滤基础 |
| FM/FFM | 特征组合 | 特征交叉 | 模型简单 | CTR预估 |
| Wide&Deep | 记忆+泛化 | 兼顾两者 | 训练复杂 | 工业界主流 |
| DIN | 注意力机制 | 捕捉兴趣变化 | 计算量大 | 序列推荐 |
| 双塔模型 | 用户/物品编码 | 高效召回 | 交叉不足 | 召回层(推荐) |
3.2 架构选型
三层架构(推荐):
召回层(Recall) 排序层(Ranking) 重排层(Re-Ranking)
├─ 协同过滤 ├─ 特征工程 ├─ 多样性
├─ 内容推荐 ├─ 模型预测 ├─ 新鲜度
├─ 热门推荐 └─ 打分排序 └─ 业务规则
├─ 深度学习召回
└─ ...
↓ 数千个候选 ↓ Top 100 ↓ Top 10-203.3 技术栈
| 组件 | 技术选型 | 作用 |
|---|---|---|
| 离线训练 | Spark、TensorFlow | 模型训练 |
| 实时计算 | Flink、Storm | 实时特征 |
| 特征存储 | Redis、HBase | 特征缓存 |
| 向量检索 | FAISS、Milvus | 向量召回 |
| 模型服务 | TensorFlow Serving | 在线预测 |
| 实验平台 | ABTest系统 | 效果评估 |
| 数据仓库 | Hive、ClickHouse | 数据存储 |
四、架构设计
4.1 系统架构图
mermaid
graph TB
subgraph 用户端
A[移动端/Web]
end
subgraph 在线服务
B[推荐API]
B --> C[召回服务]
B --> D[排序服务]
B --> E[重排服务]
end
subgraph 召回层
C --> F1[协同过滤召回]
C --> F2[内容召回]
C --> F3[热门召回]
C --> F4[深度召回]
end
subgraph 特征服务
G[实时特征]
H[离线特征]
I[用户画像]
end
subgraph 排序层
D --> G
D --> H
D --> I
D --> J[排序模型]
end
subgraph 重排层
E --> K[多样性]
E --> L[新鲜度]
E --> M[业务规则]
end
subgraph 离线计算
N[Spark训练]
O[Flink实时]
end
subgraph 存储层
P[Redis特征]
Q[HBase历史]
R[FAISS向量]
end
A --> B
F1 --> P
F2 --> P
F4 --> R
N --> P
N --> R
O --> P4.2 推荐流程
mermaid
sequenceDiagram
participant U as 用户
participant API as 推荐API
participant Recall as 召回服务
participant Rank as 排序服务
participant ReRank as 重排服务
participant Redis as Redis
participant Model as 模型服务
U->>API: 请求推荐(user_id)
API->>Recall: 召回请求
par 多路召回
Recall->>Redis: 协同过滤召回
Recall->>Redis: 内容召回
Recall->>Redis: 热门召回
end
Recall-->>API: 候选集(1000个)
API->>Rank: 排序请求
Rank->>Redis: 获取实时特征
Rank->>Model: 模型预测
Model-->>Rank: 预测分数
Rank-->>API: 排序结果(Top 100)
API->>ReRank: 重排请求
ReRank->>ReRank: 多样性、新鲜度
ReRank-->>API: 最终结果(Top 10)
API-->>U: 返回推荐列表4.3 数据模型
用户画像(Redis)
json
{
"user_id": 123456,
"demographics": {
"age": 25,
"gender": "M",
"city": "Beijing"
},
"interests": [
{"category": "科技", "score": 0.8},
{"category": "游戏", "score": 0.6}
],
"behavior": {
"click_categories": ["科技", "游戏", "数码"],
"recent_items": [1001, 1002, 1003]
}
}物品特征(Redis)
json
{
"item_id": 1001,
"title": "iPhone 15 Pro评测",
"category": "数码",
"tags": ["手机", "评测", "苹果"],
"quality_score": 0.9,
"embedding": [0.1, 0.2, ...], // 向量
"stats": {
"ctr": 0.05,
"cvr": 0.01,
"views": 10000
}
}交互记录(HBase)
RowKey: user_id_timestamp
Columns:
- item_id: 1001
- action: click/like/share/purchase
- timestamp: 1704067200
- context: {device, location, ...}五、核心实现
5.1 召回层实现
协同过滤召回(Python)
python
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.metrics.pairwise import cosine_similarity
class ItemCFRecaller:
"""基于物品的协同过滤召回"""
def __init__(self, topk=50):
self.topk = topk
self.item_sim_matrix = None
def fit(self, user_item_matrix):
"""
训练物品相似度矩阵
:param user_item_matrix: 用户-物品交互矩阵 (n_users, n_items)
"""
# 计算物品相似度(基于共同用户)
# 转置:(n_items, n_users)
item_matrix = user_item_matrix.T
# 计算余弦相似度
self.item_sim_matrix = cosine_similarity(item_matrix)
# 对角线置0(自己和自己不计算)
np.fill_diagonal(self.item_sim_matrix, 0)
def recall(self, user_history, n_recall=100):
"""
召回候选物品
:param user_history: 用户历史交互物品ID列表
:param n_recall: 召回数量
:return: 候选物品及分数
"""
if not user_history:
return []
# 累计相似物品的分数
scores = np.zeros(self.item_sim_matrix.shape[0])
for item_id in user_history:
if item_id < self.item_sim_matrix.shape[0]:
scores += self.item_sim_matrix[item_id]
# 排除已交互物品
scores[user_history] = -np.inf
# Top N
top_indices = np.argsort(scores)[::-1][:n_recall]
top_scores = scores[top_indices]
return list(zip(top_indices, top_scores))
class ContentRecaller:
"""基于内容的召回"""
def __init__(self, item_embeddings):
"""
:param item_embeddings: 物品向量 {item_id: embedding}
"""
self.item_embeddings = item_embeddings
def build_user_profile(self, user_history, decay=0.9):
"""
构建用户画像(历史物品的加权平均)
:param user_history: [(item_id, timestamp), ...]
:param decay: 时间衰减系数
"""
if not user_history:
return None
# 按时间排序
user_history = sorted(user_history, key=lambda x: x[1])
# 加权平均
user_embedding = np.zeros_like(
self.item_embeddings[user_history[0][0]]
)
total_weight = 0
for i, (item_id, timestamp) in enumerate(user_history):
weight = decay ** (len(user_history) - i - 1)
user_embedding += weight * self.item_embeddings[item_id]
total_weight += weight
user_embedding /= total_weight
return user_embedding
def recall(self, user_profile, candidate_items, n_recall=100):
"""
召回相似物品
"""
if user_profile is None:
return []
# 计算相似度
similarities = []
for item_id in candidate_items:
embedding = self.item_embeddings[item_id]
sim = np.dot(user_profile, embedding) / (
np.linalg.norm(user_profile) * np.linalg.norm(embedding)
)
similarities.append((item_id, sim))
# 排序
similarities.sort(key=lambda x: x[1], reverse=True)
return similarities[:n_recall]
class HotRecaller:
"""热门推荐召回"""
def __init__(self, redis_client):
self.redis = redis_client
def update_hot_items(self, time_window='1h'):
"""
更新热门物品(定时任务)
:param time_window: 时间窗口
"""
# 统计时间窗口内的点击、点赞等
hot_key = f'hot_items:{time_window}'
# 这里简化,实际从日志聚合
# 使用有序集合存储:score = 热度分数
# self.redis.zadd(hot_key, {item_id: score})
pass
def recall(self, n_recall=100, exclude_items=None):
"""
召回热门物品
"""
exclude_items = exclude_items or []
hot_key = 'hot_items:1h'
# 获取热门物品
hot_items = self.redis.zrevrange(
hot_key, 0, n_recall + len(exclude_items) - 1,
withscores=True
)
# 过滤已交互
results = [
(int(item_id), score)
for item_id, score in hot_items
if int(item_id) not in exclude_items
]
return results[:n_recall]5.2 排序层实现
Wide & Deep模型(TensorFlow)
python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
class WideAndDeepModel:
"""Wide & Deep排序模型"""
def __init__(self, config):
self.config = config
self.model = None
def build_model(self):
"""构建模型"""
# ===== 输入层 =====
# 类别特征(Wide侧)
user_id = layers.Input(shape=(1,), name='user_id')
item_id = layers.Input(shape=(1,), name='item_id')
category = layers.Input(shape=(1,), name='category')
# 数值特征(Wide侧)
user_age = layers.Input(shape=(1,), name='user_age')
item_price = layers.Input(shape=(1,), name='item_price')
# 向量特征(Deep侧)
user_embedding = layers.Input(
shape=(128,), name='user_embedding'
)
item_embedding = layers.Input(
shape=(128,), name='item_embedding'
)
# ===== Wide部分(线性模型) =====
# 类别特征one-hot
user_id_onehot = layers.Embedding(
input_dim=1000000, output_dim=1
)(user_id)
item_id_onehot = layers.Embedding(
input_dim=100000, output_dim=1
)(item_id)
category_onehot = layers.Embedding(
input_dim=100, output_dim=1
)(category)
# Wide特征拼接
wide_features = layers.concatenate([
layers.Flatten()(user_id_onehot),
layers.Flatten()(item_id_onehot),
layers.Flatten()(category_onehot),
user_age,
item_price
])
# Wide输出
wide_output = layers.Dense(
1, activation=None, name='wide_output'
)(wide_features)
# ===== Deep部分(深度神经网络) =====
# Embedding特征拼接
deep_features = layers.concatenate([
user_embedding,
item_embedding,
user_age,
item_price
])
# DNN层
deep = layers.Dense(256, activation='relu')(deep_features)
deep = layers.BatchNormalization()(deep)
deep = layers.Dropout(0.3)(deep)
deep = layers.Dense(128, activation='relu')(deep)
deep = layers.BatchNormalization()(deep)
deep = layers.Dropout(0.3)(deep)
deep = layers.Dense(64, activation='relu')(deep)
# Deep输出
deep_output = layers.Dense(
1, activation=None, name='deep_output'
)(deep)
# ===== Wide & Deep融合 =====
output = layers.add([wide_output, deep_output])
output = layers.Activation('sigmoid', name='output')(output)
# 构建模型
self.model = keras.Model(
inputs=[
user_id, item_id, category,
user_age, item_price,
user_embedding, item_embedding
],
outputs=output
)
# 编译
self.model.compile(
optimizer='adam',
loss='binary_crossentropy',
metrics=['auc', 'accuracy']
)
return self.model
def train(self, train_data, val_data, epochs=10):
"""训练模型"""
# 回调
callbacks = [
keras.callbacks.EarlyStopping(
patience=3, restore_best_weights=True
),
keras.callbacks.ReduceLROnPlateau(
factor=0.5, patience=2
),
keras.callbacks.ModelCheckpoint(
'wide_deep_best.h5', save_best_only=True
)
]
# 训练
history = self.model.fit(
train_data,
validation_data=val_data,
epochs=epochs,
callbacks=callbacks,
verbose=1
)
return history
def predict(self, features):
"""预测"""
return self.model.predict(features)5.3 特征工程
python
class RealtimeFeatureService:
"""实时特征服务"""
def __init__(self, redis_client):
self.redis = redis_client
def get_user_realtime_features(self, user_id):
"""获取用户实时特征"""
# 最近1小时行为
recent_clicks = self.redis.lrange(
f'user:{user_id}:clicks:1h', 0, -1
)
# 统计特征
features = {
'click_count_1h': len(recent_clicks),
'active_minutes_1h': self._calc_active_time(recent_clicks),
'category_dist_1h': self._calc_category_dist(recent_clicks)
}
# 实时偏好
realtime_prefs = self.redis.hgetall(
f'user:{user_id}:realtime_prefs'
)
features['realtime_interests'] = realtime_prefs
return features
def get_item_realtime_features(self, item_id):
"""获取物品实时特征"""
# 实时统计(最近1小时)
stats_key = f'item:{item_id}:stats:1h'
stats = self.redis.hgetall(stats_key)
features = {
'ctr_1h': float(stats.get('clicks', 0)) / max(float(stats.get('views', 1)), 1),
'like_rate_1h': float(stats.get('likes', 0)) / max(float(stats.get('clicks', 1)), 1),
'share_rate_1h': float(stats.get('shares', 0)) / max(float(stats.get('clicks', 1)), 1)
}
return features
def get_context_features(self, context):
"""获取上下文特征"""
return {
'hour': context.get('hour'),
'day_of_week': context.get('day_of_week'),
'device': context.get('device'),
'network': context.get('network'), # wifi/4g/5g
'location': context.get('location')
}java
@Service
public class OfflineFeatureService {
@Autowired
private HBaseTemplate hbaseTemplate;
/**
* 获取用户离线特征
*/
public Map<String, Object> getUserOfflineFeatures(Long userId) {
Map<String, Object> features = new HashMap<>();
// 从HBase获取
String rowKey = "user:" + userId;
Result result = hbaseTemplate.get("user_features", rowKey);
// 基础特征
features.put("age", Bytes.toInt(result.getValue("basic", "age")));
features.put("gender", Bytes.toString(result.getValue("basic", "gender")));
features.put("city", Bytes.toString(result.getValue("basic", "city")));
// 统计特征(最近30天)
features.put("click_count_30d",
Bytes.toLong(result.getValue("stats", "click_count_30d")));
features.put("like_count_30d",
Bytes.toLong(result.getValue("stats", "like_count_30d")));
features.put("purchase_count_30d",
Bytes.toLong(result.getValue("stats", "purchase_count_30d")));
// 偏好特征
String interests = Bytes.toString(result.getValue("prefs", "interests"));
features.put("interests", JSON.parseArray(interests));
// 用户embedding
byte[] embedding = result.getValue("embedding", "user_vec");
features.put("user_embedding", bytesToFloatArray(embedding));
return features;
}
/**
* 获取物品离线特征
*/
public Map<String, Object> getItemOfflineFeatures(Long itemId) {
Map<String, Object> features = new HashMap<>();
String rowKey = "item:" + itemId;
Result result = hbaseTemplate.get("item_features", rowKey);
// 基本信息
features.put("category",
Bytes.toString(result.getValue("basic", "category")));
features.put("tags",
Bytes.toString(result.getValue("basic", "tags")));
features.put("create_time",
Bytes.toLong(result.getValue("basic", "create_time")));
// 统计特征
features.put("total_views",
Bytes.toLong(result.getValue("stats", "total_views")));
features.put("total_clicks",
Bytes.toLong(result.getValue("stats", "total_clicks")));
features.put("avg_ctr",
Bytes.toDouble(result.getValue("stats", "avg_ctr")));
features.put("quality_score",
Bytes.toDouble(result.getValue("stats", "quality_score")));
// 物品embedding
byte[] embedding = result.getValue("embedding", "item_vec");
features.put("item_embedding", bytesToFloatArray(embedding));
return features;
}
}六、冷启动解决
6.1 新用户冷启动
python
class ColdStartHandler:
"""冷启动处理器"""
def handle_new_user(self, user_id, user_info):
"""
新用户冷启动策略
"""
strategies = []
# 1. 基于人口统计学推荐
if user_info.get('age') and user_info.get('gender'):
strategies.append(
self.demographic_based_recall(user_info)
)
# 2. 热门推荐
strategies.append(
self.hot_recall(time_window='7d', limit=50)
)
# 3. 编辑精选
strategies.append(
self.editorial_recall(limit=20)
)
# 4. 探索推荐(多样性)
strategies.append(
self.exploration_recall(limit=30)
)
# 合并去重
return self.merge_and_deduplicate(strategies)
def demographic_based_recall(self, user_info):
"""基于人口统计学召回"""
age_group = self._get_age_group(user_info['age'])
gender = user_info['gender']
# 查询相同年龄段、性别用户喜欢的物品
similar_users = self.find_similar_demographic_users(
age_group, gender
)
# 聚合他们喜欢的物品
items = self.aggregate_user_preferences(similar_users)
return items[:50]
def handle_new_item(self, item_id, item_info):
"""
新物品冷启动策略
"""
# 1. 内容特征提取
content_features = self.extract_content_features(item_info)
# 2. 找到相似物品
similar_items = self.find_similar_items(content_features)
# 3. 推荐给喜欢相似物品的用户
target_users = self.find_users_liked_similar_items(
similar_items
)
# 4. 探索流量(随机曝光)
exploration_users = self.sample_exploration_users(
sample_rate=0.01 # 1%流量
)
return target_users + exploration_users七、效果评估
7.1 离线评估指标
python
from sklearn.metrics import roc_auc_score, log_loss
import numpy as np
class OfflineEvaluator:
"""离线评估"""
@staticmethod
def calc_auc(y_true, y_pred):
"""AUC"""
return roc_auc_score(y_true, y_pred)
@staticmethod
def calc_logloss(y_true, y_pred):
"""LogLoss"""
return log_loss(y_true, y_pred)
@staticmethod
def calc_precision_at_k(y_true, y_pred, k=10):
"""Precision@K"""
# y_true: [1, 0, 1, 0, ...]
# y_pred: [0.9, 0.1, 0.8, ...]
# Top K
top_k_indices = np.argsort(y_pred)[::-1][:k]
relevant_in_topk = sum(y_true[i] for i in top_k_indices)
return relevant_in_topk / k
@staticmethod
def calc_recall_at_k(y_true, y_pred, k=10):
"""Recall@K"""
top_k_indices = np.argsort(y_pred)[::-1][:k]
relevant_in_topk = sum(y_true[i] for i in top_k_indices)
total_relevant = sum(y_true)
return relevant_in_topk / max(total_relevant, 1)
@staticmethod
def calc_ndcg_at_k(y_true, y_pred, k=10):
"""NDCG@K(归一化折损累计增益)"""
def dcg_at_k(r, k):
r = np.asfarray(r)[:k]
return np.sum(r / np.log2(np.arange(2, r.size + 2)))
top_k_indices = np.argsort(y_pred)[::-1][:k]
r = [y_true[i] for i in top_k_indices]
dcg = dcg_at_k(r, k)
idcg = dcg_at_k(sorted(y_true, reverse=True), k)
return dcg / idcg if idcg > 0 else 07.2 在线ABTest
python
class ABTestService:
"""ABTest服务"""
def __init__(self, redis_client):
self.redis = redis_client
def assign_bucket(self, user_id, experiment_id):
"""
分配实验桶
:return: 'control' or 'treatment'
"""
# 根据user_id哈希分桶
hash_value = hash(f"{user_id}_{experiment_id}")
bucket_id = hash_value % 100
# 50%对照组,50%实验组
return 'treatment' if bucket_id < 50 else 'control'
def track_metric(self, user_id, experiment_id, bucket, metric, value):
"""记录指标"""
key = f'ab_test:{experiment_id}:{bucket}:{metric}'
self.redis.lpush(key, value)
def get_results(self, experiment_id):
"""获取实验结果"""
results = {}
for bucket in ['control', 'treatment']:
# CTR
clicks_key = f'ab_test:{experiment_id}:{bucket}:clicks'
views_key = f'ab_test:{experiment_id}:{bucket}:views'
clicks = self.redis.llen(clicks_key)
views = self.redis.llen(views_key)
ctr = clicks / max(views, 1)
results[bucket] = {
'ctr': ctr,
'clicks': clicks,
'views': views
}
# 显著性检验
p_value = self._significance_test(
results['control'], results['treatment']
)
results['p_value'] = p_value
return results八、性能优化
8.1 召回优化
- 向量索引:使用FAISS加速向量检索
- 多路并行:多个召回策略并行执行
- 预计算:提前计算热门、相似物品
- 缓存:Redis缓存召回结果
8.2 排序优化
- 模型蒸馏:大模型->小模型
- 特征裁剪:去除低重要性特征
- 批量预测:批处理提高吞吐
- GPU加速:使用GPU推理
8.3 性能数据
| 指标 | 目标 | 实际 |
|---|---|---|
| 推荐延迟P99 | <100ms | 85ms |
| 推荐QPS | 10万+ | 12万 |
| 召回候选数 | 1000+ | 1200 |
| 排序数量 | 100 | 100 |
| 点击率CTR | >3% | 3.5% |
九、面试要点
9.1 常见追问
Q1: 召回和排序的区别?
A:
- 召回:快速从海量物品中筛选出候选集(1000+),注重召回率和性能
- 排序:对候选集精排,注重准确率,可以用复杂模型
- 关系:召回追求"不漏",排序追求"精准"
Q2: 如何解决推荐多样性问题?
A:
- 召回多样性:多种召回策略(协同、内容、热门)
- 类别打散:同类物品不连续出现
- 时间衰减:历史推荐过的降权
- DPP算法:行列式点过程保证多样性
- 探索与利用:EE策略,10%流量探索新内容
Q3: 推荐系统如何实时更新?
A:
- 实时特征:Flink计算实时行为特征
- 在线学习:增量更新模型参数
- 快速索引更新:新物品快速进入召回池
- 热更新:模型热加载,不停服务
9.2 扩展知识
十、总结
推荐系统是算法和工程的结合,核心要点:
- 三层架构:召回->排序->重排,平衡性能和准确性
- 算法演进:从协同过滤到深度学习
- 冷启动:多策略组合解决新用户、新物品
- 效果评估:离线指标 + 在线ABTest
- 工程优化:缓存、并行、模型压缩
面试中要能说清楚架构设计、算法选型、工程实现、效果评估等全链路。
推荐阅读:
- 《推荐系统实践》- 项亮
- 《深度学习推荐系统》- 王喆
- YouTube DNN论文
- Wide & Deep论文
