recovery.go (5207B)
1 // Copyright 2014 Manu Martinez-Almeida. All rights reserved. 2 // Use of this source code is governed by a MIT style 3 // license that can be found in the LICENSE file. 4 5 package gin 6 7 import ( 8 "bytes" 9 "errors" 10 "fmt" 11 "io" 12 "log" 13 "net" 14 "net/http" 15 "net/http/httputil" 16 "os" 17 "runtime" 18 "strings" 19 "time" 20 ) 21 22 var ( 23 dunno = []byte("???") 24 centerDot = []byte("·") 25 dot = []byte(".") 26 slash = []byte("/") 27 ) 28 29 // RecoveryFunc defines the function passable to CustomRecovery. 30 type RecoveryFunc func(c *Context, err any) 31 32 // Recovery returns a middleware that recovers from any panics and writes a 500 if there was one. 33 func Recovery() HandlerFunc { 34 return RecoveryWithWriter(DefaultErrorWriter) 35 } 36 37 // CustomRecovery returns a middleware that recovers from any panics and calls the provided handle func to handle it. 38 func CustomRecovery(handle RecoveryFunc) HandlerFunc { 39 return RecoveryWithWriter(DefaultErrorWriter, handle) 40 } 41 42 // RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one. 43 func RecoveryWithWriter(out io.Writer, recovery ...RecoveryFunc) HandlerFunc { 44 if len(recovery) > 0 { 45 return CustomRecoveryWithWriter(out, recovery[0]) 46 } 47 return CustomRecoveryWithWriter(out, defaultHandleRecovery) 48 } 49 50 // CustomRecoveryWithWriter returns a middleware for a given writer that recovers from any panics and calls the provided handle func to handle it. 51 func CustomRecoveryWithWriter(out io.Writer, handle RecoveryFunc) HandlerFunc { 52 var logger *log.Logger 53 if out != nil { 54 logger = log.New(out, "\n\n\x1b[31m", log.LstdFlags) 55 } 56 return func(c *Context) { 57 defer func() { 58 if err := recover(); err != nil { 59 // Check for a broken connection, as it is not really a 60 // condition that warrants a panic stack trace. 61 var brokenPipe bool 62 if ne, ok := err.(*net.OpError); ok { 63 var se *os.SyscallError 64 if errors.As(ne, &se) { 65 seStr := strings.ToLower(se.Error()) 66 if strings.Contains(seStr, "broken pipe") || 67 strings.Contains(seStr, "connection reset by peer") { 68 brokenPipe = true 69 } 70 } 71 } 72 if logger != nil { 73 stack := stack(3) 74 httpRequest, _ := httputil.DumpRequest(c.Request, false) 75 headers := strings.Split(string(httpRequest), "\r\n") 76 for idx, header := range headers { 77 current := strings.Split(header, ":") 78 if current[0] == "Authorization" { 79 headers[idx] = current[0] + ": *" 80 } 81 } 82 headersToStr := strings.Join(headers, "\r\n") 83 if brokenPipe { 84 logger.Printf("%s\n%s%s", err, headersToStr, reset) 85 } else if IsDebugging() { 86 logger.Printf("[Recovery] %s panic recovered:\n%s\n%s\n%s%s", 87 timeFormat(time.Now()), headersToStr, err, stack, reset) 88 } else { 89 logger.Printf("[Recovery] %s panic recovered:\n%s\n%s%s", 90 timeFormat(time.Now()), err, stack, reset) 91 } 92 } 93 if brokenPipe { 94 // If the connection is dead, we can't write a status to it. 95 c.Error(err.(error)) //nolint: errcheck 96 c.Abort() 97 } else { 98 handle(c, err) 99 } 100 } 101 }() 102 c.Next() 103 } 104 } 105 106 func defaultHandleRecovery(c *Context, _ any) { 107 c.AbortWithStatus(http.StatusInternalServerError) 108 } 109 110 // stack returns a nicely formatted stack frame, skipping skip frames. 111 func stack(skip int) []byte { 112 buf := new(bytes.Buffer) // the returned data 113 // As we loop, we open files and read them. These variables record the currently 114 // loaded file. 115 var lines [][]byte 116 var lastFile string 117 for i := skip; ; i++ { // Skip the expected number of frames 118 pc, file, line, ok := runtime.Caller(i) 119 if !ok { 120 break 121 } 122 // Print this much at least. If we can't find the source, it won't show. 123 fmt.Fprintf(buf, "%s:%d (0x%x)\n", file, line, pc) 124 if file != lastFile { 125 data, err := os.ReadFile(file) 126 if err != nil { 127 continue 128 } 129 lines = bytes.Split(data, []byte{'\n'}) 130 lastFile = file 131 } 132 fmt.Fprintf(buf, "\t%s: %s\n", function(pc), source(lines, line)) 133 } 134 return buf.Bytes() 135 } 136 137 // source returns a space-trimmed slice of the n'th line. 138 func source(lines [][]byte, n int) []byte { 139 n-- // in stack trace, lines are 1-indexed but our array is 0-indexed 140 if n < 0 || n >= len(lines) { 141 return dunno 142 } 143 return bytes.TrimSpace(lines[n]) 144 } 145 146 // function returns, if possible, the name of the function containing the PC. 147 func function(pc uintptr) []byte { 148 fn := runtime.FuncForPC(pc) 149 if fn == nil { 150 return dunno 151 } 152 name := []byte(fn.Name()) 153 // The name includes the path name to the package, which is unnecessary 154 // since the file name is already included. Plus, it has center dots. 155 // That is, we see 156 // runtime/debug.*T·ptrmethod 157 // and want 158 // *T.ptrmethod 159 // Also the package path might contain dot (e.g. code.google.com/...), 160 // so first eliminate the path prefix 161 if lastSlash := bytes.LastIndex(name, slash); lastSlash >= 0 { 162 name = name[lastSlash+1:] 163 } 164 if period := bytes.Index(name, dot); period >= 0 { 165 name = name[period+1:] 166 } 167 name = bytes.ReplaceAll(name, centerDot, dot) 168 return name 169 } 170 171 // timeFormat returns a customized time string for logger. 172 func timeFormat(t time.Time) string { 173 return t.Format("2006/01/02 - 15:04:05") 174 }