
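Real-Time Messaging System Design

💬 Large-Scale Instant Messaging Architecture

Core Messaging System Design

Q1: How would you design a real-time messaging system that supports tens of millions of users? How do you guarantee message reliability and ordering?

Difficulty: ⭐⭐⭐⭐⭐

Answer: A large-scale real-time messaging system has to solve message delivery, storage, and synchronization, while staying highly available and keeping latency low.

1. Overall system architecture: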

mermaid
graph TB
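    A[Client App] --> B[WebSocket Gateway]
    B --> C[Message Routing Service]
    
    C --> D[Session Management Service]
    C --> E[Message Delivery Service]
    C --> F[Offline Message Service]
    
    G[Push Service] --> H[APNs/FCM]
    
    E --> I[Message Storage Service]
    F --> I
    
    I --> J[(HBase message store)]
    D --> K[(Redis session cache)]
    
    L[Message Queue Kafka] --> E
    L --> F
    L --> G
    
    M[User Status Service] --> N[(Redis online status)]
    O[Group Management Service] --> P[(MySQL group info)]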
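
With many gateway servers, the routing service has to know which gateway currently holds a recipient's connection. The sketch below illustrates that lookup under stated assumptions: an in-memory map stands in for the Redis online-status store, and the names `PresenceRegistry`, `MessageRouter`, and `Target` are hypothetical, not part of the original design.

```java
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

/** In-memory stand-in for the Redis online-status store (userId -> gateway serverId). */
class PresenceRegistry {
    private final Map<String, String> userToServer = new ConcurrentHashMap<>();

    void setOnline(String userId, String serverId) {
        userToServer.put(userId, serverId);
    }

    /**
     * Clears the entry only if it still points at the given server, so a late
     * disconnect from an old gateway cannot erase a newer login elsewhere.
     */
    void setOffline(String userId, String serverId) {
        userToServer.remove(userId, serverId);
    }

    Optional<String> locate(String userId) {
        return Optional.ofNullable(userToServer.get(userId));
    }
}

/** Decides where the routing service should send a message. */
class MessageRouter {
    enum Target { LOCAL_GATEWAY, REMOTE_GATEWAY, OFFLINE_STORE }

    private final PresenceRegistry presence;
    private final String localServerId;

    MessageRouter(PresenceRegistry presence, String localServerId) {
        this.presence = presence;
        this.localServerId = localServerId;
    }

    Target route(String recipientId) {
        return presence.locate(recipientId)
            .map(serverId -> serverId.equals(localServerId)
                ? Target.LOCAL_GATEWAY
                : Target.REMOTE_GATEWAY)
            .orElse(Target.OFFLINE_STORE);
    }
}
```

A message routed to `REMOTE_GATEWAY` would typically be forwarded through the message queue to the server that owns the connection, while `OFFLINE_STORE` corresponds to the offline message path.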

2. WebSocket connection management:

Connection gateway implementation:

java
@Component
public class WebSocketGateway {
    
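    // Collaborators injected by Spring. The UserStatusService and HeartbeatManager type
    // names are assumed here; the methods below only show how the fields are used.
    @Autowired private ConnectionManager connectionManager;
    @Autowired private UserStatusService userStatusService;
    @Autowired private HeartbeatManager heartbeatManager;
    @Autowired private OfflineMessageService offlineMessageService;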
    
    @EventListener
    public void handleWebSocketConnect(WebSocketConnectEvent event) {
        String userId = extractUserId(event);
        Channel channel = event.getChannel();
        
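        // 1. Authenticate the user
        if (!authenticateUser(userId, event.getToken())) {
            channel.close();
            return;
        }
        
        // 2. Resolve duplicate logins, then register the new channel.
        //    handleDuplicateConnections() adds the channel itself; registering it first
        //    would let the kick-out policy close the connection that was just opened.
        handleDuplicateConnections(userId, channel);
        
        // 3. Mark the user online on this gateway server
        userStatusService.setUserOnline(userId, getServerId());
        
        // 4. Pull offline messages
        pullOfflineMessages(userId);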
        
        log.info("User {} connected from {}", userId, channel.remoteAddress());
    }
    
    @EventListener  
    public void handleWebSocketDisconnect(WebSocketDisconnectEvent event) {
        String userId = extractUserId(event);
        Channel channel = event.getChannel();
        
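        // 1. Remove the connection
        connectionManager.removeConnection(userId, channel);
        
        // 2. Mark the user offline only when no other device is still connected
        if (!connectionManager.isUserOnline(userId)) {
            userStatusService.setUserOffline(userId);
        }
        
        // 3. Clean up the heartbeat timer
        heartbeatManager.removeHeartbeat(userId);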
        
        log.info("User {} disconnected", userId);
    }
    
    private void handleDuplicateConnections(String userId, Channel newChannel) {
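        // Handle the same user logging in from multiple devices
        List<Channel> existingChannels = connectionManager.getUserChannels(userId);
        
        if (!existingChannels.isEmpty()) {
            // Policy decision: kick out the old connections, or allow several devices online at once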
            if (shouldKickOutOldConnections()) {
                existingChannels.forEach(Channel::close);
                connectionManager.clearUserConnections(userId);
            }
        }
        
        connectionManager.addConnection(userId, newChannel);
    }
    
    private void pullOfflineMessages(String userId) {
        CompletableFuture.runAsync(() -> {
            try {
                List<OfflineMessage> offlineMessages = offlineMessageService.getOfflineMessages(userId);
                
                for (OfflineMessage msg : offlineMessages) {
                    sendMessageToUser(userId, msg.toRealtimeMessage());
                }
                
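                // Mark the offline messages as delivered. A stricter design would wait for
                // the client's ACK first, so a dropped connection cannot lose them.
                offlineMessageService.markDelivered(userId, 
                    offlineMessages.stream().map(OfflineMessage::getId).collect(Collectors.toList()));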
                    
            } catch (Exception e) {
                log.error("Pull offline messages failed for user: {}", userId, e);
            }
        });
    }
}

@Component
public class ConnectionManager {
    
    // userId -> set of connection channels
    private final ConcurrentHashMap<String, Set<Channel>> userConnections = new ConcurrentHashMap<>();
    // channelId -> userId
    private final ConcurrentHashMap<String, String> channelUsers = new ConcurrentHashMap<>();
    
    public void addConnection(String userId, Channel channel) {
        userConnections.computeIfAbsent(userId, k -> ConcurrentHashMap.newKeySet()).add(channel);
        channelUsers.put(channel.id().asShortText(), userId);
        
        // Start the heartbeat
        startHeartbeat(userId, channel);
    }
    
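    public void removeConnection(String userId, Channel channel) {
        // Remove the channel and drop the user's entry once it is empty,
        // as a single atomic map operation.
        userConnections.computeIfPresent(userId, (id, channels) -> {
            channels.remove(channel);
            return channels.isEmpty() ? null : channels;
        });
        channelUsers.remove(channel.id().asShortText());
        
        // Cancel the heartbeat task that startHeartbeat stored on the channel
        Object heartbeat = channel.attr(AttributeKey.valueOf("heartbeat")).getAndSet(null);
        if (heartbeat instanceof ScheduledFuture) {
            ((ScheduledFuture<?>) heartbeat).cancel(false);
        }
    }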
    
    public List<Channel> getUserChannels(String userId) {
        return userConnections.getOrDefault(userId, Collections.emptySet())
            .stream()
            .filter(Channel::isActive)
            .collect(Collectors.toList());
    }
    
    public boolean isUserOnline(String userId) {
        return !getUserChannels(userId).isEmpty();
    }
    
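    // Heartbeat: send a ping every 30 seconds. This only sends pings; detecting a dead
    // peer also needs a read timeout (for example Netty's IdleStateHandler).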
    private void startHeartbeat(String userId, Channel channel) {
        ScheduledFuture<?> heartbeat = channel.eventLoop().scheduleAtFixedRate(() -> {
            if (channel.isActive()) {
                sendPing(channel);
            }
        }, 30, 30, TimeUnit.SECONDS);
        
        // Keep the heartbeat task on the channel so it can be cancelled during cleanup
        channel.attr(AttributeKey.valueOf("heartbeat")).set(heartbeat);
    }
}
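
The heartbeat above sends pings, but a half-open connection is only detected once the server notices that nothing has come back. A minimal sketch of that bookkeeping follows, assuming the gateway calls `touch` whenever a pong (or any inbound frame) arrives and runs `evictExpired` on a timer. `HeartbeatTracker` and its method names are illustrative and do not appear in the original code.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** Tracks the last time each channel was heard from and reports the silent ones. */
class HeartbeatTracker {
    private final long timeoutMillis;
    private final Map<String, Long> lastSeen = new ConcurrentHashMap<>();

    HeartbeatTracker(long timeoutMillis) {
        this.timeoutMillis = timeoutMillis;
    }

    /** Call on connect and whenever a pong or any other inbound frame arrives. */
    void touch(String channelId, long nowMillis) {
        lastSeen.put(channelId, nowMillis);
    }

    /** Call when a channel closes normally. */
    void remove(String channelId) {
        lastSeen.remove(channelId);
    }

    /** Returns, and forgets, every channel that has been silent longer than the timeout. */
    List<String> evictExpired(long nowMillis) {
        List<String> expired = new ArrayList<>();
        lastSeen.forEach((channelId, seenAt) -> {
            // remove(key, value) skips channels that were touched again concurrently
            if (nowMillis - seenAt > timeoutMillis && lastSeen.remove(channelId, seenAt)) {
                expired.add(channelId);
            }
        });
        return expired;
    }
}
```

With a 30-second ping interval, a timeout of 90 seconds tolerates two lost pongs before the gateway closes the channel and marks the device offline.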

3. Message routing and delivery:

Core message delivery logic:

java
@Service
public class MessageDeliveryService {
    
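    @Autowired private ConnectionManager connectionManager;
    @Autowired private OfflineMessageService offlineMessageService;
    @Autowired private MessageStorageService messageStorage;
    @Autowired private PushNotificationService pushService;
    // Used by deliverGroupMessage; the GroupService type name is assumed.
    @Autowired private GroupService groupService;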
    
    public DeliveryResult deliverMessage(Message message) {
        try {
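            // 1. Persist the message. Group messages carry no recipientId, and saveMessage
            //    builds its row key from it, so they are skipped here and must be
            //    persisted per member during fan-out.
            if (message.getRecipientId() != null) {
                messageStorage.saveMessage(message);
            }
            
            // 2. Choose a delivery strategy by message type
            DeliveryResult result = switch (message.getType()) {
                case SINGLE_CHAT -> deliverP2PMessage(message);
                case GROUP_CHAT -> deliverGroupMessage(message);
                case SYSTEM_MESSAGE -> deliverSystemMessage(message);
                default -> DeliveryResult.fail("Unknown message type");
            };
            
            // 3. Record delivery metrics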
            recordDeliveryMetrics(message, result);
            
            return result;
            
        } catch (Exception e) {
            log.error("Message delivery failed", e);
            return DeliveryResult.fail("Delivery failed: " + e.getMessage());
        }
    }
    
    private DeliveryResult deliverP2PMessage(Message message) {
        String recipientId = message.getRecipientId();
        
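        // 1. Check whether the recipient is online. ConnectionManager only sees connections
        //    on this gateway; a multi-gateway deployment consults the shared online-status store.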
        if (connectionManager.isUserOnline(recipientId)) {
            return deliverToOnlineUser(recipientId, message);
        } else {
            return deliverToOfflineUser(recipientId, message);
        }
    }
    
    private DeliveryResult deliverToOnlineUser(String userId, Message message) {
        List<Channel> channels = connectionManager.getUserChannels(userId);
        
        if (channels.isEmpty()) {
            // The user just went offline; fall back to offline delivery
            return deliverToOfflineUser(userId, message);
        }
        
        // Build the real-time message once and reuse it for every device
        RealtimeMessage rtMessage = RealtimeMessage.builder()
            .messageId(message.getId())
            .fromUserId(message.getSenderId())
            .toUserId(userId)
            .content(message.getContent())
            .messageType(message.getType())
            .timestamp(message.getTimestamp())
            .build();
        String payload = JSON.toJSONString(rtMessage);
        
        for (Channel channel : channels) {
            // writeAndFlush is asynchronous: a failed write surfaces on the returned
            // future, not as an exception thrown here, so it is handled in a listener.
            channel.writeAndFlush(new TextWebSocketFrame(payload))
                .addListener((ChannelFutureListener) future -> {
                    if (!future.isSuccess()) {
                        log.error("Failed to deliver message to channel: {}", channel.id(), future.cause());
                        // Remove the broken connection. A production system would also
                        // re-queue the message for offline delivery at this point.
                        connectionManager.removeConnection(userId, channel);
                    }
                });
        }
        
        // Record the delivery status. Here it means "written to an active connection";
        // confirmed receipt requires an ACK from the client.
        messageStorage.updateDeliveryStatus(message.getId(), DeliveryStatus.DELIVERED);
        
        return DeliveryResult.success();
    }
    
    private DeliveryResult deliverToOfflineUser(String userId, Message message) {
        // 1. Store the offline message
        OfflineMessage offlineMsg = OfflineMessage.builder()
            .userId(userId)
            .messageId(message.getId())
            .senderId(message.getSenderId())
            .content(message.getContent())
            .messageType(message.getType())
            .timestamp(message.getTimestamp())
            .status(OfflineMessageStatus.PENDING)
            .build();
            
        offlineMessageService.saveOfflineMessage(offlineMsg);
        
        // 2. Send a push notification
        PushNotification notification = PushNotification.builder()
            .userId(userId)
            .title(getSenderDisplayName(message.getSenderId()))
            .content(message.getContent())
            .data(Map.of("messageId", message.getId(), "type", message.getType()))
            .build();
            
        pushService.sendNotification(notification);
        
        // 3. Record the offline delivery status
        messageStorage.updateDeliveryStatus(message.getId(), DeliveryStatus.OFFLINE_STORED);
        
        return DeliveryResult.success();
    }
    
    private DeliveryResult deliverGroupMessage(Message message) {
        String groupId = message.getGroupId();
        
        // 1. Fetch the group members
        List<String> members = groupService.getGroupMembers(groupId);
        
        // 2. Exclude the sender
        members = members.stream()
            .filter(memberId -> !memberId.equals(message.getSenderId()))
            .collect(Collectors.toList());
        
        // 3. Fan out in parallel. supplyAsync without an executor runs on the common
        //    ForkJoinPool; blocking I/O at this scale belongs on a dedicated executor.
        List<CompletableFuture<DeliveryResult>> deliveryTasks = members.stream()
            .map(memberId -> CompletableFuture.supplyAsync(() -> {
                // Create a separate copy of the message for each member and persist it
                // in that member's inbox before delivering it
                Message memberMessage = message.toBuilder()
                    .recipientId(memberId)
                    .build();
                messageStorage.saveMessage(memberMessage);
                return deliverP2PMessage(memberMessage);
            }))
            .collect(Collectors.toList());
        
        // 4. Wait for all deliveries to finish
        CompletableFuture<Void> allDeliveries = CompletableFuture.allOf(
            deliveryTasks.toArray(new CompletableFuture[0])
        );
        
        try {
            allDeliveries.get(5, TimeUnit.SECONDS); // 5-second timeout
            
            // Count successful deliveries; every future has completed normally by now
            long successCount = deliveryTasks.stream()
                .filter(future -> future.join().isSuccess())
                .count();
            
            return DeliveryResult.success("Delivered to " + successCount + "/" + members.size() + " members");
            
        } catch (TimeoutException e) {
            log.warn("Group message delivery timeout for group: {}", groupId);
            return DeliveryResult.partial("Some deliveries may still be in progress");
        } catch (Exception e) {
            log.error("Group message delivery failed for group: {}", groupId, e);
            return DeliveryResult.fail("Group delivery failed");
        }
    }
}
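
Writing a frame to a socket does not prove the client received it, so reliable delivery usually pairs each message with a client ACK, retransmits when the ACK is missing, and lets the client drop duplicates by message ID. The sketch below shows that at-least-once pattern in isolation. `PendingAckTracker`, `ClientDeduplicator`, the retry interval, and the attempt limit are all illustrative assumptions, and the tracker expects a single sweeper thread to call `dueForRetry`.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/** Server side: remembers messages that were sent but not yet acknowledged. */
class PendingAckTracker {
    private static final class Pending {
        int attempts = 1;   // the initial send counts as attempt 1
        long lastSentAt;
        Pending(long sentAt) { this.lastSentAt = sentAt; }
    }

    private final Map<String, Pending> pending = new ConcurrentHashMap<>();
    private final long retryIntervalMillis;
    private final int maxAttempts;

    PendingAckTracker(long retryIntervalMillis, int maxAttempts) {
        this.retryIntervalMillis = retryIntervalMillis;
        this.maxAttempts = maxAttempts;
    }

    void onSent(String messageId, long nowMillis) {
        pending.put(messageId, new Pending(nowMillis));
    }

    /** Returns true if the message was still waiting for this ACK. */
    boolean onAck(String messageId) {
        return pending.remove(messageId) != null;
    }

    /**
     * Returns the IDs to retransmit now. Messages that already used up their attempts
     * are removed and appended to givenUp, to be handed to offline storage.
     */
    List<String> dueForRetry(long nowMillis, List<String> givenUp) {
        List<String> retry = new ArrayList<>();
        for (Map.Entry<String, Pending> entry : pending.entrySet()) {
            Pending p = entry.getValue();
            if (nowMillis - p.lastSentAt < retryIntervalMillis) {
                continue; // still inside the ACK window
            }
            if (p.attempts >= maxAttempts) {
                pending.remove(entry.getKey());
                givenUp.add(entry.getKey());
            } else {
                p.attempts++;
                p.lastSentAt = nowMillis;
                retry.add(entry.getKey());
            }
        }
        return retry;
    }
}

/** Client side: retransmission means duplicates, so each message ID is applied once. */
class ClientDeduplicator {
    private final Set<String> seen = new HashSet<>();

    boolean firstTime(String messageId) {
        return seen.add(messageId);
    }
}
```

In this scheme the delivery status would move to "delivered" inside `onAck` rather than right after the socket write, and anything collected in `givenUp` follows the offline path.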

4. Message storage and synchronization:

Message storage service:

java
@Service
public class MessageStorageService {
    
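    @Autowired private HBaseTemplate hbaseTemplate;
    // StringRedisTemplate is a RedisTemplate<String, String>, matching the JSON strings cached below
    @Autowired private StringRedisTemplate redisTemplate;
    
    // HBase table layout:
    // Row Key: userId_reversedTimestamp_messageId (see buildRowKey)
    // Column Family: msg
    // Columns: from, to, content, type, timestamp, group_id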
    
    public void saveMessage(Message message) {
        String rowKey = buildRowKey(message.getRecipientId(), message.getTimestamp(), message.getId());
        
        // 1. Save to HBase
        Put put = new Put(Bytes.toBytes(rowKey));
        put.addColumn(Bytes.toBytes("msg"), Bytes.toBytes("from"), Bytes.toBytes(message.getSenderId()));
        put.addColumn(Bytes.toBytes("msg"), Bytes.toBytes("to"), Bytes.toBytes(message.getRecipientId()));
        put.addColumn(Bytes.toBytes("msg"), Bytes.toBytes("content"), Bytes.toBytes(message.getContent()));
        put.addColumn(Bytes.toBytes("msg"), Bytes.toBytes("type"), Bytes.toBytes(message.getType().name()));
        put.addColumn(Bytes.toBytes("msg"), Bytes.toBytes("timestamp"), Bytes.toBytes(message.getTimestamp()));
        
        if (message.getGroupId() != null) {
            put.addColumn(Bytes.toBytes("msg"), Bytes.toBytes("group_id"), Bytes.toBytes(message.getGroupId()));
        }
        
        hbaseTemplate.execute("messages", table -> {
            table.put(put);
            return null;
        });
        
        // 2. Cache the recent message in Redis
        cacheRecentMessage(message);
    }
    
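    public List<Message> getMessageHistory(String userId1, String userId2, long startTime, long endTime, int limit) {
        // Row keys embed a reversed timestamp, so newer messages sort first. A forward
        // scan therefore starts at endTime (inclusive) and stops just past startTime
        // (the stop row is exclusive; this assumes startTime > 0).
        String startRowKey = buildRowKey(userId1, endTime, "");
        String stopRowKey = buildRowKey(userId1, startTime - 1, "");
        
        return hbaseTemplate.execute("messages", table -> {
            Scan scan = new Scan();
            scan.withStartRow(Bytes.toBytes(startRowKey));
            scan.withStopRow(Bytes.toBytes(stopRowKey));
            scan.setLimit(limit); // newest first, because of the reversed timestamp
            
            // Filter: keep only messages sent by userId2. This scans userId1's inbox, so
            // messages that userId1 sent live under userId2's rows and need a second scan.
            FilterList filters = new FilterList(FilterList.Operator.MUST_PASS_ALL);
            SingleColumnValueFilter fromFilter = new SingleColumnValueFilter(
                Bytes.toBytes("msg"), 
                Bytes.toBytes("from"), 
                CompareFilter.CompareOp.EQUAL, 
                Bytes.toBytes(userId2)
            );
            fromFilter.setFilterIfMissing(true); // skip rows that lack the column
            filters.addFilter(fromFilter);
            scan.setFilter(filters);
            
            List<Message> messages = new ArrayList<>();
            try (ResultScanner scanner = table.getScanner(scan)) {
                for (Result result : scanner) {
                    messages.add(parseMessage(result));
                }
            }
            return messages;
        });
    }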
    
    private void cacheRecentMessage(Message message) {
        // Cache the latest 100 messages for each user
        String cacheKey = "recent_messages:" + message.getRecipientId();
        
        // Store them in a sorted set, scored by timestamp
        redisTemplate.opsForZSet().add(cacheKey, JSON.toJSONString(message), message.getTimestamp());
        
        // Keep only the 100 highest-scored (most recent) entries
        redisTemplate.opsForZSet().removeRange(cacheKey, 0, -101);
        
        // Set the expiry
        redisTemplate.expire(cacheKey, Duration.ofDays(7));
    }
    
    private String buildRowKey(String userId, long timestamp, String messageId) {
        // Reverse the timestamp so that the newest messages sort first
        long reversedTimestamp = Long.MAX_VALUE - timestamp;
        return String.format("%s_%020d_%s", userId, reversedTimestamp, messageId);
    }
}
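
Because the row key stores `Long.MAX_VALUE - timestamp`, the scan bounds for a time range are easy to get backwards. The sketch below reproduces the same key format and uses a `TreeMap` as a stand-in for HBase's lexicographically sorted key space (the two orderings agree for the ASCII keys used here). `RowKeyDemo` and its `scan` helper are illustrative only and assume `startTime > 0`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.TreeMap;

class RowKeyDemo {
    /** Same format as buildRowKey in the storage service: userId_reversedTimestamp_messageId. */
    static String buildRowKey(String userId, long timestamp, String messageId) {
        long reversedTimestamp = Long.MAX_VALUE - timestamp;
        return String.format("%s_%020d_%s", userId, reversedTimestamp, messageId);
    }

    /** Simulates a forward scan over [startRow, stopRow) on a sorted table. */
    static List<String> scan(TreeMap<String, String> table, String userId,
                             long startTime, long endTime, int limit) {
        String startRow = buildRowKey(userId, endTime, "");       // newest bound, inclusive
        String stopRow = buildRowKey(userId, startTime - 1, "");  // just past the oldest, exclusive
        List<String> results = new ArrayList<>();
        for (String value : table.subMap(startRow, true, stopRow, false).values()) {
            if (results.size() == limit) {
                break;
            }
            results.add(value);
        }
        return results;
    }
}
```

Scanning a range with a limit of N returns the N newest messages in that range, newest first, without a reversed scan.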

5. Message ordering and consistency guarantees:

Message sequence number mechanism:

java
@Service
public class MessageSequenceService {
    
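    // StringRedisTemplate keeps the counters as plain strings that both INCR and GET can read
    @Autowired private StringRedisTemplate redisTemplate;
    
    // Maintain a message sequence number for each conversation
    public long generateSequenceNumber(String conversationId) {
        String key = "msg_seq:" + conversationId;
        return redisTemplate.opsForValue().increment(key);
    }
    
    public void ensureMessageOrder(String conversationId, long expectedSeq, Message message) {
        // Highest sequence number accepted so far. The key name is an assumption; it must
        // be separate from the msg_seq counter, which has already reached expectedSeq.
        String committedKey = "msg_committed_seq:" + conversationId;
        
        // Optimistic locking with WATCH/MULTI/EXEC. The read happens before MULTI because
        // commands issued inside a transaction are only queued and return null.
        List<Object> txResult = redisTemplate.execute(new SessionCallback<List<Object>>() {
            @Override
            @SuppressWarnings("unchecked")
            public List<Object> execute(RedisOperations operations) throws DataAccessException {
                operations.watch(committedKey);
                
                // Check the sequence number
                Object currentSeq = operations.opsForValue().get(committedKey);
                if (currentSeq != null && expectedSeq <= Long.parseLong(currentSeq.toString())) {
                    operations.unwatch();
                    throw new MessageOrderException("Duplicate or out-of-order message");
                }
                
                // Advance the committed sequence number atomically
                operations.multi();
                operations.opsForValue().set(committedKey, String.valueOf(expectedSeq));
                return operations.exec();
            }
        });
        
        // EXEC yields no results when another writer changed the watched key first
        if (txResult == null || txResult.isEmpty()) {
            throw new MessageOrderException("Concurrent update on conversation " + conversationId);
        }
        
        // Save the message together with its sequence number
        saveMessageWithSequence(message, expectedSeq);
    }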
    
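    // Client message synchronization
    public MessageSyncResult syncMessages(String userId, long lastSyncSeq) {
        List<Message> missedMessages = new ArrayList<>();
        
        // Fetch all of the user's conversations
        List<String> conversations = getUserConversations(userId);
        
        for (String conversationId : conversations) {
            // Fetch the messages in this conversation that the user has not synced yet.
            // Sequence numbers are per conversation, so a single lastSyncSeq is only
            // exact when the client keeps one cursor per conversation.
            List<Message> messages = getMessagesAfterSequence(conversationId, lastSyncSeq, userId);
            missedMessages.addAll(messages);
        }
        
        // Sort by timestamp
        missedMessages.sort(Comparator.comparing(Message::getTimestamp));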
        
        return MessageSyncResult.builder()
            .messages(missedMessages)
            .currentSyncSeq(getCurrentMaxSequence(userId))
            .build();
    }
}

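// Client message acknowledgement (read receipts)
@Service
public class MessageAckService {
    
    @Autowired private MessageStorageService messageStorageService;
    @Autowired private ConnectionManager connectionManager;
    
    public void handleMessageAck(String userId, String messageId) {
        Message originalMessage = messageStorageService.getMessage(messageId);
        if (originalMessage == null) {
            return; // unknown message: nothing to update
        }
        
        // 1. Mark the message as read
        messageStorageService.updateMessageStatus(messageId, MessageStatus.READ);
        
        // 2. Notify the sender that the message was read
        sendReadReceipt(originalMessage.getSenderId(), messageId);
        
        // 3. Advance the user's read sequence number
        updateUserReadSequence(userId, originalMessage.getSequenceNumber());
    }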
    
    private void sendReadReceipt(String senderId, String messageId) {
        ReadReceiptMessage receipt = ReadReceiptMessage.builder()
            .type("read_receipt")
            .messageId(messageId)
            .readTime(System.currentTimeMillis())
            .build();
        
        List<Channel> senderChannels = connectionManager.getUserChannels(senderId);
        for (Channel channel : senderChannels) {
            channel.writeAndFlush(new TextWebSocketFrame(JSON.toJSONString(receipt)));
        }
    }
}
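
Server-assigned sequence numbers only help if the client uses them: messages can still arrive out of order or more than once over the network. A minimal client-side sketch follows, assuming per-conversation sequence numbers that increase by 1. `ReorderBuffer` is a hypothetical name. It holds back messages that arrive early, drops duplicates, and releases a contiguous run as soon as the gap is filled.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.TreeMap;

/** Per-conversation buffer that releases messages strictly in sequence order. */
class ReorderBuffer {
    private long nextExpected;
    private final TreeMap<Long, String> buffered = new TreeMap<>();

    ReorderBuffer(long firstSeq) {
        this.nextExpected = firstSeq;
    }

    /** Accepts one message and returns the messages that are now deliverable, in order. */
    List<String> accept(long seq, String payload) {
        List<String> ready = new ArrayList<>();
        if (seq < nextExpected) {
            return ready; // duplicate of something already delivered
        }
        buffered.putIfAbsent(seq, payload);
        while (!buffered.isEmpty() && buffered.firstKey() == nextExpected) {
            ready.add(buffered.pollFirstEntry().getValue());
            nextExpected++;
        }
        return ready;
    }

    /** True while at least one message is waiting for an earlier one to arrive. */
    boolean hasGap() {
        return !buffered.isEmpty();
    }

    long nextExpected() {
        return nextExpected;
    }
}
```

When `hasGap()` stays true for too long, the client can request a sync starting from `nextExpected()` instead of waiting indefinitely.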

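Interview takeaways:

  • Understand connection management and load balancing in a large-scale messaging system
  • Master reliable message delivery and offline storage
  • Know the techniques for guaranteeing message ordering and consistency
  • Master the implementation strategies for message synchronization and state management

This real-time messaging system design covers WebSocket connection management, message routing and delivery, storage and synchronization, and ordering and consistency, and together these form a complete technical approach to a large-scale instant messaging system.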
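
On the load-balancing point: one common way to spread users across gateway servers is a consistent hash ring, which this document does not prescribe, so the sketch below is only an illustration under that assumption. `GatewayRing`, the CRC32 hash, and the virtual-node count are illustrative choices. The property worth knowing for an interview is that removing a gateway only remaps the users that were on it.

```java
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.TreeMap;
import java.util.zip.CRC32;

/** Consistent hash ring that maps user IDs to gateway nodes. */
class GatewayRing {
    private final TreeMap<Long, String> ring = new TreeMap<>();
    private final int virtualNodes;

    GatewayRing(int virtualNodes) {
        this.virtualNodes = virtualNodes;
    }

    private static long hash(String key) {
        CRC32 crc = new CRC32();
        crc.update(key.getBytes(StandardCharsets.UTF_8));
        return crc.getValue();
    }

    /** Places several virtual points per node so load spreads more evenly. */
    void addNode(String node) {
        for (int i = 0; i < virtualNodes; i++) {
            ring.putIfAbsent(hash(node + "#" + i), node);
        }
    }

    /** Removes only the points owned by this node. */
    void removeNode(String node) {
        ring.values().removeIf(node::equals);
    }

    /** Walks clockwise from the user's hash to the first point, wrapping at the end. */
    String nodeFor(String userId) {
        if (ring.isEmpty()) {
            throw new IllegalStateException("no gateway nodes registered");
        }
        Map.Entry<Long, String> entry = ring.ceilingEntry(hash(userId));
        return (entry != null ? entry : ring.firstEntry()).getValue();
    }
}
```

A load balancer or connection broker could call `nodeFor(userId)` when handing a client its gateway address, so reconnects from the same user tend to land on the same server.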
