// rumqttc/framed.rs

1use bytes::BytesMut;
2use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
3
4use crate::mqttbytes::{self, v4::*};
5use crate::{Incoming, MqttState, StateError};
6use std::io;
7
8/// Network transforms packets <-> frames efficiently. It takes
9/// advantage of pre-allocation, buffering and vectorization when
10/// appropriate to achieve performance
11pub struct Network {
12    /// Socket for IO
13    socket: Box<dyn N>,
14    /// Buffered reads
15    read: BytesMut,
16    /// Maximum packet size
17    max_incoming_size: usize,
18    /// Maximum readv count
19    max_readb_count: usize,
20}
21
22impl Network {
23    pub fn new(socket: impl N + 'static, max_incoming_size: usize) -> Network {
24        let socket = Box::new(socket) as Box<dyn N>;
25        Network {
26            socket,
27            read: BytesMut::with_capacity(10 * 1024),
28            max_incoming_size,
29            max_readb_count: 10,
30        }
31    }
32
33    /// Reads more than 'required' bytes to frame a packet into self.read buffer
34    async fn read_bytes(&mut self, required: usize) -> io::Result<usize> {
35        let mut total_read = 0;
36        loop {
37            let read = self.socket.read_buf(&mut self.read).await?;
38            if 0 == read {
39                return if self.read.is_empty() {
40                    Err(io::Error::new(
41                        io::ErrorKind::ConnectionAborted,
42                        "connection closed by peer",
43                    ))
44                } else {
45                    Err(io::Error::new(
46                        io::ErrorKind::ConnectionReset,
47                        "connection reset by peer",
48                    ))
49                };
50            }
51
52            total_read += read;
53            if total_read >= required {
54                return Ok(total_read);
55            }
56        }
57    }
58
59    pub async fn read(&mut self) -> io::Result<Incoming> {
60        loop {
61            let required = match read(&mut self.read, self.max_incoming_size) {
62                Ok(packet) => return Ok(packet),
63                Err(mqttbytes::Error::InsufficientBytes(required)) => required,
64                Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())),
65            };
66
67            // read more packets until a frame can be created. This function
68            // blocks until a frame can be created. Use this in a select! branch
69            self.read_bytes(required).await?;
70        }
71    }
72
73    /// Read packets in bulk. This allow replies to be in bulk. This method is used
74    /// after the connection is established to read a bunch of incoming packets
75    pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> {
76        let mut count = 0;
77        loop {
78            match read(&mut self.read, self.max_incoming_size) {
79                Ok(packet) => {
80                    state.handle_incoming_packet(packet)?;
81
82                    count += 1;
83                    if count >= self.max_readb_count {
84                        return Ok(());
85                    }
86                }
87                // If some packets are already framed, return those
88                Err(mqttbytes::Error::InsufficientBytes(_)) if count > 0 => return Ok(()),
89                // Wait for more bytes until a frame can be created
90                Err(mqttbytes::Error::InsufficientBytes(required)) => {
91                    self.read_bytes(required).await?;
92                }
93                Err(e) => return Err(StateError::Deserialization(e)),
94            };
95        }
96    }
97
98    pub async fn connect(&mut self, connect: Connect) -> io::Result<usize> {
99        let mut write = BytesMut::new();
100        let len = match connect.write(&mut write) {
101            Ok(size) => size,
102            Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())),
103        };
104
105        self.socket.write_all(&write[..]).await?;
106        Ok(len)
107    }
108
109    pub async fn flush(&mut self, write: &mut BytesMut) -> io::Result<()> {
110        if write.is_empty() {
111            return Ok(());
112        }
113
114        self.socket.write_all(&write[..]).await?;
115        write.clear();
116        Ok(())
117    }
118}
119
120pub trait N: AsyncRead + AsyncWrite + Send + Unpin {}
121impl<T> N for T where T: AsyncRead + AsyncWrite + Send + Unpin {}