summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lua.go28
-rw-r--r--luajit.go25
-rw-r--r--main.go90
-rw-r--r--worker.go241
4 files changed, 233 insertions, 151 deletions
diff --git a/lua.go b/lua.go
index 50c6fc5..43cd3f3 100644
--- a/lua.go
+++ b/lua.go
@@ -32,10 +32,13 @@ import (
"unsafe"
)
-type LuaRef C.longlong
+type LuaRef C.int
// Lua is a wrapper around C Lua state with several conveniences.
type Lua struct {
l *C.lua_State
+ running **Lua
+ yield func()
+ resume func() bool
}
//export luna_run_go_func
@@ -48,6 +51,7 @@ func luna_run_go_func(f C.uintptr_t) C.int {
// Start opens the Lua context with all built-in libraries.
func (l *Lua) Start() {
l.l = C.luaL_newstate()
+ l.running = &l
C.luaL_openlibs(l.l)
}
@@ -72,6 +76,7 @@ func (l *Lua) Require(file string) error {
// PCall calls a function with the stack index (-nargs-1) expecting nresults
// number of results.
func (l *Lua) PCall(nargs int, nresults int) error {
+ *l.running = l
res := C.lua_pcallk(l.l, C.int(nargs), C.int(nresults), 0, 0, nil)
if res != C.LUA_OK {
errMsg := l.ToString(-1)
@@ -132,6 +137,8 @@ func (l *Lua) PushNil() {
// PushBoolean pushes a boolean onto the Lua stack.
func (l *Lua) PushBoolean(b bool) {
+ // For some reason this is needed here to not mess up threads.
+ *l.running = l
if b {
C.lua_pushboolean(l.l, 1)
} else {
@@ -161,7 +168,7 @@ func (l *Lua) PushString(str string) {
// method later.
func (l *Lua) PushGoFunction(f func (l *Lua) int) cgo.Handle {
h := cgo.NewHandle(func () int {
- return f(l)
+ return f(*l.running)
})
C.luna_push_function(l.l, C.uintptr_t(h))
return h
@@ -200,6 +207,9 @@ func (l *Lua) PushAny(v any) error {
case float64:
v, _ := v.(float64)
l.PushFloatNumber(v)
+ case bool:
+ v, _ := v.(bool)
+ l.PushBoolean(v)
case map[string]any:
v, _ := v.(map[string]any)
err := l.PushObject(v)
@@ -453,7 +463,15 @@ func (l *Lua) GetGlobal(name string) {
C.lua_getglobal(l.l, cstr)
}
-func (l *Lua) Error(msg string) {
- l.PushString(msg)
- C.lua_error(l.l)
+func (l *Lua) NewThread(yield func(), resume func() bool) *Lua {
+ return &Lua{
+ l: C.lua_newthread(l.l),
+ running: l.running,
+ resume: resume,
+ yield: yield,
+ }
+}
+
+func (l *Lua) Unref(ref LuaRef) {
+ C.luaL_unref(l.l, C.LUA_REGISTRYINDEX, C.int(ref))
}
diff --git a/luajit.go b/luajit.go
index 1a8b47d..9414787 100644
--- a/luajit.go
+++ b/luajit.go
@@ -32,10 +32,13 @@ import (
"unsafe"
)
-type LuaRef C.longlong
+type LuaRef C.int
// Lua is a wrapper around C Lua state with several conveniences.
type Lua struct {
l *C.lua_State
+ running **Lua
+ yield func()
+ resume func() bool
}
//export luna_run_go_func
@@ -48,6 +51,7 @@ func luna_run_go_func(f C.uintptr_t) C.int {
// Start opens the Lua context with all built-in libraries.
func (l *Lua) Start() {
l.l = C.luaL_newstate()
+ l.running = &l
C.luaL_openlibs(l.l)
}
@@ -73,6 +77,7 @@ func (l *Lua) Require(file string) error {
// PCall calls a function with the stack index (-nargs-1) expecting nresults
// number of results.
func (l *Lua) PCall(nargs int, nresults int) error {
+ *l.running = l
res := C.lua_pcall(l.l, C.int(nargs), C.int(nresults), 0)
if res != C.LUA_OK {
errMsg := l.ToString(-1)
@@ -128,6 +133,8 @@ func (l *Lua) PushNil() {
// PushBoolean pushes a boolean onto the Lua stack.
func (l *Lua) PushBoolean(b bool) {
+ // For some reason this is needed here to not mess up threads.
+ *l.running = l
if b {
C.lua_pushboolean(l.l, 1)
} else {
@@ -157,7 +164,7 @@ func (l *Lua) PushString(str string) {
// method later.
func (l *Lua) PushGoFunction(f func (l *Lua) int) cgo.Handle {
h := cgo.NewHandle(func () int {
- return f(l)
+ return f(*l.running)
})
C.luna_push_function(l.l, C.uintptr_t(h))
return h
@@ -448,7 +455,15 @@ func (l *Lua) GetGlobal(name string) {
C.lua_getfield(l.l, C.LUA_GLOBALSINDEX, cstr)
}
-func (l *Lua) Error(msg string) {
- l.PushString(msg)
- C.lua_error(l.l)
+func (l *Lua) NewThread(yield func(), resume func() bool) *Lua {
+ return &Lua{
+ l: C.lua_newthread(l.l),
+ running: l.running,
+ resume: resume,
+ yield: yield,
+ }
+}
+
+func (l *Lua) Unref(ref LuaRef) {
+ C.luaL_unref(l.l, C.LUA_REGISTRYINDEX, C.int(ref))
}
diff --git a/main.go b/main.go
index 93e7dad..3a031e2 100644
--- a/main.go
+++ b/main.go
@@ -34,14 +34,14 @@ func main() {
httpClient := &http.Client{}
// queue for http messages for workers to handle
- msgs := make(chan any, 4096)
+ msgs := make(chan *HTTPRequest, 4096)
mux := http.NewServeMux()
wrks := []*Worker{}
// track routes for mux to avoid registering the same route twice
routes := make(map[string]bool)
// track open dbs to close them on exit
dbs := []*sql.DB{}
- dbmu := sync.Mutex{}
+ mu := sync.Mutex{}
defer func() {
for _, db := range dbs {
if db == nil {
@@ -77,6 +77,7 @@ func main() {
if ok {
return luaOk(l, nil)
}
+ mu.Lock()
routes[route] = true
mux.HandleFunc(
route,
@@ -98,6 +99,7 @@ func main() {
)
},
)
+ mu.Unlock()
return luaOk(l, nil)
}
routeModule["static"] = func (l *Lua) int {
@@ -170,9 +172,9 @@ func main() {
if err != nil {
return luaErr(l, err)
}
- dbmu.Lock()
+ mu.Lock()
dbs = append(dbs, db)
- dbmu.Unlock()
+ mu.Unlock()
h := cgo.NewHandle(db)
return luaOk(l, int(h))
}
@@ -256,33 +258,33 @@ func main() {
if err != nil {
return luaErr(l, err)
}
+ ares := []any{}
db := handle.Value().(*sql.DB)
rows, err := db.Query(query, params...)
if err != nil {
return luaErr(l, err)
}
- var res [][]any
- cols, _ := rows.Columns()
- for rows.Next() {
- scans := make([]any, len(cols))
- for i, _ := range scans {
- scans[i] = &scans[i]
+ go func() {
+ var res [][]any
+ cols, _ := rows.Columns()
+ for rows.Next() {
+ scans := make([]any, len(cols))
+ for i, _ := range scans {
+ scans[i] = &scans[i]
+ }
+ var row []any
+ rows.Scan(scans...)
+ for _, v := range scans {
+ row = append(row, v)
+ }
+ res = append(res, row)
}
- var row []any
- rows.Scan(scans...)
- for _, v := range scans {
- row = append(row, v)
+ for _, v := range res {
+ ares = append(ares, v)
}
- res = append(res, row)
- }
- ares := []any{}
- for _, v := range res {
- ares = append(ares, v)
- }
- err = l.PushArray(ares)
- if err != nil {
- return luaErr(l, err)
- }
+ l.resume()
+ }()
+ l.yield()
return luaOk(l, ares)
}
dbModule["query*"] = func (l *Lua) int {
@@ -298,24 +300,28 @@ func main() {
if err != nil {
return luaErr(l, err)
}
- var res []map[string]any
- cols, _ := rows.Columns()
- for rows.Next() {
- scans := make([]any, len(cols))
- for i, _ := range scans {
- scans[i] = &scans[i]
+ ares := []any{}
+ go func() {
+ cols, _ := rows.Columns()
+ var res []map[string]any
+ for rows.Next() {
+ scans := make([]any, len(cols))
+ for i, _ := range scans {
+ scans[i] = &scans[i]
+ }
+ row := make(map[string]any)
+ rows.Scan(scans...)
+ for i, v := range scans {
+ row[cols[i]] = v
+ }
+ res = append(res, row)
}
- row := make(map[string]any)
- rows.Scan(scans...)
- for i, v := range scans {
- row[cols[i]] = v
+ for _, v := range res {
+ ares = append(ares, v)
}
- res = append(res, row)
- }
- ares := []any{}
- for _, v := range res {
- ares = append(ares, v)
- }
+ l.resume()
+ }()
+ l.yield()
return luaOk(l, ares)
}
dbModule["close"] = func (l *Lua) int {
@@ -329,14 +335,14 @@ func main() {
if err != nil {
return luaErr(l, err)
}
- dbmu.Lock()
+ mu.Lock()
for k, _db := range dbs {
if db == _db {
dbs[k] = nil
break
}
}
- dbmu.Unlock()
+ mu.Unlock()
handle.Delete()
return luaOk(l, nil)
}
diff --git a/worker.go b/worker.go
index 54a3757..aa18b62 100644
--- a/worker.go
+++ b/worker.go
@@ -21,15 +21,8 @@ type HTTPResponse struct {
Body string
}
-type Worker struct {
- lua *Lua
- routes map[string]LuaRef
- started bool
- mu sync.Mutex
-}
-
func HandleHTTPRequest(
- queue chan any,
+ queue chan *HTTPRequest,
route string,
req *http.Request,
) chan *HTTPResponse {
@@ -42,6 +35,13 @@ func HandleHTTPRequest(
return res
}
+type Worker struct {
+ lua *Lua
+ routes map[string]LuaRef
+ started bool
+ mu sync.Mutex
+}
+
// NewWorker creates a new instance of Worker type.
func NewWorker() *Worker {
return &Worker {
@@ -97,8 +97,8 @@ func (w *Worker) Start(argv []string, module map[string]any) error {
return nil
}
-// Listen starts a goroutine listening/handling HTTP requests from the queue.
-func (w *Worker) Listen(queue chan any) {
+// Listen starts handling HTTP requests from the queue.
+func (w *Worker) Listen(queue chan *HTTPRequest) {
stringListToAny := func(slice []string) []any {
res := []any{}
for _, v := range slice {
@@ -106,109 +106,127 @@ func (w *Worker) Listen(queue chan any) {
}
return res
}
- handle := func() {
- defer w.lua.RestoreStackFunc()()
- r := <- queue
-
- switch r.(type) {
- case *HTTPRequest:
- r := r.(*HTTPRequest)
- if _, ok := w.routes[r.route]; !ok {
- r.result <- &HTTPResponse {
- Code: 404,
- Headers: make(map[string]string),
- Body: "not found",
- }
- log.Println("no corresponding route")
- return
- }
+ handle := func(r *HTTPRequest, yield func(), resume func() bool) {
+ l := w.lua.NewThread(yield, resume)
+ // Save a thread to a reference so it's not garbage collected
+ // before we are done with it.
+ ref := w.lua.PopToRef()
+ defer w.lua.Unref(ref)
- w.lua.PushFromRef(w.routes[r.route])
- res := make(map[string]any)
- res["method"] = r.request.Method
- res["path"] = r.request.URL.Path
-
- fh := make(map[string]any)
- for k := range r.request.Header {
- fh[k] = r.request.Header.Get(k)
+ if _, ok := w.routes[r.route]; !ok {
+ r.result <- &HTTPResponse {
+ Code: 404,
+ Headers: make(map[string]string),
+ Body: "not found",
}
- res["headers"] = fh
+ log.Println("no corresponding route")
+ return
+ }
- flatQr := make(map[string]any)
- qr := r.request.URL.Query()
- for k := range qr {
- flatQr[k] = stringListToAny(qr[k])
- }
- res["query"] = flatQr
-
- body, err := io.ReadAll(r.request.Body)
- if err != nil {
- r.result <- &HTTPResponse {
- Code: 500,
- Headers: make(map[string]string),
- Body: "server error",
- }
- log.Println("could not read a request body:", err)
- return
- }
- res["body"] = string(body)
-
- err = w.lua.PushObject(res)
- if err != nil {
- r.result <- &HTTPResponse {
- Code: 500,
- Headers: make(map[string]string),
- Body: "server error",
- }
- log.Println("could not form a request to lua:", err)
- return
- }
+ l.PushFromRef(w.routes[r.route])
+ res := make(map[string]any)
+ res["method"] = r.request.Method
+ res["path"] = r.request.URL.Path
- err = w.lua.PCall(1, 3)
+ fh := make(map[string]any)
+ for k := range r.request.Header {
+ fh[k] = r.request.Header.Get(k)
+ }
+ res["headers"] = fh
- if err != nil {
- r.result <- &HTTPResponse {
- Code: 500,
- Headers: make(map[string]string),
- Body: "server error",
- }
- log.Println("could not read a request body:", err)
- return
+ flatQr := make(map[string]any)
+ qr := r.request.URL.Query()
+ for k := range qr {
+ flatQr[k] = stringListToAny(qr[k])
+ }
+ res["query"] = flatQr
+
+ body, err := io.ReadAll(r.request.Body)
+ if err != nil {
+ r.result <- &HTTPResponse{
+ Code: 500,
+ Headers: make(map[string]string),
+ Body: "server error",
}
+ log.Println("could not read request body:", err)
+ return
+ }
+ res["body"] = string(body)
- code := w.lua.ToInt(-3)
- rbody := w.lua.ToString(-1)
-
- // Parse headers.
- headers := make(map[string]string)
- w.lua.Pop(1)
- w.lua.PushNil()
- for w.lua.Next() {
- if !w.lua.IsString(-2) || !w.lua.IsString(-2) {
- w.lua.Pop(1)
- continue
- }
- v := w.lua.ToString(-1)
- w.lua.Pop(1)
- // We must not pop the item key from the stack
- // because otherwise C.lua_next won't work
- // properly.
- k := w.lua.ToString(-1)
- headers[k] = v
+ err = l.PushObject(res)
+ if err != nil {
+ r.result <- &HTTPResponse{
+ Code: 500,
+ Headers: make(map[string]string),
+ Body: "server error",
}
- r.result <- &HTTPResponse {
- Code: int(code),
- Headers: headers,
- Body: rbody,
+ log.Println("could not form request to lua:", err)
+ return
+ }
+
+ err = l.PCall(1, 3)
+ if err != nil {
+ r.result <- &HTTPResponse{
+ Code: 500,
+ Headers: make(map[string]string),
+ Body: "server error",
}
+ log.Println("could not process request:", err)
+ // TODO: print lua stack as well?
+ return
+ }
- default:
- log.Fatal("unknown request")
+ // TODO: probably it would be better to just use l.Scan()
+ // here but i'm not really sure if we want to have that
+ // overhead here.
+ code := l.ToInt(-3)
+ rbody := l.ToString(-1)
+ // Parse headers.
+ headers := make(map[string]string)
+ l.Pop(1)
+ l.PushNil()
+ for l.Next() {
+ if !l.IsString(-2) || !l.IsString(-2) {
+ l.Pop(1)
+ continue
+ }
+ v := l.ToString(-1)
+ l.Pop(1)
+ // We must not pop the item key from the stack
+ // because otherwise C.lua_next won't work
+ // properly.
+ k := l.ToString(-1)
+ headers[k] = v
+ }
+ r.result <- &HTTPResponse{
+ Code: int(code),
+ Headers: headers,
+ Body: rbody,
}
}
+ resCh := make(chan func() bool, 4096)
+outer:
for {
- handle()
+ select {
+ case r, ok := <- queue:
+ // accept new requests
+ if !ok {
+ break outer
+ }
+ resCh <- NewCoroutine(func(yield func(), resume func() bool) {
+ handle(r, yield, func () bool {
+ resCh <- resume
+ return true
+ })
+ })
+ case resume, ok := <-resCh:
+ // coroutine executor
+ if !ok {
+ break outer
+ }
+ resume()
+ }
}
}
@@ -237,3 +255,28 @@ func (w *Worker) Stop() {
func (w *Worker) HasSameLua(l *Lua) bool {
return w.lua == l
}
+
+func NewCoroutine(f func (yield func(), resume func() bool)) (resume func() bool) {
+ cin := make(chan bool)
+ cout := make(chan bool)
+ running := true
+ resume = func() bool {
+ if !running {
+ return false
+ }
+ cin <- true
+ <-cout
+ return true
+ }
+ yield := func() {
+ cout <- true
+ <-cin
+ }
+ go func() {
+ <-cin
+ f(yield, resume)
+ running = false
+ cout <- true
+ }()
+ return resume
+}