Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make Gateway.Open wait until ready event is received #321

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions gateway/gateway_events.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 43,12 @@ type EventReady struct {
func (EventReady) messageData() {}
func (EventReady) eventData() {}

// EventResumed is the event sent by discord when you successfully resume
type EventResumed struct{}

func (EventResumed) messageData() {}
func (EventResumed) eventData() {}

type EventApplicationCommandPermissionsUpdate struct {
discord.ApplicationCommandPermissions
}
Expand Down
65 changes: 56 additions & 9 deletions gateway/gateway_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 92,8 @@ func (g *gatewayImpl) open(ctx context.Context) error {
g.config.Logger.Debug(g.formatLogs("opening gateway connection"))

g.connMu.Lock()
defer g.connMu.Unlock()
if g.conn != nil {
g.connMu.Unlock()
return discord.ErrGatewayAlreadyConnected
}
g.status = StatusConnecting
Expand All @@ -120,6 120,7 @@ func (g *gatewayImpl) open(ctx context.Context) error {
}

g.config.Logger.Error(g.formatLogsf("error connecting to the gateway. url: %s, error: %s, body: %s", gatewayURL, err, body))
g.connMu.Unlock()
return err
}

Expand All @@ -128,13 129,30 @@ func (g *gatewayImpl) open(ctx context.Context) error {
})

g.conn = conn
g.connMu.Unlock()

// reset rate limiter when connecting
g.config.RateLimiter.Reset()

g.status = StatusWaitingForHello

go g.listen(conn)
readyChan := make(chan error)
go g.listen(conn, readyChan)

select {
case <-ctx.Done():
closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
g.Close(closeCtx)
return ctx.Err()
case err = <-readyChan:
if err != nil {
closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
g.Close(closeCtx)
return fmt.Errorf("failed to open gateway connection: %w", err)
}
}

return nil
}
Expand Down Expand Up @@ -226,6 244,13 @@ func (g *gatewayImpl) reconnectTry(ctx context.Context, try int) error {
}

if err := g.open(ctx); err != nil {
var closeError *websocket.CloseError
if errors.As(err, &closeError) {
closeCode := CloseEventCodeByCode(closeError.Code)
if !closeCode.Reconnect {
return err
}
}
if errors.Is(err, discord.ErrGatewayAlreadyConnected) {
return err
}
Expand Down Expand Up @@ -279,7 304,7 @@ func (g *gatewayImpl) sendHeartbeat() {
g.lastHeartbeatSent = time.Now().UTC()
}

func (g *gatewayImpl) identify() {
func (g *gatewayImpl) identify() error {
g.status = StatusIdentifying
g.config.Logger.Debug(g.formatLogs("sending Identify command..."))

Expand All @@ -298,12 323,13 @@ func (g *gatewayImpl) identify() {
}

if err := g.Send(context.TODO(), OpcodeIdentify, identify); err != nil {
g.config.Logger.Error(g.formatLogs("error sending Identify command err: ", err))
return err
}
g.status = StatusWaitingForReady
return nil
}

func (g *gatewayImpl) resume() {
func (g *gatewayImpl) resume() error {
g.status = StatusResuming
resume := MessageDataResume{
Token: g.token,
Expand All @@ -313,16 339,22 @@ func (g *gatewayImpl) resume() {

g.config.Logger.Debug(g.formatLogs("sending Resume command..."))
if err := g.Send(context.TODO(), OpcodeResume, resume); err != nil {
g.config.Logger.Error(g.formatLogs("error sending resume command err: ", err))
return err
}
return nil
}

func (g *gatewayImpl) listen(conn *websocket.Conn) {
func (g *gatewayImpl) listen(conn *websocket.Conn, readyChan chan<- error) {
defer g.config.Logger.Debug(g.formatLogs("exiting listen goroutine..."))
loop:
for {
mt, data, err := conn.ReadMessage()
if err != nil {
if g.status != StatusReady {
readyChan <- err
close(readyChan)
break loop
}
g.connMu.Lock()
sameConnection := g.conn == conn
g.connMu.Unlock()
Expand Down Expand Up @@ -382,9 414,14 @@ loop:
go g.heartbeat()

if g.config.LastSequenceReceived == nil || g.config.SessionID == nil {
g.identify()
err = g.identify()
} else {
g.resume()
err = g.resume()
}
if err != nil {
readyChan <- err
close(readyChan)
return
}

case OpcodeDispatch:
Expand Down Expand Up @@ -418,6 455,16 @@ loop:
})
}
g.eventHandlerFunc(message.T, message.S, g.config.ShardID, eventData)
if _, ok = eventData.(EventReady); ok {
g.config.Logger.Debug(g.formatLogs("ready successful"))
readyChan <- nil
close(readyChan)
} else if _, ok = eventData.(EventResumed); ok {
g.config.Logger.Debug(g.formatLogs("resume successful"))
g.status = StatusReady
readyChan <- nil
close(readyChan)
}

case OpcodeHeartbeat:
g.sendHeartbeat()
Expand Down
2 changes: 1 addition & 1 deletion gateway/gateway_messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 112,7 @@ func UnmarshalEventData(data []byte, eventType EventType) (EventData, error) {
eventData = d

case EventTypeResumed:
// no data
eventData = EventResumed{}

case EventTypeApplicationCommandPermissionsUpdate:
var d EventApplicationCommandPermissionsUpdate
Expand Down