From cc8d3cf4a7cd296d8c409fef5db9df138c3b238d Mon Sep 17 00:00:00 2001 From: unwox Date: Sat, 14 Sep 2024 15:54:27 +0600 Subject: add luna.db.* module --- go.mod | 2 + go.sum | 2 + lua.go | 138 +++++++++++++++++++++++++++++++--------- main.go | 213 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++- worker.go | 17 ++--- 5 files changed, 332 insertions(+), 40 deletions(-) create mode 100644 go.sum diff --git a/go.mod b/go.mod index 0579859..6d81a5a 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module luna go 1.22.2 + +require github.com/mattn/go-sqlite3 v1.14.23 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..32531fa --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0= +github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= diff --git a/lua.go b/lua.go index 8401d37..51af7d1 100644 --- a/lua.go +++ b/lua.go @@ -25,6 +25,7 @@ import ( "reflect" "runtime/cgo" "slices" + "time" "unsafe" ) @@ -66,7 +67,7 @@ func (l *Lua) Require(file string) error { return nil } -// PCall calls a function with the stack index (-nargs - 1) expecting nresults +// PCall calls a function with the stack index (-nargs-1) expecting nresults // number of results. func (l *Lua) PCall(nargs int, nresults int) error { res := C.lua_pcallk(l.l, C.int(nargs), C.int(nresults), 0, 0, nil) @@ -122,6 +123,11 @@ func (l *Lua) PushNumber(num int) { C.lua_pushnumber(l.l, C.double(num)) } +// PushNumber pushes the number onto the Lua stack. +func (l *Lua) PushFloatNumber(num float64) { + C.lua_pushnumber(l.l, C.double(num)) +} + // PushString pushes the string onto the Lua stack. func (l *Lua) PushString(str string) { cstr := C.CString(str) @@ -153,36 +159,64 @@ func (l *Lua) SetTableItem(key string) { C.lua_setfield(l.l, -2, C.CString(key)) } -// PushTable pushes string->any table onto the stack. +// PushAny pushes value v onto the stack. +func (l *Lua) PushAny(v any) error { + switch v.(type) { + case string: + v, _ := v.(string) + l.PushString(v) + case func (l *Lua) int: + v, _ := v.(func (l *Lua) int) + l.PushGoFunction(v) + case int: + v, _ := v.(int) + l.PushNumber(v) + case int64: + v, _ := v.(int64) + l.PushNumber(int(v)) + case float64: + v, _ := v.(float64) + l.PushFloatNumber(v) + case map[string]any: + v, _ := v.(map[string]any) + l.PushObject(v) + case []any: + v, _ := v.([]any) + l.PushArray(v) + case time.Time: + v, _ := v.(time.Time) + l.PushString(v.Format(time.DateTime)) + default: + return fmt.Errorf("unsupported value type: %T", v) + } + return nil +} + +// PushObject recursively pushes string->any Go table onto the stack. func (l *Lua) PushObject(table map[string]any) error { - var pushTable func(t map[string]any) error - pushTable = func (t map[string]any) error { - l.CreateTable(len(t)) - for k, v := range t { - switch v.(type) { - case string: - v, _ := v.(string) - l.PushString(v) - case func (l *Lua) int: - v, _ := v.(func (l *Lua) int) - l.PushGoFunction(v) - case int: - v, _ := v.(int) - l.PushNumber(v) - case map[string]any: - v, _ := v.(map[string]any) - pushTable(v) - default: - return fmt.Errorf( - "unsupported value type: %T", - v, - ) - } - l.SetTableItem(k) + l.CreateTable(len(table)) + for k, v := range table { + err := l.PushAny(v) + if err != nil { + return err } - return nil + l.SetTableItem(k) } - return pushTable(table) + return nil +} + +// PushArray recursively pushes an array of Go values onto the stack. +func (l *Lua) PushArray(array []any) error { + l.CreateTable(len(array)) + for k, v := range array { + l.PushNumber(k + 1) + err := l.PushAny(v) + if err != nil { + return err + } + C.lua_settable(l.l, C.int(-3)) + } + return nil } // PushFromRef pushes a value from registry ref onto the stack. @@ -190,6 +224,16 @@ func (l *Lua) PushFromRef(ref LuaRef) { C.lua_rawgeti(l.l, C.LUA_REGISTRYINDEX, C.longlong(ref)); } +// Type returns type of the value sitting at n index on the stack. +func (l *Lua) Type(n int) int { + return int(C.lua_type(l.l, C.int(n))) +} + +// IsNil checks if the stack contains nil under the given index. +func (l *Lua) IsNil(index int) bool { + return C.lua_type(l.l, C.int(index)) == C.LUA_TNIL +} + // IsString checks if the stack contains a string under the given index. func (l *Lua) IsString(index int) bool { return C.lua_isstring(l.l, C.int(index)) == 1 @@ -251,6 +295,15 @@ func (l *Lua) Scan(vars ...any) error { ) } *v.(*LuaRef) = l.PopToRef() + } else if t.String() == "cgo.Handle" { + if l.IsNumber(-1) == false { + return fmt.Errorf( + "passed arg #%d must be Go handler", + len(vars)-i, + ) + } + *v.(*cgo.Handle) = cgo.Handle(uintptr(l.ToInt(-1))) + l.Pop(1) } else if tk == reflect.String { if l.IsString(-1) == false { return fmt.Errorf( @@ -287,13 +340,38 @@ func (l *Lua) Scan(vars ...any) error { len(vars)-i, ) } - iv := l.ToString(-1) + v := l.ToString(-1) l.Pop(1) // We must not pop the item key from the stack // because otherwise C.lua_next won't work // properly. k := l.ToString(-1) - (*vm)[k] = iv + (*vm)[k] = v + } + l.Pop(1) + } else if tk == reflect.Slice && + t.Elem().Kind() == reflect.Interface { + if !l.IsTable(-1) { + return fmt.Errorf( + "passed arg #%d must be a table", + len(vars)-i, + ) + } + va, _ := v.(*[]any) + l.PushNil() + for l.Next() { + var v any = nil + if l.IsString(-1) { + v = l.ToString(-1) + } else if l.IsNumber(-1) { + v = l.ToInt(-1) + } else if l.IsNil(-1) { + v = nil + } else { + return fmt.Errorf("unknown value in array") + } + l.Pop(1) + (*va) = append((*va), v) } l.Pop(1) } else { diff --git a/main.go b/main.go index bde394e..eaa3ec6 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "database/sql" "errors" "flag" "fmt" @@ -10,8 +11,11 @@ import ( "net/http" "os" "runtime" + "runtime/cgo" "strings" "time" + + _ "github.com/mattn/go-sqlite3" ) func main() { @@ -28,7 +32,7 @@ func main() { httpClient := &http.Client{} // queue for http messages for workers to handle - msgs := make(chan interface{}, 4096) + msgs := make(chan any, 4096) mux := http.NewServeMux() wrks := []*Worker{} // track routes for mux to avoid registering the same route twice @@ -128,9 +132,214 @@ func main() { return 3 } + // define luna.db module + dbModule := make(map[string]any) + dbModule["open"] = func (l *Lua) int { + var file string + err := l.Scan(&file) + if err != nil { + // FIXME: handle. + return 0 + } + r, err := sql.Open("sqlite3", file) + if err != nil { + // FIXME: handle. + return 0 + } + h := cgo.NewHandle(r) + l.PushNumber(int(h)) + return 1 + } + dbModule["begin"] = func (l *Lua) int { + var handle cgo.Handle + err := l.Scan(&handle) + if err != nil { + // FIXME: handle. + return 0 + } + db := handle.Value().(*sql.DB) + tx, err := db.Begin() + if err != nil { + // FIXME: handle. + return 0 + } + txh := cgo.NewHandle(tx) + l.PushNumber(int(txh)) + return 1 + } + dbModule["commit"] = func (l *Lua) int { + var handle cgo.Handle + err := l.Scan(&handle) + if err != nil { + // FIXME: handle. + return 0 + } + tx := handle.Value().(*sql.Tx) + err = tx.Commit() + if err != nil { + // FIXME: handle. + return 0 + } + handle.Delete() + return 0 + } + dbModule["rollback"] = func (l *Lua) int { + var handle cgo.Handle + err := l.Scan(&handle) + if err != nil { + // FIXME: handle. + return 0 + } + tx := handle.Value().(*sql.Tx) + err = tx.Rollback() + if err != nil { + // FIXME: handle. + return 0 + } + handle.Delete() + return 0 + } + dbModule["exec-tx"] = func (l *Lua) int { + var handle cgo.Handle + var query string + var params []any + err := l.Scan(&handle, &query, ¶ms) + if err != nil { + fmt.Println(err) + // FIXME: handle. + return 0 + } + tx := handle.Value().(*sql.Tx) + _, err = tx.Exec(query, params...) + if err != nil { + fmt.Println(err) + // FIXME: handle. + return 0 + } + return 0 + } + dbModule["exec"] = func (l *Lua) int { + var handle cgo.Handle + var query string + var params []any + err := l.Scan(&handle, &query, ¶ms) + if err != nil { + fmt.Println(err) + // FIXME: handle. + return 0 + } + db := handle.Value().(*sql.DB) + _, err = db.Exec(query, params...) + if err != nil { + fmt.Println(err) + // FIXME: handle. + return 0 + } + return 0 + } + dbModule["query"] = func (l *Lua) int { + var handle cgo.Handle + var query string + var params []any + err := l.Scan(&handle, &query, ¶ms) + if err != nil { + fmt.Println(err) + // FIXME: handle. + return 0 + } + db := handle.Value().(*sql.DB) + rows, err := db.Query(query, params...) + if err != nil { + // FIXME: handle. + return 0 + } + var res [][]any + cols, _ := rows.Columns() + for rows.Next() { + scans := make([]any, len(cols)) + for i, _ := range scans { + scans[i] = &scans[i] + } + var row []any + rows.Scan(scans...) + for _, v := range scans { + row = append(row, v) + } + res = append(res, row) + } + var ares []any + for _, v := range res { + ares = append(ares, v) + } + err = l.PushArray(ares) + if err != nil { + fmt.Println(err) + // FIXME: handle + return 0 + } + return 1 + } + dbModule["query*"] = func (l *Lua) int { + var handle cgo.Handle + var query string + var params []any + err := l.Scan(&handle, &query, ¶ms) + if err != nil { + // FIXME: handle. + return 0 + } + db := handle.Value().(*sql.DB) + rows, err := db.Query(query, params...) + if err != nil { + // FIXME: handle. + return 0 + } + var res []map[string]any + cols, _ := rows.Columns() + for rows.Next() { + scans := make([]any, len(cols)) + for i, _ := range scans { + scans[i] = &scans[i] + } + row := make(map[string]any) + rows.Scan(scans...) + for i, v := range scans { + row[cols[i]] = v + } + res = append(res, row) + } + var ares []any + for _, v := range res { + ares = append(ares, v) + } + err = l.PushArray(ares) + if err != nil { + // FIXME: handle + return 0 + } + return 1 + } + dbModule["close"] = func (l *Lua) int { + var handle cgo.Handle + err := l.Scan(&handle) + if err != nil { + // FIXME: handle. + return 0 + } + db := handle.Value().(*sql.DB) + err = db.Close() + if err != nil { + // FIXME: handle. + return 0 + } + handle.Delete() + return 0 + } + module := make(map[string]any) module["router"] = routeModule module["http"] = httpModule + module["db"] = dbModule // start workers for i := 0; i < *wrksNum; i++ { @@ -177,7 +386,7 @@ func printUsage() { func mustExist(file string) { if _, err := os.Stat(file); errors.Is(err, os.ErrNotExist) { - fmt.Printf(`file "%s" does not exist\n`, file) + fmt.Printf("file \"%s\" does not exist\n", file) os.Exit(1) } } diff --git a/worker.go b/worker.go index b4cae12..2df98a0 100644 --- a/worker.go +++ b/worker.go @@ -64,8 +64,15 @@ func (w *Worker) Start(filename string, module map[string]any) error { } w.lua.Start() defer w.lua.RestoreStackFunc()() - w.initLunaModule(module) - err := w.lua.Require(filename) + + // registers the module in the Lua context + err := w.lua.PushObject(module) + if err != nil { + return err + } + w.lua.SetGlobal("luna") + + err = w.lua.Require(filename) if err != nil { return err } @@ -206,9 +213,3 @@ func (w *Worker) Stop() { func (w *Worker) HasSameLua(l *Lua) bool { return w.lua == l } - -// initLunaModule registers the module in the Lua context. -func (w *Worker) initLunaModule(module map[string]any) { - w.lua.PushObject(module) - w.lua.SetGlobal("luna") -} -- cgit v1.2.3