mirror of
https://github.com/mainflux/mainflux.git
synced 2025-05-02 22:17:10 +08:00

Signed-off-by: dusanb94 <dusan.borovcanin@mainflux.com> Signed-off-by: dusanb94 <dusan.borovcanin@mainflux.com>
129 lines
2.9 KiB
Go
129 lines
2.9 KiB
Go
package net
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
|
|
"go.uber.org/atomic"
|
|
)
|
|
|
|
// Conn is a generic stream-oriented network connection that provides Read/Write with context.
|
|
//
|
|
// Multiple goroutines may invoke methods on a Conn simultaneously.
|
|
type Conn struct {
|
|
connection net.Conn
|
|
closed atomic.Bool
|
|
handshakeContext func(ctx context.Context) error
|
|
lock sync.Mutex
|
|
}
|
|
|
|
// NewConn creates connection over net.Conn.
|
|
func NewConn(c net.Conn) *Conn {
|
|
connection := Conn{
|
|
connection: c,
|
|
}
|
|
|
|
if v, ok := c.(interface {
|
|
HandshakeContext(ctx context.Context) error
|
|
}); ok {
|
|
connection.handshakeContext = v.HandshakeContext
|
|
}
|
|
|
|
return &connection
|
|
}
|
|
|
|
// LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it.
|
|
func (c *Conn) LocalAddr() net.Addr {
|
|
return c.connection.LocalAddr()
|
|
}
|
|
|
|
// Connection returns the network connection. The Conn returned is shared by all invocations of Connection, so do not modify it.
|
|
func (c *Conn) Connection() net.Conn {
|
|
return c.connection
|
|
}
|
|
|
|
// RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it.
|
|
func (c *Conn) RemoteAddr() net.Addr {
|
|
return c.connection.RemoteAddr()
|
|
}
|
|
|
|
// Close closes the connection.
|
|
func (c *Conn) Close() error {
|
|
if !c.closed.CAS(false, true) {
|
|
return nil
|
|
}
|
|
return c.connection.Close()
|
|
}
|
|
|
|
func (c *Conn) handshake(ctx context.Context) error {
|
|
if c.handshakeContext != nil {
|
|
err := c.handshakeContext(ctx)
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
errC := c.Close()
|
|
if errC == nil {
|
|
return err
|
|
}
|
|
return fmt.Errorf("%v", []error{err, errC})
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// WriteWithContext writes data with context.
|
|
func (c *Conn) WriteWithContext(ctx context.Context, data []byte) error {
|
|
if err := c.handshake(ctx); err != nil {
|
|
return err
|
|
}
|
|
written := 0
|
|
c.lock.Lock()
|
|
defer c.lock.Unlock()
|
|
for written < len(data) {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
default:
|
|
}
|
|
if c.closed.Load() {
|
|
return ErrConnectionIsClosed
|
|
}
|
|
n, err := c.connection.Write(data[written:])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
written += n
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ReadFullWithContext reads stream with context until whole buffer is satisfied.
|
|
func (c *Conn) ReadFullWithContext(ctx context.Context, buffer []byte) error {
|
|
offset := 0
|
|
for offset < len(buffer) {
|
|
n, err := c.ReadWithContext(ctx, buffer[offset:])
|
|
if err != nil {
|
|
return fmt.Errorf("cannot read full from connection: %w", err)
|
|
}
|
|
offset += n
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ReadWithContext reads stream with context.
|
|
func (c *Conn) ReadWithContext(ctx context.Context, buffer []byte) (int, error) {
|
|
select {
|
|
case <-ctx.Done():
|
|
return -1, ctx.Err()
|
|
default:
|
|
}
|
|
if c.closed.Load() {
|
|
return -1, ErrConnectionIsClosed
|
|
}
|
|
if err := c.handshake(ctx); err != nil {
|
|
return -1, err
|
|
}
|
|
return c.connection.Read(buffer)
|
|
}
|