gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

recovery.go (5207B)


      1 // Copyright 2014 Manu Martinez-Almeida. All rights reserved.
      2 // Use of this source code is governed by a MIT style
      3 // license that can be found in the LICENSE file.
      4 
      5 package gin
      6 
      7 import (
      8 	"bytes"
      9 	"errors"
     10 	"fmt"
     11 	"io"
     12 	"log"
     13 	"net"
     14 	"net/http"
     15 	"net/http/httputil"
     16 	"os"
     17 	"runtime"
     18 	"strings"
     19 	"time"
     20 )
     21 
// Byte-slice constants shared by the stack-trace helpers in this file:
//   - dunno is the placeholder returned by source and function when a source
//     line or function name cannot be determined.
//   - centerDot is replaced by dot when function cleans up a function name.
//   - dot is also used by function to locate the end of the package name.
//   - slash is used by function to locate the end of the package path.
var (
	dunno     = []byte("???")
	centerDot = []byte("·")
	dot       = []byte(".")
	slash     = []byte("/")
)
     28 
// RecoveryFunc defines the function passable to CustomRecovery.
// It receives the request Context and the value returned by recover().
// CustomRecoveryWithWriter does not call it when the panic is classified as a
// broken client connection.
type RecoveryFunc func(c *Context, err any)
     31 
// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
// It is shorthand for RecoveryWithWriter(DefaultErrorWriter).
func Recovery() HandlerFunc {
	return RecoveryWithWriter(DefaultErrorWriter)
}
     36 
// CustomRecovery returns a middleware that recovers from any panics and calls the provided handle func to handle it.
// It is shorthand for RecoveryWithWriter(DefaultErrorWriter, handle).
func CustomRecovery(handle RecoveryFunc) HandlerFunc {
	return RecoveryWithWriter(DefaultErrorWriter, handle)
}
     41 
     42 // RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one.
     43 func RecoveryWithWriter(out io.Writer, recovery ...RecoveryFunc) HandlerFunc {
     44 	if len(recovery) > 0 {
     45 		return CustomRecoveryWithWriter(out, recovery[0])
     46 	}
     47 	return CustomRecoveryWithWriter(out, defaultHandleRecovery)
     48 }
     49 
     50 // CustomRecoveryWithWriter returns a middleware for a given writer that recovers from any panics and calls the provided handle func to handle it.
     51 func CustomRecoveryWithWriter(out io.Writer, handle RecoveryFunc) HandlerFunc {
     52 	var logger *log.Logger
     53 	if out != nil {
     54 		logger = log.New(out, "\n\n\x1b[31m", log.LstdFlags)
     55 	}
     56 	return func(c *Context) {
     57 		defer func() {
     58 			if err := recover(); err != nil {
     59 				// Check for a broken connection, as it is not really a
     60 				// condition that warrants a panic stack trace.
     61 				var brokenPipe bool
     62 				if ne, ok := err.(*net.OpError); ok {
     63 					var se *os.SyscallError
     64 					if errors.As(ne, &se) {
     65 						seStr := strings.ToLower(se.Error())
     66 						if strings.Contains(seStr, "broken pipe") ||
     67 							strings.Contains(seStr, "connection reset by peer") {
     68 							brokenPipe = true
     69 						}
     70 					}
     71 				}
     72 				if logger != nil {
     73 					stack := stack(3)
     74 					httpRequest, _ := httputil.DumpRequest(c.Request, false)
     75 					headers := strings.Split(string(httpRequest), "\r\n")
     76 					for idx, header := range headers {
     77 						current := strings.Split(header, ":")
     78 						if current[0] == "Authorization" {
     79 							headers[idx] = current[0] + ": *"
     80 						}
     81 					}
     82 					headersToStr := strings.Join(headers, "\r\n")
     83 					if brokenPipe {
     84 						logger.Printf("%s\n%s%s", err, headersToStr, reset)
     85 					} else if IsDebugging() {
     86 						logger.Printf("[Recovery] %s panic recovered:\n%s\n%s\n%s%s",
     87 							timeFormat(time.Now()), headersToStr, err, stack, reset)
     88 					} else {
     89 						logger.Printf("[Recovery] %s panic recovered:\n%s\n%s%s",
     90 							timeFormat(time.Now()), err, stack, reset)
     91 					}
     92 				}
     93 				if brokenPipe {
     94 					// If the connection is dead, we can't write a status to it.
     95 					c.Error(err.(error)) //nolint: errcheck
     96 					c.Abort()
     97 				} else {
     98 					handle(c, err)
     99 				}
    100 			}
    101 		}()
    102 		c.Next()
    103 	}
    104 }
    105 
// defaultHandleRecovery is the RecoveryFunc used when the caller supplies none.
// It ignores the recovered value and aborts the request with a
// 500 Internal Server Error status.
func defaultHandleRecovery(c *Context, _ any) {
	c.AbortWithStatus(http.StatusInternalServerError)
}
    109 
    110 // stack returns a nicely formatted stack frame, skipping skip frames.
    111 func stack(skip int) []byte {
    112 	buf := new(bytes.Buffer) // the returned data
    113 	// As we loop, we open files and read them. These variables record the currently
    114 	// loaded file.
    115 	var lines [][]byte
    116 	var lastFile string
    117 	for i := skip; ; i++ { // Skip the expected number of frames
    118 		pc, file, line, ok := runtime.Caller(i)
    119 		if !ok {
    120 			break
    121 		}
    122 		// Print this much at least.  If we can't find the source, it won't show.
    123 		fmt.Fprintf(buf, "%s:%d (0x%x)\n", file, line, pc)
    124 		if file != lastFile {
    125 			data, err := os.ReadFile(file)
    126 			if err != nil {
    127 				continue
    128 			}
    129 			lines = bytes.Split(data, []byte{'\n'})
    130 			lastFile = file
    131 		}
    132 		fmt.Fprintf(buf, "\t%s: %s\n", function(pc), source(lines, line))
    133 	}
    134 	return buf.Bytes()
    135 }
    136 
    137 // source returns a space-trimmed slice of the n'th line.
    138 func source(lines [][]byte, n int) []byte {
    139 	n-- // in stack trace, lines are 1-indexed but our array is 0-indexed
    140 	if n < 0 || n >= len(lines) {
    141 		return dunno
    142 	}
    143 	return bytes.TrimSpace(lines[n])
    144 }
    145 
    146 // function returns, if possible, the name of the function containing the PC.
    147 func function(pc uintptr) []byte {
    148 	fn := runtime.FuncForPC(pc)
    149 	if fn == nil {
    150 		return dunno
    151 	}
    152 	name := []byte(fn.Name())
    153 	// The name includes the path name to the package, which is unnecessary
    154 	// since the file name is already included.  Plus, it has center dots.
    155 	// That is, we see
    156 	//	runtime/debug.*T·ptrmethod
    157 	// and want
    158 	//	*T.ptrmethod
    159 	// Also the package path might contain dot (e.g. code.google.com/...),
    160 	// so first eliminate the path prefix
    161 	if lastSlash := bytes.LastIndex(name, slash); lastSlash >= 0 {
    162 		name = name[lastSlash+1:]
    163 	}
    164 	if period := bytes.Index(name, dot); period >= 0 {
    165 		name = name[period+1:]
    166 	}
    167 	name = bytes.ReplaceAll(name, centerDot, dot)
    168 	return name
    169 }
    170 
    171 // timeFormat returns a customized time string for logger.
    172 func timeFormat(t time.Time) string {
    173 	return t.Format("2006/01/02 - 15:04:05")
    174 }