util.go (5097B)
1 // GoToSocial 2 // Copyright (C) GoToSocial Authors admin@gotosocial.org 3 // SPDX-License-Identifier: AGPL-3.0-or-later 4 // 5 // This program is free software: you can redistribute it and/or modify 6 // it under the terms of the GNU Affero General Public License as published by 7 // the Free Software Foundation, either version 3 of the License, or 8 // (at your option) any later version. 9 // 10 // This program is distributed in the hope that it will be useful, 11 // but WITHOUT ANY WARRANTY; without even the implied warranty of 12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 // GNU Affero General Public License for more details. 14 // 15 // You should have received a copy of the GNU Affero General Public License 16 // along with this program. If not, see <http://www.gnu.org/licenses/>. 17 18 package testrig 19 20 import ( 21 "bytes" 22 "context" 23 "fmt" 24 "io" 25 "mime/multipart" 26 "net/url" 27 "os" 28 "time" 29 30 "github.com/superseriousbusiness/gotosocial/internal/messages" 31 tlprocessor "github.com/superseriousbusiness/gotosocial/internal/processing/timeline" 32 "github.com/superseriousbusiness/gotosocial/internal/state" 33 "github.com/superseriousbusiness/gotosocial/internal/timeline" 34 "github.com/superseriousbusiness/gotosocial/internal/typeutils" 35 "github.com/superseriousbusiness/gotosocial/internal/visibility" 36 ) 37 38 func StartWorkers(state *state.State) { 39 state.Workers.EnqueueClientAPI = func(context.Context, ...messages.FromClientAPI) {} 40 state.Workers.EnqueueFederator = func(context.Context, ...messages.FromFederator) {} 41 42 _ = state.Workers.Scheduler.Start(nil) 43 _ = state.Workers.ClientAPI.Start(1, 10) 44 _ = state.Workers.Federator.Start(1, 10) 45 _ = state.Workers.Media.Start(1, 10) 46 } 47 48 func StopWorkers(state *state.State) { 49 _ = state.Workers.Scheduler.Stop() 50 _ = state.Workers.ClientAPI.Stop() 51 _ = state.Workers.Federator.Stop() 52 _ = state.Workers.Media.Stop() 53 } 54 55 func 
StartTimelines(state *state.State, filter *visibility.Filter, typeConverter typeutils.TypeConverter) { 56 state.Timelines.Home = timeline.NewManager( 57 tlprocessor.HomeTimelineGrab(state), 58 tlprocessor.HomeTimelineFilter(state, filter), 59 tlprocessor.HomeTimelineStatusPrepare(state, typeConverter), 60 tlprocessor.SkipInsert(), 61 ) 62 if err := state.Timelines.Home.Start(); err != nil { 63 panic(fmt.Sprintf("error starting home timeline: %s", err)) 64 } 65 66 state.Timelines.List = timeline.NewManager( 67 tlprocessor.ListTimelineGrab(state), 68 tlprocessor.ListTimelineFilter(state, filter), 69 tlprocessor.ListTimelineStatusPrepare(state, typeConverter), 70 tlprocessor.SkipInsert(), 71 ) 72 if err := state.Timelines.List.Start(); err != nil { 73 panic(fmt.Sprintf("error starting list timeline: %s", err)) 74 } 75 } 76 77 // CreateMultipartFormData is a handy function for taking a fieldname and a filename, and creating a multipart form bytes buffer 78 // with the file contents set in the given fieldname. The extraFields param can be used to add extra FormFields to the request, as necessary. 
// The returned bytes.Buffer b can be used like so:
//
//	httptest.NewRequest(http.MethodPost, "https://example.org/whateverpath", bytes.NewReader(b.Bytes()))
//
// The returned *multipart.Writer w can be used to set the content type of the request, like so:
//
//	req.Header.Set("Content-Type", w.FormDataContentType())
//
// If fileName is empty, no file part is written. Otherwise the file is opened,
// its full contents are copied into a form-file part named fieldName (with the
// opened file's name as the part's filename), and the file is closed before
// the function returns.
//
// Extra fields are written in map iteration order, which Go leaves unspecified.
// On any error, the partially written buffer, a nil writer and the error are returned.
func CreateMultipartFormData(fieldName string, fileName string, extraFields map[string]string) (bytes.Buffer, *multipart.Writer, error) {
	var b bytes.Buffer
	w := multipart.NewWriter(&b)

	if fileName != "" {
		file, err := os.Open(fileName)
		if err != nil {
			return b, nil, err
		}
		// Release the descriptor on every return path. The file is only read,
		// so a Close error carries no useful information here.
		defer file.Close()

		fw, err := w.CreateFormFile(fieldName, file.Name())
		if err != nil {
			return b, nil, err
		}
		if _, err := io.Copy(fw, file); err != nil {
			return b, nil, err
		}
	}

	for k, v := range extraFields {
		f, err := w.CreateFormField(k)
		if err != nil {
			return b, nil, err
		}
		if _, err := io.WriteString(f, v); err != nil {
			return b, nil, err
		}
	}

	// Close writes the trailing boundary; without it the body is incomplete.
	if err := w.Close(); err != nil {
		return b, nil, err
	}
	return b, w, nil
}

// URLMustParse tries to parse the given URL and panics if it can't.
// Should only be used in tests.
func URLMustParse(stringURL string) *url.URL {
	u, err := url.Parse(stringURL)
	if err != nil {
		panic(err)
	}
	return u
}

// TimeMustParse tries to parse the given time as RFC3339, and panics if it can't.
// Should only be used in tests.
func TimeMustParse(timeString string) time.Time {
	t, err := time.Parse(time.RFC3339, timeString)
	if err != nil {
		panic(err)
	}
	return t
}

// WaitFor calls condition every 200ms, returning true
// when condition() returns true, or false after 5s.
//
// The first call to condition happens after the first 200ms tick, not
// immediately.
//
// It's useful for when you're waiting for something to
// happen, but you don't know exactly how long it will take,
// and you want to fail if the thing doesn't happen within 5s.
func WaitFor(condition func() bool) bool {
	tick := time.NewTicker(200 * time.Millisecond)
	defer tick.Stop()

	timeout := time.NewTimer(5 * time.Second)
	defer timeout.Stop()

	for {
		select {
		case <-tick.C:
			if condition() {
				return true
			}
		case <-timeout.C:
			return false
		}
	}
}