rumqttc/v5/mqttbytes/v5/connect.rs

1use super::*;
2use bytes::{Buf, Bytes};
3
4/// Connection packet initiated by the client
5#[derive(Debug, Clone, PartialEq, Eq)]
6pub struct Connect {
7    /// Mqtt keep alive time
8    pub keep_alive: u16,
9    /// Client Id
10    pub client_id: String,
11    /// Clean session. Asks the broker to clear previous state
12    pub clean_start: bool,
13    pub properties: Option<ConnectProperties>,
14}
15
16impl Connect {
17    #[allow(clippy::type_complexity)]
18    pub fn read(
19        fixed_header: FixedHeader,
20        mut bytes: Bytes,
21    ) -> Result<(Connect, Option<LastWill>, Option<Login>), Error> {
22        let variable_header_index = fixed_header.fixed_header_len;
23        bytes.advance(variable_header_index);
24
25        // Variable header
26        let protocol_name = read_mqtt_string(&mut bytes)?;
27        let protocol_level = read_u8(&mut bytes)?;
28        if protocol_name != "MQTT" {
29            return Err(Error::InvalidProtocol);
30        }
31
32        if protocol_level != 5 {
33            return Err(Error::InvalidProtocolLevel(protocol_level));
34        }
35
36        let connect_flags = read_u8(&mut bytes)?;
37        let clean_start = (connect_flags & 0b10) != 0;
38        let keep_alive = read_u16(&mut bytes)?;
39
40        let properties = ConnectProperties::read(&mut bytes)?;
41
42        let client_id = read_mqtt_string(&mut bytes)?;
43        let will = LastWill::read(connect_flags, &mut bytes)?;
44        let login = Login::read(connect_flags, &mut bytes)?;
45
46        let connect = Connect {
47            keep_alive,
48            client_id,
49            clean_start,
50            properties,
51        };
52
53        Ok((connect, will, login))
54    }
55
56    fn len(&self, will: &Option<LastWill>, l: &Option<Login>) -> usize {
57        let mut len = 2 + "MQTT".len() // protocol name
58                        + 1            // protocol version
59                        + 1            // connect flags
60                        + 2; // keep alive
61
62        if let Some(p) = &self.properties {
63            let properties_len = p.len();
64            let properties_len_len = len_len(properties_len);
65            len += properties_len_len + properties_len;
66        } else {
67            // just 1 byte representing 0 len
68            len += 1;
69        }
70
71        len += 2 + self.client_id.len();
72
73        // last will len
74        if let Some(w) = will {
75            len += w.len();
76        }
77
78        // username and password len
79        if let Some(l) = l {
80            len += l.len();
81        }
82
83        len
84    }
85
86    pub fn write(
87        &self,
88        will: &Option<LastWill>,
89        l: &Option<Login>,
90        buffer: &mut BytesMut,
91    ) -> Result<usize, Error> {
92        let len = self.len(will, l);
93
94        buffer.put_u8(0b0001_0000);
95        let count = write_remaining_length(buffer, len)?;
96        write_mqtt_string(buffer, "MQTT");
97
98        buffer.put_u8(0x05);
99        let flags_index = 1 + count + 2 + 4 + 1;
100
101        let mut connect_flags = 0;
102        if self.clean_start {
103            connect_flags |= 0x02;
104        }
105
106        buffer.put_u8(connect_flags);
107        buffer.put_u16(self.keep_alive);
108
109        match &self.properties {
110            Some(p) => p.write(buffer)?,
111            None => {
112                write_remaining_length(buffer, 0)?;
113            }
114        };
115
116        write_mqtt_string(buffer, &self.client_id);
117
118        if let Some(w) = will {
119            connect_flags |= w.write(buffer)?;
120        }
121
122        if let Some(l) = l {
123            connect_flags |= l.write(buffer);
124        }
125
126        // update connect flags
127        buffer[flags_index] = connect_flags;
128        Ok(1 + count + len)
129    }
130}
131
132#[derive(Debug, Clone, PartialEq, Eq)]
133pub struct ConnectProperties {
134    /// Expiry interval property after loosing connection
135    pub session_expiry_interval: Option<u32>,
136    /// Maximum simultaneous packets
137    pub receive_maximum: Option<u16>,
138    /// Maximum packet size
139    pub max_packet_size: Option<u32>,
140    /// Maximum mapping integer for a topic
141    pub topic_alias_max: Option<u16>,
142    pub request_response_info: Option<u8>,
143    pub request_problem_info: Option<u8>,
144    /// List of user properties
145    pub user_properties: Vec<(String, String)>,
146    /// Method of authentication
147    pub authentication_method: Option<String>,
148    /// Authentication data
149    pub authentication_data: Option<Bytes>,
150}
151
152impl ConnectProperties {
153    pub fn new() -> ConnectProperties {
154        ConnectProperties {
155            session_expiry_interval: None,
156            receive_maximum: None,
157            max_packet_size: None,
158            topic_alias_max: None,
159            request_response_info: None,
160            request_problem_info: None,
161            user_properties: Vec::new(),
162            authentication_method: None,
163            authentication_data: None,
164        }
165    }
166
167    pub fn read(bytes: &mut Bytes) -> Result<Option<ConnectProperties>, Error> {
168        let mut session_expiry_interval = None;
169        let mut receive_maximum = None;
170        let mut max_packet_size = None;
171        let mut topic_alias_max = None;
172        let mut request_response_info = None;
173        let mut request_problem_info = None;
174        let mut user_properties = Vec::new();
175        let mut authentication_method = None;
176        let mut authentication_data = None;
177
178        let (properties_len_len, properties_len) = length(bytes.iter())?;
179        bytes.advance(properties_len_len);
180        if properties_len == 0 {
181            return Ok(None);
182        }
183
184        let mut cursor = 0;
185        // read until cursor reaches property length. properties_len = 0 will skip this loop
186        while cursor < properties_len {
187            let prop = read_u8(bytes)?;
188            cursor += 1;
189            match property(prop)? {
190                PropertyType::SessionExpiryInterval => {
191                    session_expiry_interval = Some(read_u32(bytes)?);
192                    cursor += 4;
193                }
194                PropertyType::ReceiveMaximum => {
195                    receive_maximum = Some(read_u16(bytes)?);
196                    cursor += 2;
197                }
198                PropertyType::MaximumPacketSize => {
199                    max_packet_size = Some(read_u32(bytes)?);
200                    cursor += 4;
201                }
202                PropertyType::TopicAliasMaximum => {
203                    topic_alias_max = Some(read_u16(bytes)?);
204                    cursor += 2;
205                }
206                PropertyType::RequestResponseInformation => {
207                    request_response_info = Some(read_u8(bytes)?);
208                    cursor += 1;
209                }
210                PropertyType::RequestProblemInformation => {
211                    request_problem_info = Some(read_u8(bytes)?);
212                    cursor += 1;
213                }
214                PropertyType::UserProperty => {
215                    let key = read_mqtt_string(bytes)?;
216                    let value = read_mqtt_string(bytes)?;
217                    cursor += 2 + key.len() + 2 + value.len();
218                    user_properties.push((key, value));
219                }
220                PropertyType::AuthenticationMethod => {
221                    let method = read_mqtt_string(bytes)?;
222                    cursor += 2 + method.len();
223                    authentication_method = Some(method);
224                }
225                PropertyType::AuthenticationData => {
226                    let data = read_mqtt_bytes(bytes)?;
227                    cursor += 2 + data.len();
228                    authentication_data = Some(data);
229                }
230                _ => return Err(Error::InvalidPropertyType(prop)),
231            }
232        }
233
234        Ok(Some(ConnectProperties {
235            session_expiry_interval,
236            receive_maximum,
237            max_packet_size,
238            topic_alias_max,
239            request_response_info,
240            request_problem_info,
241            user_properties,
242            authentication_method,
243            authentication_data,
244        }))
245    }
246
247    fn len(&self) -> usize {
248        let mut len = 0;
249
250        if self.session_expiry_interval.is_some() {
251            len += 1 + 4;
252        }
253
254        if self.receive_maximum.is_some() {
255            len += 1 + 2;
256        }
257
258        if self.max_packet_size.is_some() {
259            len += 1 + 4;
260        }
261
262        if self.topic_alias_max.is_some() {
263            len += 1 + 2;
264        }
265
266        if self.request_response_info.is_some() {
267            len += 1 + 1;
268        }
269
270        if self.request_problem_info.is_some() {
271            len += 1 + 1;
272        }
273
274        for (key, value) in self.user_properties.iter() {
275            len += 1 + 2 + key.len() + 2 + value.len();
276        }
277
278        if let Some(authentication_method) = &self.authentication_method {
279            len += 1 + 2 + authentication_method.len();
280        }
281
282        if let Some(authentication_data) = &self.authentication_data {
283            len += 1 + 2 + authentication_data.len();
284        }
285
286        len
287    }
288
289    pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
290        let len = self.len();
291        write_remaining_length(buffer, len)?;
292
293        if let Some(session_expiry_interval) = self.session_expiry_interval {
294            buffer.put_u8(PropertyType::SessionExpiryInterval as u8);
295            buffer.put_u32(session_expiry_interval);
296        }
297
298        if let Some(receive_maximum) = self.receive_maximum {
299            buffer.put_u8(PropertyType::ReceiveMaximum as u8);
300            buffer.put_u16(receive_maximum);
301        }
302
303        if let Some(max_packet_size) = self.max_packet_size {
304            buffer.put_u8(PropertyType::MaximumPacketSize as u8);
305            buffer.put_u32(max_packet_size);
306        }
307
308        if let Some(topic_alias_max) = self.topic_alias_max {
309            buffer.put_u8(PropertyType::TopicAliasMaximum as u8);
310            buffer.put_u16(topic_alias_max);
311        }
312
313        if let Some(request_response_info) = self.request_response_info {
314            buffer.put_u8(PropertyType::RequestResponseInformation as u8);
315            buffer.put_u8(request_response_info);
316        }
317
318        if let Some(request_problem_info) = self.request_problem_info {
319            buffer.put_u8(PropertyType::RequestProblemInformation as u8);
320            buffer.put_u8(request_problem_info);
321        }
322
323        for (key, value) in self.user_properties.iter() {
324            buffer.put_u8(PropertyType::UserProperty as u8);
325            write_mqtt_string(buffer, key);
326            write_mqtt_string(buffer, value);
327        }
328
329        if let Some(authentication_method) = &self.authentication_method {
330            buffer.put_u8(PropertyType::AuthenticationMethod as u8);
331            write_mqtt_string(buffer, authentication_method);
332        }
333
334        if let Some(authentication_data) = &self.authentication_data {
335            buffer.put_u8(PropertyType::AuthenticationData as u8);
336            write_mqtt_bytes(buffer, authentication_data);
337        }
338
339        Ok(())
340    }
341}
342
343impl Default for ConnectProperties {
344    fn default() -> Self {
345        Self::new()
346    }
347}
348
349/// LastWill that broker forwards on behalf of the client
350#[derive(Debug, Clone, PartialEq, Eq)]
351pub struct LastWill {
352    pub topic: Bytes,
353    pub message: Bytes,
354    pub qos: QoS,
355    pub retain: bool,
356    pub properties: Option<LastWillProperties>,
357}
358
359impl LastWill {
360    pub fn new(
361        topic: impl Into<String>,
362        payload: impl Into<Vec<u8>>,
363        qos: QoS,
364        retain: bool,
365        properties: Option<LastWillProperties>,
366    ) -> LastWill {
367        let topic = Bytes::copy_from_slice(topic.into().as_bytes());
368        LastWill {
369            topic,
370            message: Bytes::from(payload.into()),
371            qos,
372            retain,
373            properties,
374        }
375    }
376
377    fn len(&self) -> usize {
378        let mut len = 0;
379
380        if let Some(p) = &self.properties {
381            let properties_len = p.len();
382            let properties_len_len = len_len(properties_len);
383            len += properties_len_len + properties_len;
384        } else {
385            // just 1 byte representing 0 len
386            len += 1;
387        }
388
389        len += 2 + self.topic.len() + 2 + self.message.len();
390        len
391    }
392
393    pub fn read(connect_flags: u8, bytes: &mut Bytes) -> Result<Option<LastWill>, Error> {
394        let o = match connect_flags & 0b100 {
395            0 if (connect_flags & 0b0011_1000) != 0 => {
396                return Err(Error::IncorrectPacketFormat);
397            }
398            0 => None,
399            _ => {
400                // Properties in variable header
401                let properties = LastWillProperties::read(bytes)?;
402
403                let will_topic = read_mqtt_bytes(bytes)?;
404                let will_message = read_mqtt_bytes(bytes)?;
405                let qos_num = (connect_flags & 0b11000) >> 3;
406                let will_qos = qos(qos_num).ok_or(Error::InvalidQoS(qos_num))?;
407                Some(LastWill {
408                    topic: will_topic,
409                    message: will_message,
410                    qos: will_qos,
411                    retain: (connect_flags & 0b0010_0000) != 0,
412                    properties,
413                })
414            }
415        };
416
417        Ok(o)
418    }
419
420    pub fn write(&self, buffer: &mut BytesMut) -> Result<u8, Error> {
421        let mut connect_flags = 0;
422
423        connect_flags |= 0x04 | (self.qos as u8) << 3;
424        if self.retain {
425            connect_flags |= 0x20;
426        }
427
428        if let Some(p) = &self.properties {
429            p.write(buffer)?;
430        } else {
431            write_remaining_length(buffer, 0)?;
432        }
433
434        write_mqtt_bytes(buffer, &self.topic);
435        write_mqtt_bytes(buffer, &self.message);
436        Ok(connect_flags)
437    }
438}
439
440#[derive(Debug, Clone, PartialEq, Eq)]
441pub struct LastWillProperties {
442    pub delay_interval: Option<u32>,
443    pub payload_format_indicator: Option<u8>,
444    pub message_expiry_interval: Option<u32>,
445    pub content_type: Option<String>,
446    pub response_topic: Option<String>,
447    pub correlation_data: Option<Bytes>,
448    pub user_properties: Vec<(String, String)>,
449}
450
451impl LastWillProperties {
452    fn len(&self) -> usize {
453        let mut len = 0;
454
455        if self.delay_interval.is_some() {
456            len += 1 + 4;
457        }
458
459        if self.payload_format_indicator.is_some() {
460            len += 1 + 1;
461        }
462
463        if self.message_expiry_interval.is_some() {
464            len += 1 + 4;
465        }
466
467        if let Some(typ) = &self.content_type {
468            len += 1 + 2 + typ.len()
469        }
470
471        if let Some(topic) = &self.response_topic {
472            len += 1 + 2 + topic.len()
473        }
474
475        if let Some(data) = &self.correlation_data {
476            len += 1 + 2 + data.len()
477        }
478
479        for (key, value) in self.user_properties.iter() {
480            len += 1 + 2 + key.len() + 2 + value.len();
481        }
482
483        len
484    }
485
486    pub fn read(bytes: &mut Bytes) -> Result<Option<LastWillProperties>, Error> {
487        let mut delay_interval = None;
488        let mut payload_format_indicator = None;
489        let mut message_expiry_interval = None;
490        let mut content_type = None;
491        let mut response_topic = None;
492        let mut correlation_data = None;
493        let mut user_properties = Vec::new();
494
495        let (properties_len_len, properties_len) = length(bytes.iter())?;
496        bytes.advance(properties_len_len);
497        if properties_len == 0 {
498            return Ok(None);
499        }
500
501        let mut cursor = 0;
502        // read until cursor reaches property length. properties_len = 0 will skip this loop
503        while cursor < properties_len {
504            let prop = read_u8(bytes)?;
505            cursor += 1;
506
507            match property(prop)? {
508                PropertyType::WillDelayInterval => {
509                    delay_interval = Some(read_u32(bytes)?);
510                    cursor += 4;
511                }
512                PropertyType::PayloadFormatIndicator => {
513                    payload_format_indicator = Some(read_u8(bytes)?);
514                    cursor += 1;
515                }
516                PropertyType::MessageExpiryInterval => {
517                    message_expiry_interval = Some(read_u32(bytes)?);
518                    cursor += 4;
519                }
520                PropertyType::ContentType => {
521                    let typ = read_mqtt_string(bytes)?;
522                    cursor += 2 + typ.len();
523                    content_type = Some(typ);
524                }
525                PropertyType::ResponseTopic => {
526                    let topic = read_mqtt_string(bytes)?;
527                    cursor += 2 + topic.len();
528                    response_topic = Some(topic);
529                }
530                PropertyType::CorrelationData => {
531                    let data = read_mqtt_bytes(bytes)?;
532                    cursor += 2 + data.len();
533                    correlation_data = Some(data);
534                }
535                PropertyType::UserProperty => {
536                    let key = read_mqtt_string(bytes)?;
537                    let value = read_mqtt_string(bytes)?;
538                    cursor += 2 + key.len() + 2 + value.len();
539                    user_properties.push((key, value));
540                }
541                _ => return Err(Error::InvalidPropertyType(prop)),
542            }
543        }
544
545        Ok(Some(LastWillProperties {
546            delay_interval,
547            payload_format_indicator,
548            message_expiry_interval,
549            content_type,
550            response_topic,
551            correlation_data,
552            user_properties,
553        }))
554    }
555
556    pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
557        let len = self.len();
558        write_remaining_length(buffer, len)?;
559
560        if let Some(delay_interval) = self.delay_interval {
561            buffer.put_u8(PropertyType::WillDelayInterval as u8);
562            buffer.put_u32(delay_interval);
563        }
564
565        if let Some(payload_format_indicator) = self.payload_format_indicator {
566            buffer.put_u8(PropertyType::PayloadFormatIndicator as u8);
567            buffer.put_u8(payload_format_indicator);
568        }
569
570        if let Some(message_expiry_interval) = self.message_expiry_interval {
571            buffer.put_u8(PropertyType::MessageExpiryInterval as u8);
572            buffer.put_u32(message_expiry_interval);
573        }
574
575        if let Some(typ) = &self.content_type {
576            buffer.put_u8(PropertyType::ContentType as u8);
577            write_mqtt_string(buffer, typ);
578        }
579
580        if let Some(topic) = &self.response_topic {
581            buffer.put_u8(PropertyType::ResponseTopic as u8);
582            write_mqtt_string(buffer, topic);
583        }
584
585        if let Some(data) = &self.correlation_data {
586            buffer.put_u8(PropertyType::CorrelationData as u8);
587            write_mqtt_bytes(buffer, data);
588        }
589
590        for (key, value) in self.user_properties.iter() {
591            buffer.put_u8(PropertyType::UserProperty as u8);
592            write_mqtt_string(buffer, key);
593            write_mqtt_string(buffer, value);
594        }
595
596        Ok(())
597    }
598}
/// Username/password credentials carried in the CONNECT payload.
///
/// An empty string stands for a field that is absent on the wire.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Login {
    /// User name; only encoded when non-empty.
    pub username: String,
    /// Password; only encoded when non-empty.
    pub password: String,
}
604
605impl Login {
606    pub fn new<U: Into<String>, P: Into<String>>(u: U, p: P) -> Login {
607        Login {
608            username: u.into(),
609            password: p.into(),
610        }
611    }
612
613    pub fn read(connect_flags: u8, bytes: &mut Bytes) -> Result<Option<Login>, Error> {
614        let username = match connect_flags & 0b1000_0000 {
615            0 => String::new(),
616            _ => read_mqtt_string(bytes)?,
617        };
618
619        let password = match connect_flags & 0b0100_0000 {
620            0 => String::new(),
621            _ => read_mqtt_string(bytes)?,
622        };
623
624        if username.is_empty() && password.is_empty() {
625            Ok(None)
626        } else {
627            Ok(Some(Login { username, password }))
628        }
629    }
630
631    fn len(&self) -> usize {
632        let mut len = 0;
633
634        if !self.username.is_empty() {
635            len += 2 + self.username.len();
636        }
637
638        if !self.password.is_empty() {
639            len += 2 + self.password.len();
640        }
641
642        len
643    }
644
645    pub fn write(&self, buffer: &mut BytesMut) -> u8 {
646        let mut connect_flags = 0;
647        if !self.username.is_empty() {
648            connect_flags |= 0x80;
649            write_mqtt_string(buffer, &self.username);
650        }
651
652        if !self.password.is_empty() {
653            connect_flags |= 0x40;
654            write_mqtt_string(buffer, &self.password);
655        }
656
657        connect_flags
658    }
659}
660
#[cfg(test)]
mod test {
    use super::super::test::{USER_PROP_KEY, USER_PROP_VAL};
    use super::*;
    use bytes::BytesMut;
    use pretty_assertions::assert_eq;

    /// The size reported by `Connect::write` must equal the number of bytes it
    /// actually appended to the buffer.
    #[test]
    fn length_calculation() {
        // The user property is meant to pad the packet beyond ~128 bytes so the
        // remaining-length field needs two bytes.
        let props = ConnectProperties {
            user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())],
            ..ConnectProperties::new()
        };

        let packet = Connect {
            keep_alive: 5,
            client_id: "client".into(),
            clean_start: true,
            properties: Some(props),
        };

        let mut buffer = BytesMut::new();
        let written = packet.write(&None, &None, &mut buffer).unwrap();

        assert_eq!(written, buffer.len());
    }
}