jwt_access.go (2561B)
1 package generates 2 3 import ( 4 "context" 5 "encoding/base64" 6 "strings" 7 "time" 8 9 "github.com/superseriousbusiness/oauth2/v4" 10 "github.com/superseriousbusiness/oauth2/v4/errors" 11 "github.com/golang-jwt/jwt" 12 "github.com/google/uuid" 13 ) 14 15 // JWTAccessClaims jwt claims 16 type JWTAccessClaims struct { 17 jwt.StandardClaims 18 } 19 20 // Valid claims verification 21 func (a *JWTAccessClaims) Valid() error { 22 if time.Unix(a.ExpiresAt, 0).Before(time.Now()) { 23 return errors.ErrInvalidAccessToken 24 } 25 return nil 26 } 27 28 // NewJWTAccessGenerate create to generate the jwt access token instance 29 func NewJWTAccessGenerate(kid string, key []byte, method jwt.SigningMethod) *JWTAccessGenerate { 30 return &JWTAccessGenerate{ 31 SignedKeyID: kid, 32 SignedKey: key, 33 SignedMethod: method, 34 } 35 } 36 37 // JWTAccessGenerate generate the jwt access token 38 type JWTAccessGenerate struct { 39 SignedKeyID string 40 SignedKey []byte 41 SignedMethod jwt.SigningMethod 42 } 43 44 // Token based on the UUID generated token 45 func (a *JWTAccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (string, string, error) { 46 claims := &JWTAccessClaims{ 47 StandardClaims: jwt.StandardClaims{ 48 Audience: data.Client.GetID(), 49 Subject: data.UserID, 50 ExpiresAt: data.TokenInfo.GetAccessCreateAt().Add(data.TokenInfo.GetAccessExpiresIn()).Unix(), 51 }, 52 } 53 54 token := jwt.NewWithClaims(a.SignedMethod, claims) 55 if a.SignedKeyID != "" { 56 token.Header["kid"] = a.SignedKeyID 57 } 58 var key interface{} 59 if a.isEs() { 60 v, err := jwt.ParseECPrivateKeyFromPEM(a.SignedKey) 61 if err != nil { 62 return "", "", err 63 } 64 key = v 65 } else if a.isRsOrPS() { 66 v, err := jwt.ParseRSAPrivateKeyFromPEM(a.SignedKey) 67 if err != nil { 68 return "", "", err 69 } 70 key = v 71 } else if a.isHs() { 72 key = a.SignedKey 73 } else { 74 return "", "", errors.New("unsupported sign method") 75 } 76 77 access, err := token.SignedString(key) 78 if err != nil { 79 return "", "", err 80 } 81 refresh := "" 82 83 if isGenRefresh { 84 t := uuid.NewSHA1(uuid.Must(uuid.NewRandom()), []byte(access)).String() 85 refresh = base64.URLEncoding.EncodeToString([]byte(t)) 86 refresh = strings.ToUpper(strings.TrimRight(refresh, "=")) 87 } 88 89 return access, refresh, nil 90 } 91 92 func (a *JWTAccessGenerate) isEs() bool { 93 return strings.HasPrefix(a.SignedMethod.Alg(), "ES") 94 } 95 96 func (a *JWTAccessGenerate) isRsOrPS() bool { 97 isRs := strings.HasPrefix(a.SignedMethod.Alg(), "RS") 98 isPs := strings.HasPrefix(a.SignedMethod.Alg(), "PS") 99 return isRs || isPs 100 } 101 102 func (a *JWTAccessGenerate) isHs() bool { 103 return strings.HasPrefix(a.SignedMethod.Alg(), "HS") 104 }