mirror of
https://github.com/eventials/goevents.git
synced 2025-04-24 13:48:53 +08:00

- Add approximateNumberOfMessages method to the consumer type. It uses the AWS SQS GetQueueAttributes API to fetch an estimated count of the messages currently in the queue.
508 lines
13 KiB
Go
508 lines
13 KiB
Go
package sns
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/aws/aws-sdk-go/aws"
|
|
"github.com/aws/aws-sdk-go/aws/credentials"
|
|
"github.com/aws/aws-sdk-go/aws/session"
|
|
"github.com/aws/aws-sdk-go/service/sqs"
|
|
"github.com/eventials/goevents/messaging"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// Errors returned by NewConsumer (and raised as panics by MustNewConsumer)
// when the supplied ConsumerConfig cannot be used.
var (
	// ErrEmptyConfig is returned when the config pointer is nil.
	ErrEmptyConfig = errors.New("empty config")

	// ErrEmptyAccessKey is returned when ConsumerConfig.AccessKey is empty.
	ErrEmptyAccessKey = errors.New("empty access key")

	// ErrEmptySecretKey is returned when ConsumerConfig.SecretKey is empty.
	ErrEmptySecretKey = errors.New("empty secret key")
)
|
|
|
|
// snsMessagePayload is the JSON envelope that SNS writes into the body of an
// SQS message for each notification. handleMessage decodes it and uses only
// Message (the published payload, passed on as the event body) and TopicArn
// (used to select the subscribed handler); the other fields are decoded but
// not used in this file.
type snsMessagePayload struct {
	Message          string    `json:"Message"`
	MessageID        string    `json:"MessageId"`
	Signature        string    `json:"Signature"`
	SignatureVersion string    `json:"SignatureVersion"`
	SigningCertURL   string    `json:"SigningCertURL"`
	Subject          string    `json:"Subject"`
	Timestamp        time.Time `json:"Timestamp"`
	TopicArn         string    `json:"TopicArn"`
	Type             string    `json:"Type"`
	UnsubscribeURL   string    `json:"UnsubscribeURL"`
}
|
|
|
|
// ConsumerConfig configures a consumer created by NewConsumer. AccessKey and
// SecretKey are required; zero-valued tuning fields (VisibilityTimeout,
// WaitTimeSeconds, MaxNumberOfMessages) are replaced with defaults by
// NewConsumer, which modifies the config in place.
type ConsumerConfig struct {
	// AccessKey is the AWS access key ID used for static credentials (required).
	AccessKey string

	// SecretKey is the AWS secret access key used for static credentials (required).
	SecretKey string

	// Region is the AWS region used for the session.
	Region string

	// VisibilityTimeout is passed to every ReceiveMessage call (seconds, as
	// defined by the SQS API). Default: 45.
	VisibilityTimeout int64

	// WaitTimeSeconds is the long-poll duration passed to every ReceiveMessage
	// call. Default: 20.
	WaitTimeSeconds int64

	// MaxNumberOfMessages bounds both the number of messages requested per
	// poll and the number of handler goroutines running at once (it sizes the
	// QOS channel). Default: 5.
	MaxNumberOfMessages int64

	// QueueUrl is the URL of the SQS queue to consume from.
	QueueUrl string
}
|
|
|
|
func (c *ConsumerConfig) setDefaults() {
|
|
if c.VisibilityTimeout == 0 {
|
|
c.VisibilityTimeout = 45
|
|
}
|
|
|
|
if c.WaitTimeSeconds == 0 {
|
|
c.WaitTimeSeconds = 20
|
|
}
|
|
|
|
if c.MaxNumberOfMessages == 0 {
|
|
c.MaxNumberOfMessages = 5
|
|
}
|
|
}
|
|
|
|
func (c *ConsumerConfig) isValid() error {
|
|
if c == nil {
|
|
return ErrEmptyConfig
|
|
}
|
|
|
|
if c.AccessKey == "" {
|
|
return ErrEmptyAccessKey
|
|
}
|
|
|
|
if c.SecretKey == "" {
|
|
return ErrEmptySecretKey
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// handler is one subscription registered through Subscribe.
type handler struct {
	// action is the subscribed action; incoming messages are matched against
	// it by their SNS TopicArn.
	action string

	// fn is invoked with each event received for action.
	fn messaging.EventHandler

	// options are the options passed to Subscribe. They are stored but not
	// consulted anywhere in this file.
	options *messaging.SubscribeOptions
}
|
|
|
|
// consumer is the SQS-backed messaging.Consumer returned by NewConsumer. It
// receives SNS notifications delivered to one SQS queue and dispatches them,
// by topic ARN, to the subscribed handlers.
type consumer struct {
	// sqs is the client used for ReceiveMessage, DeleteMessage and
	// GetQueueAttributes calls.
	sqs *sqs.SQS

	// stop is unbuffered: Close sends on it to end the Consume/PriorityConsume
	// loop, and closes it once shutdown has completed.
	stop chan bool

	// qos holds one token per running handler goroutine. Its capacity
	// (MaxNumberOfMessages) bounds handler concurrency.
	qos chan bool

	config *ConsumerConfig

	// receiveMessageInput is built once in NewConsumer and reused for every
	// poll; doConsume rewrites its MaxNumberOfMessages before each call.
	receiveMessageInput *sqs.ReceiveMessageInput

	// m guards handlers.
	m sync.RWMutex

	// wg counts running handler goroutines so Close can wait for them.
	wg sync.WaitGroup

	// handlers maps an action (SNS topic ARN) to its subscribed handler.
	handlers map[string]handler

	// processingMessages is the set of SQS MessageIds whose handler goroutine
	// is still running. It is used to skip redeliveries of in-progress messages
	// and is guarded by mProcessingMessages.
	processingMessages map[string]bool

	mProcessingMessages sync.RWMutex

	// closeOnce ensures the shutdown performed by Close runs only once.
	closeOnce sync.Once

	// stopped is set by PriorityConsume once it has received this consumer's
	// stop signal. It is read and written without a lock.
	// NOTE(review): this assumes only a single PriorityConsume goroutine ever
	// touches it — confirm.
	stopped bool
}
|
|
|
|
func NewConsumer(config *ConsumerConfig) (messaging.Consumer, error) {
|
|
if err := config.isValid(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
config.setDefaults()
|
|
|
|
creds := credentials.NewStaticCredentials(config.AccessKey, config.SecretKey, "")
|
|
|
|
sess, err := session.NewSession(&aws.Config{
|
|
Region: aws.String(config.Region),
|
|
Credentials: creds,
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
c := &consumer{
|
|
sqs: sqs.New(sess),
|
|
config: config,
|
|
stop: make(chan bool),
|
|
qos: make(chan bool, config.MaxNumberOfMessages),
|
|
handlers: make(map[string]handler),
|
|
processingMessages: make(map[string]bool),
|
|
}
|
|
|
|
c.receiveMessageInput = &sqs.ReceiveMessageInput{
|
|
AttributeNames: []*string{
|
|
aws.String(sqs.MessageSystemAttributeNameSentTimestamp),
|
|
},
|
|
MessageAttributeNames: []*string{
|
|
aws.String(sqs.QueueAttributeNameAll),
|
|
},
|
|
QueueUrl: aws.String(c.config.QueueUrl),
|
|
MaxNumberOfMessages: aws.Int64(c.config.MaxNumberOfMessages),
|
|
VisibilityTimeout: aws.Int64(c.config.VisibilityTimeout),
|
|
WaitTimeSeconds: aws.Int64(c.config.WaitTimeSeconds),
|
|
}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
func MustNewConsumer(config *ConsumerConfig) messaging.Consumer {
|
|
consumer, err := NewConsumer(config)
|
|
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return consumer
|
|
}
|
|
|
|
func (c *consumer) getHandler(action string) (handler, bool) {
|
|
c.m.RLock()
|
|
defer c.m.RUnlock()
|
|
|
|
hnd, ok := c.handlers[action]
|
|
|
|
return hnd, ok
|
|
}
|
|
|
|
// Subscribe subscribes an handler to a action. Action must be ARN URL.
|
|
func (c *consumer) Subscribe(action string, handlerFn messaging.EventHandler, options *messaging.SubscribeOptions) error {
|
|
c.m.Lock()
|
|
defer c.m.Unlock()
|
|
|
|
c.handlers[action] = handler{
|
|
action: action,
|
|
options: options,
|
|
fn: handlerFn,
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *consumer) Unsubscribe(action string) error {
|
|
c.m.Lock()
|
|
defer c.m.Unlock()
|
|
|
|
delete(c.handlers, action)
|
|
|
|
return nil
|
|
}
|
|
|
|
// BindActions is currently a no-op: actions is ignored and nil is returned.
//
// TODO: needs to bind SNS Topic (actions) to SQS Queue in AWS.
func (c *consumer) BindActions(actions ...string) error {
	return nil
}
|
|
|
|
// TODO: reverts anything done by BindActions
|
|
func (c *consumer) UnbindActions(actions ...string) error {
|
|
c.m.Lock()
|
|
defer c.m.Unlock()
|
|
|
|
for _, action := range actions {
|
|
delete(c.handlers, action)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// stringToTime converts a decimal string of milliseconds since the Unix epoch
// (the format of the SQS SentTimestamp attribute) into a time.Time. An empty
// or non-numeric input yields the zero time.Time.
func stringToTime(t string) time.Time {
	if t == "" {
		return time.Time{}
	}

	ms, err := strconv.ParseInt(t, 10, 64)
	if err != nil {
		return time.Time{}
	}

	// Split into whole seconds plus leftover milliseconds. Converting the whole
	// value to nanoseconds (ms * 1e6) overflows int64 for instants after the
	// year 2262 and silently produces a wrong time.
	const msPerSecond = 1000
	return time.Unix(ms/msPerSecond, (ms%msPerSecond)*int64(time.Millisecond))
}
|
|
|
|
func (c *consumer) callAndHandlePanic(event messaging.Event, fn messaging.EventHandler) (err error) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
switch x := r.(type) {
|
|
case string:
|
|
err = errors.New(x)
|
|
case error:
|
|
err = x
|
|
default:
|
|
err = errors.New("Unknown panic")
|
|
}
|
|
}
|
|
}()
|
|
|
|
err = fn(event)
|
|
|
|
return
|
|
}
|
|
|
|
func (c *consumer) isMessageProcessing(id string) bool {
|
|
c.mProcessingMessages.RLock()
|
|
defer c.mProcessingMessages.RUnlock()
|
|
|
|
_, ok := c.processingMessages[id]
|
|
return ok
|
|
}
|
|
|
|
func (c *consumer) addMessageProcessing(id string) {
|
|
c.mProcessingMessages.Lock()
|
|
defer c.mProcessingMessages.Unlock()
|
|
|
|
c.processingMessages[id] = true
|
|
}
|
|
|
|
func (c *consumer) deleteMessageProcessing(id string) {
|
|
c.mProcessingMessages.Lock()
|
|
defer c.mProcessingMessages.Unlock()
|
|
|
|
delete(c.processingMessages, id)
|
|
}
|
|
|
|
func (c *consumer) handleMessage(message *sqs.Message) {
|
|
sns := &snsMessagePayload{}
|
|
err := json.Unmarshal([]byte(*message.Body), sns)
|
|
|
|
if err != nil {
|
|
logrus.WithFields(logrus.Fields{
|
|
"error": err,
|
|
"body": []byte(*message.Body),
|
|
}).Error("Failed to unmarshall sns message.")
|
|
return
|
|
}
|
|
|
|
id := *message.MessageId
|
|
receiptHandle := *message.ReceiptHandle
|
|
|
|
log := logrus.WithFields(logrus.Fields{
|
|
"action": sns.TopicArn,
|
|
"message_id": id,
|
|
"body": sns.Message,
|
|
})
|
|
|
|
handler, ok := c.getHandler(sns.TopicArn)
|
|
|
|
if !ok {
|
|
log.Error("Action not found.")
|
|
return
|
|
}
|
|
|
|
// Check if message is already processing in goroutine.
|
|
// This will occur if consumer is slower than VisibilityTimeout.
|
|
if c.isMessageProcessing(id) {
|
|
log.Debug("Message is already processing.")
|
|
return
|
|
}
|
|
|
|
// QOS: do not consume while prior messages are in goroutines.
|
|
c.qos <- true
|
|
|
|
c.addMessageProcessing(id)
|
|
|
|
c.wg.Add(1)
|
|
go func(event messaging.Event, fn messaging.EventHandler, receiptHandle string) {
|
|
defer func() {
|
|
c.wg.Done()
|
|
|
|
c.deleteMessageProcessing(event.Id)
|
|
|
|
<-c.qos
|
|
}()
|
|
|
|
err := c.callAndHandlePanic(event, fn)
|
|
|
|
if err != nil {
|
|
log.WithError(err).Debug("Failed to process event.")
|
|
return
|
|
}
|
|
|
|
log.Debug("Deleting message.")
|
|
|
|
_, err = c.sqs.DeleteMessage(&sqs.DeleteMessageInput{
|
|
QueueUrl: aws.String(c.config.QueueUrl),
|
|
ReceiptHandle: aws.String(receiptHandle),
|
|
})
|
|
|
|
if err != nil {
|
|
log.WithError(err).Error("Failed to delete message.")
|
|
return
|
|
}
|
|
|
|
log.Debug("Message handled successfully.")
|
|
}(messaging.Event{
|
|
Id: id,
|
|
Action: handler.action,
|
|
Body: []byte(sns.Message),
|
|
Timestamp: stringToTime(*message.Attributes["SentTimestamp"]),
|
|
}, handler.fn, receiptHandle)
|
|
}
|
|
|
|
func (c *consumer) doConsume() {
|
|
var nextMessages int64 = c.config.MaxNumberOfMessages - int64(len(c.qos))
|
|
|
|
if nextMessages == 0 {
|
|
logrus.Debugf("QOS full with %d.", c.config.MaxNumberOfMessages)
|
|
time.Sleep(5 * time.Second)
|
|
return
|
|
}
|
|
|
|
c.receiveMessageInput.MaxNumberOfMessages = aws.Int64(nextMessages)
|
|
|
|
result, err := c.sqs.ReceiveMessage(c.receiveMessageInput)
|
|
|
|
if err != nil {
|
|
logrus.WithError(err).Error("Failed to get messages.")
|
|
return
|
|
}
|
|
|
|
for _, message := range result.Messages {
|
|
c.handleMessage(message)
|
|
}
|
|
}
|
|
|
|
// Consume consumes with long-poll, messages from AWS SQS and dispatch to handlers.
|
|
// Polling time is configured by WaitTimeSeconds.
|
|
// Messages successfully handled will be deleted from SQS.
|
|
// Messages who failed to delete from SQS will be received again, and application needs to handle
|
|
// by using MessageId.
|
|
// Receiving duplicate messages may happen using more than one consumer if not processing in VisibilityTimeout.
|
|
func (c *consumer) Consume() {
|
|
logrus.Info("Registered handlers:")
|
|
|
|
for _, handler := range c.handlers {
|
|
logrus.Infof(" %s", handler.action)
|
|
}
|
|
|
|
logrus.WithFields(logrus.Fields{
|
|
"queue": c.config.QueueUrl,
|
|
}).Info("Consuming messages...")
|
|
|
|
for {
|
|
select {
|
|
case <-c.stop:
|
|
return
|
|
default:
|
|
c.doConsume()
|
|
}
|
|
}
|
|
}
|
|
|
|
// PriorityConsume function takes a list of consumer objects and starts consuming them in a loop.
|
|
// The consumers are processed according to their priority. The process is stopped if the stop signal is received.
|
|
// The function exits when all consumers are stopped.
|
|
func PriorityConsume(consumers []messaging.Consumer) {
|
|
logrus.Info("Registered handlers:")
|
|
|
|
// consumersQueue holds the actual *consumer types cast from the messaging.Consumer interface
|
|
consumersQueue := make([]*consumer, len(consumers))
|
|
|
|
// Logging the registered handlers for each consumer
|
|
for priority, c := range consumers {
|
|
actualConsumer, ok := c.(*consumer) // Casting to *consumer
|
|
if !ok {
|
|
logrus.Error("Failed to cast consumer to consumer type")
|
|
return
|
|
}
|
|
for _, handler := range actualConsumer.handlers {
|
|
logrus.Infof(" %s (priority %d)", handler.action, priority)
|
|
}
|
|
consumersQueue[priority] = actualConsumer
|
|
}
|
|
|
|
// Counter for tracking the number of stopped consumers
|
|
var consumersStopped int
|
|
|
|
// Main loop to consume messages
|
|
for {
|
|
// Iterate over each consumer
|
|
for _, consumer := range consumersQueue {
|
|
// Determine if a consumer should consume messages
|
|
shouldConsume := !consumer.stopped && checkPriorityMessages(consumersQueue, consumer)
|
|
|
|
if !shouldConsume {
|
|
continue
|
|
}
|
|
|
|
// If the consumer is still active, try to consume a message
|
|
select {
|
|
case <-consumer.stop:
|
|
// If a stop signal is received, stop the consumer and increase the count of stopped consumers
|
|
consumer.stopped = true
|
|
consumersStopped++
|
|
default:
|
|
// If no stop signal, consume the message
|
|
consumer.doConsume()
|
|
}
|
|
}
|
|
|
|
// If all consumers are stopped, exit the function
|
|
if consumersStopped == len(consumers) {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// checkPriorityMessages checks whether the given consumer can consume a message.
|
|
// It checks all other consumers and allows a consumer to consume only if there are no higher priority consumers with messages to consume.
|
|
func checkPriorityMessages(consumers []*consumer, currentConsumer *consumer) bool {
|
|
// Iterate over each consumer
|
|
for _, consumer := range consumers {
|
|
|
|
// Skip stopped consumers
|
|
if consumer.stopped {
|
|
continue
|
|
}
|
|
|
|
// If we reached the currentConsumer in the iteration, then all higher priority consumers were checked and they don't have any messages to consume.
|
|
// So, the currentConsumer is allowed to consume.
|
|
if consumer == currentConsumer {
|
|
return true
|
|
}
|
|
|
|
// Get the approximate number of messages in the queue for the consumer
|
|
qtMessages, err := consumer.approximateNumberOfMessages()
|
|
if err != nil {
|
|
logrus.WithError(err).Errorf("Failed to get approximate number of messages for consumer %s", consumer.config.QueueUrl)
|
|
continue
|
|
}
|
|
|
|
// If any higher priority consumer has messages to consume, then the currentConsumer should not consume.
|
|
if len(consumer.qos)+qtMessages > 0 {
|
|
logrus.Debugf("Higher priority consumer %s has messages to consume, skipping current consumer %s", consumer.config.QueueUrl, currentConsumer.config.QueueUrl)
|
|
return false
|
|
}
|
|
}
|
|
|
|
// If no consumers were found to have messages, the currentConsumer can consume.
|
|
return true
|
|
}
|
|
|
|
// approximateNumberOfMessages is a method on the consumer type that gets the approximate number of messages in the queue.
|
|
// It uses AWS SQS's GetQueueAttributes API to fetch the ApproximateNumberOfMessages attribute,
|
|
// which gives an estimate of the number of visible messages in the queue.
|
|
func (c *consumer) approximateNumberOfMessages() (int, error) {
|
|
// Request to get queue attributes from AWS SQS
|
|
result, err := c.sqs.GetQueueAttributes(&sqs.GetQueueAttributesInput{
|
|
QueueUrl: aws.String(c.config.QueueUrl), // Specify the queue URL from consumer config
|
|
AttributeNames: []*string{aws.String(sqs.QueueAttributeNameApproximateNumberOfMessages)}, // Specify the attribute we want to fetch
|
|
})
|
|
|
|
// If there's an error in fetching the queue attributes, return 0 and the error
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// If the fetch was successful, convert the returned attribute (which is a string) to an integer
|
|
return strconv.Atoi(*result.Attributes[sqs.QueueAttributeNameApproximateNumberOfMessages])
|
|
}
|
|
|
|
// doClose performs the one-time shutdown behind Close. The statement order
// matters:
//
//  1. The send on the unbuffered stop channel blocks until the running
//     Consume/PriorityConsume loop receives it. That loop receives only from
//     its select, not while inside doConsume, and after receiving it starts no
//     further polls for this consumer. So no new handler goroutine (wg.Add)
//     can start while wg.Wait below is running.
//     NOTE(review): if no Consume/PriorityConsume loop is running for this
//     consumer, the send blocks forever and Close never returns — confirm
//     callers always start one before calling Close.
//  2. wg.Wait blocks until every in-flight handler goroutine has finished.
//  3. Finally the stop channel is closed.
func (c *consumer) doClose() {
	logrus.Info("Closing SNS consumer...")

	c.stop <- true

	logrus.Info("SNS consumer closed. Waiting remaining handlers...")
	c.wg.Wait()

	logrus.Info("SNS consumer closed.")

	close(c.stop)
}
|
|
|
|
func (c *consumer) Close() {
|
|
c.closeOnce.Do(c.doClose)
|
|
}
|