package tcp

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/plgd-dev/go-coap/v2/message"
	"github.com/plgd-dev/go-coap/v2/message/codes"
	coapNet "github.com/plgd-dev/go-coap/v2/net"
	"github.com/plgd-dev/go-coap/v2/net/blockwise"
	"github.com/plgd-dev/go-coap/v2/net/monitor/inactivity"
	coapTCP "github.com/plgd-dev/go-coap/v2/tcp/message"
	"github.com/plgd-dev/go-coap/v2/tcp/message/pool"
)

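// EventFunc is a parameterless callback; in this file it is used for the
// handlers registered with AddOnClose and run when the session shuts down.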
type EventFunc func()

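// Session holds the state of a single CoAP over TCP connection: its context,
// the pending token/MID handler containers, the block-wise transfer state and
// the capabilities announced by the peer in its CSM.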
type Session struct {
	// This field needs to be the first in the struct to ensure proper word alignment on 32-bit platforms.
	// See: https://golang.org/pkg/sync/atomic/#pkg-note-BUG
	sequence uint64

	onClose []EventFunc

	ctx atomic.Value

	inactivityMonitor inactivity.Monitor

	errSendCSM error

	cancel context.CancelFunc

	done chan struct{}

	goPool    GoPoolFunc
	errors    ErrorFunc
	blockWise *blockwise.BlockWise

	connection *coapNet.Conn

	handler HandlerFunc

	midHandlerContainer *HandlerContainer

	tokenHandlerContainer *HandlerContainer

	messagePool *pool.Pool

	mutex sync.Mutex

	maxMessageSize                  uint32
	peerBlockWiseTranferEnabled     uint32
	peerMaxMessageSize              uint32
	connectionCacheSize             uint16
	disableTCPSignalMessageCSM      bool
	disablePeerTCPSignalMessageCSMs bool

	blockwiseSZX blockwise.SZX
	closeSocket  bool
}

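// NewSession wraps an established connection in a Session. Unless
// disableTCPSignalMessageCSM is set, it immediately tries to send a CSM signal
// message; a failure is recorded in errSendCSM and reported later by Run.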
func NewSession(
	ctx context.Context,
	connection *coapNet.Conn,
	handler HandlerFunc,
	maxMessageSize uint32,
	goPool GoPoolFunc,
	errors ErrorFunc,
	blockwiseSZX blockwise.SZX,
	blockWise *blockwise.BlockWise,
	disablePeerTCPSignalMessageCSMs bool,
	disableTCPSignalMessageCSM bool,
	closeSocket bool,
	inactivityMonitor inactivity.Monitor,
	connectionCacheSize uint16,
	messagePool *pool.Pool,
) *Session {
	ctx, cancel := context.WithCancel(ctx)
	if errors == nil {
		errors = func(error) {
			// default no-op
		}
	}
	if inactivityMonitor == nil {
		inactivityMonitor = inactivity.NewNilMonitor()
	}

	s := &Session{
		cancel:                          cancel,
		connection:                      connection,
		handler:                         handler,
		maxMessageSize:                  maxMessageSize,
		tokenHandlerContainer:           NewHandlerContainer(),
		midHandlerContainer:             NewHandlerContainer(),
		goPool:                          goPool,
		errors:                          errors,
		blockWise:                       blockWise,
		blockwiseSZX:                    blockwiseSZX,
		disablePeerTCPSignalMessageCSMs: disablePeerTCPSignalMessageCSMs,
		disableTCPSignalMessageCSM:      disableTCPSignalMessageCSM,
		closeSocket:                     closeSocket,
		inactivityMonitor:               inactivityMonitor,
		done:                            make(chan struct{}),
		connectionCacheSize:             connectionCacheSize,
		messagePool:                     messagePool,
	}
	s.ctx.Store(&ctx)

	if !disableTCPSignalMessageCSM {
		err := s.sendCSM()
		if err != nil {
			s.errSendCSM = fmt.Errorf("cannot send CSM: %w", err)
		}
	}

	return s
}

// SetContextValue stores the value associated with key in the context of the connection.
func (s *Session) SetContextValue(key interface{}, val interface{}) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	ctx := context.WithValue(s.Context(), key, val)
	s.ctx.Store(&ctx)
}

// Done returns a channel that is closed when the connection is no longer being processed.
func (s *Session) Done() <-chan struct{} {
	return s.done
}

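// AddOnClose registers a callback that is executed when the session shuts down.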
func (s *Session) AddOnClose(f EventFunc) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	s.onClose = append(s.onClose, f)
}

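// popOnClose takes the registered on-close callbacks and clears the list
// under the session mutex.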
func (s *Session) popOnClose() []EventFunc {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	tmp := s.onClose
	s.onClose = nil
	return tmp
}

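// shutdown runs the registered on-close callbacks and then closes the done channel.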
func (s *Session) shutdown() {
	defer close(s.done)
	for _, f := range s.popOnClose() {
		f()
	}
}

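// Close cancels the session context and, when the session owns the socket
// (closeSocket), also closes the underlying connection.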
func (s *Session) Close() error {
	s.cancel()
	if s.closeSocket {
		return s.connection.Close()
	}
	return nil
}

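// Sequence returns the next value of the session-wide message sequence counter.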
func (s *Session) Sequence() uint64 {
	return atomic.AddUint64(&s.sequence, 1)
}

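// Context returns the current context of the session.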
func (s *Session) Context() context.Context {
	return *s.ctx.Load().(*context.Context)
}

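// PeerMaxMessageSize returns the Max-Message-Size announced by the peer in its
// CSM, or 0 if none was received.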
func (s *Session) PeerMaxMessageSize() uint32 {
	return atomic.LoadUint32(&s.peerMaxMessageSize)
}

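// PeerBlockWiseTransferEnabled reports whether the peer announced support for
// block-wise transfers in its CSM.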
func (s *Session) PeerBlockWiseTransferEnabled() bool {
	return atomic.LoadUint32(&s.peerBlockWiseTranferEnabled) == 1
}

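// handleBlockwise routes the request through the block-wise layer when it is
// configured and the peer supports it; otherwise the message is delivered
// directly to a pending token handler or, failing that, to the default handler.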
func (s *Session) handleBlockwise(w *ResponseWriter, r *pool.Message) {
	if s.blockWise != nil && s.PeerBlockWiseTransferEnabled() {
		bwr := bwResponseWriter{
			w: w,
		}
		s.blockWise.Handle(&bwr, r, s.blockwiseSZX, s.maxMessageSize, func(bw blockwise.ResponseWriter, br blockwise.Message) {
			h, err := s.tokenHandlerContainer.Pop(r.Token())
			rw := bw.(*bwResponseWriter).w
			m := br.(*pool.Message)
			if err == nil {
				h(rw, m)
				return
			}
			s.handler(rw, m)
		})
		return
	}
	h, err := s.tokenHandlerContainer.Pop(r.Token())
	if err == nil {
		h(w, r)
		return
	}
	s.handler(w, r)
}

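// handleSignals processes CoAP over TCP signaling messages (CSM, Ping, Pong,
// Release, Abort). It returns true when the message was a signal and needs no
// further processing.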
func (s *Session) handleSignals(r *pool.Message, cc *ClientConn) bool {
	switch r.Code() {
	case codes.CSM:
		if s.disablePeerTCPSignalMessageCSMs {
			return true
		}
		if size, err := r.GetOptionUint32(coapTCP.MaxMessageSize); err == nil {
			atomic.StoreUint32(&s.peerMaxMessageSize, size)
		}
		if r.HasOption(coapTCP.BlockWiseTransfer) {
			atomic.StoreUint32(&s.peerBlockWiseTranferEnabled, 1)
		}
		return true
	case codes.Ping:
		// if r.HasOption(coapTCP.Custody) {
		//	TODO
		// }
		if err := s.sendPong(r.Token()); err != nil && !coapNet.IsConnectionBrokenError(err) {
			s.errors(fmt.Errorf("cannot handle ping signal: %w", err))
		}
		return true
	case codes.Release:
		// if r.HasOption(coapTCP.AlternativeAddress) {
		//	TODO
		// }
		return true
	case codes.Abort:
		// if r.HasOption(coapTCP.BadCSMOption) {
		//	TODO
		// }
		return true
	case codes.Pong:
		h, err := s.tokenHandlerContainer.Pop(r.Token())
		if err == nil {
			s.processReq(r, cc, h)
		}
		return true
	}
	return false
}

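// bwResponseWriter adapts *ResponseWriter to the blockwise.ResponseWriter interface.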
type bwResponseWriter struct {
	w *ResponseWriter
}

func (b *bwResponseWriter) Message() blockwise.Message {
	return b.w.response
}

func (b *bwResponseWriter) SetMessage(m blockwise.Message) {
	b.w.cc.session.messagePool.ReleaseMessage(b.w.response)
	b.w.response = m.(*pool.Message)
}

func (b *bwResponseWriter) RemoteAddr() net.Addr {
	return b.w.cc.RemoteAddr()
}

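// Handle processes a single decoded request by passing it through handleBlockwise.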
func (s *Session) Handle(w *ResponseWriter, r *pool.Message) {
	s.handleBlockwise(w, r)
}

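// TokenHandler returns the container of handlers registered per message token.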
func (s *Session) TokenHandler() *HandlerContainer {
	return s.tokenHandlerContainer
}

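// processReq invokes handler on req, writes the response back to the peer when
// the handler modified it, and returns the pooled messages. A failed write
// closes the session.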
func (s *Session) processReq(req *pool.Message, cc *ClientConn, handler func(w *ResponseWriter, r *pool.Message)) {
	origResp := s.messagePool.AcquireMessage(s.Context())
	origResp.SetToken(req.Token())
	w := NewResponseWriter(origResp, cc, req.Options())
	handler(w, req)
	defer s.messagePool.ReleaseMessage(w.response)
	if !req.IsHijacked() {
		s.messagePool.ReleaseMessage(req)
	}
	if w.response.IsModified() {
		err := s.WriteMessage(w.response)
		if err != nil {
			if errC := s.Close(); errC != nil {
				s.errors(fmt.Errorf("cannot close connection: %w", errC))
			}
			s.errors(fmt.Errorf("cannot write response to %v: %w", s.connection.RemoteAddr(), err))
		}
	}
}

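// seekBufferToNextMessage discards the first msgSize bytes of buffer, i.e. the
// message that has just been decoded, so the buffer starts at the next message.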
func seekBufferToNextMessage(buffer *bytes.Buffer, msgSize int) *bytes.Buffer {
	if msgSize == buffer.Len() {
		// the buffer holds exactly the processed message, so just reset it
		buffer.Reset()
		return buffer
	}
	// discard the processed message so the buffer starts at the next one
	trimmed := 0
	for trimmed != msgSize {
		b := make([]byte, 4096)
		max := 4096
		if msgSize-trimmed < max {
			max = msgSize - trimmed
		}
		v, _ := buffer.Read(b[:max])
		trimmed += v
	}
	return buffer
}

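// processBuffer decodes as many complete messages as buffer currently holds.
// Signals are handled inline, regular requests are dispatched on the goroutine
// pool. It returns nil when only a partial message remains in the buffer.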
func (s *Session) processBuffer(buffer *bytes.Buffer, cc *ClientConn) error {
	for buffer.Len() > 0 {
		var hdr coapTCP.MessageHeader
		err := hdr.Unmarshal(buffer.Bytes())
		if errors.Is(err, message.ErrShortRead) {
			return nil
		}
		if hdr.TotalLen > s.maxMessageSize {
			return fmt.Errorf("max message size (%v) exceeded by message of size %v", s.maxMessageSize, hdr.TotalLen)
		}
		if uint32(buffer.Len()) < hdr.TotalLen {
			return nil
		}
		req := s.messagePool.AcquireMessage(s.Context())
		read, err := req.Unmarshal(buffer.Bytes()[:hdr.TotalLen])
		if err != nil {
			s.messagePool.ReleaseMessage(req)
			return fmt.Errorf("cannot unmarshal with header: %w", err)
		}
		buffer = seekBufferToNextMessage(buffer, read)
		req.SetSequence(s.Sequence())
		s.inactivityMonitor.Notify()
		if s.handleSignals(req, cc) {
			continue
		}
		err = s.goPool(func() {
			s.processReq(req, cc, s.Handle)
		})
		if err != nil {
			return fmt.Errorf("cannot spawn goroutine: %w", err)
		}
	}
	return nil
}

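// WriteMessage marshals req and writes it to the connection.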
func (s *Session) WriteMessage(req *pool.Message) error {
	data, err := req.Marshal()
	if err != nil {
		return fmt.Errorf("cannot marshal: %w", err)
	}
	err = s.connection.WriteWithContext(req.Context(), data)
	if err != nil {
		return fmt.Errorf("cannot write to connection: %w", err)
	}
	return nil
}

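// sendCSM sends the initial Capabilities and Settings Message (CSM) signal to the peer.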
func (s *Session) sendCSM() error {
	token, err := message.GetToken()
	if err != nil {
		return fmt.Errorf("cannot get token: %w", err)
	}
	req := s.messagePool.AcquireMessage(s.Context())
	defer s.messagePool.ReleaseMessage(req)
	req.SetCode(codes.CSM)
	req.SetToken(token)
	return s.WriteMessage(req)
}

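// sendPong answers a Ping signal with a Pong carrying the same token.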
func (s *Session) sendPong(token message.Token) error {
	req := s.messagePool.AcquireMessage(s.Context())
	defer s.messagePool.ReleaseMessage(req)
	req.SetCode(codes.Pong)
	req.SetToken(token)
	return s.WriteMessage(req)
}

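// shrinkBufferIfNecessary replaces an empty buffer whose capacity has grown
// beyond maxCap with a fresh one, keeping the per-connection cache bounded.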
func shrinkBufferIfNecessary(buffer *bytes.Buffer, maxCap uint16) *bytes.Buffer {
	if buffer.Len() == 0 && buffer.Cap() > int(maxCap) {
		buffer = bytes.NewBuffer(make([]byte, 0, maxCap))
	}
	return buffer
}

// Run reads and processes requests from the connection until it is closed.
func (s *Session) Run(cc *ClientConn) (err error) {
	defer func() {
		err1 := s.Close()
		if err == nil {
			err = err1
		}
		s.shutdown()
	}()
	if s.errSendCSM != nil {
		return s.errSendCSM
	}
	buffer := bytes.NewBuffer(make([]byte, 0, s.connectionCacheSize))
	readBuf := make([]byte, s.connectionCacheSize)
	for {
		err = s.processBuffer(buffer, cc)
		if err != nil {
			return err
		}
		buffer = shrinkBufferIfNecessary(buffer, s.connectionCacheSize)
		readLen, err := s.connection.ReadWithContext(s.Context(), readBuf)
		if err != nil {
			if coapNet.IsConnectionBrokenError(err) {
				// the other side closed the connection; ignore the error and return
				return nil
			}
			return fmt.Errorf("cannot read from connection: %w", err)
		}
		if readLen > 0 {
			buffer.Write(readBuf[:readLen])
		}
	}
}

// CheckExpirations checks the inactivity monitor and removes expired items from the block-wise cache.
func (s *Session) CheckExpirations(now time.Time, cc *ClientConn) {
	s.inactivityMonitor.CheckInactivity(now, cc)
	if s.blockWise != nil {
		s.blockWise.CheckExpirations(now)
	}
}