1
0
mirror of https://github.com/mainflux/mainflux.git synced 2025-04-26 13:48:53 +08:00

Pulled latest master branch version of paho mqtt with glide

Signed-off-by: Nikola Marčetić <nikola.marcetic@nsystems.rs>
This commit is contained in:
Nikola Marčetić 2016-10-12 17:45:05 +02:00
parent 63a97748b6
commit b1fab11c32
47 changed files with 304 additions and 108 deletions

8
glide.lock generated
View File

@ -1,10 +1,10 @@
hash: 66d7291cd74288a935de41fb4a151c983c52feee556cae1d6e2230c3c70a31d8
updated: 2016-10-12T15:23:23.717254004+02:00
hash: e45f1b36f1f69712061260e9ecf2d9d6a2cff1d1d6fcf9c4246e9678a4796f72
updated: 2016-10-12T17:44:20.535180376+02:00
imports:
- name: github.com/BurntSushi/toml
version: bbd5bb678321a0d6e58f1099321dfa73391c1b6f
- name: github.com/eclipse/paho.mqtt.golang
version: 45f9b18f4864c81d49c3ed01e5faec9eeb05de31
version: 13afcbe8e41508479762a90e9242577210c2ca8d
subpackages:
- packets
- name: github.com/fatih/color
@ -26,7 +26,7 @@ imports:
- name: github.com/xeipuuv/gojsonschema
version: 00f9fafb54d2244d291b86ab63d12c38bd5c3886
- name: golang.org/x/net
version: cf4effbb9db1f3ef07f7e1891402991b6afbb276
version: 6dba816f1056709e29a1c442883cab1336d3c083
subpackages:
- websocket
- name: golang.org/x/sys

View File

@ -78,8 +78,6 @@ type client struct {
stop chan struct{}
persist Store
options ClientOptions
pingTimer *time.Timer
pingRespTimer *time.Timer
pingResp chan struct{}
status connStatus
workers sync.WaitGroup
@ -224,9 +222,6 @@ func (c *client) Connect() Token {
c.ibound = make(chan packets.ControlPacket)
c.errors = make(chan error, 1)
c.stop = make(chan struct{})
c.pingTimer = time.NewTimer(c.options.KeepAlive)
c.pingRespTimer = time.NewTimer(time.Duration(10) * time.Second)
c.pingRespTimer.Stop()
c.pingResp = make(chan struct{}, 1)
c.incomingPubChan = make(chan *packets.PublishPacket, c.options.MessageChannelDepth)
@ -268,12 +263,14 @@ func (c *client) Connect() Token {
// internal function used to reconnect the client when it loses its connection
func (c *client) reconnect() {
DEBUG.Println(CLI, "enter reconnect")
c.setConnected(reconnecting)
var rc byte = 1
var sleep uint = 1
var err error
var (
err error
for rc != 0 {
rc = byte(1)
sleep = time.Duration(1 * time.Second)
)
for rc != 0 && c.status != disconnected {
cm := newConnectMsgFromOptions(&c.options)
for _, broker := range c.options.Servers {
@ -318,15 +315,23 @@ func (c *client) reconnect() {
}
}
if rc != 0 {
DEBUG.Println(CLI, "Reconnect failed, sleeping for", sleep, "seconds")
time.Sleep(time.Duration(sleep) * time.Second)
if sleep <= uint(c.options.MaxReconnectInterval.Seconds()) {
DEBUG.Println(CLI, "Reconnect failed, sleeping for", int(sleep.Seconds()), "seconds")
time.Sleep(sleep)
if sleep < c.options.MaxReconnectInterval {
sleep *= 2
}
if sleep > c.options.MaxReconnectInterval {
sleep = c.options.MaxReconnectInterval
}
}
}
// Disconnect() must have been called while we were trying to reconnect.
if c.status == disconnected {
DEBUG.Println(CLI, "Client moved to disconnected state while reconnecting, abandoning reconnect")
return
}
c.pingTimer.Reset(c.options.KeepAlive)
c.stop = make(chan struct{})
c.workers.Add(1)
@ -378,19 +383,21 @@ func (c *client) connect() byte {
// the specified number of milliseconds to wait for existing work to be
// completed.
func (c *client) Disconnect(quiesce uint) {
if !c.IsConnected() {
WARN.Println(CLI, "already disconnected")
return
if c.status == connected {
DEBUG.Println(CLI, "disconnecting")
c.setConnected(disconnected)
dm := packets.NewControlPacket(packets.Disconnect).(*packets.DisconnectPacket)
dt := newToken(packets.Disconnect)
c.oboundP <- &PacketAndToken{p: dm, t: dt}
// wait for work to finish, or quiesce time consumed
dt.WaitTimeout(time.Duration(quiesce) * time.Millisecond)
} else {
WARN.Println(CLI, "Disconnect() called but not connected (disconnected/reconnecting)")
c.setConnected(disconnected)
}
DEBUG.Println(CLI, "disconnecting")
c.setConnected(disconnected)
dm := packets.NewControlPacket(packets.Disconnect).(*packets.DisconnectPacket)
dt := newToken(packets.Disconnect)
c.oboundP <- &PacketAndToken{p: dm, t: dt}
// wait for work to finish, or quiesce time consumed
dt.WaitTimeout(time.Duration(quiesce) * time.Millisecond)
c.disconnect()
}
@ -414,14 +421,15 @@ func (c *client) internalConnLost(err error) {
c.closeStop()
c.conn.Close()
c.workers.Wait()
if c.options.OnConnectionLost != nil {
go c.options.OnConnectionLost(c, err)
}
if c.options.AutoReconnect {
c.setConnected(reconnecting)
go c.reconnect()
} else {
c.setConnected(disconnected)
}
if c.options.OnConnectionLost != nil {
go c.options.OnConnectionLost(c, err)
}
}
}
@ -436,9 +444,17 @@ func (c *client) closeStop() {
}
}
func (c *client) closeConn() {
c.Lock()
defer c.Unlock()
if c.conn != nil {
c.conn.Close()
}
}
func (c *client) disconnect() {
c.closeStop()
c.conn.Close()
c.closeConn()
c.workers.Wait()
close(c.stopRouter)
DEBUG.Println(CLI, "disconnected")

11
vendor/github.com/eclipse/paho.mqtt.golang/cmd/build.sh generated vendored Executable file
View File

@ -0,0 +1,11 @@
#!/bin/sh
for dir in `ls -d */ | cut -f1 -d'/'`
do
echo "Compiling $dir ...\c"
cd $dir
go clean
go build
cd ..
echo " done."
done

View File

@ -75,7 +75,7 @@ func (store *FileStore) Close() {
store.Lock()
defer store.Unlock()
store.opened = false
WARN.Println(STR, "store is not open")
DEBUG.Println(STR, "store is closed")
}
// Put will put a message into the store, associated with the provided
@ -83,10 +83,15 @@ func (store *FileStore) Close() {
func (store *FileStore) Put(key string, m packets.ControlPacket) {
store.Lock()
defer store.Unlock()
chkcond(store.opened)
if !store.opened {
ERROR.Println(STR, "Trying to use file store, but not open")
return
}
full := fullpath(store.directory, key)
write(store.directory, key, m)
chkcond(exists(full))
if !exists(full) {
ERROR.Println(STR, "file not created:", full)
}
}
// Get will retrieve a message from the store, the one associated with
@ -94,7 +99,10 @@ func (store *FileStore) Put(key string, m packets.ControlPacket) {
func (store *FileStore) Get(key string) packets.ControlPacket {
store.RLock()
defer store.RUnlock()
chkcond(store.opened)
if !store.opened {
ERROR.Println(STR, "Trying to use file store, but not open")
return nil
}
filepath := fullpath(store.directory, key)
if !exists(filepath) {
return nil
@ -142,7 +150,10 @@ func (store *FileStore) Reset() {
// lockless
func (store *FileStore) all() []string {
chkcond(store.opened)
if !store.opened {
ERROR.Println(STR, "Trying to use file store, but not open")
return nil
}
keys := []string{}
files, rderr := ioutil.ReadDir(store.directory)
chkerr(rderr)
@ -161,7 +172,10 @@ func (store *FileStore) all() []string {
// lockless
func (store *FileStore) del(key string) {
chkcond(store.opened)
if !store.opened {
ERROR.Println(STR, "Trying to use file store, but not open")
return
}
DEBUG.Println(STR, "store del filepath:", store.directory)
DEBUG.Println(STR, "store delete key:", key)
filepath := fullpath(store.directory, key)
@ -173,7 +187,9 @@ func (store *FileStore) del(key string) {
rerr := os.Remove(filepath)
chkerr(rerr)
DEBUG.Println(STR, "del msg:", key)
chkcond(!exists(filepath))
if exists(filepath) {
ERROR.Println(STR, "file not deleted:", filepath)
}
}
func fullpath(store string, key string) string {

View File

@ -53,7 +53,10 @@ func (store *MemoryStore) Open() {
func (store *MemoryStore) Put(key string, message packets.ControlPacket) {
store.Lock()
defer store.Unlock()
chkcond(store.opened)
if !store.opened {
ERROR.Println(STR, "Trying to use memory store, but not open")
return
}
store.messages[key] = message
}
@ -62,7 +65,10 @@ func (store *MemoryStore) Put(key string, message packets.ControlPacket) {
func (store *MemoryStore) Get(key string) packets.ControlPacket {
store.RLock()
defer store.RUnlock()
chkcond(store.opened)
if !store.opened {
ERROR.Println(STR, "Trying to use memory store, but not open")
return nil
}
mid := mIDFromKey(key)
m := store.messages[key]
if m == nil {
@ -78,7 +84,10 @@ func (store *MemoryStore) Get(key string) packets.ControlPacket {
func (store *MemoryStore) All() []string {
store.RLock()
defer store.RUnlock()
chkcond(store.opened)
if !store.opened {
ERROR.Println(STR, "Trying to use memory store, but not open")
return nil
}
keys := []string{}
for k := range store.messages {
keys = append(keys, k)
@ -91,6 +100,10 @@ func (store *MemoryStore) All() []string {
func (store *MemoryStore) Del(key string) {
store.Lock()
defer store.Unlock()
if !store.opened {
ERROR.Println(STR, "Trying to use memory store, but not open")
return
}
mid := mIDFromKey(key)
m := store.messages[key]
if m == nil {
@ -105,7 +118,10 @@ func (store *MemoryStore) Del(key string) {
func (store *MemoryStore) Close() {
store.Lock()
defer store.Unlock()
chkcond(store.opened)
if !store.opened {
ERROR.Println(STR, "Trying to close memory store, but not open")
return
}
store.opened = false
DEBUG.Println(STR, "memorystore closed")
}
@ -114,7 +130,9 @@ func (store *MemoryStore) Close() {
func (store *MemoryStore) Reset() {
store.Lock()
defer store.Unlock()
chkcond(store.opened)
if !store.opened {
ERROR.Println(STR, "Trying to reset memory store, but not open")
}
store.messages = make(map[string]packets.ControlPacket)
WARN.Println(STR, "memorystore wiped")
}

View File

@ -98,7 +98,7 @@ func newConnectMsgFromOptions(options *ClientOptions) *packets.ConnectPacket {
}
}
m.KeepaliveTimer = uint16(options.KeepAlive.Seconds())
m.Keepalive = uint16(options.KeepAlive.Seconds())
return m
}

View File

@ -159,7 +159,7 @@ func outgoing(c *client) {
}
}
// Reset ping timer after sending control packet.
c.pingTimer.Reset(c.options.KeepAlive)
c.pingResp <- struct{}{}
}
}
@ -186,8 +186,8 @@ func alllogic(c *client) {
sa := msg.(*packets.SubackPacket)
DEBUG.Println(NET, "received suback, id:", sa.MessageID)
token := c.getToken(sa.MessageID).(*SubscribeToken)
DEBUG.Println(NET, "granted qoss", sa.GrantedQoss)
for i, qos := range sa.GrantedQoss {
DEBUG.Println(NET, "granted qoss", sa.ReturnCodes)
for i, qos := range sa.ReturnCodes {
token.subResult[token.subs[i]] = qos
}
token.flowComplete()

View File

@ -19,9 +19,3 @@ func chkerr(e error) {
panic(e)
}
}
func chkcond(b bool) {
if !b {
panic("oops")
}
}

View File

@ -10,13 +10,13 @@ import (
//Connack MQTT packet
type ConnackPacket struct {
FixedHeader
TopicNameCompression byte
ReturnCode byte
SessionPresent bool
ReturnCode byte
}
func (ca *ConnackPacket) String() string {
str := fmt.Sprintf("%s\n", ca.FixedHeader)
str += fmt.Sprintf("returncode: %d", ca.ReturnCode)
str += fmt.Sprintf("sessionpresent: %t returncode: %d", ca.SessionPresent, ca.ReturnCode)
return str
}
@ -24,7 +24,7 @@ func (ca *ConnackPacket) Write(w io.Writer) error {
var body bytes.Buffer
var err error
body.WriteByte(ca.TopicNameCompression)
body.WriteByte(boolToByte(ca.SessionPresent))
body.WriteByte(ca.ReturnCode)
ca.FixedHeader.RemainingLength = 2
packet := ca.FixedHeader.pack()
@ -36,9 +36,11 @@ func (ca *ConnackPacket) Write(w io.Writer) error {
//Unpack decodes the details of a ControlPacket after the fixed
//header has been read
func (ca *ConnackPacket) Unpack(b io.Reader) {
ca.TopicNameCompression = decodeByte(b)
func (ca *ConnackPacket) Unpack(b io.Reader) error {
ca.SessionPresent = 1&decodeByte(b) > 0
ca.ReturnCode = decodeByte(b)
return nil
}
//Details returns a Details struct containing the Qos and

View File

@ -19,7 +19,7 @@ type ConnectPacket struct {
UsernameFlag bool
PasswordFlag bool
ReservedBit byte
KeepaliveTimer uint16
Keepalive uint16
ClientIdentifier string
WillTopic string
@ -30,7 +30,7 @@ type ConnectPacket struct {
func (c *ConnectPacket) String() string {
str := fmt.Sprintf("%s\n", c.FixedHeader)
str += fmt.Sprintf("protocolversion: %d protocolname: %s cleansession: %t willflag: %t WillQos: %d WillRetain: %t Usernameflag: %t Passwordflag: %t keepalivetimer: %d\nclientId: %s\nwilltopic: %s\nwillmessage: %s\nUsername: %s\nPassword: %s\n", c.ProtocolVersion, c.ProtocolName, c.CleanSession, c.WillFlag, c.WillQos, c.WillRetain, c.UsernameFlag, c.PasswordFlag, c.KeepaliveTimer, c.ClientIdentifier, c.WillTopic, c.WillMessage, c.Username, c.Password)
str += fmt.Sprintf("protocolversion: %d protocolname: %s cleansession: %t willflag: %t WillQos: %d WillRetain: %t Usernameflag: %t Passwordflag: %t keepalive: %d\nclientId: %s\nwilltopic: %s\nwillmessage: %s\nUsername: %s\nPassword: %s\n", c.ProtocolVersion, c.ProtocolName, c.CleanSession, c.WillFlag, c.WillQos, c.WillRetain, c.UsernameFlag, c.PasswordFlag, c.Keepalive, c.ClientIdentifier, c.WillTopic, c.WillMessage, c.Username, c.Password)
return str
}
@ -41,7 +41,7 @@ func (c *ConnectPacket) Write(w io.Writer) error {
body.Write(encodeString(c.ProtocolName))
body.WriteByte(c.ProtocolVersion)
body.WriteByte(boolToByte(c.CleanSession)<<1 | boolToByte(c.WillFlag)<<2 | c.WillQos<<3 | boolToByte(c.WillRetain)<<5 | boolToByte(c.PasswordFlag)<<6 | boolToByte(c.UsernameFlag)<<7)
body.Write(encodeUint16(c.KeepaliveTimer))
body.Write(encodeUint16(c.Keepalive))
body.Write(encodeString(c.ClientIdentifier))
if c.WillFlag {
body.Write(encodeString(c.WillTopic))
@ -63,7 +63,7 @@ func (c *ConnectPacket) Write(w io.Writer) error {
//Unpack decodes the details of a ControlPacket after the fixed
//header has been read
func (c *ConnectPacket) Unpack(b io.Reader) {
func (c *ConnectPacket) Unpack(b io.Reader) error {
c.ProtocolName = decodeString(b)
c.ProtocolVersion = decodeByte(b)
options := decodeByte(b)
@ -74,7 +74,7 @@ func (c *ConnectPacket) Unpack(b io.Reader) {
c.WillRetain = 1&(options>>5) > 0
c.PasswordFlag = 1&(options>>6) > 0
c.UsernameFlag = 1&(options>>7) > 0
c.KeepaliveTimer = decodeUint16(b)
c.Keepalive = decodeUint16(b)
c.ClientIdentifier = decodeString(b)
if c.WillFlag {
c.WillTopic = decodeString(b)
@ -86,6 +86,8 @@ func (c *ConnectPacket) Unpack(b io.Reader) {
if c.PasswordFlag {
c.Password = decodeBytes(b)
}
return nil
}
//Validate performs validation of the fields of a Connect packet

View File

@ -25,7 +25,8 @@ func (d *DisconnectPacket) Write(w io.Writer) error {
//Unpack decodes the details of a ControlPacket after the fixed
//header has been read
func (d *DisconnectPacket) Unpack(b io.Reader) {
func (d *DisconnectPacket) Unpack(b io.Reader) error {
return nil
}
//Details returns a Details struct containing the Qos and

View File

@ -13,7 +13,7 @@ import (
//written
type ControlPacket interface {
Write(io.Writer) error
Unpack(io.Reader)
Unpack(io.Reader) error
String() string
Details() Details
}
@ -116,8 +116,8 @@ func ReadPacket(r io.Reader) (cp ControlPacket, err error) {
if err != nil {
return nil, err
}
cp.Unpack(bytes.NewBuffer(packetBytes))
return cp, nil
err = cp.Unpack(bytes.NewBuffer(packetBytes))
return cp, err
}
//NewControlPacket is used to create a new ControlPacket of the type specified

View File

@ -25,7 +25,8 @@ func (pr *PingreqPacket) Write(w io.Writer) error {
//Unpack decodes the details of a ControlPacket after the fixed
//header has been read
func (pr *PingreqPacket) Unpack(b io.Reader) {
func (pr *PingreqPacket) Unpack(b io.Reader) error {
return nil
}
//Details returns a Details struct containing the Qos and

View File

@ -25,7 +25,8 @@ func (pr *PingrespPacket) Write(w io.Writer) error {
//Unpack decodes the details of a ControlPacket after the fixed
//header has been read
func (pr *PingrespPacket) Unpack(b io.Reader) {
func (pr *PingrespPacket) Unpack(b io.Reader) error {
return nil
}
//Details returns a Details struct containing the Qos and

View File

@ -30,8 +30,10 @@ func (pa *PubackPacket) Write(w io.Writer) error {
//Unpack decodes the details of a ControlPacket after the fixed
//header has been read
func (pa *PubackPacket) Unpack(b io.Reader) {
func (pa *PubackPacket) Unpack(b io.Reader) error {
pa.MessageID = decodeUint16(b)
return nil
}
//Details returns a Details struct containing the Qos and

View File

@ -30,8 +30,10 @@ func (pc *PubcompPacket) Write(w io.Writer) error {
//Unpack decodes the details of a ControlPacket after the fixed
//header has been read
func (pc *PubcompPacket) Unpack(b io.Reader) {
func (pc *PubcompPacket) Unpack(b io.Reader) error {
pc.MessageID = decodeUint16(b)
return nil
}
//Details returns a Details struct containing the Qos and

View File

@ -41,7 +41,7 @@ func (p *PublishPacket) Write(w io.Writer) error {
//Unpack decodes the details of a ControlPacket after the fixed
//header has been read
func (p *PublishPacket) Unpack(b io.Reader) {
func (p *PublishPacket) Unpack(b io.Reader) error {
var payloadLength = p.FixedHeader.RemainingLength
p.TopicName = decodeString(b)
if p.Qos > 0 {
@ -50,8 +50,13 @@ func (p *PublishPacket) Unpack(b io.Reader) {
} else {
payloadLength -= len(p.TopicName) + 2
}
if payloadLength < 0 {
return fmt.Errorf("Error upacking publish, payload length < 0")
}
p.Payload = make([]byte, payloadLength)
b.Read(p.Payload)
_, err := b.Read(p.Payload)
return err
}
//Copy creates a new PublishPacket with the same topic and payload

View File

@ -30,8 +30,10 @@ func (pr *PubrecPacket) Write(w io.Writer) error {
//Unpack decodes the details of a ControlPacket after the fixed
//header has been read
func (pr *PubrecPacket) Unpack(b io.Reader) {
func (pr *PubrecPacket) Unpack(b io.Reader) error {
pr.MessageID = decodeUint16(b)
return nil
}
//Details returns a Details struct containing the Qos and

View File

@ -30,8 +30,10 @@ func (pr *PubrelPacket) Write(w io.Writer) error {
//Unpack decodes the details of a ControlPacket after the fixed
//header has been read
func (pr *PubrelPacket) Unpack(b io.Reader) {
func (pr *PubrelPacket) Unpack(b io.Reader) error {
pr.MessageID = decodeUint16(b)
return nil
}
//Details returns a Details struct containing the Qos and

View File

@ -11,7 +11,7 @@ import (
type SubackPacket struct {
FixedHeader
MessageID uint16
GrantedQoss []byte
ReturnCodes []byte
}
func (sa *SubackPacket) String() string {
@ -24,7 +24,7 @@ func (sa *SubackPacket) Write(w io.Writer) error {
var body bytes.Buffer
var err error
body.Write(encodeUint16(sa.MessageID))
body.Write(sa.GrantedQoss)
body.Write(sa.ReturnCodes)
sa.FixedHeader.RemainingLength = body.Len()
packet := sa.FixedHeader.pack()
packet.Write(body.Bytes())
@ -35,11 +35,13 @@ func (sa *SubackPacket) Write(w io.Writer) error {
//Unpack decodes the details of a ControlPacket after the fixed
//header has been read
func (sa *SubackPacket) Unpack(b io.Reader) {
func (sa *SubackPacket) Unpack(b io.Reader) error {
var qosBuffer bytes.Buffer
sa.MessageID = decodeUint16(b)
qosBuffer.ReadFrom(b)
sa.GrantedQoss = qosBuffer.Bytes()
sa.ReturnCodes = qosBuffer.Bytes()
return nil
}
//Details returns a Details struct containing the Qos and

View File

@ -40,7 +40,7 @@ func (s *SubscribePacket) Write(w io.Writer) error {
//Unpack decodes the details of a ControlPacket after the fixed
//header has been read
func (s *SubscribePacket) Unpack(b io.Reader) {
func (s *SubscribePacket) Unpack(b io.Reader) error {
s.MessageID = decodeUint16(b)
payloadLength := s.FixedHeader.RemainingLength - 2
for payloadLength > 0 {
@ -50,6 +50,8 @@ func (s *SubscribePacket) Unpack(b io.Reader) {
s.Qoss = append(s.Qoss, qos)
payloadLength -= 2 + len(topic) + 1 //2 bytes of string length, plus string, plus 1 byte for Qos
}
return nil
}
//Details returns a Details struct containing the Qos and

View File

@ -30,8 +30,10 @@ func (ua *UnsubackPacket) Write(w io.Writer) error {
//Unpack decodes the details of a ControlPacket after the fixed
//header has been read
func (ua *UnsubackPacket) Unpack(b io.Reader) {
func (ua *UnsubackPacket) Unpack(b io.Reader) error {
ua.MessageID = decodeUint16(b)
return nil
}
//Details returns a Details struct containing the Qos and

View File

@ -37,12 +37,14 @@ func (u *UnsubscribePacket) Write(w io.Writer) error {
//Unpack decodes the details of a ControlPacket after the fixed
//header has been read
func (u *UnsubscribePacket) Unpack(b io.Reader) {
func (u *UnsubscribePacket) Unpack(b io.Reader) error {
u.MessageID = decodeUint16(b)
var topic string
for topic = decodeString(b); topic != ""; topic = decodeString(b) {
u.Topics = append(u.Topics, topic)
}
return nil
}
//Details returns a Details struct containing the Qos and

View File

@ -16,6 +16,7 @@ package mqtt
import (
"errors"
"time"
"github.com/eclipse/paho.mqtt.golang/packets"
)
@ -23,29 +24,62 @@ import (
func keepalive(c *client) {
DEBUG.Println(PNG, "keepalive starting")
pingTimer := timer{Timer: time.NewTimer(c.options.KeepAlive)}
pingRespTimer := timer{Timer: time.NewTimer(c.options.PingTimeout)}
pingRespTimer.Stop()
for {
select {
case <-c.stop:
DEBUG.Println(PNG, "keepalive stopped")
c.workers.Done()
return
case <-c.pingTimer.C:
case <-pingTimer.C:
pingTimer.SetRead(true)
DEBUG.Println(PNG, "keepalive sending ping")
ping := packets.NewControlPacket(packets.Pingreq).(*packets.PingreqPacket)
//We don't want to wait behind large messages being sent, the Write call
//will block until it it able to send the packet.
ping.Write(c.conn)
c.pingRespTimer.Reset(c.options.PingTimeout)
pingRespTimer.Reset(c.options.PingTimeout)
case <-c.pingResp:
DEBUG.Println(NET, "resetting ping timers")
c.pingRespTimer.Stop()
c.pingTimer.Reset(c.options.KeepAlive)
case <-c.pingRespTimer.C:
pingRespTimer.Stop()
pingTimer.Reset(c.options.KeepAlive)
case <-pingRespTimer.C:
pingRespTimer.SetRead(true)
CRITICAL.Println(PNG, "pingresp not received, disconnecting")
c.workers.Done()
c.internalConnLost(errors.New("pingresp not received, disconnecting"))
c.pingTimer.Stop()
pingTimer.Stop()
return
}
}
}
type timer struct {
*time.Timer
readFrom bool
}
func (t *timer) SetRead(v bool) {
t.readFrom = v
}
func (t *timer) Stop() bool {
defer t.SetRead(true)
if !t.Timer.Stop() && !t.readFrom {
<-t.C
return false
}
return true
}
func (t *timer) Reset(d time.Duration) bool {
defer t.SetRead(false)
t.Stop()
return t.Timer.Reset(d)
}

View File

@ -1,10 +0,0 @@
#!/bin/sh
go clean
for file in *.go
do
echo -n "Compiling $file ..."
go build "$file"
echo " done."
done

View File

@ -77,7 +77,7 @@ func persistOutbound(s Store, m packets.ControlPacket) {
// until puback received
s.Put(outboundKeyFromMID(m.Details().MessageID), m)
default:
chkcond(false)
ERROR.Println(STR, "Asked to persist an invalid message type")
}
case 2:
switch m.(type) {
@ -86,7 +86,7 @@ func persistOutbound(s Store, m packets.ControlPacket) {
// until pubrel received
s.Put(outboundKeyFromMID(m.Details().MessageID), m)
default:
chkcond(false)
ERROR.Println(STR, "Asked to persist an invalid message type")
}
}
}
@ -102,7 +102,7 @@ func persistInbound(s Store, m packets.ControlPacket) {
s.Del(outboundKeyFromMID(m.Details().MessageID))
case *packets.PublishPacket, *packets.PubrecPacket, *packets.PingrespPacket, *packets.ConnackPacket:
default:
chkcond(false)
ERROR.Println(STR, "Asked to persist an invalid messages type")
}
case 1:
switch m.(type) {
@ -111,7 +111,7 @@ func persistInbound(s Store, m packets.ControlPacket) {
// until puback sent
s.Put(inboundKeyFromMID(m.Details().MessageID), m)
default:
chkcond(false)
ERROR.Println(STR, "Asked to persist an invalid messages type")
}
case 2:
switch m.(type) {
@ -120,7 +120,7 @@ func persistInbound(s Store, m packets.ControlPacket) {
// until pubrel received
s.Put(inboundKeyFromMID(m.Details().MessageID), m)
default:
chkcond(false)
ERROR.Println(STR, "Asked to persist an invalid messages type")
}
}
}

View File

@ -15,6 +15,7 @@
package mqtt
import (
"fmt"
"io/ioutil"
"testing"
@ -44,7 +45,9 @@ func Test_exists_no(t *testing.T) {
}
func isemptydir(dir string) bool {
chkcond(exists(dir))
if !exists(dir) {
panic(fmt.Errorf("Directory %s does not exist", dir))
}
files, err := ioutil.ReadDir(dir)
chkerr(err)
return len(files) == 0

View File

@ -32,6 +32,8 @@ const (
PingFrame = 9
PongFrame = 10
UnknownFrame = 255
DefaultMaxPayloadBytes = 32 << 20 // 32MB
)
// ProtocolError represents WebSocket protocol errors.
@ -58,6 +60,10 @@ var (
ErrNotSupported = &ProtocolError{"not supported"}
)
// ErrFrameTooLarge is returned by Codec's Receive method if payload size
// exceeds limit set by Conn.MaxPayloadBytes
var ErrFrameTooLarge = errors.New("websocket: frame payload size exceeds limit")
// Addr is an implementation of net.Addr for WebSocket.
type Addr struct {
*url.URL
@ -166,6 +172,10 @@ type Conn struct {
frameHandler
PayloadType byte
defaultCloseStatus int
// MaxPayloadBytes limits the size of frame payload received over Conn
// by Codec's Receive method. If zero, DefaultMaxPayloadBytes is used.
MaxPayloadBytes int
}
// Read implements the io.Reader interface:
@ -302,7 +312,12 @@ func (cd Codec) Send(ws *Conn, v interface{}) (err error) {
return err
}
// Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores in v.
// Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores
// in v. The whole frame payload is read to an in-memory buffer; max size of
// payload is defined by ws.MaxPayloadBytes. If frame payload size exceeds
// limit, ErrFrameTooLarge is returned; in this case frame is not read off wire
// completely. The next call to Receive would read and discard leftover data of
// previous oversized frame before processing next frame.
func (cd Codec) Receive(ws *Conn, v interface{}) (err error) {
ws.rio.Lock()
defer ws.rio.Unlock()
@ -325,6 +340,19 @@ again:
if frame == nil {
goto again
}
maxPayloadBytes := ws.MaxPayloadBytes
if maxPayloadBytes == 0 {
maxPayloadBytes = DefaultMaxPayloadBytes
}
if hf, ok := frame.(*hybiFrameReader); ok && hf.header.Length > int64(maxPayloadBytes) {
// payload size exceeds limit, no need to call Unmarshal
//
// set frameReader to current oversized frame so that
// the next call to this function can drain leftover
// data before processing the next frame
ws.frameReader = frame
return ErrFrameTooLarge
}
payloadType := frame.PayloadType()
data, err := ioutil.ReadAll(frame)
if err != nil {

View File

@ -9,6 +9,7 @@ import (
"fmt"
"io"
"log"
"math/rand"
"net"
"net/http"
"net/http/httptest"
@ -605,3 +606,60 @@ func TestCtrlAndData(t *testing.T) {
}
}
}
func TestCodec_ReceiveLimited(t *testing.T) {
const limit = 2048
var payloads [][]byte
for _, size := range []int{
1024,
2048,
4096, // receive of this message would be interrupted due to limit
2048, // this one is to make sure next receive recovers discarding leftovers
} {
b := make([]byte, size)
rand.Read(b)
payloads = append(payloads, b)
}
handlerDone := make(chan struct{})
limitedHandler := func(ws *Conn) {
defer close(handlerDone)
ws.MaxPayloadBytes = limit
defer ws.Close()
for i, p := range payloads {
t.Logf("payload #%d (size %d, exceeds limit: %v)", i, len(p), len(p) > limit)
var recv []byte
err := Message.Receive(ws, &recv)
switch err {
case nil:
case ErrFrameTooLarge:
if len(p) <= limit {
t.Fatalf("unexpected frame size limit: expected %d bytes of payload having limit at %d", len(p), limit)
}
continue
default:
t.Fatalf("unexpected error: %v (want either nil or ErrFrameTooLarge)", err)
}
if len(recv) > limit {
t.Fatalf("received %d bytes of payload having limit at %d", len(recv), limit)
}
if !bytes.Equal(p, recv) {
t.Fatalf("received payload differs:\ngot:\t%v\nwant:\t%v", recv, p)
}
}
}
server := httptest.NewServer(Handler(limitedHandler))
defer server.CloseClientConnections()
defer server.Close()
addr := server.Listener.Addr().String()
ws, err := Dial("ws://"+addr+"/", "", "http://localhost/")
if err != nil {
t.Fatal(err)
}
defer ws.Close()
for i, p := range payloads {
if err := Message.Send(ws, p); err != nil {
t.Fatalf("payload #%d (size %d): %v", i, len(p), err)
}
}
<-handlerDone
}