summaryrefslogtreecommitdiff
path: root/lua_common.go
diff options
context:
space:
mode:
Diffstat (limited to 'lua_common.go')
-rw-r--r--lua_common.go95
1 files changed, 95 insertions, 0 deletions
diff --git a/lua_common.go b/lua_common.go
index 4e58d69..29fdab9 100644
--- a/lua_common.go
+++ b/lua_common.go
@@ -5,8 +5,82 @@ import (
"slices"
"fmt"
"runtime/cgo"
+ "time"
+ "errors"
)
+const LUA_MULTRET = -1
+
+// Require loads and executes the file pushing results onto the Lua stack.
+func (l *Lua) Require(file string) error {
+ err := l.LoadFile(file)
+ if err != nil {
+ return errors.New("could not open the file:\n" + err.Error())
+ }
+ err = l.PCall(0, LUA_MULTRET, 1)
+ if err != nil {
+ return errors.New("could not execute the file:\n" + err.Error())
+ }
+ return nil
+}
+
+// PushObject recursively pushes string->any Go table onto the stack.
+func (l *Lua) PushObject(table map[string]any) error {
+ l.CreateTable(len(table))
+ for k, v := range table {
+ err := l.PushAny(v)
+ if err != nil {
+ return err
+ }
+ l.SetTableItem(k)
+ }
+ return nil
+}
+
+// PushAny pushes value v onto the stack.
+func (l *Lua) PushAny(v any) error {
+ switch v.(type) {
+ case nil:
+ l.PushNil()
+ case string:
+ v, _ := v.(string)
+ l.PushString(v)
+ case func (l *Lua) int:
+ v, _ := v.(func (l *Lua) int)
+ l.PushGoFunction(v)
+ case int:
+ v, _ := v.(int)
+ l.PushNumber(v)
+ case int64:
+ v, _ := v.(int64)
+ l.PushNumber(int(v))
+ case float64:
+ v, _ := v.(float64)
+ l.PushFloatNumber(v)
+ case bool:
+ v, _ := v.(bool)
+ l.PushBoolean(v)
+ case map[string]any:
+ v, _ := v.(map[string]any)
+ err := l.PushObject(v)
+ if err != nil {
+ return fmt.Errorf("object push error: ", err)
+ }
+ case []any:
+ v, _ := v.([]any)
+ err := l.PushArray(v)
+ if err != nil {
+ return fmt.Errorf("array push error: ", err)
+ }
+ case time.Time:
+ v, _ := v.(time.Time)
+ l.PushString(v.Format(time.DateTime))
+ default:
+ return fmt.Errorf("unsupported value type: %T", v)
+ }
+ return nil
+}
+
// Scan scans values from the Lua stack into vars according to their types.
func (l *Lua) Scan(vars ...any) error {
slices.Reverse(vars)
@@ -137,3 +211,24 @@ func (l *Lua) Scan(vars ...any) error {
}
return nil
}
+
+// RestoreStackFunc remembers the Lua stack size and then restores it when a
+// returned function is called. It's a helper function to avoid stack leakage.
+func (l *Lua) RestoreStackFunc() func () {
+ before := l.StackLen()
+ return func () {
+ after := l.StackLen()
+ diff := after - before
+ if diff == 0 {
+ return
+ } else if diff < 0 {
+ msg := fmt.Sprintf(
+ "too many stack pops: len before: %d, after: %d\n",
+ before,
+ after,
+ )
+ panic(msg)
+ }
+ l.SetTop(before)
+ }
+}