summaryrefslogtreecommitdiff
path: root/worker.go
diff options
context:
space:
mode:
authorunwox <me@unwox.com>2024-06-13 14:07:23 +0600
committerunwox <me@unwox.com>2024-06-13 14:07:23 +0600
commitcae471d4bced5f7490cc18e86b50e51df64ddb7b (patch)
tree5d289199438060fdeefa2228b5a8ced05c9e34f8 /worker.go
parent4374d8c5d98fbc9e2589074cd0f05266db6a10e1 (diff)
make evaluation of custom lua code thread-safe
Diffstat (limited to 'worker.go')
-rw-r--r--worker.go164
1 files changed, 100 insertions, 64 deletions
diff --git a/worker.go b/worker.go
index 01b55e3..2b563ef 100644
--- a/worker.go
+++ b/worker.go
@@ -7,27 +7,33 @@ import (
"net/http"
)
-type WorkerRequest struct {
- Request *http.Request
- Route string
- result chan *WorkerResponse
+type HTTPRequest struct {
+ request *http.Request
+ route string
+ debug bool
+ result chan *HTTPResponse
}
-type WorkerResponse struct {
+type EvalRequest struct {
+ code string
+ result chan error
+}
+
+type HTTPResponse struct {
Code int
Headers map[string]string
Body string
}
type Worker struct {
- read chan *WorkerRequest
+ read chan interface{}
lua *Lua
api LuaRef
routes map[string]LuaRef
started bool
}
-func NewWorker(read chan *WorkerRequest) *Worker {
+func NewWorker(read chan interface{}) *Worker {
return &Worker {
read: read,
routes: make(map[string]LuaRef),
@@ -35,7 +41,7 @@ func NewWorker(read chan *WorkerRequest) *Worker {
}
}
-func (w *Worker) Start (filename string) error {
+func (w *Worker) Start(filename string) error {
if w.started {
return errors.New("already started")
}
@@ -54,74 +60,107 @@ func (w *Worker) Start (filename string) error {
return nil
}
-func (w *Worker) Listen () {
+func (w *Worker) Listen() {
for {
- r := <- w.read
resStack := w.lua.RestoreStackFunc()
- handlerRef := w.routes[r.Route]
- w.lua.PushFromRef(handlerRef)
- w.lua.PushString(r.Request.Method)
- w.lua.PushString(r.Request.URL.Path)
-
- fh := make(map[string]string)
- for k := range r.Request.Header {
- fh[k] = r.Request.Header.Get(k)
- }
- w.lua.PushTable(fh)
-
- body, err := io.ReadAll(r.Request.Body)
- if err != nil {
- resStack()
- r.result <- &WorkerResponse {
- Code: 500,
- Headers: make(map[string]string),
- Body: "server error",
+ r := <- w.read
+
+ switch r.(type) {
+ case *HTTPRequest:
+ r := r.(*HTTPRequest)
+ // If in debug mode always use handlers from api table
+ // instead of cached references. Makes it much easier
+ // to hot-replace route handlers.
+ if r.debug {
+ w.lua.PushFromRef(w.api)
+ w.lua.PushTableItem("routes")
+ w.lua.PushTableItem(r.route)
+ } else {
+ w.lua.PushFromRef(w.routes[r.route])
}
- log.Println("could not read a request body")
- continue
- }
- w.lua.PushString(string(body))
-
- w.lua.PCall(4, 3)
- code := w.lua.ToInt(-3)
- rbody := w.lua.ToString(-1)
-
- // parse headers
- headers := make(map[string]string)
- w.lua.Pop(1)
- w.lua.PushNil()
- for w.lua.Next() {
- if !w.lua.IsString(-2) || !w.lua.IsString(-2) {
- w.lua.Pop(1)
+ w.lua.PushString(r.request.Method)
+ w.lua.PushString(r.request.URL.Path)
+
+ fh := make(map[string]string)
+ for k := range r.request.Header {
+ fh[k] = r.request.Header.Get(k)
+ }
+ w.lua.PushTable(fh)
+
+ body, err := io.ReadAll(r.request.Body)
+ if err != nil {
+ r.result <- &HTTPResponse {
+ Code: 500,
+ Headers: make(map[string]string),
+ Body: "server error",
+ }
+ log.Println("could not read a request body")
+ resStack()
continue
}
- v := w.lua.ToString(-1)
+ w.lua.PushString(string(body))
+
+ w.lua.PCall(4, 3)
+ code := w.lua.ToInt(-3)
+ rbody := w.lua.ToString(-1)
+
+ // Parse headers.
+ headers := make(map[string]string)
w.lua.Pop(1)
- // we must not pop the item key from the stack because
- // otherwise C.lua_next won't work properly
- k := w.lua.ToString(-1)
- headers[k] = v
- }
- resStack()
+ w.lua.PushNil()
+ for w.lua.Next() {
+ if !w.lua.IsString(-2) || !w.lua.IsString(-2) {
+ w.lua.Pop(1)
+ continue
+ }
+ v := w.lua.ToString(-1)
+ w.lua.Pop(1)
+ // We must not pop the item key from the stack
+ // because otherwise C.lua_next won't work
+ // properly.
+ k := w.lua.ToString(-1)
+ headers[k] = v
+ }
+ r.result <- &HTTPResponse {
+ Code: int(code),
+ Headers: headers,
+ Body: rbody,
+ }
+
+ case *EvalRequest:
+ r := r.(*EvalRequest)
+ err := w.lua.LoadString(r.code)
+ r.result <- err
- r.result <- &WorkerResponse {
- Code: int(code),
- Headers: headers,
- Body: rbody,
+ default:
+ log.Fatal("unknown request")
}
+
+ resStack()
}
}
-func (w *Worker) Request (route string, r *http.Request) chan *WorkerResponse {
- res := make(chan *WorkerResponse)
- w.read <- &WorkerRequest{
- Request: r,
- Route: route,
+func (w *Worker) Request(
+ route string,
+ r *http.Request,
+ debug bool,
+) chan *HTTPResponse {
+ res := make(chan *HTTPResponse)
+ w.read <- &HTTPRequest{
+ request: r,
+ route: route,
+ debug: debug,
result: res,
}
return res
}
+func (w *Worker) Eval(code string) chan error {
+ res := make(chan error)
+ w.read <- &EvalRequest{code: code, result: res}
+ return res
+}
+
func (w *Worker) ListRoutes() []string {
res := []string{}
for route, _ := range w.routes {
@@ -130,15 +169,12 @@ func (w *Worker) ListRoutes() []string {
return res
}
-func (w *Worker) Eval(code string) error {
- return w.lua.LoadString(code)
-}
-
func (w *Worker) Stop() {
w.lua.Close()
}
func (w *Worker) initRoutes() error {
+ defer w.lua.RestoreStackFunc()()
w.lua.PushFromRef(w.api)
w.lua.PushTableItem("routes")
if !w.lua.IsTable(-1) {