diff options
| author | unwox <me@unwox.com> | 2024-06-13 14:07:23 +0600 |
|---|---|---|
| committer | unwox <me@unwox.com> | 2024-06-13 14:07:23 +0600 |
| commit | cae471d4bced5f7490cc18e86b50e51df64ddb7b (patch) | |
| tree | 5d289199438060fdeefa2228b5a8ced05c9e34f8 /worker.go | |
| parent | 4374d8c5d98fbc9e2589074cd0f05266db6a10e1 (diff) | |
make evaluation of custom lua code thread-safe
Diffstat (limited to 'worker.go')
| -rw-r--r-- | worker.go | 164 |
1 files changed, 100 insertions, 64 deletions
@@ -7,27 +7,33 @@ import ( "net/http" ) -type WorkerRequest struct { - Request *http.Request - Route string - result chan *WorkerResponse +type HTTPRequest struct { + request *http.Request + route string + debug bool + result chan *HTTPResponse } -type WorkerResponse struct { +type EvalRequest struct { + code string + result chan error +} + +type HTTPResponse struct { Code int Headers map[string]string Body string } type Worker struct { - read chan *WorkerRequest + read chan interface{} lua *Lua api LuaRef routes map[string]LuaRef started bool } -func NewWorker(read chan *WorkerRequest) *Worker { +func NewWorker(read chan interface{}) *Worker { return &Worker { read: read, routes: make(map[string]LuaRef), @@ -35,7 +41,7 @@ func NewWorker(read chan *WorkerRequest) *Worker { } } -func (w *Worker) Start (filename string) error { +func (w *Worker) Start(filename string) error { if w.started { return errors.New("already started") } @@ -54,74 +60,107 @@ func (w *Worker) Start (filename string) error { return nil } -func (w *Worker) Listen () { +func (w *Worker) Listen() { for { - r := <- w.read resStack := w.lua.RestoreStackFunc() - handlerRef := w.routes[r.Route] - w.lua.PushFromRef(handlerRef) - w.lua.PushString(r.Request.Method) - w.lua.PushString(r.Request.URL.Path) - - fh := make(map[string]string) - for k := range r.Request.Header { - fh[k] = r.Request.Header.Get(k) - } - w.lua.PushTable(fh) - - body, err := io.ReadAll(r.Request.Body) - if err != nil { - resStack() - r.result <- &WorkerResponse { - Code: 500, - Headers: make(map[string]string), - Body: "server error", + r := <- w.read + + switch r.(type) { + case *HTTPRequest: + r := r.(*HTTPRequest) + // If in debug mode always use handlers from api table + // instead of cached references. Makes it much easier + // to hot-replace route handlers. + if r.debug { + w.lua.PushFromRef(w.api) + w.lua.PushTableItem("routes") + w.lua.PushTableItem(r.route) + } else { + w.lua.PushFromRef(w.routes[r.route]) } - log.Println("could not read a request body") - continue - } - w.lua.PushString(string(body)) - - w.lua.PCall(4, 3) - code := w.lua.ToInt(-3) - rbody := w.lua.ToString(-1) - - // parse headers - headers := make(map[string]string) - w.lua.Pop(1) - w.lua.PushNil() - for w.lua.Next() { - if !w.lua.IsString(-2) || !w.lua.IsString(-2) { - w.lua.Pop(1) + w.lua.PushString(r.request.Method) + w.lua.PushString(r.request.URL.Path) + + fh := make(map[string]string) + for k := range r.request.Header { + fh[k] = r.request.Header.Get(k) + } + w.lua.PushTable(fh) + + body, err := io.ReadAll(r.request.Body) + if err != nil { + r.result <- &HTTPResponse { + Code: 500, + Headers: make(map[string]string), + Body: "server error", + } + log.Println("could not read a request body") + resStack() continue } - v := w.lua.ToString(-1) + w.lua.PushString(string(body)) + + w.lua.PCall(4, 3) + code := w.lua.ToInt(-3) + rbody := w.lua.ToString(-1) + + // Parse headers. + headers := make(map[string]string) w.lua.Pop(1) - // we must not pop the item key from the stack because - // otherwise C.lua_next won't work properly - k := w.lua.ToString(-1) - headers[k] = v - } - resStack() + w.lua.PushNil() + for w.lua.Next() { + if !w.lua.IsString(-2) || !w.lua.IsString(-2) { + w.lua.Pop(1) + continue + } + v := w.lua.ToString(-1) + w.lua.Pop(1) + // We must not pop the item key from the stack + // because otherwise C.lua_next won't work + // properly. + k := w.lua.ToString(-1) + headers[k] = v + } + r.result <- &HTTPResponse { + Code: int(code), + Headers: headers, + Body: rbody, + } + + case *EvalRequest: + r := r.(*EvalRequest) + err := w.lua.LoadString(r.code) + r.result <- err - r.result <- &WorkerResponse { - Code: int(code), - Headers: headers, - Body: rbody, + default: + log.Fatal("unknown request") } + + resStack() } } -func (w *Worker) Request (route string, r *http.Request) chan *WorkerResponse { - res := make(chan *WorkerResponse) - w.read <- &WorkerRequest{ - Request: r, - Route: route, +func (w *Worker) Request( + route string, + r *http.Request, + debug bool, +) chan *HTTPResponse { + res := make(chan *HTTPResponse) + w.read <- &HTTPRequest{ + request: r, + route: route, + debug: debug, result: res, } return res } +func (w *Worker) Eval(code string) chan error { + res := make(chan error) + w.read <- &EvalRequest{code: code, result: res} + return res +} + func (w *Worker) ListRoutes() []string { res := []string{} for route, _ := range w.routes { @@ -130,15 +169,12 @@ func (w *Worker) ListRoutes() []string { return res } -func (w *Worker) Eval(code string) error { - return w.lua.LoadString(code) -} - func (w *Worker) Stop() { w.lua.Close() } func (w *Worker) initRoutes() error { + defer w.lua.RestoreStackFunc()() w.lua.PushFromRef(w.api) w.lua.PushTableItem("routes") if !w.lua.IsTable(-1) { |
