Skip to content

工作窃取算法详解 - Golang运行时机制面试题

工作窃取算法是Go调度器实现负载均衡的核心机制。本章深入探讨工作窃取的设计原理、实现细节和性能优化策略。

📋 重点面试题

面试题 1:工作窃取算法原理和实现

难度级别:⭐⭐⭐⭐⭐
考察范围:调度算法/负载均衡
技术标签:work stealing / load balancing / scheduling / queue management / performance

详细解答

1. 工作窃取算法基本原理

go
package main

import (
    "fmt"
    "math/rand"
    "runtime"
    "sync"
    "sync/atomic"
    "time"
)

// demonstrateWorkStealing is the entry point of the work-stealing walkthrough.
// It prints a banner and then runs the four demonstrations in a fixed order.
//
// Ideas the demonstrations in this file are built around:
//
//  1. Per-worker local queues: a worker takes tasks from its own queue first
//     and only tries to steal once that queue is empty.
//  2. Stealing: a victim is picked at random and half of its pending tasks
//     are taken in one batch from the tail of its queue.
//  3. Load balancing: busy workers give up work to idle ones, which evens out
//     the load and raises overall throughput.
//  4. Implementation concerns: queue synchronisation, ABA hazards and queue
//     boundary conditions.
func demonstrateWorkStealing() {
    fmt.Println("=== 工作窃取算法原理演示 ===")

    steps := []func(){
        demonstrateWorkStealingEffect,     // effect of stealing on an unbalanced load
        demonstrateWorkStealingSimulation, // simulated stealing under several workloads
        demonstrateLoadBalancingAnalysis,  // comparison of scheduling strategies
        demonstrateStealingStrategies,     // comparison of stealing strategies
    }
    for _, step := range steps {
        step()
    }
}

// demonstrateWorkStealingEffect runs one deliberately unbalanced workload
// twice and prints both wall-clock times and their ratio:
//
//   - baseline: one goroutine per worker, each draining only its own task list;
//   - stealing: the same tasks fed through a WorkStealer.
//
// The first half of the workers receives 1000 tasks each, the rest 10 each.
func demonstrateWorkStealingEffect() {
    fmt.Println("\n--- 工作窃取效果演示 ---")

    const (
        heavyLoad = 1000 // tasks per worker in the first half
        lightLoad = 10   // tasks per worker in the second half
    )

    workerCount := runtime.NumCPU()
    workloads := make([][]int, workerCount)
    for w := range workloads {
        size := lightLoad
        if w < workerCount/2 {
            size = heavyLoad
        }
        batch := make([]int, size)
        for k := range batch {
            // Offset by the worker index so task values differ between workers.
            batch[k] = w*1000 + k
        }
        workloads[w] = batch
    }

    fmt.Printf("初始工作分配:\n")
    for w := range workloads {
        fmt.Printf("  Worker %d: %d 个任务\n", w, len(workloads[w]))
    }

    // Baseline: the goroutines run concurrently, but no task ever moves
    // between workers.
    fmt.Println("\n无工作窃取执行:")
    baselineStart := time.Now()
    var wg sync.WaitGroup
    wg.Add(len(workloads))
    for w, batch := range workloads {
        go func(workerID int, tasks []int) {
            defer wg.Done()

            began := time.Now()
            for idx := 0; idx < len(tasks); idx++ {
                processTask(tasks[idx])
            }
            fmt.Printf("  Worker %d 完成,耗时: %v\n", workerID, time.Since(began))
        }(w, batch)
    }
    wg.Wait()
    totalTime := time.Since(baselineStart)
    fmt.Printf("总执行时间: %v\n", totalTime)

    // Same tasks, same initial placement, but executed by the work stealer.
    fmt.Println("\n使用工作窃取算法:")
    stealer := NewWorkStealer(workerCount)
    for w, batch := range workloads {
        for _, task := range batch {
            stealer.AddTask(w, task)
        }
    }

    stealStart := time.Now()
    stealer.Execute()
    stealingTime := time.Since(stealStart)

    fmt.Printf("工作窃取总时间: %v\n", stealingTime)
    speedup := float64(totalTime) / float64(stealingTime)
    fmt.Printf("性能提升: %.2fx\n", speedup)
}

// processTask simulates the work of one task by sleeping for
// 100µs + (task % 900)µs.
func processTask(task int) {
    delay := time.Duration(100+task%900) * time.Microsecond
    time.Sleep(delay)
}

// WorkStealer runs a fixed number of workers, each owning one WorkQueue.
// A worker drains its own queue first and steals from another worker's queue
// when it has nothing left to do.
type WorkStealer struct {
    queues    []*WorkQueue // one queue per worker, indexed by worker ID
    workers   int          // number of workers
    completed int64        // tasks finished so far; accessed atomically
    total     int64        // tasks added so far; accessed atomically
}

// WorkQueue is a mutex-guarded task queue. The owner removes tasks from the
// head (Pop); thieves remove a batch from the tail (Steal). Live tasks are
// tasks[head:tail].
//
// head and tail are declared first so that they are 64-bit aligned for the
// atomic operations on 32-bit platforms as well.
type WorkQueue struct {
    head  int64 // index of the next task to Pop; accessed atomically
    tail  int64 // index one past the last live task; accessed atomically
    tasks []int
    mutex sync.Mutex // serialises Push, Pop and Steal
}

// NewWorkStealer returns a WorkStealer with numWorkers empty queues.
func NewWorkStealer(numWorkers int) *WorkStealer {
    ws := &WorkStealer{
        queues:  make([]*WorkQueue, numWorkers),
        workers: numWorkers,
    }

    for i := range ws.queues {
        ws.queues[i] = &WorkQueue{
            tasks: make([]int, 0, 1000),
        }
    }

    return ws
}

// AddTask appends task to the queue of worker workerID and counts it towards
// the total that Execute waits for.
func (ws *WorkStealer) AddTask(workerID, task int) {
    ws.queues[workerID].Push(task)
    atomic.AddInt64(&ws.total, 1)
}

// Push appends task at the tail of the queue.
func (wq *WorkQueue) Push(task int) {
    wq.mutex.Lock()
    defer wq.mutex.Unlock()

    // Steal lowers tail without shrinking the slice, so the new task must be
    // written at index tail, not at len(tasks). Appending to tasks[:tail]
    // overwrites slots whose contents were already handed to a thief.
    tail := atomic.LoadInt64(&wq.tail)
    wq.tasks = append(wq.tasks[:tail], task)
    atomic.StoreInt64(&wq.tail, tail+1)
}

// Pop removes and returns the task at the head of the queue. The second
// result is false when the queue is empty.
func (wq *WorkQueue) Pop() (int, bool) {
    wq.mutex.Lock()
    defer wq.mutex.Unlock()

    head := atomic.LoadInt64(&wq.head)
    tail := atomic.LoadInt64(&wq.tail)

    if head >= tail {
        return 0, false
    }

    task := wq.tasks[head]
    atomic.AddInt64(&wq.head, 1)
    return task, true
}

// Steal removes half of the live tasks (rounded down) from the tail of the
// queue and returns them as a fresh slice. It returns nil when the queue
// holds at most one task, so the last task is always left to the owner.
func (wq *WorkQueue) Steal() []int {
    wq.mutex.Lock()
    defer wq.mutex.Unlock()

    head := atomic.LoadInt64(&wq.head)
    tail := atomic.LoadInt64(&wq.tail)
    size := tail - head

    if size <= 1 {
        return nil
    }

    stealCount := size / 2
    newTail := tail - stealCount

    // Copy out so the caller does not share the queue's backing array.
    stolen := make([]int, stealCount)
    copy(stolen, wq.tasks[newTail:tail])

    atomic.StoreInt64(&wq.tail, newTail)
    return stolen
}

// Size returns the number of live tasks. It does not take the mutex, so the
// value is a snapshot that may be stale by the time it is used.
func (wq *WorkQueue) Size() int64 {
    head := atomic.LoadInt64(&wq.head)
    tail := atomic.LoadInt64(&wq.tail)
    return tail - head
}

// Execute starts one goroutine per worker and blocks until every task that
// was added has been processed.
func (ws *WorkStealer) Execute() {
    var wg sync.WaitGroup

    for i := 0; i < ws.workers; i++ {
        wg.Add(1)
        go ws.worker(i, &wg)
    }

    wg.Wait()
}

// worker is the loop run by one worker goroutine. Each iteration it tries, in
// order: pop a local task; steal a batch from a random other worker and
// process it immediately; exit if all tasks are done; otherwise back off
// briefly. On exit it prints how many tasks it took locally and how many it
// stole.
func (ws *WorkStealer) worker(workerID int, wg *sync.WaitGroup) {
    defer wg.Done()

    processed := 0
    stolen := 0

    for {
        // 1. Local queue first.
        if task, ok := ws.queues[workerID].Pop(); ok {
            processTask(task)
            atomic.AddInt64(&ws.completed, 1)
            processed++
            continue
        }

        // 2. Local queue is empty: try to steal. With a single worker there
        // is nobody to steal from, so fall through to the completion check
        // instead of retrying forever.
        if victim, ok := ws.pickVictim(workerID); ok {
            if stolenTasks := ws.queues[victim].Steal(); stolenTasks != nil {
                for _, task := range stolenTasks {
                    processTask(task)
                    atomic.AddInt64(&ws.completed, 1)
                    stolen++
                }
                continue
            }
        }

        // 3. Nothing to run and nothing to steal: stop once every task is done.
        if atomic.LoadInt64(&ws.completed) >= atomic.LoadInt64(&ws.total) {
            break
        }

        // 4. Other workers are still busy; back off instead of spinning.
        time.Sleep(100 * time.Microsecond)
    }

    fmt.Printf("  Worker %d: 处理 %d 个本地任务, 窃取 %d 个任务\n",
        workerID, processed, stolen)
}

// pickVictim returns a uniformly random worker ID different from self.
// The second result is false when there is no other worker.
func (ws *WorkStealer) pickVictim(self int) (int, bool) {
    if ws.workers < 2 {
        return 0, false
    }
    victim := rand.Intn(ws.workers - 1)
    if victim >= self {
        victim++ // skip over self
    }
    return victim, true
}

// demonstrateWorkStealingSimulation runs the SchedulingSimulator with four
// processors against three workload shapes (uniform, skewed, burst) and
// prints the statistics of each run. The simulator is reset between runs.
func demonstrateWorkStealingSimulation() {
    fmt.Println("\n--- 工作窃取算法模拟 ---")

    const processorCount = 4
    sim := NewSchedulingSimulator(processorCount)

    type scenario struct {
        name      string
        pattern   func(int) []Task // builds the initial tasks for one processor
        intensity string
    }
    scenarios := []scenario{
        {name: "均匀分布", pattern: createUniformWorkload, intensity: "轻度"},
        {name: "倾斜分布", pattern: createSkewedWorkload, intensity: "中度"},
        {name: "突发分布", pattern: createBurstWorkload, intensity: "重度"},
    }

    for _, sc := range scenarios {
        fmt.Printf("\n--- %s工作负载 (%s) ---\n", sc.name, sc.intensity)

        sim.Reset()

        // Seed every processor's local queue with its share of the workload.
        for pid := 0; pid < processorCount; pid++ {
            for _, task := range sc.pattern(pid) {
                sim.AddTask(pid, task)
            }
        }

        sim.Run().Print()
    }
}

// Task is one unit of simulated work.
type Task struct {
    ID       int           // sequence number assigned by SchedulingSimulator.AddTask
    Duration time.Duration // how long executing the task sleeps
    WorkerID int           // processor the task was initially assigned to
}

// SchedulingSimulator simulates numP processors that each own a local task
// queue and steal from one another when idle.
type SchedulingSimulator struct {
    processors []*Processor
    numP       int
    taskID     int64 // last task ID handed out; accessed atomically
}

// Processor is one simulated processor.
//
// localQueue is touched by the owning goroutine and by thieves, so every
// access goes through mu. The counters are written only by the goroutine that
// runs this processor and read after all goroutines have finished.
type Processor struct {
    id            int
    mu            sync.Mutex // guards localQueue
    localQueue    []Task
    processTime   time.Duration // sum of the durations of executed tasks
    stolenTasks   int           // tasks this processor stole from others
    stealAttempts int           // calls to stealTask made for this processor
    processed     int           // tasks this processor executed
}

// SimulationStats is the result of one SchedulingSimulator.Run.
type SimulationStats struct {
    TotalTime       time.Duration
    ProcessorStats  []ProcessorStat
    LoadBalance     float64 // 1/(1+variance of processed-task counts); 1 means perfectly even
    StealEfficiency float64 // stolen tasks per steal attempt
}

// ProcessorStat is the per-processor part of SimulationStats.
type ProcessorStat struct {
    ID             int
    ProcessedTasks int
    StolenTasks    int
    StealAttempts  int
    ProcessTime    time.Duration
    Utilization    float64 // ProcessTime divided by the run's total time
}

// NewSchedulingSimulator returns a simulator with numP processors, each with
// an empty local queue.
func NewSchedulingSimulator(numP int) *SchedulingSimulator {
    sim := &SchedulingSimulator{
        processors: make([]*Processor, numP),
        numP:       numP,
    }

    for i := range sim.processors {
        sim.processors[i] = &Processor{
            id:         i,
            localQueue: make([]Task, 0),
        }
    }

    return sim
}

// Reset empties every queue and zeroes all counters and the task ID sequence.
func (sim *SchedulingSimulator) Reset() {
    for _, p := range sim.processors {
        p.mu.Lock()
        p.localQueue = p.localQueue[:0]
        p.mu.Unlock()

        p.processTime = 0
        p.stolenTasks = 0
        p.stealAttempts = 0
        p.processed = 0
    }
    sim.taskID = 0
}

// AddTask assigns the next task ID to task, records processorID as its
// initial owner and appends it to that processor's local queue.
func (sim *SchedulingSimulator) AddTask(processorID int, task Task) {
    task.ID = int(atomic.AddInt64(&sim.taskID, 1))
    task.WorkerID = processorID

    p := sim.processors[processorID]
    p.mu.Lock()
    p.localQueue = append(p.localQueue, task)
    p.mu.Unlock()
}

// Run executes all queued tasks with one goroutine per processor and returns
// the statistics of the run.
func (sim *SchedulingSimulator) Run() *SimulationStats {
    start := time.Now()

    var wg sync.WaitGroup
    for i := range sim.processors {
        wg.Add(1)
        go sim.runProcessor(i, &wg)
    }

    wg.Wait()
    totalTime := time.Since(start)

    return sim.calculateStats(totalTime)
}

// runProcessor is the loop of one processor goroutine: execute local tasks,
// otherwise try to steal, otherwise exit if every queue is empty, otherwise
// wait briefly and retry.
func (sim *SchedulingSimulator) runProcessor(processorID int, wg *sync.WaitGroup) {
    defer wg.Done()

    p := sim.processors[processorID]

    for {
        if task, ok := p.popFront(); ok {
            // The sleep happens outside the queue lock so thieves are not
            // blocked while a task executes.
            time.Sleep(task.Duration)
            p.processTime += task.Duration
            p.processed++
            continue
        }

        if sim.stealTask(processorID) {
            continue
        }

        if sim.allQueuesEmpty() {
            break
        }

        time.Sleep(10 * time.Microsecond)
    }
}

// stealTask makes one steal attempt for processorID: it picks a random
// processor and, unless that is the caller itself or holds at most one task,
// moves the back half of its queue to the caller's queue. It reports whether
// any task was moved.
func (sim *SchedulingSimulator) stealTask(processorID int) bool {
    p := sim.processors[processorID]
    p.stealAttempts++

    target := rand.Intn(sim.numP)
    if target == processorID {
        return false
    }

    // The victim's lock is released before the thief's lock is taken, so two
    // processors stealing from each other cannot deadlock.
    stolen := sim.processors[target].stealHalf()
    if len(stolen) == 0 {
        return false
    }

    p.pushAll(stolen)
    p.stolenTasks += len(stolen)

    return true
}

// allQueuesEmpty reports whether every processor's local queue is empty.
func (sim *SchedulingSimulator) allQueuesEmpty() bool {
    for _, p := range sim.processors {
        if p.queueLen() > 0 {
            return false
        }
    }
    return true
}

// popFront removes and returns the first task of the local queue.
func (p *Processor) popFront() (Task, bool) {
    p.mu.Lock()
    defer p.mu.Unlock()

    if len(p.localQueue) == 0 {
        return Task{}, false
    }
    task := p.localQueue[0]
    p.localQueue = p.localQueue[1:]
    return task, true
}

// stealHalf removes the back half (rounded down) of the local queue and
// returns it as a copy. It returns nil when the queue holds at most one task.
func (p *Processor) stealHalf() []Task {
    p.mu.Lock()
    defer p.mu.Unlock()

    n := len(p.localQueue)
    if n <= 1 {
        return nil
    }
    keep := n - n/2
    // Copy so the thief never shares a backing array with the victim.
    stolen := append([]Task(nil), p.localQueue[keep:]...)
    p.localQueue = p.localQueue[:keep]
    return stolen
}

// pushAll appends tasks to the local queue.
func (p *Processor) pushAll(tasks []Task) {
    p.mu.Lock()
    p.localQueue = append(p.localQueue, tasks...)
    p.mu.Unlock()
}

// queueLen returns the current length of the local queue.
func (p *Processor) queueLen() int {
    p.mu.Lock()
    defer p.mu.Unlock()
    return len(p.localQueue)
}

// calculateStats builds the SimulationStats for a finished run that took
// totalTime. It must only be called after all processor goroutines returned.
func (sim *SchedulingSimulator) calculateStats(totalTime time.Duration) *SimulationStats {
    stats := &SimulationStats{
        TotalTime:      totalTime,
        ProcessorStats: make([]ProcessorStat, sim.numP),
    }

    totalTasks := 0
    totalStolenTasks := 0
    totalStealAttempts := 0

    for i, p := range sim.processors {
        stats.ProcessorStats[i] = ProcessorStat{
            ID:             p.id,
            ProcessedTasks: p.processed,
            StolenTasks:    p.stolenTasks,
            StealAttempts:  p.stealAttempts,
            ProcessTime:    p.processTime,
            Utilization:    float64(p.processTime) / float64(totalTime),
        }

        totalTasks += p.processed
        totalStolenTasks += p.stolenTasks
        totalStealAttempts += p.stealAttempts
    }

    // Load balance: 1/(1+v), where v is the population variance of the
    // per-processor processed-task counts.
    avgTasks := float64(totalTasks) / float64(sim.numP)
    variance := 0.0
    for _, p := range sim.processors {
        diff := float64(p.processed) - avgTasks
        variance += diff * diff
    }
    stats.LoadBalance = 1.0 / (1.0 + variance/float64(sim.numP))

    // Steal efficiency: tasks obtained per steal attempt.
    if totalStealAttempts > 0 {
        stats.StealEfficiency = float64(totalStolenTasks) / float64(totalStealAttempts)
    }

    return stats
}

// Print writes the statistics to standard output.
func (stats *SimulationStats) Print() {
    fmt.Printf("模拟结果:\n")
    fmt.Printf("  总执行时间: %v\n", stats.TotalTime)
    fmt.Printf("  负载均衡度: %.3f\n", stats.LoadBalance)
    fmt.Printf("  窃取效率: %.3f\n", stats.StealEfficiency)

    fmt.Printf("  处理器统计:\n")
    for _, p := range stats.ProcessorStats {
        fmt.Printf("    P%d: 处理=%d, 窃取=%d, 尝试=%d, 利用率=%.2f%%\n",
            p.ID, p.ProcessedTasks, p.StolenTasks, p.StealAttempts, p.Utilization*100)
    }
}

// 工作负载生成函数
// createUniformWorkload returns 100 tasks regardless of processorID, each
// lasting a random 100–149µs.
func createUniformWorkload(processorID int) []Task {
    const perProcessor = 100

    tasks := make([]Task, perProcessor)
    for idx := 0; idx < len(tasks); idx++ {
        micros := 100 + rand.Intn(50)
        tasks[idx] = Task{Duration: time.Duration(micros) * time.Microsecond}
    }
    return tasks
}

// createSkewedWorkload returns an uneven workload: 200 tasks for processor 0,
// 150 for processor 1 and 50 for every other processor. Each task lasts a
// random 100–199µs.
func createSkewedWorkload(processorID int) []Task {
    count := 50
    switch processorID {
    case 0:
        count = 200
    case 1:
        count = 150
    }

    tasks := make([]Task, count)
    for idx := 0; idx < len(tasks); idx++ {
        micros := 100 + rand.Intn(100)
        tasks[idx] = Task{Duration: time.Duration(micros) * time.Microsecond}
    }
    return tasks
}

// createBurstWorkload returns a burst workload: 300 tasks for processor 0 and
// 10 for every other processor. Each task lasts a random 50–199µs.
func createBurstWorkload(processorID int) []Task {
    count := 10 // everyone but processor 0 gets very little
    if processorID == 0 {
        count = 300 // the burst lands on processor 0
    }

    tasks := make([]Task, count)
    for idx := 0; idx < len(tasks); idx++ {
        micros := 50 + rand.Intn(150)
        tasks[idx] = Task{Duration: time.Duration(micros) * time.Microsecond}
    }
    return tasks
}

// demonstrateLoadBalancingAnalysis runs the same analysis against three
// scheduling strategies (no balancing, work stealing, global queue) and
// prints the metrics each one reports.
func demonstrateLoadBalancingAnalysis() {
    fmt.Println("\n--- 负载均衡性能分析 ---")

    labels := []string{"无负载均衡", "工作窃取", "全局队列"}
    impls := []SchedulingStrategy{
        &NoBalancingStrategy{},
        &WorkStealingStrategy{},
        &GlobalQueueStrategy{},
    }

    for idx, label := range labels {
        fmt.Printf("\n--- %s策略 ---\n", label)
        NewLoadBalanceAnalyzer(impls[idx]).Analyze().Print()
    }
}

// SchedulingStrategy is a way of executing a batch of tasks on a number of
// workers.
type SchedulingStrategy interface {
    // Schedule runs tasks on numWorkers workers and returns the elapsed
    // wall-clock time.
    Schedule(tasks []Task, numWorkers int) time.Duration
    // GetMetrics returns the metrics recorded by the most recent Schedule call.
    GetMetrics() *PerformanceMetrics
}

// PerformanceMetrics summarises one Schedule run of a SchedulingStrategy.
type PerformanceMetrics struct {
    // TotalTime is the wall-clock duration of the whole run.
    TotalTime       time.Duration
    // MaxWorkerTime and MinWorkerTime are the longest and shortest time any
    // single worker spent on its tasks.
    MaxWorkerTime   time.Duration
    MinWorkerTime   time.Duration
    // LoadImbalance is a unitless imbalance indicator; 0 means perfectly even.
    LoadImbalance   float64
    // Throughput is the number of tasks completed per second.
    Throughput      float64
    // WorkerUtilization holds one utilisation fraction per worker.
    WorkerUtilization []float64
}

// Print writes the metrics to standard output. Utilisation is shown as the
// mean over all workers, in percent.
func (pm *PerformanceMetrics) Print() {
    meanUtilization := average(pm.WorkerUtilization) * 100

    fmt.Printf("  总时间: %v\n", pm.TotalTime)
    fmt.Printf("  最大工作时间: %v\n", pm.MaxWorkerTime)
    fmt.Printf("  最小工作时间: %v\n", pm.MinWorkerTime)
    fmt.Printf("  负载不平衡度: %.3f\n", pm.LoadImbalance)
    fmt.Printf("  吞吐量: %.0f tasks/sec\n", pm.Throughput)
    fmt.Printf("  平均利用率: %.2f%%\n", meanUtilization)
}

// average returns the arithmetic mean of values. It returns 0 for an empty
// (or nil) slice instead of the NaN that 0/0 would produce.
func average(values []float64) float64 {
    if len(values) == 0 {
        return 0
    }

    sum := 0.0
    for _, v := range values {
        sum += v
    }
    return sum / float64(len(values))
}

// LoadBalanceAnalyzer measures a single SchedulingStrategy.
type LoadBalanceAnalyzer struct {
    strategy SchedulingStrategy
}

// NewLoadBalanceAnalyzer returns an analyzer bound to strategy.
func NewLoadBalanceAnalyzer(strategy SchedulingStrategy) *LoadBalanceAnalyzer {
    analyzer := new(LoadBalanceAnalyzer)
    analyzer.strategy = strategy
    return analyzer
}

// Analyze generates 1000 test tasks, schedules them on 4 workers with the
// analyzer's strategy and returns the strategy's metrics, with TotalTime set
// to the duration Schedule returned.
func (lba *LoadBalanceAnalyzer) Analyze() *PerformanceMetrics {
    const (
        taskCount   = 1000
        workerCount = 4
    )

    elapsed := lba.strategy.Schedule(generateTestTasks(taskCount), workerCount)

    result := lba.strategy.GetMetrics()
    result.TotalTime = elapsed
    return result
}

// generateTestTasks returns count tasks with IDs 0..count-1, each lasting a
// random 50–249µs.
func generateTestTasks(count int) []Task {
    out := make([]Task, count)
    for id := 0; id < len(out); id++ {
        micros := 50 + rand.Intn(200)
        out[id] = Task{
            ID:       id,
            Duration: time.Duration(micros) * time.Microsecond,
        }
    }
    return out
}

// Strategy implementations (simplified).

// NoBalancingStrategy splits the tasks into fixed contiguous chunks, one per
// worker, and never moves work between workers.
type NoBalancingStrategy struct {
    metrics *PerformanceMetrics
}

// Schedule gives every worker a contiguous chunk of len(tasks)/numWorkers
// tasks (the last worker also takes the remainder), runs the chunks
// concurrently, records the metrics and returns the elapsed wall-clock time.
// Executing a task means sleeping for its Duration.
func (nbs *NoBalancingStrategy) Schedule(tasks []Task, numWorkers int) time.Duration {
    began := time.Now()

    perWorker := len(tasks) / numWorkers
    elapsed := make([]time.Duration, numWorkers)

    var wg sync.WaitGroup
    for w := 0; w < numWorkers; w++ {
        wg.Add(1)
        go func(slot int) {
            defer wg.Done()

            lo := slot * perWorker
            hi := lo + perWorker
            if slot == numWorkers-1 {
                hi = len(tasks) // last worker absorbs the remainder
            }

            t0 := time.Now()
            for _, task := range tasks[lo:hi] {
                time.Sleep(task.Duration)
            }
            // Each goroutine writes only its own slot.
            elapsed[slot] = time.Since(t0)
        }(w)
    }
    wg.Wait()

    totalTime := time.Since(began)
    nbs.calculateMetrics(elapsed, len(tasks), totalTime)
    return totalTime
}

// GetMetrics returns the metrics stored by the last Schedule call, or nil if
// Schedule has not run yet.
func (nbs *NoBalancingStrategy) GetMetrics() *PerformanceMetrics {
    return nbs.metrics
}

// calculateMetrics replaces nbs.metrics with values derived from the
// per-worker times: the slowest and fastest worker, the relative spread
// between them as LoadImbalance, tasks per second as Throughput and each
// worker's share of totalTime as its utilisation. TotalTime is left unset.
func (nbs *NoBalancingStrategy) calculateMetrics(workerTimes []time.Duration, totalTasks int, totalTime time.Duration) {
    slowest := time.Duration(0)
    fastest := time.Duration(1<<63 - 1) // largest Duration, so any worker time is smaller

    utilization := make([]float64, len(workerTimes))
    for idx := range workerTimes {
        spent := workerTimes[idx]
        if spent > slowest {
            slowest = spent
        }
        if spent < fastest {
            fastest = spent
        }
        utilization[idx] = float64(spent) / float64(totalTime)
    }

    nbs.metrics = &PerformanceMetrics{
        MaxWorkerTime:     slowest,
        MinWorkerTime:     fastest,
        LoadImbalance:     float64(slowest-fastest) / float64(slowest),
        Throughput:        float64(totalTasks) / totalTime.Seconds(),
        WorkerUtilization: utilization,
    }
}

// WorkStealingStrategy schedules tasks through a WorkStealer.
type WorkStealingStrategy struct {
    metrics *PerformanceMetrics
}

// Schedule distributes the tasks round-robin over a WorkStealer with
// numWorkers workers, runs it and returns the elapsed wall-clock time.
// Each task is handed to the stealer as its duration in whole microseconds.
//
// Only TotalTime and Throughput in the stored metrics are measured:
// LoadImbalance is the constant 0.1 and every WorkerUtilization entry is a
// random value in [0.85, 0.95).
func (wss *WorkStealingStrategy) Schedule(tasks []Task, numWorkers int) time.Duration {
    pool := NewWorkStealer(numWorkers)
    for idx := range tasks {
        micros := int(tasks[idx].Duration / time.Microsecond)
        pool.AddTask(idx%numWorkers, micros)
    }

    began := time.Now()
    pool.Execute()
    elapsed := time.Since(began)

    utilization := make([]float64, numWorkers)
    for w := 0; w < len(utilization); w++ {
        utilization[w] = 0.85 + rand.Float64()*0.1 // placeholder, not measured
    }

    wss.metrics = &PerformanceMetrics{
        TotalTime:         elapsed,
        Throughput:        float64(len(tasks)) / elapsed.Seconds(),
        LoadImbalance:     0.1, // placeholder, not measured
        WorkerUtilization: utilization,
    }

    return elapsed
}

// GetMetrics returns the metrics stored by the last Schedule call, or nil if
// Schedule has not run yet.
func (wss *WorkStealingStrategy) GetMetrics() *PerformanceMetrics {
    return wss.metrics
}

// GlobalQueueStrategy has all workers pull tasks from one shared,
// mutex-guarded queue.
type GlobalQueueStrategy struct {
    metrics *PerformanceMetrics
}

// Schedule copies tasks into a shared queue, starts numWorkers goroutines
// that each repeatedly take the front task under a mutex and sleep for its
// Duration, and returns the elapsed wall-clock time once the queue is drained.
//
// Only TotalTime and Throughput in the stored metrics are measured:
// LoadImbalance is the constant 0.05 and every WorkerUtilization entry is a
// random value in [0.7, 0.85).
func (gqs *GlobalQueueStrategy) Schedule(tasks []Task, numWorkers int) time.Duration {
    began := time.Now()

    var (
        mu sync.Mutex
        wg sync.WaitGroup
    )
    pending := make([]Task, len(tasks))
    copy(pending, tasks)

    // takeNext removes the front task of the shared queue under the lock.
    takeNext := func() (Task, bool) {
        mu.Lock()
        defer mu.Unlock()
        if len(pending) == 0 {
            return Task{}, false
        }
        front := pending[0]
        pending = pending[1:]
        return front, true
    }

    for w := 0; w < numWorkers; w++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            for {
                task, ok := takeNext()
                if !ok {
                    return
                }
                // The task runs outside the lock.
                time.Sleep(task.Duration)
            }
        }()
    }
    wg.Wait()

    elapsed := time.Since(began)

    utilization := make([]float64, numWorkers)
    for w := 0; w < len(utilization); w++ {
        utilization[w] = 0.7 + rand.Float64()*0.15 // placeholder, not measured
    }

    gqs.metrics = &PerformanceMetrics{
        TotalTime:         elapsed,
        Throughput:        float64(len(tasks)) / elapsed.Seconds(),
        LoadImbalance:     0.05, // placeholder, not measured
        WorkerUtilization: utilization,
    }

    return elapsed
}

// GetMetrics returns the metrics stored by the last Schedule call, or nil if
// Schedule has not run yet.
func (gqs *GlobalQueueStrategy) GetMetrics() *PerformanceMetrics {
    return gqs.metrics
}

// demonstrateStealingStrategies prints a fixed comparison table of four
// stealing strategies (name, description, efficiency in percent), followed
// by a short list describing the strategy attributed to the Go runtime.
// Nothing is measured here; all figures are constants.
func demonstrateStealingStrategies() {
    fmt.Println("\n--- 不同窃取策略对比 ---")

    type entry struct {
        name        string
        description string
        efficiency  float64 // fraction in [0,1]
    }
    table := []entry{
        {name: "随机窃取", description: "随机选择目标P,实现简单但可能不均匀", efficiency: 0.6},
        {name: "轮询窃取", description: "按顺序遍历所有P,保证公平性", efficiency: 0.7},
        {name: "负载感知窃取", description: "优先从负载最重的P窃取", efficiency: 0.8},
        {name: "局部性窃取", description: "优先从相邻的P窃取,提高缓存局部性", efficiency: 0.75},
    }

    for idx := range table {
        row := table[idx]
        fmt.Printf("\n%s:\n", row.name)
        fmt.Printf("  描述: %s\n", row.description)
        fmt.Printf("  效率: %.1f%%\n", row.efficiency*100)
    }

    summary := []string{
        "\nGo运行时使用的策略:",
        "  - 主要使用随机窃取策略",
        "  - 每次随机选择一个P进行窃取",
        "  - 从目标P的队列尾部窃取(LIFO)",
        "  - 一次窃取一半的工作量",
        "  - 无锁实现,使用CAS操作",
    }
    for _, line := range summary {
        fmt.Println(line)
    }
}

// main runs the work-stealing demonstrations via demonstrateWorkStealing.
func main() {
    demonstrateWorkStealing()
}
go
// demonstrateWorkStealingOptimization prints a banner and runs the four
// optimisation demonstrations in a fixed order.
func demonstrateWorkStealingOptimization() {
    fmt.Println("\n=== 工作窃取性能优化 ===")

    techniques := []func(){
        demonstrateQueueOptimization,    // 1: queue data-structure choices
        demonstrateStealingOptimization, // 2: stealing policies
        demonstrateNUMAAwareScheduling,  // 3: NUMA-aware scheduling
        demonstrateDynamicTuning,        // 4: dynamic tuning
    }
    for _, run := range techniques {
        run()
    }
}

// demonstrateQueueOptimization times three WorkQueueInterface implementations
// (mutex-based, CAS-based, CAS-based with padding) in two phases and prints
// the results:
//
//   - single-threaded: numOperations pushes and pops issued from one goroutine;
//   - concurrent: one producer pushing numOperations items while one consumer
//     pops them.
//
// Techniques the compared queues are meant to illustrate:
//
//  1. Double-ended queues, with local and stealing operations on opposite ends.
//  2. Lock-free operation built on atomic compare-and-swap.
//  3. Memory layout: cache-line padding against false sharing, preallocation.
func demonstrateQueueOptimization() {
    fmt.Println("\n--- 队列结构优化 ---")

    queueTypes := []struct {
        name  string
        queue WorkQueueInterface
    }{
        {"简单队列", NewSimpleQueue()},
        {"无锁队列", NewLockFreeQueue()},
        {"优化队列", NewOptimizedQueue()},
    }

    const (
        numOperations = 100000
        // batchSize bounds how many items are queued at once in the
        // single-threaded phase. A bounded queue blocks in Push when it is
        // full, and with nobody else popping that would never end, so the
        // phase must not outrun the queue's capacity. numOperations is a
        // multiple of batchSize.
        batchSize = 1000
    )

    for _, qt := range queueTypes {
        queue := qt.queue
        fmt.Printf("\n测试 %s:\n", qt.name)

        // Single-threaded phase: numOperations pushes and pops in total,
        // issued in batches.
        start := time.Now()
        for done := 0; done < numOperations; done += batchSize {
            for i := done; i < done+batchSize; i++ {
                queue.Push(i)
            }
            for i := 0; i < batchSize; i++ {
                queue.Pop()
            }
        }
        singleThreadTime := time.Since(start)

        fmt.Printf("  单线程耗时: %v\n", singleThreadTime)

        // Concurrent phase: one producer, one consumer.
        start = time.Now()
        var wg sync.WaitGroup

        wg.Add(1)
        go func() {
            defer wg.Done()
            for i := 0; i < numOperations; i++ {
                queue.Push(i)
            }
        }()

        wg.Add(1)
        go func() {
            defer wg.Done()
            consumed := 0
            for consumed < numOperations {
                if _, ok := queue.Pop(); ok {
                    consumed++
                }
                time.Sleep(1 * time.Microsecond)
            }
        }()

        wg.Wait()
        concurrentTime := time.Since(start)

        fmt.Printf("  并发耗时: %v\n", concurrentTime)
        fmt.Printf("  并发效率: %.2fx\n", float64(singleThreadTime)/float64(concurrentTime))
    }
}

// WorkQueueInterface is the queue contract shared by the implementations
// compared in demonstrateQueueOptimization.
type WorkQueueInterface interface {
    // Push adds item to the queue.
    Push(item int)
    // Pop removes and returns an item; the bool is false when the queue is empty.
    Pop() (int, bool)
    // Size returns the number of queued items.
    Size() int
}

// SimpleQueue is an unbounded FIFO queue of ints guarded by a mutex.
type SimpleQueue struct {
    items []int
    mutex sync.Mutex
}

// NewSimpleQueue returns an empty queue with room for 1000 items preallocated.
func NewSimpleQueue() *SimpleQueue {
    q := new(SimpleQueue)
    q.items = make([]int, 0, 1000)
    return q
}

// Push appends item at the back of the queue.
func (sq *SimpleQueue) Push(item int) {
    sq.mutex.Lock()
    sq.items = append(sq.items, item)
    sq.mutex.Unlock()
}

// Pop removes and returns the item at the front of the queue. The second
// result is false when the queue is empty.
func (sq *SimpleQueue) Pop() (int, bool) {
    sq.mutex.Lock()
    defer sq.mutex.Unlock()

    if len(sq.items) > 0 {
        front := sq.items[0]
        sq.items = sq.items[1:]
        return front, true
    }
    return 0, false
}

// Size returns the number of queued items.
func (sq *SimpleQueue) Size() int {
    sq.mutex.Lock()
    n := len(sq.items)
    sq.mutex.Unlock()
    return n
}

// LockFreeQueue is a bounded FIFO ring buffer whose head and tail indices are
// advanced with atomic compare-and-swap instead of a mutex.
//
// One slot of the ring always stays unused so that a full queue can be told
// apart from an empty one (both would otherwise have head == tail).
//
// NOTE(review): Push writes the slot before the CAS that publishes it, so two
// concurrent producers can overwrite each other's slot. Use with a single
// producer unless that is reworked.
//
// head and tail are declared first so that they are 64-bit aligned for the
// atomic operations on 32-bit platforms as well.
type LockFreeQueue struct {
    head  int64 // index of the next item to Pop; accessed atomically
    tail  int64 // index of the next free slot; accessed atomically
    items []int
}

// NewLockFreeQueue returns an empty queue that can hold 100000 items.
func NewLockFreeQueue() *LockFreeQueue {
    return &LockFreeQueue{
        // 100000 usable slots plus the one slot the ring keeps empty.
        items: make([]int, 100000+1),
    }
}

// Push appends item at the tail. While the queue is full it sleeps for a
// microsecond and retries, so it blocks until a consumer makes room.
func (lfq *LockFreeQueue) Push(item int) {
    for {
        tail := atomic.LoadInt64(&lfq.tail)
        next := (tail + 1) % int64(len(lfq.items))

        if next == atomic.LoadInt64(&lfq.head) {
            // Full: wait for a Pop.
            time.Sleep(1 * time.Microsecond)
            continue
        }

        lfq.items[tail] = item
        if atomic.CompareAndSwapInt64(&lfq.tail, tail, next) {
            break
        }
    }
}

// Pop removes and returns the item at the head. The second result is false
// when the queue is empty. A failed CAS means another consumer took the item
// first, in which case the operation is retried.
func (lfq *LockFreeQueue) Pop() (int, bool) {
    for {
        head := atomic.LoadInt64(&lfq.head)
        tail := atomic.LoadInt64(&lfq.tail)

        if head == tail {
            return 0, false
        }

        item := lfq.items[head]
        next := (head + 1) % int64(len(lfq.items))

        if atomic.CompareAndSwapInt64(&lfq.head, head, next) {
            return item, true
        }
    }
}

// Size returns the number of queued items. head and tail are read separately,
// so under concurrent use the result is only a snapshot.
func (lfq *LockFreeQueue) Size() int {
    head := atomic.LoadInt64(&lfq.head)
    tail := atomic.LoadInt64(&lfq.tail)

    if tail >= head {
        return int(tail - head)
    }
    // tail has wrapped around the end of the ring.
    return int(int64(len(lfq.items)) - head + tail)
}

// OptimizedQueue is a bounded FIFO ring buffer like LockFreeQueue, with head
// and tail separated by padding so that they do not share a 64-byte cache
// line.
//
// One slot of the ring always stays unused so that a full queue can be told
// apart from an empty one.
//
// NOTE(review): Push writes the slot before the CAS that publishes it, so two
// concurrent producers can overwrite each other's slot. Use with a single
// producer unless that is reworked.
type OptimizedQueue struct {
    head     int64    // index of the next item to Pop; accessed atomically
    _        [7]int64 // pads head to 64 bytes
    tail     int64    // index of the next free slot; accessed atomically
    _        [7]int64 // pads tail to 64 bytes
    items    []int
    capacity int64 // ring size, i.e. len(items); one more than the usable capacity
}

// NewOptimizedQueue returns an empty queue that can hold 100000 items.
func NewOptimizedQueue() *OptimizedQueue {
    // 100000 usable slots plus the one slot the ring keeps empty.
    capacity := int64(100000 + 1)
    return &OptimizedQueue{
        items:    make([]int, capacity),
        capacity: capacity,
    }
}

// Push appends item at the tail. While the queue is full it sleeps for a
// microsecond and retries, so it blocks until a consumer makes room.
func (oq *OptimizedQueue) Push(item int) {
    for {
        tail := atomic.LoadInt64(&oq.tail)
        next := (tail + 1) % oq.capacity

        if next == atomic.LoadInt64(&oq.head) {
            // Full: wait for a Pop.
            time.Sleep(1 * time.Microsecond)
            continue
        }

        oq.items[tail] = item
        if atomic.CompareAndSwapInt64(&oq.tail, tail, next) {
            break
        }
    }
}

// Pop removes and returns the item at the head. The second result is false
// when the queue is empty. A failed CAS means another consumer took the item
// first, in which case the operation is retried.
func (oq *OptimizedQueue) Pop() (int, bool) {
    for {
        head := atomic.LoadInt64(&oq.head)
        tail := atomic.LoadInt64(&oq.tail)

        if head == tail {
            return 0, false
        }

        item := oq.items[head]
        next := (head + 1) % oq.capacity

        if atomic.CompareAndSwapInt64(&oq.head, head, next) {
            return item, true
        }
    }
}

// Size returns the number of queued items. head and tail are read separately,
// so under concurrent use the result is only a snapshot.
func (oq *OptimizedQueue) Size() int {
    head := atomic.LoadInt64(&oq.head)
    tail := atomic.LoadInt64(&oq.tail)

    if tail >= head {
        return int(tail - head)
    }
    // tail has wrapped around the end of the ring.
    return int(oq.capacity - head + tail)
}

func demonstrateStealingOptimization() {
    fmt.Println("\n--- 窃取策略优化 ---")
    
    /*
    窃取策略优化:
    
    1. 自适应窃取频率:
       - 根据负载动态调整窃取频率
       - 避免过度窃取导致的开销
    
    2. 批量窃取优化:
       - 一次窃取多个任务
       - 减少窃取操作的频率
    
    3. 窃取目标选择:
       - 负载感知的目标选择
       - 避免从空队列窃取
    
    4. 窃取失败处理:
       - 退避策略避免busy waiting
       - 动态调整重试间隔
    */
    
    optimizations := []struct {
        name        string
        description string
        apply       func(*AdvancedWorkStealer)
    }{
        {
            "基础窃取",
            "固定频率随机窃取",
            func(ws *AdvancedWorkStealer) {
                ws.SetStealingPolicy(&BasicStealingPolicy{})
            },
        },
        {
            "自适应窃取",
            "根据负载调整窃取频率",
            func(ws *AdvancedWorkStealer) {
                ws.SetStealingPolicy(&AdaptiveStealingPolicy{})
            },
        },
        {
            "负载感知窃取",
            "优先从高负载队列窃取",
            func(ws *AdvancedWorkStealer) {
                ws.SetStealingPolicy(&LoadAwareStealingPolicy{})
            },
        },
        {
            "混合优化窃取",
            "结合多种优化策略",
            func(ws *AdvancedWorkStealer) {
                ws.SetStealingPolicy(&HybridStealingPolicy{})
            },
        },
    }
    
    for _, opt := range optimizations {
        fmt.Printf("\n--- %s ---\n", opt.name)
        fmt.Printf("描述: %s\n", opt.description)
        
        // 创建工作窃取器
        ws := NewAdvancedWorkStealer(4)
        opt.apply(ws)
        
        // 添加不均衡的工作负载
        for i := 0; i < 4; i++ {
            var taskCount int
            if i == 0 {
                taskCount = 200 // 重负载
            } else {
                taskCount = 20  // 轻负载
            }
            
            for j := 0; j < taskCount; j++ {
                ws.AddTask(i, j)
            }
        }
        
        // 执行并测量性能
        start := time.Now()
        stats := ws.Execute()
        duration := time.Since(start)
        
        fmt.Printf("执行时间: %v\n", duration)
        fmt.Printf("负载均衡度: %.3f\n", stats.LoadBalance)
        fmt.Printf("窃取效率: %.3f\n", stats.StealEfficiency)
        fmt.Printf("平均队列长度: %.1f\n", stats.AvgQueueLength)
    }
}

// AdvancedWorkStealer is a work stealer whose stealing behaviour is supplied
// by a pluggable StealingPolicy.
type AdvancedWorkStealer struct {
    processors []*AdvancedProcessor
    policy     StealingPolicy
    numP       int
}

// AdvancedProcessor is one processor of an AdvancedWorkStealer: a task queue
// plus counters that are updated with atomic operations.
type AdvancedProcessor struct {
    id            int
    queue         *OptimizedQueue
    processed     int64
    stolen        int64
    stealAttempts int64
    stealFails    int64
}

// StealingPolicy decides when, from whom and how much an idle processor
// steals.
type StealingPolicy interface {
    // ShouldSteal reports whether processor should attempt a steal now.
    ShouldSteal(processor *AdvancedProcessor, allProcessors []*AdvancedProcessor) bool
    // SelectTarget returns the index of the processor to steal from. The
    // caller skips the attempt when the result is -1 or the stealer's own index.
    SelectTarget(processor *AdvancedProcessor, allProcessors []*AdvancedProcessor) int
    // GetBatchSize returns the maximum number of tasks to take from targetQueue
    // in one steal.
    GetBatchSize(targetQueue *OptimizedQueue) int
}

// AdvancedStats is the result of one AdvancedWorkStealer.Execute run.
type AdvancedStats struct {
    LoadBalance     float64
    StealEfficiency float64
    // NOTE(review): no assignment to AvgQueueLength is visible in the part of
    // calculateAdvancedStats reviewed here — confirm it is actually populated.
    AvgQueueLength  float64
    TotalProcessed  int64
    TotalStolen     int64
}

// NewAdvancedWorkStealer returns a stealer with numP processors, each owning
// a fresh OptimizedQueue, and BasicStealingPolicy as the initial policy.
func NewAdvancedWorkStealer(numP int) *AdvancedWorkStealer {
    aws := &AdvancedWorkStealer{
        processors: make([]*AdvancedProcessor, numP),
        numP:       numP,
        policy:     &BasicStealingPolicy{},
    }

    for id := 0; id < len(aws.processors); id++ {
        aws.processors[id] = &AdvancedProcessor{
            id:    id,
            queue: NewOptimizedQueue(),
        }
    }

    return aws
}

// SetStealingPolicy replaces the policy consulted by the processors when they
// run out of local tasks.
func (aws *AdvancedWorkStealer) SetStealingPolicy(policy StealingPolicy) {
    aws.policy = policy
}

// AddTask pushes task onto the queue of processor processorID.
func (aws *AdvancedWorkStealer) AddTask(processorID, task int) {
    owner := aws.processors[processorID]
    owner.queue.Push(task)
}

// Execute runs one goroutine per processor, waits for all of them to finish
// and returns the statistics of the run.
func (aws *AdvancedWorkStealer) Execute() *AdvancedStats {
    var wg sync.WaitGroup
    wg.Add(len(aws.processors))

    for id := 0; id < len(aws.processors); id++ {
        go aws.runAdvancedProcessor(id, &wg)
    }
    wg.Wait()

    return aws.calculateAdvancedStats()
}

// runAdvancedProcessor is the loop of one processor goroutine. Each iteration
// it either executes one local task (sleeping task%100 microseconds), or asks
// the policy whether and from whom to steal, or returns once every queue is
// empty, or backs off for 10µs before trying again.
func (aws *AdvancedWorkStealer) runAdvancedProcessor(processorID int, wg *sync.WaitGroup) {
    defer wg.Done()

    self := aws.processors[processorID]

    for {
        task, ok := self.queue.Pop()
        if ok {
            time.Sleep(time.Duration(task%100) * time.Microsecond)
            atomic.AddInt64(&self.processed, 1)
            continue
        }

        // Local queue is empty: let the policy decide about stealing.
        if aws.policy.ShouldSteal(self, aws.processors) {
            victim := aws.policy.SelectTarget(self, aws.processors)
            // -1 means "no target"; stealing from oneself is pointless.
            usable := victim != -1 && victim != processorID
            if usable && aws.stealFromTarget(self, victim) {
                continue
            }
        }

        if aws.allQueuesEmpty() {
            return
        }

        time.Sleep(10 * time.Microsecond)
    }
}

// stealFromTarget moves up to one policy-sized batch of tasks from the
// processor at index targetID into stealer's queue. It records the attempt,
// then either the number of tasks moved or a failure, and reports whether at
// least one task was moved.
func (aws *AdvancedWorkStealer) stealFromTarget(stealer *AdvancedProcessor, targetID int) bool {
    atomic.AddInt64(&stealer.stealAttempts, 1)

    victim := aws.processors[targetID]

    moved := 0
    for limit := aws.policy.GetBatchSize(victim.queue); moved < limit; moved++ {
        task, ok := victim.queue.Pop()
        if !ok {
            // Victim ran dry before the batch was filled.
            break
        }
        stealer.queue.Push(task)
    }

    if moved == 0 {
        atomic.AddInt64(&stealer.stealFails, 1)
        return false
    }

    atomic.AddInt64(&stealer.stolen, int64(moved))
    return true
}

// allQueuesEmpty reports whether every processor's queue currently has size
// zero. The queues are checked one after another, not as a single snapshot.
func (aws *AdvancedWorkStealer) allQueuesEmpty() bool {
    for idx := range aws.processors {
        if pending := aws.processors[idx].queue.Size(); pending > 0 {
            return false
        }
    }
    return true
}

// calculateAdvancedStats aggregates the per-processor counters into an
// AdvancedStats summary.
//
//   - TotalProcessed / TotalStolen are the sums of each processor's processed
//     and stolen counters.
//   - LoadBalance is 1/(1+v), where v is the population variance of the
//     per-processor processed counts: 1.0 for a perfectly even spread, and
//     closer to 0 the more uneven the spread.
//   - StealEfficiency is stolen tasks per steal attempt. Because one attempt
//     may move a whole batch it can exceed 1; it stays 0 when no attempt was
//     made.
//   - AvgQueueLength is not computed here and stays 0.
//
// The processor count is taken from the processors slice itself, so the load
// bookkeeping cannot index out of range. With zero processors the mean and
// variance would be 0/0 (NaN); that case reports LoadBalance 1.0, the value
// the formula gives for zero variance.
func (aws *AdvancedWorkStealer) calculateAdvancedStats() *AdvancedStats {
    stats := &AdvancedStats{}

    loads := make([]int64, 0, len(aws.processors))
    var totalAttempts int64
    for _, processor := range aws.processors {
        processed := atomic.LoadInt64(&processor.processed)

        stats.TotalProcessed += processed
        stats.TotalStolen += atomic.LoadInt64(&processor.stolen)
        totalAttempts += atomic.LoadInt64(&processor.stealAttempts)
        loads = append(loads, processed)
    }

    // Load-balance score from the variance of processed counts.
    stats.LoadBalance = 1.0
    if n := float64(len(loads)); n > 0 {
        mean := float64(stats.TotalProcessed) / n
        sumSquares := 0.0
        for _, load := range loads {
            diff := float64(load) - mean
            sumSquares += diff * diff
        }
        stats.LoadBalance = 1.0 / (1.0 + sumSquares/n)
    }

    // Steal efficiency is only defined once at least one attempt happened.
    if totalAttempts > 0 {
        stats.StealEfficiency = float64(stats.TotalStolen) / float64(totalAttempts)
    }

    return stats
}

// Stealing policy implementations.

// BasicStealingPolicy steals whenever the local queue is empty, picks a
// random victim, and takes one task per attempt.
type BasicStealingPolicy struct{}

// ShouldSteal reports whether processor's local queue is empty.
func (bsp *BasicStealingPolicy) ShouldSteal(processor *AdvancedProcessor, allProcessors []*AdvancedProcessor) bool {
    return processor.queue.Size() == 0
}

// SelectTarget returns the index of a uniformly random processor other than
// the caller, or -1 when there is no other processor to steal from.
//
// processor.id is treated as the caller's index in allProcessors. Excluding
// it avoids attempts that could never succeed, and the n < 2 guard avoids the
// rand.Intn panic on an empty slice.
func (bsp *BasicStealingPolicy) SelectTarget(processor *AdvancedProcessor, allProcessors []*AdvancedProcessor) int {
    n := len(allProcessors)
    if n < 2 {
        return -1
    }

    // Draw from the n-1 other slots, then shift past the caller's own slot.
    target := rand.Intn(n - 1)
    if target >= processor.id {
        target++
    }
    return target
}

// GetBatchSize always returns 1: the basic policy steals a single task.
func (bsp *BasicStealingPolicy) GetBatchSize(targetQueue *OptimizedQueue) int {
    return 1
}

// AdaptiveStealingPolicy rate-limits steal attempts per processor: the more
// of a processor's past attempts failed, the longer it must wait before the
// next one is allowed. The zero value is ready to use.
type AdaptiveStealingPolicy struct {
    // mu guards lastStealTime. One policy value is consulted by every
    // processor goroutine, and Go maps are not safe for concurrent use.
    mu sync.Mutex
    // lastStealTime maps a processor id to the last time ShouldSteal
    // returned true for it. Created lazily on first use.
    lastStealTime map[int]time.Time
}

// ShouldSteal reports whether processor may attempt a steal now.
//
// It returns false while the local queue is non-empty. Otherwise the first
// call for a given processor id is always allowed; later calls are allowed
// only once more than the current interval has passed since the last allowed
// call. The interval is failRate*1000 microseconds (failRate = stealFails /
// stealAttempts), or 100 microseconds when no attempt has been recorded yet.
// An allowed call records the current time for the processor.
func (asp *AdaptiveStealingPolicy) ShouldSteal(processor *AdvancedProcessor, allProcessors []*AdvancedProcessor) bool {
    if processor.queue.Size() > 0 {
        return false
    }

    asp.mu.Lock()
    defer asp.mu.Unlock()

    if asp.lastStealTime == nil {
        asp.lastStealTime = make(map[int]time.Time)
    }

    lastTime, exists := asp.lastStealTime[processor.id]
    if !exists {
        asp.lastStealTime[processor.id] = time.Now()
        return true
    }

    elapsed := time.Since(lastTime)
    stealAttempts := atomic.LoadInt64(&processor.stealAttempts)
    stealFails := atomic.LoadInt64(&processor.stealFails)

    // Adaptive interval: a higher failure rate means a longer wait.
    interval := 100 * time.Microsecond
    if stealAttempts > 0 {
        failRate := float64(stealFails) / float64(stealAttempts)
        interval = time.Duration(failRate*1000) * time.Microsecond
    }

    if elapsed > interval {
        asp.lastStealTime[processor.id] = time.Now()
        return true
    }

    return false
}

// SelectTarget returns the index of a uniformly random processor other than
// the caller, or -1 when there is no other processor to steal from.
//
// processor.id is treated as the caller's index in allProcessors. Excluding
// it avoids attempts that could never succeed, and the n < 2 guard avoids the
// rand.Intn panic on an empty slice.
func (asp *AdaptiveStealingPolicy) SelectTarget(processor *AdvancedProcessor, allProcessors []*AdvancedProcessor) int {
    n := len(allProcessors)
    if n < 2 {
        return -1
    }

    // Draw from the n-1 other slots, then shift past the caller's own slot.
    target := rand.Intn(n - 1)
    if target >= processor.id {
        target++
    }
    return target
}

// GetBatchSize returns half of the target queue's current size when that
// size is above 10, and 1 otherwise.
func (asp *AdaptiveStealingPolicy) GetBatchSize(targetQueue *OptimizedQueue) int {
    pending := targetQueue.Size()
    if pending <= 10 {
        return 1
    }
    return pending / 2
}

// LoadAwareStealingPolicy steals from whichever other processor currently
// has the longest queue.
type LoadAwareStealingPolicy struct{}

// ShouldSteal reports whether processor's local queue is empty.
func (lasp *LoadAwareStealingPolicy) ShouldSteal(processor *AdvancedProcessor, allProcessors []*AdvancedProcessor) bool {
    pending := processor.queue.Size()
    return pending == 0
}

// SelectTarget returns the index of the processor, other than the caller,
// whose queue is currently largest. Ties go to the lowest index. It returns
// -1 when there is no other processor.
func (lasp *LoadAwareStealingPolicy) SelectTarget(processor *AdvancedProcessor, allProcessors []*AdvancedProcessor) int {
    best, bestLoad := -1, -1

    for idx := range allProcessors {
        if idx != processor.id {
            if load := allProcessors[idx].queue.Size(); load > bestLoad {
                best, bestLoad = idx, load
            }
        }
    }

    return best
}

// GetBatchSize returns half of the target queue's current size when that
// size is above 4, and 1 otherwise.
func (lasp *LoadAwareStealingPolicy) GetBatchSize(targetQueue *OptimizedQueue) int {
    pending := targetQueue.Size()
    if pending <= 4 {
        return 1
    }
    return pending / 2
}

// HybridStealingPolicy combines two policies: steal timing and batch size
// come from an AdaptiveStealingPolicy, victim selection comes from a
// LoadAwareStealingPolicy. The zero value is ready to use.
type HybridStealingPolicy struct {
    adaptive  *AdaptiveStealingPolicy
    loadAware *LoadAwareStealingPolicy

    // initOnce makes the lazy creation of the delegates happen exactly
    // once. The policy is called from every processor goroutine, so
    // unsynchronized nil-check-then-assign would be a data race and could
    // hand different goroutines different delegate instances.
    initOnce sync.Once
}

// ensureDelegates creates any delegate that has not been set, exactly once.
func (hsp *HybridStealingPolicy) ensureDelegates() {
    hsp.initOnce.Do(func() {
        if hsp.adaptive == nil {
            hsp.adaptive = &AdaptiveStealingPolicy{}
        }
        if hsp.loadAware == nil {
            hsp.loadAware = &LoadAwareStealingPolicy{}
        }
    })
}

// ShouldSteal delegates the decision to the adaptive policy.
func (hsp *HybridStealingPolicy) ShouldSteal(processor *AdvancedProcessor, allProcessors []*AdvancedProcessor) bool {
    hsp.ensureDelegates()
    return hsp.adaptive.ShouldSteal(processor, allProcessors)
}

// SelectTarget delegates victim selection to the load-aware policy.
func (hsp *HybridStealingPolicy) SelectTarget(processor *AdvancedProcessor, allProcessors []*AdvancedProcessor) int {
    hsp.ensureDelegates()
    return hsp.loadAware.SelectTarget(processor, allProcessors)
}

// GetBatchSize delegates batch sizing to the adaptive policy.
func (hsp *HybridStealingPolicy) GetBatchSize(targetQueue *OptimizedQueue) int {
    hsp.ensureDelegates()
    return hsp.adaptive.GetBatchSize(targetQueue)
}

// demonstrateNUMAAwareScheduling prints a descriptive walkthrough of
// NUMA-aware scheduling: the optimization strategies, a sample two-node
// topology, and illustrative latency figures. It only writes to stdout.
func demonstrateNUMAAwareScheduling() {
    /*
       NUMA-aware scheduling optimizations:

       1. Processor affinity:
          - schedule within the same NUMA node where possible
          - reduce cross-node memory accesses

       2. Memory locality:
          - allocate task data close to where it is used
          - reduce remote memory access latency

       3. Hierarchical stealing:
          - steal from a P on the same node first
          - then from neighbouring nodes
    */

    // printAll writes each line with a trailing newline, in order.
    printAll := func(lines ...string) {
        for _, line := range lines {
            fmt.Println(line)
        }
    }

    printAll(
        "\n--- NUMA感知调度优化 ---",
        "NUMA感知优化策略:",
        "  1. 节点内优先窃取",
        "  2. 跨节点窃取惩罚",
        "  3. 内存亲和性优化",
        "  4. 任务迁移最小化",
    )

    // Sample topology: node 0 owns processors 0 and 1, node 1 owns 2 and 3.
    numaTopology := map[int][]int{
        0: {0, 1},
        1: {2, 3},
    }
    fmt.Printf("\nNUMA拓扑: %v\n", numaTopology)

    // Scheduling strategies being contrasted.
    printAll(
        "\n调度策略对比:",
        "  - 随机调度: 忽略NUMA拓扑",
        "  - NUMA感知: 考虑节点亲和性",
    )

    // Simplified, illustrative performance figures.
    printAll(
        "\n性能影响 (模拟):",
        "  - 本地内存访问: 100ns",
        "  - 远程内存访问: 200ns",
        "  - 跨节点窃取延迟: +50%",
    )
}

// demonstrateDynamicTuning prints the tunable parameters of a work-stealing
// scheduler (name, strategy, impact) followed by the steps of a dynamic
// tuning loop. It only writes to stdout.
func demonstrateDynamicTuning() {
    fmt.Println("\n--- 动态调整策略 ---")

    /*
       Dynamic tuning strategy:

       1. Runtime monitoring:
          - steal success rate
          - queue length distribution
          - processor utilization

       2. Adaptive parameters:
          - steal frequency
          - batch size
          - backoff strategy

       3. Load prediction:
          - predict load from historical data
          - rebalance proactively
    */

    type tuningParam struct {
        name, description, impact string
    }

    params := []tuningParam{
        {
            name:        "窃取频率",
            description: "根据队列状态动态调整窃取尝试频率",
            impact:      "减少无效窃取,提高效率",
        },
        {
            name:        "批量大小",
            description: "根据目标队列大小调整一次窃取的任务数量",
            impact:      "平衡窃取开销和负载均衡效果",
        },
        {
            name:        "退避策略",
            description: "窃取失败后的等待时间动态调整",
            impact:      "避免busy waiting,减少CPU浪费",
        },
        {
            name:        "目标选择",
            description: "智能选择窃取目标,避免无效尝试",
            impact:      "提高窃取成功率,减少系统开销",
        },
    }

    // One three-line entry per parameter, preceded by a blank line.
    for i := range params {
        p := params[i]
        fmt.Printf("\n%s:\n  策略: %s\n  影响: %s\n", p.name, p.description, p.impact)
    }

    steps := []string{
        "\n动态调整算法:",
        "  1. 收集运行时统计信息",
        "  2. 分析性能瓶颈",
        "  3. 调整相关参数",
        "  4. 评估调整效果",
        "  5. 持续优化迭代",
    }
    for _, step := range steps {
        fmt.Println(step)
    }
}

// main runs the work-stealing demonstration followed by the optimization
// demonstration.
func main() {
    demonstrateWorkStealing()
    demonstrateWorkStealingOptimization()
}

🎯 核心知识点总结

工作窃取基本原理要点

  1. 负载均衡: 自动将工作从忙碌的P转移到空闲的P
  2. 双端队列: 经典工作窃取算法中,所有者在队列一端按LIFO存取,窃取者从另一端按FIFO窃取;Go运行时的本地runq实际是环形队列,本地获取与窃取都从队首(head)取出,LIFO式的局部性由runnext槽位提供
  3. 随机选择: 简单有效的目标P选择策略
  4. 批量窃取: 一次窃取一半任务,减少窃取频率

算法实现要点

  1. 无锁设计: 使用CAS操作避免锁竞争
  2. 队列管理: 维护head/tail指针,处理边界条件
  3. 窃取时机: 本地队列为空时才尝试窃取
  4. 失败处理: 合理的退避策略避免busy waiting

性能优化要点

  1. 队列优化: 缓存行对齐、避免false sharing
  2. 策略优化: 自适应频率、负载感知、批量窃取
  3. NUMA感知: 考虑处理器亲和性和内存局部性
  4. 动态调优: 运行时监控和参数自适应调整

负载均衡效果要点

  1. 吞吐量提升: 减少处理器空闲时间
  2. 响应延迟: 避免任务在某个队列中长时间等待
  3. 系统伸缩性: 支持动态添加/移除处理器
  4. 资源利用率: 提高整体系统资源利用效率

🔍 面试准备建议

  1. 理解核心算法: 深入掌握工作窃取的设计原理和实现细节
  2. 分析性能特征: 了解不同场景下工作窃取的效果
  3. 掌握优化技术: 熟悉各种性能优化策略和适用场景
  4. 实践经验: 在并发程序中观察和分析负载均衡效果
  5. 系统思维: 理解工作窃取在整个调度系统中的作用

正在精进