1use super::*;
2use bytes::{Buf, Bytes};
3
4#[derive(Debug, Clone, PartialEq, Eq)]
6pub struct Connect {
7 pub keep_alive: u16,
9 pub client_id: String,
11 pub clean_start: bool,
13 pub properties: Option<ConnectProperties>,
14}
15
16impl Connect {
17 #[allow(clippy::type_complexity)]
18 pub fn read(
19 fixed_header: FixedHeader,
20 mut bytes: Bytes,
21 ) -> Result<(Connect, Option<LastWill>, Option<Login>), Error> {
22 let variable_header_index = fixed_header.fixed_header_len;
23 bytes.advance(variable_header_index);
24
25 let protocol_name = read_mqtt_string(&mut bytes)?;
27 let protocol_level = read_u8(&mut bytes)?;
28 if protocol_name != "MQTT" {
29 return Err(Error::InvalidProtocol);
30 }
31
32 if protocol_level != 5 {
33 return Err(Error::InvalidProtocolLevel(protocol_level));
34 }
35
36 let connect_flags = read_u8(&mut bytes)?;
37 let clean_start = (connect_flags & 0b10) != 0;
38 let keep_alive = read_u16(&mut bytes)?;
39
40 let properties = ConnectProperties::read(&mut bytes)?;
41
42 let client_id = read_mqtt_string(&mut bytes)?;
43 let will = LastWill::read(connect_flags, &mut bytes)?;
44 let login = Login::read(connect_flags, &mut bytes)?;
45
46 let connect = Connect {
47 keep_alive,
48 client_id,
49 clean_start,
50 properties,
51 };
52
53 Ok((connect, will, login))
54 }
55
56 fn len(&self, will: &Option<LastWill>, l: &Option<Login>) -> usize {
57 let mut len = 2 + "MQTT".len() + 1 + 1 + 2; if let Some(p) = &self.properties {
63 let properties_len = p.len();
64 let properties_len_len = len_len(properties_len);
65 len += properties_len_len + properties_len;
66 } else {
67 len += 1;
69 }
70
71 len += 2 + self.client_id.len();
72
73 if let Some(w) = will {
75 len += w.len();
76 }
77
78 if let Some(l) = l {
80 len += l.len();
81 }
82
83 len
84 }
85
86 pub fn write(
87 &self,
88 will: &Option<LastWill>,
89 l: &Option<Login>,
90 buffer: &mut BytesMut,
91 ) -> Result<usize, Error> {
92 let len = self.len(will, l);
93
94 buffer.put_u8(0b0001_0000);
95 let count = write_remaining_length(buffer, len)?;
96 write_mqtt_string(buffer, "MQTT");
97
98 buffer.put_u8(0x05);
99 let flags_index = 1 + count + 2 + 4 + 1;
100
101 let mut connect_flags = 0;
102 if self.clean_start {
103 connect_flags |= 0x02;
104 }
105
106 buffer.put_u8(connect_flags);
107 buffer.put_u16(self.keep_alive);
108
109 match &self.properties {
110 Some(p) => p.write(buffer)?,
111 None => {
112 write_remaining_length(buffer, 0)?;
113 }
114 };
115
116 write_mqtt_string(buffer, &self.client_id);
117
118 if let Some(w) = will {
119 connect_flags |= w.write(buffer)?;
120 }
121
122 if let Some(l) = l {
123 connect_flags |= l.write(buffer);
124 }
125
126 buffer[flags_index] = connect_flags;
128 Ok(1 + count + len)
129 }
130}
131
132#[derive(Debug, Clone, PartialEq, Eq)]
133pub struct ConnectProperties {
134 pub session_expiry_interval: Option<u32>,
136 pub receive_maximum: Option<u16>,
138 pub max_packet_size: Option<u32>,
140 pub topic_alias_max: Option<u16>,
142 pub request_response_info: Option<u8>,
143 pub request_problem_info: Option<u8>,
144 pub user_properties: Vec<(String, String)>,
146 pub authentication_method: Option<String>,
148 pub authentication_data: Option<Bytes>,
150}
151
152impl ConnectProperties {
153 pub fn new() -> ConnectProperties {
154 ConnectProperties {
155 session_expiry_interval: None,
156 receive_maximum: None,
157 max_packet_size: None,
158 topic_alias_max: None,
159 request_response_info: None,
160 request_problem_info: None,
161 user_properties: Vec::new(),
162 authentication_method: None,
163 authentication_data: None,
164 }
165 }
166
167 pub fn read(bytes: &mut Bytes) -> Result<Option<ConnectProperties>, Error> {
168 let mut session_expiry_interval = None;
169 let mut receive_maximum = None;
170 let mut max_packet_size = None;
171 let mut topic_alias_max = None;
172 let mut request_response_info = None;
173 let mut request_problem_info = None;
174 let mut user_properties = Vec::new();
175 let mut authentication_method = None;
176 let mut authentication_data = None;
177
178 let (properties_len_len, properties_len) = length(bytes.iter())?;
179 bytes.advance(properties_len_len);
180 if properties_len == 0 {
181 return Ok(None);
182 }
183
184 let mut cursor = 0;
185 while cursor < properties_len {
187 let prop = read_u8(bytes)?;
188 cursor += 1;
189 match property(prop)? {
190 PropertyType::SessionExpiryInterval => {
191 session_expiry_interval = Some(read_u32(bytes)?);
192 cursor += 4;
193 }
194 PropertyType::ReceiveMaximum => {
195 receive_maximum = Some(read_u16(bytes)?);
196 cursor += 2;
197 }
198 PropertyType::MaximumPacketSize => {
199 max_packet_size = Some(read_u32(bytes)?);
200 cursor += 4;
201 }
202 PropertyType::TopicAliasMaximum => {
203 topic_alias_max = Some(read_u16(bytes)?);
204 cursor += 2;
205 }
206 PropertyType::RequestResponseInformation => {
207 request_response_info = Some(read_u8(bytes)?);
208 cursor += 1;
209 }
210 PropertyType::RequestProblemInformation => {
211 request_problem_info = Some(read_u8(bytes)?);
212 cursor += 1;
213 }
214 PropertyType::UserProperty => {
215 let key = read_mqtt_string(bytes)?;
216 let value = read_mqtt_string(bytes)?;
217 cursor += 2 + key.len() + 2 + value.len();
218 user_properties.push((key, value));
219 }
220 PropertyType::AuthenticationMethod => {
221 let method = read_mqtt_string(bytes)?;
222 cursor += 2 + method.len();
223 authentication_method = Some(method);
224 }
225 PropertyType::AuthenticationData => {
226 let data = read_mqtt_bytes(bytes)?;
227 cursor += 2 + data.len();
228 authentication_data = Some(data);
229 }
230 _ => return Err(Error::InvalidPropertyType(prop)),
231 }
232 }
233
234 Ok(Some(ConnectProperties {
235 session_expiry_interval,
236 receive_maximum,
237 max_packet_size,
238 topic_alias_max,
239 request_response_info,
240 request_problem_info,
241 user_properties,
242 authentication_method,
243 authentication_data,
244 }))
245 }
246
247 fn len(&self) -> usize {
248 let mut len = 0;
249
250 if self.session_expiry_interval.is_some() {
251 len += 1 + 4;
252 }
253
254 if self.receive_maximum.is_some() {
255 len += 1 + 2;
256 }
257
258 if self.max_packet_size.is_some() {
259 len += 1 + 4;
260 }
261
262 if self.topic_alias_max.is_some() {
263 len += 1 + 2;
264 }
265
266 if self.request_response_info.is_some() {
267 len += 1 + 1;
268 }
269
270 if self.request_problem_info.is_some() {
271 len += 1 + 1;
272 }
273
274 for (key, value) in self.user_properties.iter() {
275 len += 1 + 2 + key.len() + 2 + value.len();
276 }
277
278 if let Some(authentication_method) = &self.authentication_method {
279 len += 1 + 2 + authentication_method.len();
280 }
281
282 if let Some(authentication_data) = &self.authentication_data {
283 len += 1 + 2 + authentication_data.len();
284 }
285
286 len
287 }
288
289 pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
290 let len = self.len();
291 write_remaining_length(buffer, len)?;
292
293 if let Some(session_expiry_interval) = self.session_expiry_interval {
294 buffer.put_u8(PropertyType::SessionExpiryInterval as u8);
295 buffer.put_u32(session_expiry_interval);
296 }
297
298 if let Some(receive_maximum) = self.receive_maximum {
299 buffer.put_u8(PropertyType::ReceiveMaximum as u8);
300 buffer.put_u16(receive_maximum);
301 }
302
303 if let Some(max_packet_size) = self.max_packet_size {
304 buffer.put_u8(PropertyType::MaximumPacketSize as u8);
305 buffer.put_u32(max_packet_size);
306 }
307
308 if let Some(topic_alias_max) = self.topic_alias_max {
309 buffer.put_u8(PropertyType::TopicAliasMaximum as u8);
310 buffer.put_u16(topic_alias_max);
311 }
312
313 if let Some(request_response_info) = self.request_response_info {
314 buffer.put_u8(PropertyType::RequestResponseInformation as u8);
315 buffer.put_u8(request_response_info);
316 }
317
318 if let Some(request_problem_info) = self.request_problem_info {
319 buffer.put_u8(PropertyType::RequestProblemInformation as u8);
320 buffer.put_u8(request_problem_info);
321 }
322
323 for (key, value) in self.user_properties.iter() {
324 buffer.put_u8(PropertyType::UserProperty as u8);
325 write_mqtt_string(buffer, key);
326 write_mqtt_string(buffer, value);
327 }
328
329 if let Some(authentication_method) = &self.authentication_method {
330 buffer.put_u8(PropertyType::AuthenticationMethod as u8);
331 write_mqtt_string(buffer, authentication_method);
332 }
333
334 if let Some(authentication_data) = &self.authentication_data {
335 buffer.put_u8(PropertyType::AuthenticationData as u8);
336 write_mqtt_bytes(buffer, authentication_data);
337 }
338
339 Ok(())
340 }
341}
342
343impl Default for ConnectProperties {
344 fn default() -> Self {
345 Self::new()
346 }
347}
348
349#[derive(Debug, Clone, PartialEq, Eq)]
351pub struct LastWill {
352 pub topic: Bytes,
353 pub message: Bytes,
354 pub qos: QoS,
355 pub retain: bool,
356 pub properties: Option<LastWillProperties>,
357}
358
359impl LastWill {
360 pub fn new(
361 topic: impl Into<String>,
362 payload: impl Into<Vec<u8>>,
363 qos: QoS,
364 retain: bool,
365 properties: Option<LastWillProperties>,
366 ) -> LastWill {
367 let topic = Bytes::copy_from_slice(topic.into().as_bytes());
368 LastWill {
369 topic,
370 message: Bytes::from(payload.into()),
371 qos,
372 retain,
373 properties,
374 }
375 }
376
377 fn len(&self) -> usize {
378 let mut len = 0;
379
380 if let Some(p) = &self.properties {
381 let properties_len = p.len();
382 let properties_len_len = len_len(properties_len);
383 len += properties_len_len + properties_len;
384 } else {
385 len += 1;
387 }
388
389 len += 2 + self.topic.len() + 2 + self.message.len();
390 len
391 }
392
393 pub fn read(connect_flags: u8, bytes: &mut Bytes) -> Result<Option<LastWill>, Error> {
394 let o = match connect_flags & 0b100 {
395 0 if (connect_flags & 0b0011_1000) != 0 => {
396 return Err(Error::IncorrectPacketFormat);
397 }
398 0 => None,
399 _ => {
400 let properties = LastWillProperties::read(bytes)?;
402
403 let will_topic = read_mqtt_bytes(bytes)?;
404 let will_message = read_mqtt_bytes(bytes)?;
405 let qos_num = (connect_flags & 0b11000) >> 3;
406 let will_qos = qos(qos_num).ok_or(Error::InvalidQoS(qos_num))?;
407 Some(LastWill {
408 topic: will_topic,
409 message: will_message,
410 qos: will_qos,
411 retain: (connect_flags & 0b0010_0000) != 0,
412 properties,
413 })
414 }
415 };
416
417 Ok(o)
418 }
419
420 pub fn write(&self, buffer: &mut BytesMut) -> Result<u8, Error> {
421 let mut connect_flags = 0;
422
423 connect_flags |= 0x04 | (self.qos as u8) << 3;
424 if self.retain {
425 connect_flags |= 0x20;
426 }
427
428 if let Some(p) = &self.properties {
429 p.write(buffer)?;
430 } else {
431 write_remaining_length(buffer, 0)?;
432 }
433
434 write_mqtt_bytes(buffer, &self.topic);
435 write_mqtt_bytes(buffer, &self.message);
436 Ok(connect_flags)
437 }
438}
439
440#[derive(Debug, Clone, PartialEq, Eq)]
441pub struct LastWillProperties {
442 pub delay_interval: Option<u32>,
443 pub payload_format_indicator: Option<u8>,
444 pub message_expiry_interval: Option<u32>,
445 pub content_type: Option<String>,
446 pub response_topic: Option<String>,
447 pub correlation_data: Option<Bytes>,
448 pub user_properties: Vec<(String, String)>,
449}
450
451impl LastWillProperties {
452 fn len(&self) -> usize {
453 let mut len = 0;
454
455 if self.delay_interval.is_some() {
456 len += 1 + 4;
457 }
458
459 if self.payload_format_indicator.is_some() {
460 len += 1 + 1;
461 }
462
463 if self.message_expiry_interval.is_some() {
464 len += 1 + 4;
465 }
466
467 if let Some(typ) = &self.content_type {
468 len += 1 + 2 + typ.len()
469 }
470
471 if let Some(topic) = &self.response_topic {
472 len += 1 + 2 + topic.len()
473 }
474
475 if let Some(data) = &self.correlation_data {
476 len += 1 + 2 + data.len()
477 }
478
479 for (key, value) in self.user_properties.iter() {
480 len += 1 + 2 + key.len() + 2 + value.len();
481 }
482
483 len
484 }
485
486 pub fn read(bytes: &mut Bytes) -> Result<Option<LastWillProperties>, Error> {
487 let mut delay_interval = None;
488 let mut payload_format_indicator = None;
489 let mut message_expiry_interval = None;
490 let mut content_type = None;
491 let mut response_topic = None;
492 let mut correlation_data = None;
493 let mut user_properties = Vec::new();
494
495 let (properties_len_len, properties_len) = length(bytes.iter())?;
496 bytes.advance(properties_len_len);
497 if properties_len == 0 {
498 return Ok(None);
499 }
500
501 let mut cursor = 0;
502 while cursor < properties_len {
504 let prop = read_u8(bytes)?;
505 cursor += 1;
506
507 match property(prop)? {
508 PropertyType::WillDelayInterval => {
509 delay_interval = Some(read_u32(bytes)?);
510 cursor += 4;
511 }
512 PropertyType::PayloadFormatIndicator => {
513 payload_format_indicator = Some(read_u8(bytes)?);
514 cursor += 1;
515 }
516 PropertyType::MessageExpiryInterval => {
517 message_expiry_interval = Some(read_u32(bytes)?);
518 cursor += 4;
519 }
520 PropertyType::ContentType => {
521 let typ = read_mqtt_string(bytes)?;
522 cursor += 2 + typ.len();
523 content_type = Some(typ);
524 }
525 PropertyType::ResponseTopic => {
526 let topic = read_mqtt_string(bytes)?;
527 cursor += 2 + topic.len();
528 response_topic = Some(topic);
529 }
530 PropertyType::CorrelationData => {
531 let data = read_mqtt_bytes(bytes)?;
532 cursor += 2 + data.len();
533 correlation_data = Some(data);
534 }
535 PropertyType::UserProperty => {
536 let key = read_mqtt_string(bytes)?;
537 let value = read_mqtt_string(bytes)?;
538 cursor += 2 + key.len() + 2 + value.len();
539 user_properties.push((key, value));
540 }
541 _ => return Err(Error::InvalidPropertyType(prop)),
542 }
543 }
544
545 Ok(Some(LastWillProperties {
546 delay_interval,
547 payload_format_indicator,
548 message_expiry_interval,
549 content_type,
550 response_topic,
551 correlation_data,
552 user_properties,
553 }))
554 }
555
556 pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
557 let len = self.len();
558 write_remaining_length(buffer, len)?;
559
560 if let Some(delay_interval) = self.delay_interval {
561 buffer.put_u8(PropertyType::WillDelayInterval as u8);
562 buffer.put_u32(delay_interval);
563 }
564
565 if let Some(payload_format_indicator) = self.payload_format_indicator {
566 buffer.put_u8(PropertyType::PayloadFormatIndicator as u8);
567 buffer.put_u8(payload_format_indicator);
568 }
569
570 if let Some(message_expiry_interval) = self.message_expiry_interval {
571 buffer.put_u8(PropertyType::MessageExpiryInterval as u8);
572 buffer.put_u32(message_expiry_interval);
573 }
574
575 if let Some(typ) = &self.content_type {
576 buffer.put_u8(PropertyType::ContentType as u8);
577 write_mqtt_string(buffer, typ);
578 }
579
580 if let Some(topic) = &self.response_topic {
581 buffer.put_u8(PropertyType::ResponseTopic as u8);
582 write_mqtt_string(buffer, topic);
583 }
584
585 if let Some(data) = &self.correlation_data {
586 buffer.put_u8(PropertyType::CorrelationData as u8);
587 write_mqtt_bytes(buffer, data);
588 }
589
590 for (key, value) in self.user_properties.iter() {
591 buffer.put_u8(PropertyType::UserProperty as u8);
592 write_mqtt_string(buffer, key);
593 write_mqtt_string(buffer, value);
594 }
595
596 Ok(())
597 }
598}
/// Username/password credentials carried in the CONNECT payload.
///
/// An empty string means the corresponding field is not sent.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Login {
    /// User name; sent only when non-empty.
    pub username: String,
    /// Password; sent only when non-empty.
    pub password: String,
}
604
605impl Login {
606 pub fn new<U: Into<String>, P: Into<String>>(u: U, p: P) -> Login {
607 Login {
608 username: u.into(),
609 password: p.into(),
610 }
611 }
612
613 pub fn read(connect_flags: u8, bytes: &mut Bytes) -> Result<Option<Login>, Error> {
614 let username = match connect_flags & 0b1000_0000 {
615 0 => String::new(),
616 _ => read_mqtt_string(bytes)?,
617 };
618
619 let password = match connect_flags & 0b0100_0000 {
620 0 => String::new(),
621 _ => read_mqtt_string(bytes)?,
622 };
623
624 if username.is_empty() && password.is_empty() {
625 Ok(None)
626 } else {
627 Ok(Some(Login { username, password }))
628 }
629 }
630
631 fn len(&self) -> usize {
632 let mut len = 0;
633
634 if !self.username.is_empty() {
635 len += 2 + self.username.len();
636 }
637
638 if !self.password.is_empty() {
639 len += 2 + self.password.len();
640 }
641
642 len
643 }
644
645 pub fn write(&self, buffer: &mut BytesMut) -> u8 {
646 let mut connect_flags = 0;
647 if !self.username.is_empty() {
648 connect_flags |= 0x80;
649 write_mqtt_string(buffer, &self.username);
650 }
651
652 if !self.password.is_empty() {
653 connect_flags |= 0x40;
654 write_mqtt_string(buffer, &self.password);
655 }
656
657 connect_flags
658 }
659}
660
#[cfg(test)]
mod test {
    use super::super::test::{USER_PROP_KEY, USER_PROP_VAL};
    use super::*;
    use bytes::BytesMut;
    use pretty_assertions::assert_eq;

    /// `Connect::write` must report exactly the number of bytes it appended.
    #[test]
    fn length_calculation() {
        let properties = ConnectProperties {
            user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())],
            ..ConnectProperties::new()
        };
        let packet = Connect {
            keep_alive: 5,
            client_id: "client".into(),
            clean_start: true,
            properties: Some(properties),
        };

        let mut out = BytesMut::new();
        let written = packet.write(&None, &None, &mut out).unwrap();
        let actual = out.len();

        assert_eq!(written, actual);
    }
}