1use bytes::BytesMut;
2use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
3
4use super::mqttbytes;
5use super::mqttbytes::v5::{Connect, Login, Packet};
6use super::{Incoming, MqttOptions, MqttState, StateError};
7use std::io;
8
9pub struct Network {
13 socket: Box<dyn N>,
15 read: BytesMut,
17 max_incoming_size: Option<usize>,
19 max_readb_count: usize,
21}
22
23impl Network {
24 pub fn new(socket: impl N + 'static, max_incoming_size: Option<usize>) -> Network {
25 let socket = Box::new(socket) as Box<dyn N>;
26 Network {
27 socket,
28 read: BytesMut::with_capacity(10 * 1024),
29 max_incoming_size,
30 max_readb_count: 10,
31 }
32 }
33
34 async fn read_bytes(&mut self, required: usize) -> io::Result<usize> {
36 let mut total_read = 0;
37 loop {
38 let read = self.socket.read_buf(&mut self.read).await?;
39 if 0 == read {
40 return if self.read.is_empty() {
41 Err(io::Error::new(
42 io::ErrorKind::ConnectionAborted,
43 "connection closed by peer",
44 ))
45 } else {
46 Err(io::Error::new(
47 io::ErrorKind::ConnectionReset,
48 "connection reset by peer",
49 ))
50 };
51 }
52
53 total_read += read;
54 if total_read >= required {
55 return Ok(total_read);
56 }
57 }
58 }
59
60 pub async fn read(&mut self) -> io::Result<Incoming> {
61 loop {
62 let required = match Packet::read(&mut self.read, self.max_incoming_size) {
63 Ok(packet) => return Ok(packet),
64 Err(mqttbytes::Error::InsufficientBytes(required)) => required,
65 Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())),
66 };
67
68 self.read_bytes(required).await?;
71 }
72 }
73
74 pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> {
77 let mut count = 0;
78 loop {
79 match Packet::read(&mut self.read, self.max_incoming_size) {
80 Ok(packet) => {
81 state.handle_incoming_packet(packet)?;
82
83 count += 1;
84 if count >= self.max_readb_count {
85 return Ok(());
86 }
87 }
88 Err(mqttbytes::Error::InsufficientBytes(_)) if count > 0 => return Ok(()),
90 Err(mqttbytes::Error::InsufficientBytes(required)) => {
92 self.read_bytes(required).await?;
93 }
94 Err(mqttbytes::Error::PayloadSizeLimitExceeded { pkt_size, max }) => {
95 state.handle_protocol_error()?;
96 return Err(StateError::IncomingPacketTooLarge { pkt_size, max });
97 }
98 Err(e) => return Err(StateError::Deserialization(e)),
99 };
100 }
101 }
102
103 pub async fn connect(&mut self, connect: Connect, options: &MqttOptions) -> io::Result<usize> {
104 let mut write = BytesMut::new();
105 let last_will = options.last_will();
106 let login = options.credentials().map(|l| Login {
107 username: l.0,
108 password: l.1,
109 });
110
111 let len = match Packet::Connect(connect, last_will, login).write(&mut write) {
112 Ok(size) => size,
113 Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())),
114 };
115
116 self.socket.write_all(&write[..]).await?;
117 Ok(len)
118 }
119
120 pub async fn flush(&mut self, write: &mut BytesMut) -> io::Result<()> {
121 if write.is_empty() {
122 return Ok(());
123 }
124
125 self.socket.write_all(&write[..]).await?;
126 write.clear();
127 Ok(())
128 }
129}
130
131pub trait N: AsyncRead + AsyncWrite + Send + Unpin {}
132impl<T> N for T where T: AsyncRead + AsyncWrite + Send + Unpin {}