gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

signing.go (8883B)


      1 package httpsig
      2 
      3 import (
      4 	"bytes"
      5 	"crypto"
      6 	"crypto/rand"
      7 	"encoding/base64"
      8 	"fmt"
      9 	"net/http"
     10 	"net/textproto"
     11 	"strconv"
     12 	"strings"
     13 )
     14 
     15 const (
     16 	// Signature Parameters
     17 	keyIdParameter            = "keyId"
     18 	algorithmParameter        = "algorithm"
     19 	headersParameter          = "headers"
     20 	signatureParameter        = "signature"
     21 	prefixSeparater           = " "
     22 	parameterKVSeparater      = "="
     23 	parameterValueDelimiter   = "\""
     24 	parameterSeparater        = ","
     25 	headerParameterValueDelim = " "
     26 	// RequestTarget specifies to include the http request method and
     27 	// entire URI in the signature. Pass it as a header to NewSigner.
     28 	RequestTarget = "(request-target)"
     29 	createdKey    = "created"
     30 	expiresKey    = "expires"
     31 	dateHeader    = "date"
     32 
     33 	// Signature String Construction
     34 	headerFieldDelimiter   = ": "
     35 	headersDelimiter       = "\n"
     36 	headerValueDelimiter   = ", "
     37 	requestTargetSeparator = " "
     38 )
     39 
     40 var defaultHeaders = []string{dateHeader}
     41 
     42 var _ Signer = &macSigner{}
     43 
     44 type macSigner struct {
     45 	m            macer
     46 	makeDigest   bool
     47 	dAlgo        DigestAlgorithm
     48 	headers      []string
     49 	targetHeader SignatureScheme
     50 	prefix       string
     51 	created      int64
     52 	expires      int64
     53 }
     54 
     55 func (m *macSigner) SignRequest(pKey crypto.PrivateKey, pubKeyId string, r *http.Request, body []byte) error {
     56 	if body != nil {
     57 		err := addDigest(r, m.dAlgo, body)
     58 		if err != nil {
     59 			return err
     60 		}
     61 	}
     62 	s, err := m.signatureString(r)
     63 	if err != nil {
     64 		return err
     65 	}
     66 	enc, err := m.signSignature(pKey, s)
     67 	if err != nil {
     68 		return err
     69 	}
     70 	setSignatureHeader(r.Header, string(m.targetHeader), m.prefix, pubKeyId, m.m.String(), enc, m.headers, m.created, m.expires)
     71 	return nil
     72 }
     73 
     74 func (m *macSigner) SignResponse(pKey crypto.PrivateKey, pubKeyId string, r http.ResponseWriter, body []byte) error {
     75 	if body != nil {
     76 		err := addDigestResponse(r, m.dAlgo, body)
     77 		if err != nil {
     78 			return err
     79 		}
     80 	}
     81 	s, err := m.signatureStringResponse(r)
     82 	if err != nil {
     83 		return err
     84 	}
     85 	enc, err := m.signSignature(pKey, s)
     86 	if err != nil {
     87 		return err
     88 	}
     89 	setSignatureHeader(r.Header(), string(m.targetHeader), m.prefix, pubKeyId, m.m.String(), enc, m.headers, m.created, m.expires)
     90 	return nil
     91 }
     92 
     93 func (m *macSigner) signSignature(pKey crypto.PrivateKey, s string) (string, error) {
     94 	pKeyBytes, ok := pKey.([]byte)
     95 	if !ok {
     96 		return "", fmt.Errorf("private key for MAC signing must be of type []byte")
     97 	}
     98 	sig, err := m.m.Sign([]byte(s), pKeyBytes)
     99 	if err != nil {
    100 		return "", err
    101 	}
    102 	enc := base64.StdEncoding.EncodeToString(sig)
    103 	return enc, nil
    104 }
    105 
    106 func (m *macSigner) signatureString(r *http.Request) (string, error) {
    107 	return signatureString(r.Header, m.headers, addRequestTarget(r), m.created, m.expires)
    108 }
    109 
    110 func (m *macSigner) signatureStringResponse(r http.ResponseWriter) (string, error) {
    111 	return signatureString(r.Header(), m.headers, requestTargetNotPermitted, m.created, m.expires)
    112 }
    113 
    114 var _ Signer = &asymmSigner{}
    115 
    116 type asymmSigner struct {
    117 	s            signer
    118 	makeDigest   bool
    119 	dAlgo        DigestAlgorithm
    120 	headers      []string
    121 	targetHeader SignatureScheme
    122 	prefix       string
    123 	created      int64
    124 	expires      int64
    125 }
    126 
    127 func (a *asymmSigner) SignRequest(pKey crypto.PrivateKey, pubKeyId string, r *http.Request, body []byte) error {
    128 	if body != nil {
    129 		err := addDigest(r, a.dAlgo, body)
    130 		if err != nil {
    131 			return err
    132 		}
    133 	}
    134 	s, err := a.signatureString(r)
    135 	if err != nil {
    136 		return err
    137 	}
    138 	enc, err := a.signSignature(pKey, s)
    139 	if err != nil {
    140 		return err
    141 	}
    142 	setSignatureHeader(r.Header, string(a.targetHeader), a.prefix, pubKeyId, a.s.String(), enc, a.headers, a.created, a.expires)
    143 	return nil
    144 }
    145 
    146 func (a *asymmSigner) SignResponse(pKey crypto.PrivateKey, pubKeyId string, r http.ResponseWriter, body []byte) error {
    147 	if body != nil {
    148 		err := addDigestResponse(r, a.dAlgo, body)
    149 		if err != nil {
    150 			return err
    151 		}
    152 	}
    153 	s, err := a.signatureStringResponse(r)
    154 	if err != nil {
    155 		return err
    156 	}
    157 	enc, err := a.signSignature(pKey, s)
    158 	if err != nil {
    159 		return err
    160 	}
    161 	setSignatureHeader(r.Header(), string(a.targetHeader), a.prefix, pubKeyId, a.s.String(), enc, a.headers, a.created, a.expires)
    162 	return nil
    163 }
    164 
    165 func (a *asymmSigner) signSignature(pKey crypto.PrivateKey, s string) (string, error) {
    166 	sig, err := a.s.Sign(rand.Reader, pKey, []byte(s))
    167 	if err != nil {
    168 		return "", err
    169 	}
    170 	enc := base64.StdEncoding.EncodeToString(sig)
    171 	return enc, nil
    172 }
    173 
    174 func (a *asymmSigner) signatureString(r *http.Request) (string, error) {
    175 	return signatureString(r.Header, a.headers, addRequestTarget(r), a.created, a.expires)
    176 }
    177 
    178 func (a *asymmSigner) signatureStringResponse(r http.ResponseWriter) (string, error) {
    179 	return signatureString(r.Header(), a.headers, requestTargetNotPermitted, a.created, a.expires)
    180 }
    181 
    182 var _ SSHSigner = &asymmSSHSigner{}
    183 
    184 type asymmSSHSigner struct {
    185 	*asymmSigner
    186 }
    187 
    188 func (a *asymmSSHSigner) SignRequest(pubKeyId string, r *http.Request, body []byte) error {
    189 	return a.asymmSigner.SignRequest(nil, pubKeyId, r, body)
    190 }
    191 
    192 func (a *asymmSSHSigner) SignResponse(pubKeyId string, r http.ResponseWriter, body []byte) error {
    193 	return a.asymmSigner.SignResponse(nil, pubKeyId, r, body)
    194 }
    195 
    196 func setSignatureHeader(h http.Header, targetHeader, prefix, pubKeyId, algo, enc string, headers []string, created int64, expires int64) {
    197 	if len(headers) == 0 {
    198 		headers = defaultHeaders
    199 	}
    200 	var b bytes.Buffer
    201 	// KeyId
    202 	b.WriteString(prefix)
    203 	if len(prefix) > 0 {
    204 		b.WriteString(prefixSeparater)
    205 	}
    206 	b.WriteString(keyIdParameter)
    207 	b.WriteString(parameterKVSeparater)
    208 	b.WriteString(parameterValueDelimiter)
    209 	b.WriteString(pubKeyId)
    210 	b.WriteString(parameterValueDelimiter)
    211 	b.WriteString(parameterSeparater)
    212 	// Algorithm
    213 	b.WriteString(algorithmParameter)
    214 	b.WriteString(parameterKVSeparater)
    215 	b.WriteString(parameterValueDelimiter)
    216 	b.WriteString("hs2019") //real algorithm is hidden, see newest version of spec draft
    217 	b.WriteString(parameterValueDelimiter)
    218 	b.WriteString(parameterSeparater)
    219 
    220 	hasCreated := false
    221 	hasExpires := false
    222 	for _, h := range headers {
    223 		val := strings.ToLower(h)
    224 		if val == "("+createdKey+")" {
    225 			hasCreated = true
    226 		} else if val == "("+expiresKey+")" {
    227 			hasExpires = true
    228 		}
    229 	}
    230 
    231 	// Created
    232 	if hasCreated == true {
    233 		b.WriteString(createdKey)
    234 		b.WriteString(parameterKVSeparater)
    235 		b.WriteString(strconv.FormatInt(created, 10))
    236 		b.WriteString(parameterSeparater)
    237 	}
    238 
    239 	// Expires
    240 	if hasExpires == true {
    241 		b.WriteString(expiresKey)
    242 		b.WriteString(parameterKVSeparater)
    243 		b.WriteString(strconv.FormatInt(expires, 10))
    244 		b.WriteString(parameterSeparater)
    245 	}
    246 
    247 	// Headers
    248 	b.WriteString(headersParameter)
    249 	b.WriteString(parameterKVSeparater)
    250 	b.WriteString(parameterValueDelimiter)
    251 	for i, h := range headers {
    252 		b.WriteString(strings.ToLower(h))
    253 		if i != len(headers)-1 {
    254 			b.WriteString(headerParameterValueDelim)
    255 		}
    256 	}
    257 	b.WriteString(parameterValueDelimiter)
    258 	b.WriteString(parameterSeparater)
    259 	// Signature
    260 	b.WriteString(signatureParameter)
    261 	b.WriteString(parameterKVSeparater)
    262 	b.WriteString(parameterValueDelimiter)
    263 	b.WriteString(enc)
    264 	b.WriteString(parameterValueDelimiter)
    265 	h.Add(targetHeader, b.String())
    266 }
    267 
    268 func requestTargetNotPermitted(b *bytes.Buffer) error {
    269 	return fmt.Errorf("cannot sign with %q on anything other than an http request", RequestTarget)
    270 }
    271 
    272 func addRequestTarget(r *http.Request) func(b *bytes.Buffer) error {
    273 	return func(b *bytes.Buffer) error {
    274 		b.WriteString(RequestTarget)
    275 		b.WriteString(headerFieldDelimiter)
    276 		b.WriteString(strings.ToLower(r.Method))
    277 		b.WriteString(requestTargetSeparator)
    278 		b.WriteString(r.URL.Path)
    279 
    280 		if r.URL.RawQuery != "" {
    281 			b.WriteString("?")
    282 			b.WriteString(r.URL.RawQuery)
    283 		}
    284 
    285 		return nil
    286 	}
    287 }
    288 
    289 func signatureString(values http.Header, include []string, requestTargetFn func(b *bytes.Buffer) error, created int64, expires int64) (string, error) {
    290 	if len(include) == 0 {
    291 		include = defaultHeaders
    292 	}
    293 	var b bytes.Buffer
    294 	for n, i := range include {
    295 		i := strings.ToLower(i)
    296 		if i == RequestTarget {
    297 			err := requestTargetFn(&b)
    298 			if err != nil {
    299 				return "", err
    300 			}
    301 		} else if i == "("+expiresKey+")" {
    302 			if expires == 0 {
    303 				return "", fmt.Errorf("missing expires value")
    304 			}
    305 			b.WriteString(i)
    306 			b.WriteString(headerFieldDelimiter)
    307 			b.WriteString(strconv.FormatInt(expires, 10))
    308 		} else if i == "("+createdKey+")" {
    309 			if created == 0 {
    310 				return "", fmt.Errorf("missing created value")
    311 			}
    312 			b.WriteString(i)
    313 			b.WriteString(headerFieldDelimiter)
    314 			b.WriteString(strconv.FormatInt(created, 10))
    315 		} else {
    316 			hv, ok := values[textproto.CanonicalMIMEHeaderKey(i)]
    317 			if !ok {
    318 				return "", fmt.Errorf("missing header %q", i)
    319 			}
    320 			b.WriteString(i)
    321 			b.WriteString(headerFieldDelimiter)
    322 			for i, v := range hv {
    323 				b.WriteString(strings.TrimSpace(v))
    324 				if i < len(hv)-1 {
    325 					b.WriteString(headerValueDelimiter)
    326 				}
    327 			}
    328 		}
    329 		if n < len(include)-1 {
    330 			b.WriteString(headersDelimiter)
    331 		}
    332 	}
    333 	return b.String(), nil
    334 }