From 7f62c1fd66ffeb7e46127e5972d814abb9299848 Mon Sep 17 00:00:00 2001 From: unwox Date: Fri, 25 Oct 2024 17:33:56 +0600 Subject: implement non-blocking IO operations --- lua.go | 28 ++++++-- luajit.go | 25 +++++-- main.go | 90 ++++++++++++----------- worker.go | 241 ++++++++++++++++++++++++++++++++++++-------------------------- 4 files changed, 233 insertions(+), 151 deletions(-) diff --git a/lua.go b/lua.go index 50c6fc5..43cd3f3 100644 --- a/lua.go +++ b/lua.go @@ -32,10 +32,13 @@ import ( "unsafe" ) -type LuaRef C.longlong +type LuaRef C.int // Lua is a wrapper around C Lua state with several conveniences. type Lua struct { l *C.lua_State + running **Lua + yield func() + resume func() bool } //export luna_run_go_func @@ -48,6 +51,7 @@ func luna_run_go_func(f C.uintptr_t) C.int { // Start opens the Lua context with all built-in libraries. func (l *Lua) Start() { l.l = C.luaL_newstate() + l.running = &l C.luaL_openlibs(l.l) } @@ -72,6 +76,7 @@ func (l *Lua) Require(file string) error { // PCall calls a function with the stack index (-nargs-1) expecting nresults // number of results. func (l *Lua) PCall(nargs int, nresults int) error { + *l.running = l res := C.lua_pcallk(l.l, C.int(nargs), C.int(nresults), 0, 0, nil) if res != C.LUA_OK { errMsg := l.ToString(-1) @@ -132,6 +137,8 @@ func (l *Lua) PushNil() { // PushBoolean pushes a boolean onto the Lua stack. func (l *Lua) PushBoolean(b bool) { + // For some reason this is needed here to not mess up threads. + *l.running = l if b { C.lua_pushboolean(l.l, 1) } else { @@ -161,7 +168,7 @@ func (l *Lua) PushString(str string) { // method later. func (l *Lua) PushGoFunction(f func (l *Lua) int) cgo.Handle { h := cgo.NewHandle(func () int { - return f(l) + return f(*l.running) }) C.luna_push_function(l.l, C.uintptr_t(h)) return h @@ -200,6 +207,9 @@ func (l *Lua) PushAny(v any) error { case float64: v, _ := v.(float64) l.PushFloatNumber(v) + case bool: + v, _ := v.(bool) + l.PushBoolean(v) case map[string]any: v, _ := v.(map[string]any) err := l.PushObject(v) @@ -453,7 +463,15 @@ func (l *Lua) GetGlobal(name string) { C.lua_getglobal(l.l, cstr) } -func (l *Lua) Error(msg string) { - l.PushString(msg) - C.lua_error(l.l) +func (l *Lua) NewThread(yield func(), resume func() bool) *Lua { + return &Lua{ + l: C.lua_newthread(l.l), + running: l.running, + resume: resume, + yield: yield, + } +} + +func (l *Lua) Unref(ref LuaRef) { + C.luaL_unref(l.l, C.LUA_REGISTRYINDEX, C.int(ref)) } diff --git a/luajit.go b/luajit.go index 1a8b47d..9414787 100644 --- a/luajit.go +++ b/luajit.go @@ -32,10 +32,13 @@ import ( "unsafe" ) -type LuaRef C.longlong +type LuaRef C.int // Lua is a wrapper around C Lua state with several conveniences. type Lua struct { l *C.lua_State + running **Lua + yield func() + resume func() bool } //export luna_run_go_func @@ -48,6 +51,7 @@ func luna_run_go_func(f C.uintptr_t) C.int { // Start opens the Lua context with all built-in libraries. func (l *Lua) Start() { l.l = C.luaL_newstate() + l.running = &l C.luaL_openlibs(l.l) } @@ -73,6 +77,7 @@ func (l *Lua) Require(file string) error { // PCall calls a function with the stack index (-nargs-1) expecting nresults // number of results. func (l *Lua) PCall(nargs int, nresults int) error { + *l.running = l res := C.lua_pcall(l.l, C.int(nargs), C.int(nresults), 0) if res != C.LUA_OK { errMsg := l.ToString(-1) @@ -128,6 +133,8 @@ func (l *Lua) PushNil() { // PushBoolean pushes a boolean onto the Lua stack. func (l *Lua) PushBoolean(b bool) { + // For some reason this is needed here to not mess up threads. + *l.running = l if b { C.lua_pushboolean(l.l, 1) } else { @@ -157,7 +164,7 @@ func (l *Lua) PushString(str string) { // method later. func (l *Lua) PushGoFunction(f func (l *Lua) int) cgo.Handle { h := cgo.NewHandle(func () int { - return f(l) + return f(*l.running) }) C.luna_push_function(l.l, C.uintptr_t(h)) return h @@ -448,7 +455,15 @@ func (l *Lua) GetGlobal(name string) { C.lua_getfield(l.l, C.LUA_GLOBALSINDEX, cstr) } -func (l *Lua) Error(msg string) { - l.PushString(msg) - C.lua_error(l.l) +func (l *Lua) NewThread(yield func(), resume func() bool) *Lua { + return &Lua{ + l: C.lua_newthread(l.l), + running: l.running, + resume: resume, + yield: yield, + } +} + +func (l *Lua) Unref(ref LuaRef) { + C.luaL_unref(l.l, C.LUA_REGISTRYINDEX, C.int(ref)) } diff --git a/main.go b/main.go index 93e7dad..3a031e2 100644 --- a/main.go +++ b/main.go @@ -34,14 +34,14 @@ func main() { httpClient := &http.Client{} // queue for http messages for workers to handle - msgs := make(chan any, 4096) + msgs := make(chan *HTTPRequest, 4096) mux := http.NewServeMux() wrks := []*Worker{} // track routes for mux to avoid registering the same route twice routes := make(map[string]bool) // track open dbs to close them on exit dbs := []*sql.DB{} - dbmu := sync.Mutex{} + mu := sync.Mutex{} defer func() { for _, db := range dbs { if db == nil { @@ -77,6 +77,7 @@ func main() { if ok { return luaOk(l, nil) } + mu.Lock() routes[route] = true mux.HandleFunc( route, @@ -98,6 +99,7 @@ func main() { ) }, ) + mu.Unlock() return luaOk(l, nil) } routeModule["static"] = func (l *Lua) int { @@ -170,9 +172,9 @@ func main() { if err != nil { return luaErr(l, err) } - dbmu.Lock() + mu.Lock() dbs = append(dbs, db) - dbmu.Unlock() + mu.Unlock() h := cgo.NewHandle(db) return luaOk(l, int(h)) } @@ -256,33 +258,33 @@ func main() { if err != nil { return luaErr(l, err) } + ares := []any{} db := handle.Value().(*sql.DB) rows, err := db.Query(query, params...) if err != nil { return luaErr(l, err) } - var res [][]any - cols, _ := rows.Columns() - for rows.Next() { - scans := make([]any, len(cols)) - for i, _ := range scans { - scans[i] = &scans[i] + go func() { + var res [][]any + cols, _ := rows.Columns() + for rows.Next() { + scans := make([]any, len(cols)) + for i, _ := range scans { + scans[i] = &scans[i] + } + var row []any + rows.Scan(scans...) + for _, v := range scans { + row = append(row, v) + } + res = append(res, row) } - var row []any - rows.Scan(scans...) - for _, v := range scans { - row = append(row, v) + for _, v := range res { + ares = append(ares, v) } - res = append(res, row) - } - ares := []any{} - for _, v := range res { - ares = append(ares, v) - } - err = l.PushArray(ares) - if err != nil { - return luaErr(l, err) - } + l.resume() + }() + l.yield() return luaOk(l, ares) } dbModule["query*"] = func (l *Lua) int { @@ -298,24 +300,28 @@ func main() { if err != nil { return luaErr(l, err) } - var res []map[string]any - cols, _ := rows.Columns() - for rows.Next() { - scans := make([]any, len(cols)) - for i, _ := range scans { - scans[i] = &scans[i] + ares := []any{} + go func() { + cols, _ := rows.Columns() + var res []map[string]any + for rows.Next() { + scans := make([]any, len(cols)) + for i, _ := range scans { + scans[i] = &scans[i] + } + row := make(map[string]any) + rows.Scan(scans...) + for i, v := range scans { + row[cols[i]] = v + } + res = append(res, row) } - row := make(map[string]any) - rows.Scan(scans...) - for i, v := range scans { - row[cols[i]] = v + for _, v := range res { + ares = append(ares, v) } - res = append(res, row) - } - ares := []any{} - for _, v := range res { - ares = append(ares, v) - } + l.resume() + }() + l.yield() return luaOk(l, ares) } dbModule["close"] = func (l *Lua) int { @@ -329,14 +335,14 @@ func main() { if err != nil { return luaErr(l, err) } - dbmu.Lock() + mu.Lock() for k, _db := range dbs { if db == _db { dbs[k] = nil break } } - dbmu.Unlock() + mu.Unlock() handle.Delete() return luaOk(l, nil) } diff --git a/worker.go b/worker.go index 54a3757..aa18b62 100644 --- a/worker.go +++ b/worker.go @@ -21,15 +21,8 @@ type HTTPResponse struct { Body string } -type Worker struct { - lua *Lua - routes map[string]LuaRef - started bool - mu sync.Mutex -} - func HandleHTTPRequest( - queue chan any, + queue chan *HTTPRequest, route string, req *http.Request, ) chan *HTTPResponse { @@ -42,6 +35,13 @@ func HandleHTTPRequest( return res } +type Worker struct { + lua *Lua + routes map[string]LuaRef + started bool + mu sync.Mutex +} + // NewWorker creates a new instance of Worker type. func NewWorker() *Worker { return &Worker { @@ -97,8 +97,8 @@ func (w *Worker) Start(argv []string, module map[string]any) error { return nil } -// Listen starts a goroutine listening/handling HTTP requests from the queue. -func (w *Worker) Listen(queue chan any) { +// Listen starts handling HTTP requests from the queue. +func (w *Worker) Listen(queue chan *HTTPRequest) { stringListToAny := func(slice []string) []any { res := []any{} for _, v := range slice { @@ -106,109 +106,127 @@ func (w *Worker) Listen(queue chan any) { } return res } - handle := func() { - defer w.lua.RestoreStackFunc()() - r := <- queue - - switch r.(type) { - case *HTTPRequest: - r := r.(*HTTPRequest) - if _, ok := w.routes[r.route]; !ok { - r.result <- &HTTPResponse { - Code: 404, - Headers: make(map[string]string), - Body: "not found", - } - log.Println("no corresponding route") - return - } + handle := func(r *HTTPRequest, yield func(), resume func() bool) { + l := w.lua.NewThread(yield, resume) + // Save a thread to a reference so it's not garbage collected + // before we are done with it. + ref := w.lua.PopToRef() + defer w.lua.Unref(ref) - w.lua.PushFromRef(w.routes[r.route]) - res := make(map[string]any) - res["method"] = r.request.Method - res["path"] = r.request.URL.Path - - fh := make(map[string]any) - for k := range r.request.Header { - fh[k] = r.request.Header.Get(k) + if _, ok := w.routes[r.route]; !ok { + r.result <- &HTTPResponse { + Code: 404, + Headers: make(map[string]string), + Body: "not found", } - res["headers"] = fh + log.Println("no corresponding route") + return + } - flatQr := make(map[string]any) - qr := r.request.URL.Query() - for k := range qr { - flatQr[k] = stringListToAny(qr[k]) - } - res["query"] = flatQr - - body, err := io.ReadAll(r.request.Body) - if err != nil { - r.result <- &HTTPResponse { - Code: 500, - Headers: make(map[string]string), - Body: "server error", - } - log.Println("could not read a request body:", err) - return - } - res["body"] = string(body) - - err = w.lua.PushObject(res) - if err != nil { - r.result <- &HTTPResponse { - Code: 500, - Headers: make(map[string]string), - Body: "server error", - } - log.Println("could not form a request to lua:", err) - return - } + l.PushFromRef(w.routes[r.route]) + res := make(map[string]any) + res["method"] = r.request.Method + res["path"] = r.request.URL.Path - err = w.lua.PCall(1, 3) + fh := make(map[string]any) + for k := range r.request.Header { + fh[k] = r.request.Header.Get(k) + } + res["headers"] = fh - if err != nil { - r.result <- &HTTPResponse { - Code: 500, - Headers: make(map[string]string), - Body: "server error", - } - log.Println("could not read a request body:", err) - return + flatQr := make(map[string]any) + qr := r.request.URL.Query() + for k := range qr { + flatQr[k] = stringListToAny(qr[k]) + } + res["query"] = flatQr + + body, err := io.ReadAll(r.request.Body) + if err != nil { + r.result <- &HTTPResponse{ + Code: 500, + Headers: make(map[string]string), + Body: "server error", } + log.Println("could not read request body:", err) + return + } + res["body"] = string(body) - code := w.lua.ToInt(-3) - rbody := w.lua.ToString(-1) - - // Parse headers. - headers := make(map[string]string) - w.lua.Pop(1) - w.lua.PushNil() - for w.lua.Next() { - if !w.lua.IsString(-2) || !w.lua.IsString(-2) { - w.lua.Pop(1) - continue - } - v := w.lua.ToString(-1) - w.lua.Pop(1) - // We must not pop the item key from the stack - // because otherwise C.lua_next won't work - // properly. - k := w.lua.ToString(-1) - headers[k] = v + err = l.PushObject(res) + if err != nil { + r.result <- &HTTPResponse{ + Code: 500, + Headers: make(map[string]string), + Body: "server error", } - r.result <- &HTTPResponse { - Code: int(code), - Headers: headers, - Body: rbody, + log.Println("could not form request to lua:", err) + return + } + + err = l.PCall(1, 3) + if err != nil { + r.result <- &HTTPResponse{ + Code: 500, + Headers: make(map[string]string), + Body: "server error", } + log.Println("could not process request:", err) + // TODO: print lua stack as well? + return + } - default: - log.Fatal("unknown request") + // TODO: probably it would be better to just use l.Scan() + // here but i'm not really sure if we want to have that + // overhead here. + code := l.ToInt(-3) + rbody := l.ToString(-1) + // Parse headers. + headers := make(map[string]string) + l.Pop(1) + l.PushNil() + for l.Next() { + if !l.IsString(-2) || !l.IsString(-2) { + l.Pop(1) + continue + } + v := l.ToString(-1) + l.Pop(1) + // We must not pop the item key from the stack + // because otherwise C.lua_next won't work + // properly. + k := l.ToString(-1) + headers[k] = v + } + r.result <- &HTTPResponse{ + Code: int(code), + Headers: headers, + Body: rbody, } } + resCh := make(chan func() bool, 4096) +outer: for { - handle() + select { + case r, ok := <- queue: + // accept new requests + if !ok { + break outer + } + resCh <- NewCoroutine(func(yield func(), resume func() bool) { + handle(r, yield, func () bool { + resCh <- resume + return true + }) + }) + case resume, ok := <-resCh: + // coroutine executor + if !ok { + break outer + } + resume() + } } } @@ -237,3 +255,28 @@ func (w *Worker) Stop() { func (w *Worker) HasSameLua(l *Lua) bool { return w.lua == l } + +func NewCoroutine(f func (yield func(), resume func() bool)) (resume func() bool) { + cin := make(chan bool) + cout := make(chan bool) + running := true + resume = func() bool { + if !running { + return false + } + cin <- true + <-cout + return true + } + yield := func() { + cout <- true + <-cin + } + go func() { + <-cin + f(yield, resume) + running = false + cout <- true + }() + return resume +} -- cgit v1.2.3