Skip to content

Context上下文管理详解 - Golang并发编程面试题

Context是Go语言中用于管理请求生命周期、传递取消信号和截止时间的核心机制。本章深入探讨Context的设计原理、使用模式和最佳实践。

📋 重点面试题

面试题 1:Context的基本概念和使用

难度级别:⭐⭐⭐⭐
考察范围:并发控制/生命周期管理
技术标签context cancellation timeout deadline value passing

详细解答

1. Context基本概念和类型

点击查看完整代码实现
点击查看完整代码实现
go
package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

func demonstrateContextBasics() {
    fmt.Println("=== Context基本概念 ===")
    
    /*
    Context主要用途:
    1. 取消传播:将取消信号传播到整个调用链
    2. 超时控制:设置操作的超时时间
    3. 截止时间:设置绝对的截止时间
    4. 值传递:在调用链中传递请求级别的数据
    
    Context类型:
    - Background:根上下文,永不取消
    - TODO:占位符上下文,用于不确定的情况
    - WithCancel:可手动取消的上下文
    - WithTimeout:带超时的上下文
    - WithDeadline:带截止时间的上下文
    - WithValue:携带值的上下文
    */
    
    // 演示不同类型的Context
    demonstrateContextTypes()
    
    // 演示Context的取消传播
    demonstrateCancellationPropagation()
    
    // 演示Context的超时控制
    demonstrateTimeoutControl()
    
    // 演示Context传递值
    demonstrateValuePassing()
}

func demonstrateContextTypes() {
    fmt.Println("\n--- Context类型演示 ---")
    
    // 1. Background Context
    bgCtx := context.Background()
    fmt.Printf("Background Context: %T\n", bgCtx)
    
    // 2. TODO Context
    todoCtx := context.TODO()
    fmt.Printf("TODO Context: %T\n", todoCtx)
    
    // 3. WithCancel Context
    cancelCtx, cancel := context.WithCancel(bgCtx)
    fmt.Printf("Cancel Context: %T\n", cancelCtx)
    defer cancel()
    
    // 4. WithTimeout Context
    timeoutCtx, timeoutCancel := context.WithTimeout(bgCtx, 1*time.Second)
    fmt.Printf("Timeout Context: %T\n", timeoutCtx)
    defer timeoutCancel()
    
    // 5. WithDeadline Context
    deadline := time.Now().Add(2 * time.Second)
    deadlineCtx, deadlineCancel := context.WithDeadline(bgCtx, deadline)
    fmt.Printf("Deadline Context: %T\n", deadlineCtx)
    defer deadlineCancel()
    
    // 6. WithValue Context
    valueCtx := context.WithValue(bgCtx, "user_id", "12345")
    fmt.Printf("Value Context: %T\n", valueCtx)
    
    // 检查Context状态
    fmt.Printf("Cancel Context Done: %v\n", cancelCtx.Done() != nil)
    fmt.Printf("Timeout Context Deadline: %v\n", timeoutCtx.Deadline())
    fmt.Printf("Value Context Value: %v\n", valueCtx.Value("user_id"))
}

func demonstrateCancellationPropagation() {
    fmt.Println("\n--- 取消传播演示 ---")
    
    // 创建可取消的根上下文
    rootCtx, rootCancel := context.WithCancel(context.Background())
    defer rootCancel()
    
    var wg sync.WaitGroup
    
    // 启动多层嵌套的goroutine
    wg.Add(1)
    go func() {
        defer wg.Done()
        level1Worker(rootCtx, "Worker-1")
    }()
    
    wg.Add(1)
    go func() {
        defer wg.Done()
        level1Worker(rootCtx, "Worker-2")
    }()
    
    // 让worker运行一段时间
    time.Sleep(500 * time.Millisecond)
    
    // 取消根上下文,观察取消传播
    fmt.Println("取消根上下文...")
    rootCancel()
    
    wg.Wait()
    fmt.Println("所有worker已完成")
}

func level1Worker(ctx context.Context, name string) {
    fmt.Printf("%s 启动\n", name)
    defer fmt.Printf("%s 退出\n", name)
    
    // 创建子上下文
    childCtx, cancel := context.WithCancel(ctx)
    defer cancel()
    
    var wg sync.WaitGroup
    
    // 启动子worker
    for i := 0; i < 2; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            level2Worker(childCtx, fmt.Sprintf("%s-Sub%d", name, id))
        }(i)
    }
    
    // 等待取消信号或子worker完成
    select {
    case <-ctx.Done():
        fmt.Printf("%s 收到取消信号: %v\n", name, ctx.Err())
    }
    
    wg.Wait()
}

func level2Worker(ctx context.Context, name string) {
    fmt.Printf("  %s 启动\n", name)
    defer fmt.Printf("  %s 退出\n", name)
    
    ticker := time.NewTicker(100 * time.Millisecond)
    defer ticker.Stop()
    
    for {
        select {
        case <-ctx.Done():
            fmt.Printf("  %s 收到取消信号: %v\n", name, ctx.Err())
            return
            
        case <-ticker.C:
            fmt.Printf("  %s 工作中...\n", name)
        }
    }
}

func demonstrateTimeoutControl() {
    fmt.Println("\n--- 超时控制演示 ---")
    
    // 测试不同的超时场景
    testCases := []struct {
        name    string
        timeout time.Duration
        work    time.Duration
    }{
        {"快速任务", 200 * time.Millisecond, 100 * time.Millisecond},
        {"慢速任务", 200 * time.Millisecond, 300 * time.Millisecond},
        {"边界任务", 200 * time.Millisecond, 200 * time.Millisecond},
    }
    
    for _, tc := range testCases {
        fmt.Printf("\n测试: %s\n", tc.name)
        
        ctx, cancel := context.WithTimeout(context.Background(), tc.timeout)
        
        result := make(chan string, 1)
        
        // 启动工作goroutine
        go func() {
            defer cancel()
            
            // 模拟工作
            select {
            case <-time.After(tc.work):
                result <- "工作完成"
                
            case <-ctx.Done():
                result <- fmt.Sprintf("工作被取消: %v", ctx.Err())
                return
            }
        }()
        
        // 等待结果或超时
        select {
        case res := <-result:
            fmt.Printf("结果: %s\n", res)
            
        case <-ctx.Done():
            fmt.Printf("上下文超时: %v\n", ctx.Err())
        }
        
        cancel()
    }
}

func demonstrateValuePassing() {
    fmt.Println("\n--- 值传递演示 ---")
    
    // 定义键类型(避免字符串键冲突)
    type contextKey string
    
    const (
        userIDKey    contextKey = "user_id"
        requestIDKey contextKey = "request_id"
        traceIDKey   contextKey = "trace_id"
    )
    
    // 创建带值的上下文
    ctx := context.Background()
    ctx = context.WithValue(ctx, userIDKey, "user-123")
    ctx = context.WithValue(ctx, requestIDKey, "req-456")
    ctx = context.WithValue(ctx, traceIDKey, "trace-789")
    
    // 模拟请求处理流程
    handleRequest(ctx)
}

func handleRequest(ctx context.Context) {
    fmt.Println("处理请求...")
    
    // 从上下文获取值
    userID := ctx.Value(contextKey("user_id"))
    requestID := ctx.Value(contextKey("request_id"))
    traceID := ctx.Value(contextKey("trace_id"))
    
    fmt.Printf("用户ID: %v\n", userID)
    fmt.Printf("请求ID: %v\n", requestID)
    fmt.Printf("跟踪ID: %v\n", traceID)
    
    // 传递给下一层
    processData(ctx)
}

func processData(ctx context.Context) {
    fmt.Println("处理数据...")
    
    // 获取用户信息
    if userID, ok := ctx.Value(contextKey("user_id")).(string); ok {
        fmt.Printf("为用户 %s 处理数据\n", userID)
    }
    
    // 调用外部服务
    callExternalService(ctx)
}

func callExternalService(ctx context.Context) {
    fmt.Println("调用外部服务...")
    
    // 为外部调用设置超时
    serviceCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
    defer cancel()
    
    // 模拟外部服务调用
    done := make(chan bool)
    
    go func() {
        // 模拟服务响应时间
        time.Sleep(50 * time.Millisecond)
        done <- true
    }()
    
    select {
    case <-done:
        fmt.Println("外部服务调用成功")
        
    case <-serviceCtx.Done():
        fmt.Printf("外部服务调用超时: %v\n", serviceCtx.Err())
    }
}

:::

面试题 2:Context的高级用法和最佳实践

难度级别:⭐⭐⭐⭐⭐
考察范围:高级并发模式/最佳实践
技术标签context patterns request scoped data graceful shutdown middleware

详细解答

1. Context的高级使用模式

点击查看完整代码实现
点击查看完整代码实现
go
func demonstrateAdvancedContextPatterns() {
    fmt.Println("\n=== Context高级使用模式 ===")
    
    // 模式1:请求级别的数据管理
    demonstrateRequestScopedData()
    
    // 模式2:优雅关闭模式
    demonstrateGracefulShutdown()
    
    // 模式3:中间件模式
    demonstrateMiddlewarePattern()
    
    // 模式4:扇出扇入模式
    demonstrateFanOutFanIn()
}

func demonstrateRequestScopedData() {
    fmt.Println("\n--- 请求级别数据管理 ---")
    
    // 请求上下文构建器
    type RequestContext struct {
        UserID    string
        RequestID string
        TraceID   string
        StartTime time.Time
        Logger    *Logger
    }
    
    type Logger struct {
        prefix string
    }
    
    func (l *Logger) Info(msg string) {
        fmt.Printf("[INFO] %s: %s\n", l.prefix, msg)
    }
    
    func (l *Logger) Error(msg string, err error) {
        fmt.Printf("[ERROR] %s: %s - %v\n", l.prefix, msg, err)
    }
    
    // 构建请求上下文
    buildRequestContext := func(userID string) context.Context {
        reqCtx := &RequestContext{
            UserID:    userID,
            RequestID: fmt.Sprintf("req-%d", time.Now().UnixNano()),
            TraceID:   fmt.Sprintf("trace-%d", time.Now().UnixNano()),
            StartTime: time.Now(),
            Logger:    &Logger{prefix: fmt.Sprintf("[%s]", userID)},
        }
        
        ctx := context.Background()
        ctx = context.WithValue(ctx, "request_context", reqCtx)
        
        return ctx
    }
    
    // 获取请求上下文
    getRequestContext := func(ctx context.Context) *RequestContext {
        if reqCtx, ok := ctx.Value("request_context").(*RequestContext); ok {
            return reqCtx
        }
        return nil
    }
    
    // 模拟HTTP请求处理
    handleHTTPRequest := func(userID string) {
        ctx := buildRequestContext(userID)
        reqCtx := getRequestContext(ctx)
        
        reqCtx.Logger.Info("开始处理请求")
        
        // 处理业务逻辑
        if err := businessLogic(ctx); err != nil {
            reqCtx.Logger.Error("业务逻辑处理失败", err)
        } else {
            reqCtx.Logger.Info("请求处理成功")
        }
        
        // 记录请求耗时
        duration := time.Since(reqCtx.StartTime)
        reqCtx.Logger.Info(fmt.Sprintf("请求处理耗时: %v", duration))
    }
    
    businessLogic := func(ctx context.Context) error {
        reqCtx := getRequestContext(ctx)
        reqCtx.Logger.Info("执行业务逻辑")
        
        // 模拟数据库查询
        if err := queryDatabase(ctx); err != nil {
            return err
        }
        
        // 模拟外部API调用
        if err := callExternalAPI(ctx); err != nil {
            return err
        }
        
        return nil
    }
    
    queryDatabase := func(ctx context.Context) error {
        reqCtx := getRequestContext(ctx)
        reqCtx.Logger.Info("查询数据库")
        
        // 设置数据库查询超时
        dbCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
        defer cancel()
        
        // 模拟数据库查询
        time.Sleep(50 * time.Millisecond)
        
        select {
        case <-dbCtx.Done():
            return fmt.Errorf("数据库查询超时")
        default:
            reqCtx.Logger.Info("数据库查询成功")
            return nil
        }
    }
    
    callExternalAPI := func(ctx context.Context) error {
        reqCtx := getRequestContext(ctx)
        reqCtx.Logger.Info("调用外部API")
        
        // 模拟API调用
        time.Sleep(30 * time.Millisecond)
        reqCtx.Logger.Info("外部API调用成功")
        
        return nil
    }
    
    // 模拟多个并发请求
    var wg sync.WaitGroup
    users := []string{"alice", "bob", "charlie"}
    
    for _, user := range users {
        wg.Add(1)
        go func(userID string) {
            defer wg.Done()
            handleHTTPRequest(userID)
        }(user)
    }
    
    wg.Wait()
}

func demonstrateGracefulShutdown() {
    fmt.Println("\n--- 优雅关闭模式 ---")
    
    // 服务器管理器
    type Server struct {
        ctx    context.Context
        cancel context.CancelFunc
        wg     sync.WaitGroup
    }
    
    func NewServer() *Server {
        ctx, cancel := context.WithCancel(context.Background())
        return &Server{
            ctx:    ctx,
            cancel: cancel,
        }
    }
    
    func (s *Server) Start() {
        fmt.Println("服务器启动")
        
        // 启动HTTP服务器
        s.wg.Add(1)
        go s.runHTTPServer()
        
        // 启动后台工作者
        s.wg.Add(1)
        go s.runBackgroundWorker()
        
        // 启动定时任务
        s.wg.Add(1)
        go s.runScheduledTasks()
    }
    
    func (s *Server) Shutdown() {
        fmt.Println("开始优雅关闭...")
        s.cancel()
        s.wg.Wait()
        fmt.Println("服务器已完全关闭")
    }
    
    func (s *Server) runHTTPServer() {
        defer s.wg.Done()
        defer fmt.Println("HTTP服务器已停止")
        
        ticker := time.NewTicker(100 * time.Millisecond)
        defer ticker.Stop()
        
        for {
            select {
            case <-s.ctx.Done():
                fmt.Println("HTTP服务器收到关闭信号")
                return
                
            case <-ticker.C:
                // 模拟处理HTTP请求
                fmt.Println("处理HTTP请求...")
            }
        }
    }
    
    func (s *Server) runBackgroundWorker() {
        defer s.wg.Done()
        defer fmt.Println("后台工作者已停止")
        
        ticker := time.NewTicker(200 * time.Millisecond)
        defer ticker.Stop()
        
        for {
            select {
            case <-s.ctx.Done():
                fmt.Println("后台工作者收到关闭信号")
                // 完成当前正在处理的任务
                s.finishCurrentTasks()
                return
                
            case <-ticker.C:
                // 模拟后台任务
                fmt.Println("执行后台任务...")
            }
        }
    }
    
    func (s *Server) runScheduledTasks() {
        defer s.wg.Done()
        defer fmt.Println("定时任务已停止")
        
        ticker := time.NewTicker(300 * time.Millisecond)
        defer ticker.Stop()
        
        for {
            select {
            case <-s.ctx.Done():
                fmt.Println("定时任务收到关闭信号")
                return
                
            case <-ticker.C:
                // 模拟定时任务
                fmt.Println("执行定时任务...")
            }
        }
    }
    
    func (s *Server) finishCurrentTasks() {
        fmt.Println("完成当前正在处理的任务...")
        time.Sleep(50 * time.Millisecond)
        fmt.Println("当前任务已完成")
    }
    
    // 演示优雅关闭
    server := NewServer()
    server.Start()
    
    // 运行一段时间后关闭
    time.Sleep(1 * time.Second)
    server.Shutdown()
}

func demonstrateMiddlewarePattern() {
    fmt.Println("\n--- 中间件模式 ---")
    
    // 中间件类型定义
    type Handler func(ctx context.Context) error
    type Middleware func(Handler) Handler
    
    // 日志中间件
    loggingMiddleware := func(next Handler) Handler {
        return func(ctx context.Context) error {
            start := time.Now()
            fmt.Printf("[LOG] 请求开始 - %v\n", start)
            
            err := next(ctx)
            
            duration := time.Since(start)
            if err != nil {
                fmt.Printf("[LOG] 请求失败 - 耗时: %v, 错误: %v\n", duration, err)
            } else {
                fmt.Printf("[LOG] 请求成功 - 耗时: %v\n", duration)
            }
            
            return err
        }
    }
    
    // 超时中间件
    timeoutMiddleware := func(timeout time.Duration) Middleware {
        return func(next Handler) Handler {
            return func(ctx context.Context) error {
                timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
                defer cancel()
                
                done := make(chan error, 1)
                
                go func() {
                    done <- next(timeoutCtx)
                }()
                
                select {
                case err := <-done:
                    return err
                case <-timeoutCtx.Done():
                    return fmt.Errorf("请求超时: %v", timeoutCtx.Err())
                }
            }
        }
    }
    
    // 重试中间件
    retryMiddleware := func(maxRetries int) Middleware {
        return func(next Handler) Handler {
            return func(ctx context.Context) error {
                var lastErr error
                
                for i := 0; i <= maxRetries; i++ {
                    if i > 0 {
                        fmt.Printf("[RETRY] 第 %d 次重试\n", i)
                        time.Sleep(time.Duration(i*100) * time.Millisecond)
                    }
                    
                    if err := next(ctx); err != nil {
                        lastErr = err
                        continue
                    }
                    
                    return nil
                }
                
                return fmt.Errorf("重试 %d 次后仍然失败: %v", maxRetries, lastErr)
            }
        }
    }
    
    // 业务处理器
    businessHandler := func(ctx context.Context) error {
        fmt.Println("[BUSINESS] 执行业务逻辑")
        
        // 模拟随机失败
        if time.Now().UnixNano()%3 == 0 {
            return fmt.Errorf("业务逻辑执行失败")
        }
        
        // 模拟处理时间
        time.Sleep(50 * time.Millisecond)
        fmt.Println("[BUSINESS] 业务逻辑执行成功")
        return nil
    }
    
    // 组合中间件
    handler := loggingMiddleware(
        timeoutMiddleware(200 * time.Millisecond)(
            retryMiddleware(2)(
                businessHandler,
            ),
        ),
    )
    
    // 执行处理器
    ctx := context.Background()
    if err := handler(ctx); err != nil {
        fmt.Printf("最终处理失败: %v\n", err)
    }
}

func demonstrateFanOutFanIn() {
    fmt.Println("\n--- 扇出扇入模式 ---")
    
    // 扇出:将任务分发给多个worker
    fanOut := func(ctx context.Context, input <-chan int, numWorkers int) []<-chan int {
        outputs := make([]<-chan int, numWorkers)
        
        for i := 0; i < numWorkers; i++ {
            output := make(chan int)
            outputs[i] = output
            
            go func(workerID int, out chan<- int) {
                defer close(out)
                
                for {
                    select {
                    case task, ok := <-input:
                        if !ok {
                            fmt.Printf("Worker %d: 输入通道关闭\n", workerID)
                            return
                        }
                        
                        // 处理任务
                        result := processTask(ctx, task, workerID)
                        
                        select {
                        case out <- result:
                            fmt.Printf("Worker %d: 处理任务 %d -> %d\n", workerID, task, result)
                        case <-ctx.Done():
                            fmt.Printf("Worker %d: 上下文取消\n", workerID)
                            return
                        }
                        
                    case <-ctx.Done():
                        fmt.Printf("Worker %d: 上下文取消\n", workerID)
                        return
                    }
                }
            }(i, output)
        }
        
        return outputs
    }
    
    // 扇入:将多个worker的结果合并
    fanIn := func(ctx context.Context, inputs ...<-chan int) <-chan int {
        output := make(chan int)
        var wg sync.WaitGroup
        
        for i, input := range inputs {
            wg.Add(1)
            go func(inputID int, in <-chan int) {
                defer wg.Done()
                
                for {
                    select {
                    case result, ok := <-in:
                        if !ok {
                            fmt.Printf("扇入 %d: 输入通道关闭\n", inputID)
                            return
                        }
                        
                        select {
                        case output <- result:
                            fmt.Printf("扇入 %d: 转发结果 %d\n", inputID, result)
                        case <-ctx.Done():
                            fmt.Printf("扇入 %d: 上下文取消\n", inputID)
                            return
                        }
                        
                    case <-ctx.Done():
                        fmt.Printf("扇入 %d: 上下文取消\n", inputID)
                        return
                    }
                }
            }(i, input)
        }
        
        go func() {
            wg.Wait()
            close(output)
            fmt.Println("扇入: 所有输入处理完成,关闭输出通道")
        }()
        
        return output
    }
    
    processTask := func(ctx context.Context, task, workerID int) int {
        // 检查上下文是否已取消
        select {
        case <-ctx.Done():
            return -1
        default:
        }
        
        // 模拟任务处理时间
        time.Sleep(time.Duration(50+workerID*10) * time.Millisecond)
        
        return task * 2
    }
    
    // 创建带超时的上下文
    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
    defer cancel()
    
    // 创建输入通道
    input := make(chan int, 10)
    
    // 发送任务
    go func() {
        defer close(input)
        for i := 1; i <= 10; i++ {
            select {
            case input <- i:
                fmt.Printf("发送任务: %d\n", i)
            case <-ctx.Done():
                fmt.Println("任务发送被取消")
                return
            }
        }
    }()
    
    // 扇出到3个worker
    workerOutputs := fanOut(ctx, input, 3)
    
    // 扇入结果
    finalOutput := fanIn(ctx, workerOutputs...)
    
    // 收集最终结果
    var results []int
    for {
        select {
        case result, ok := <-finalOutput:
            if !ok {
                fmt.Printf("收集完成,总共收到 %d 个结果: %v\n", len(results), results)
                return
            }
            results = append(results, result)
            
        case <-ctx.Done():
            fmt.Printf("结果收集超时,已收到 %d 个结果: %v\n", len(results), results)
            return
        }
    }
}

func main() {
    demonstrateContextBasics()
    demonstrateAdvancedContextPatterns()
}

:::

🎯 核心知识点总结

Context基础要点

  1. 取消传播: Context可以将取消信号传播到整个调用链
  2. 超时控制: WithTimeout和WithDeadline提供超时控制
  3. 值传递: WithValue在请求范围内传递数据
  4. 组合使用: 不同类型的Context可以组合使用

Context类型要点

  1. Background: 根上下文,通常用于main函数和测试
  2. TODO: 占位符上下文,用于不确定使用哪种上下文的情况
  3. WithCancel: 提供手动取消功能
  4. WithTimeout/WithDeadline: 提供时间控制功能

最佳实践要点

  1. 参数传递: Context应该作为函数的第一个参数
  2. 不要存储: 不要将Context存储在结构体中
  3. 及时取消: 使用defer确保cancel函数被调用
  4. 合理使用Value: 只用于请求级别的数据,不要滥用

高级模式要点

  1. 中间件模式: 使用Context在中间件之间传递数据
  2. 优雅关闭: 使用Context协调多个组件的关闭
  3. 扇出扇入: 在并发模式中使用Context控制生命周期
  4. 请求追踪: 使用Context传递请求ID、用户ID等追踪信息

🔍 面试准备建议

  1. 理解设计原理: 深入理解Context的设计理念和使用场景
  2. 掌握使用模式: 熟练使用各种Context创建和组合方式
  3. 避免常见错误: 了解Context使用中的常见陷阱
  4. 实践应用: 在实际项目中正确使用Context
  5. 性能考虑: 理解Context对性能的影响和优化方法

正在精进