diff --git a/component.go b/component.go new file mode 100644 index 0000000..9a859e0 --- /dev/null +++ b/component.go @@ -0,0 +1,15 @@ +package ken + +import ( + "github.com/bwmarrin/discordgo" + "github.com/zekrotja/ken/util" +) + +type MessageComponent struct { + discordgo.MessageComponent +} + +func (t MessageComponent) GetValue() string { + val, _ := util.GetFieldValue(t.MessageComponent, "Value") + return val +} diff --git a/componentbuilder.go b/componentbuilder.go index 8f3135b..88fb4d4 100644 --- a/componentbuilder.go +++ b/componentbuilder.go @@ -1,9 +1,8 @@ package ken import ( - "reflect" - "github.com/bwmarrin/discordgo" + "github.com/zekrotja/ken/util" ) // ComponentAssembler helps to build the message @@ -228,15 +227,8 @@ func (t *ComponentBuilder) Build() (unreg func() error, err error) { } func getCustomId(component discordgo.MessageComponent) string { - componentValue := reflect.ValueOf(component) - customIdValue := componentValue.FieldByName("CustomID") - - var customId string - if customIdValue.IsValid() { - customId = customIdValue.String() - } - - return customId + val, _ := util.GetFieldValue(component, "CustomID") + return val } func removeComponentRecursive(components []discordgo.MessageComponent, customKey string) []discordgo.MessageComponent { diff --git a/components.go b/components.go index 0ddcde3..5ee73bf 100644 --- a/components.go +++ b/components.go @@ -16,6 +16,8 @@ import ( // the execution of the handler. type ComponentHandlerFunc func(ctx ComponentContext) bool +type ModalHandlerFunc func(ctx ModalContext) bool + // ComponentHandler keeps a registry of component handler // callbacks to be executed when a given component has // been interacted with. @@ -23,10 +25,12 @@ type ComponentHandler struct { ken *Ken unregisterFunc func() - mtx sync.RWMutex - handlers map[string]ComponentHandlerFunc + mtx sync.RWMutex + handlers map[string]ComponentHandlerFunc + modalHandlers map[string]ModalHandlerFunc - ctxPool sync.Pool + ctxPool sync.Pool + modalCtxPool sync.Pool } // NewComponentHandler returns a new instance of @@ -37,12 +41,18 @@ func NewComponentHandler(ken *Ken) *ComponentHandler { t.ken = ken t.handlers = make(map[string]ComponentHandlerFunc) + t.modalHandlers = make(map[string]ModalHandlerFunc) t.unregisterFunc = t.ken.s.AddHandler(t.handle) t.ctxPool = sync.Pool{ New: func() interface{} { return &ComponentCtx{} }, } + t.modalCtxPool = sync.Pool{ + New: func() interface{} { + return &ModalCtx{} + }, + } return &t } @@ -110,11 +120,43 @@ func (t *ComponentHandler) UnregisterDiscordHandler() { t.unregisterFunc() } -func (t *ComponentHandler) handle(_ *discordgo.Session, e *discordgo.InteractionCreate) { - if e.Type != discordgo.InteractionMessageComponent { +func (t *ComponentHandler) registerModalHandler(customId string, handler ModalHandlerFunc) func() { + t.mtx.Lock() + defer t.mtx.Unlock() + t.modalHandlers[customId] = func(ctx ModalContext) bool { + ok := handler(ctx) + if ok { + t.unregisterModalhandler(customId) + } + return ok + } + + return func() { + t.unregisterModalhandler(customId) + } +} + +func (t *ComponentHandler) unregisterModalhandler(customId ...string) { + if len(customId) == 0 { return } + t.mtx.Lock() + defer t.mtx.Unlock() + for _, id := range customId { + delete(t.modalHandlers, id) + } +} + +func (t *ComponentHandler) handle(_ *discordgo.Session, e *discordgo.InteractionCreate) { + switch e.Type { + case discordgo.InteractionMessageComponent: + t.handleMessageComponent(e) + case discordgo.InteractionModalSubmit: + t.handleModalSubmit(e) + } +} +func (t *ComponentHandler) handleMessageComponent(e *discordgo.InteractionCreate) { data := e.MessageComponentData() t.mtx.RLock() @@ -139,3 +181,29 @@ func (t *ComponentHandler) handle(_ *discordgo.Session, e *discordgo.Interaction handler(ctx) } + +func (t *ComponentHandler) handleModalSubmit(e *discordgo.InteractionCreate) { + data := e.ModalSubmitData() + + t.mtx.RLock() + handler, ok := t.modalHandlers[data.CustomID] + t.mtx.RUnlock() + + if !ok { + return + } + + ctx := t.modalCtxPool.Get().(*ModalCtx) + ctx.Data = data + ctx.Ephemeral = false + ctx.Event = e + ctx.Session = t.ken.s + ctx.Ken = t.ken + ctx.responded = false + + defer func() { + t.modalCtxPool.Put(ctx) + }() + + handler(ctx) +} diff --git a/context.go b/context.go index 539603e..46cc0c3 100644 --- a/context.go +++ b/context.go @@ -2,6 +2,7 @@ package ken import ( "github.com/bwmarrin/discordgo" + "github.com/rs/xid" ) // ContextResponder defines the implementation of an @@ -357,8 +358,12 @@ type ComponentContext interface { ContextResponder GetData() discordgo.MessageComponentInteractionData + OpenModal( + title string, + content string, + build func(b ComponentAssembler), + ) (<-chan ModalContext, error) } - type ComponentCtx struct { CtxResponder @@ -370,3 +375,76 @@ var _ ComponentContext = (*ComponentCtx)(nil) func (c *ComponentCtx) GetData() discordgo.MessageComponentInteractionData { return c.Data } + +func (c *ComponentCtx) OpenModal( + title string, + content string, + build func(b ComponentAssembler), +) (<-chan ModalContext, error) { + b := newComponentAssembler() + build(b) + + modalId := xid.New().String() + err := c.Respond(&discordgo.InteractionResponse{ + Type: discordgo.InteractionResponseModal, + Data: &discordgo.InteractionResponseData{ + CustomID: modalId, + Title: title, + Content: content, + Components: b.components, + }, + }) + if err != nil { + return nil, err + } + + cCtx := make(chan ModalContext, 1) + + c.Ken.componentHandler.registerModalHandler(modalId, func(ctx ModalContext) bool { + cCtx <- ctx + return true + }) + + return cCtx, nil +} + +type ModalContext interface { + ContextResponder + + GetData() discordgo.ModalSubmitInteractionData + GetComponentByID(customId string) MessageComponent +} + +type ModalCtx struct { + CtxResponder + + Data discordgo.ModalSubmitInteractionData +} + +var _ ModalContext = (*ModalCtx)(nil) + +func (c *ModalCtx) GetData() discordgo.ModalSubmitInteractionData { + return c.Data +} + +func (c *ModalCtx) GetComponentByID(customId string) MessageComponent { + return MessageComponent{getComponentByID(customId, c.GetData().Components)} +} + +func getComponentByID( + customId string, + comps []discordgo.MessageComponent, +) discordgo.MessageComponent { + for _, comp := range comps { + if row, ok := comp.(*discordgo.ActionsRow); ok { + found := getComponentByID(customId, row.Components) + if found != nil { + return found + } + } + if customId == getCustomId(comp) { + return comp + } + } + return nil +} diff --git a/examples/components/commands/modal.go b/examples/components/commands/modal.go new file mode 100644 index 0000000..43ede87 --- /dev/null +++ b/examples/components/commands/modal.go @@ -0,0 +1,89 @@ +package commands + +import ( + "fmt" + + "github.com/bwmarrin/discordgo" + "github.com/zekrotja/ken" +) + +type ModalCommand struct{} + +var ( + _ ken.SlashCommand = (*TestCommand)(nil) + _ ken.DmCapable = (*TestCommand)(nil) +) + +func (c *ModalCommand) Name() string { + return "modal" +} + +func (c *ModalCommand) Description() string { + return "Modal Test Command" +} + +func (c *ModalCommand) Version() string { + return "1.0.0" +} + +func (c *ModalCommand) Type() discordgo.ApplicationCommandType { + return discordgo.ChatApplicationCommand +} + +func (c *ModalCommand) Options() []*discordgo.ApplicationCommandOption { + return []*discordgo.ApplicationCommandOption{} +} + +func (c *ModalCommand) IsDmCapable() bool { + return false +} + +func (c *ModalCommand) Run(ctx *ken.Ctx) (err error) { + if err = ctx.Defer(); err != nil { + return + } + + fum := ctx.FollowUpEmbed(&discordgo.MessageEmbed{ + Description: "How are you?", + }) + if fum.HasError() { + return fum.Error + } + + _, err = fum.AddComponents(). + AddActionsRow(func(b ken.ComponentAssembler) { + b.Add(discordgo.Button{ + CustomID: "open-modal", + Label: "Write it!", + Style: discordgo.PrimaryButton, + }, func(ctx ken.ComponentContext) bool { + cCtx, err := ctx.OpenModal("Hello world", "Lorem ipsum ...", func(b ken.ComponentAssembler) { + b.AddActionsRow(func(b ken.ComponentAssembler) { + b.Add(discordgo.TextInput{ + CustomID: "text-input", + Label: "How are you?", + Style: discordgo.TextInputShort, + Required: true, + MaxLength: 1000, + }, nil) + }) + }) + + if err != nil { + fmt.Println("Error:", err) + return false + } + + embCtx := <-cCtx + + resp := embCtx.GetComponentByID("text-input").GetValue() + embCtx.RespondEmbed(&discordgo.MessageEmbed{ + Description: fmt.Sprintf(`"%s" - ok, thats cool`, resp), + }) + return true + }) + }, true). + Build() + + return err +} diff --git a/examples/components/main.go b/examples/components/main.go index 964a5d1..0262ad6 100644 --- a/examples/components/main.go +++ b/examples/components/main.go @@ -31,7 +31,10 @@ func main() { }) must(err) - must(k.RegisterCommands(new(commands.TestCommand))) + must(k.RegisterCommands( + new(commands.TestCommand), + new(commands.ModalCommand), + )) defer k.Unregister() diff --git a/go.mod b/go.mod index 108be9b..0d2926a 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/go-redis/redis/v8 v8.11.5 // indirect + github.com/rs/xid v1.4.0 github.com/zekroTJA/timedmap v1.4.0 go.opentelemetry.io/otel v1.9.0 // indirect go.opentelemetry.io/otel/metric v0.31.0 // indirect diff --git a/go.sum b/go.sum index fe312dc..16a5824 100644 --- a/go.sum +++ b/go.sum @@ -53,6 +53,8 @@ github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7J github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/onsi/gomega v1.10.5/go.mod h1:gza4q3jKQJijlu05nKWRCW/GavJumGt8aNRxWg7mt48= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY= +github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/info.go b/info.go index f87cd76..add5376 100644 --- a/info.go +++ b/info.go @@ -37,7 +37,7 @@ func mustToJson(v interface{}) string { return string(d) } -type keyTransformerFunc func(string) string +type KeyTransformerFunc func(string) string // GetCommandInfo returns a list with information about all // registered commands. @@ -52,7 +52,7 @@ type keyTransformerFunc func(string) string // If you want to disable this behavior, you can set // Config.DisableCommandInfoCache to true on intializing // Ken. -func (k *Ken) GetCommandInfo(keyTransformer ...keyTransformerFunc) (cis CommandInfoList) { +func (k *Ken) GetCommandInfo(keyTransformer ...KeyTransformerFunc) (cis CommandInfoList) { kt := func(v string) string { return v } @@ -73,7 +73,7 @@ func (k *Ken) GetCommandInfo(keyTransformer ...keyTransformerFunc) (cis CommandI return } -func (k *Ken) collectCommandInfo(kt keyTransformerFunc) (cis CommandInfoList) { +func (k *Ken) collectCommandInfo(kt KeyTransformerFunc) (cis CommandInfoList) { cis = make(CommandInfoList, 0, len(k.cmds)) for _, cmd := range k.cmds { typ := reflect.TypeOf(cmd) diff --git a/util/reflect.go b/util/reflect.go new file mode 100644 index 0000000..7b8b569 --- /dev/null +++ b/util/reflect.go @@ -0,0 +1,18 @@ +package util + +import "reflect" + +func GetFieldValue(v interface{}, fieldName string) (value string, ok bool) { + val := reflect.ValueOf(v) + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + fieldValue := val.FieldByName(fieldName) + + if fieldValue.IsValid() { + value = fieldValue.String() + ok = true + } + + return value, ok +}