channel.rs 6.5 KB


  1. use byteorder::{BigEndian, ByteOrder};
  2. use bytes::Bytes;
  3. use futures::sync::{mpsc, BiLock};
  4. use futures::{Async, Poll, Stream};
  5. use std::collections::HashMap;
  6. use std::time::Instant;
  7. use crate::util::SeqGenerator;
// Declares `ChannelManager` and its inner state `ChannelManagerInner` through the
// crate's `component!` macro. Each entry below is `field: Type = initial value`.
component! {
    ChannelManager : ChannelManagerInner {
        // Source of the channel ids handed out by `allocate`.
        sequence: SeqGenerator<u16> = SeqGenerator::new(0),
        // Sending halves of the registered channels, keyed by channel id;
        // `dispatch` forwards `(cmd, payload)` pairs through them.
        channels: HashMap<u16, mpsc::UnboundedSender<(u8, Bytes)>> = HashMap::new(),
        // Last computed download rate, in bytes per second.
        download_rate_estimate: usize = 0,
        // Start of the current measurement window; `None` until the first packet is dispatched.
        download_measurement_start: Option<Instant> = None,
        // Payload bytes counted since the current measurement window started.
        download_measurement_bytes: usize = 0,
        // Set by `shutdown`; once true, `allocate` no longer registers new senders.
        invalid: bool = false,
    }
}
/// Error produced by the channel streams in this module.
///
/// It is returned when the sending side of a channel has gone away or when an
/// error packet (command `0xa`) is received; it carries no further detail.
#[derive(Debug, Hash, PartialEq, Eq, Copy, Clone)]
pub struct ChannelError;
/// Receiving end of a channel created by `ChannelManager::allocate`.
///
/// Polled as a `Stream`, it turns the raw packets forwarded by
/// `ChannelManager::dispatch` into `ChannelEvent`s.
pub struct Channel {
    /// Incoming `(command, payload)` pairs for this channel.
    receiver: mpsc::UnboundedReceiver<(u8, Bytes)>,
    /// Current position in the header / data / closed sequence.
    state: ChannelState,
}
/// Half of a split `Channel` (see `Channel::split`) that yields header events
/// as `(id, data)` pairs. The underlying channel is shared through a `BiLock`.
pub struct ChannelHeaders(BiLock<Channel>);
/// Half of a split `Channel` (see `Channel::split`) that yields the data
/// payloads. The underlying channel is shared through a `BiLock`.
pub struct ChannelData(BiLock<Channel>);
/// Item type of the `Channel` stream.
pub enum ChannelEvent {
    /// A single header: its one-byte id followed by its payload.
    Header(u8, Vec<u8>),
    /// A non-empty data packet.
    Data(Bytes),
}
/// Internal parsing state of a `Channel`.
#[derive(Clone)]
enum ChannelState {
    /// Reading headers; holds the not-yet-parsed remainder of the current
    /// header packet (empty when a new packet must be received).
    Header(Bytes),
    /// Headers are finished; every received packet is data.
    Data,
    /// The channel has ended or failed; polling it again panics.
    Closed,
}
  36. impl ChannelManager {
  37. pub fn allocate(&self) -> (u16, Channel) {
  38. let (tx, rx) = mpsc::unbounded();
  39. let seq = self.lock(|inner| {
  40. let seq = inner.sequence.get();
  41. if !inner.invalid {
  42. inner.channels.insert(seq, tx);
  43. }
  44. seq
  45. });
  46. let channel = Channel {
  47. receiver: rx,
  48. state: ChannelState::Header(Bytes::new()),
  49. };
  50. (seq, channel)
  51. }
  52. pub(crate) fn dispatch(&self, cmd: u8, mut data: Bytes) {
  53. use std::collections::hash_map::Entry;
  54. let id: u16 = BigEndian::read_u16(data.split_to(2).as_ref());
  55. self.lock(|inner| {
  56. let current_time = Instant::now();
  57. if let Some(download_measurement_start) = inner.download_measurement_start {
  58. if (current_time - download_measurement_start).as_millis() > 1000 {
  59. inner.download_rate_estimate = 1000 * inner.download_measurement_bytes
  60. / (current_time - download_measurement_start).as_millis() as usize;
  61. inner.download_measurement_start = Some(current_time);
  62. inner.download_measurement_bytes = 0;
  63. }
  64. } else {
  65. inner.download_measurement_start = Some(current_time);
  66. }
  67. inner.download_measurement_bytes += data.len();
  68. if let Entry::Occupied(entry) = inner.channels.entry(id) {
  69. let _ = entry.get().unbounded_send((cmd, data));
  70. }
  71. });
  72. }
  73. pub fn get_download_rate_estimate(&self) -> usize {
  74. return self.lock(|inner| inner.download_rate_estimate);
  75. }
  76. pub(crate) fn shutdown(&self) {
  77. self.lock(|inner| {
  78. inner.invalid = true;
  79. // destroy the sending halves of the channels to signal everyone who is waiting for something.
  80. inner.channels.clear();
  81. });
  82. }
  83. }
  84. impl Channel {
  85. fn recv_packet(&mut self) -> Poll<Bytes, ChannelError> {
  86. let (cmd, packet) = match self.receiver.poll() {
  87. Ok(Async::Ready(Some(t))) => t,
  88. Ok(Async::Ready(None)) => return Err(ChannelError), // The channel has been closed.
  89. Ok(Async::NotReady) => return Ok(Async::NotReady),
  90. Err(()) => unreachable!(),
  91. };
  92. if cmd == 0xa {
  93. let code = BigEndian::read_u16(&packet.as_ref()[..2]);
  94. error!("channel error: {} {}", packet.len(), code);
  95. self.state = ChannelState::Closed;
  96. Err(ChannelError)
  97. } else {
  98. Ok(Async::Ready(packet))
  99. }
  100. }
  101. pub fn split(self) -> (ChannelHeaders, ChannelData) {
  102. let (headers, data) = BiLock::new(self);
  103. (ChannelHeaders(headers), ChannelData(data))
  104. }
  105. }
  106. impl Stream for Channel {
  107. type Item = ChannelEvent;
  108. type Error = ChannelError;
  109. fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
  110. loop {
  111. match self.state.clone() {
  112. ChannelState::Closed => panic!("Polling already terminated channel"),
  113. ChannelState::Header(mut data) => {
  114. if data.len() == 0 {
  115. data = try_ready!(self.recv_packet());
  116. }
  117. let length = BigEndian::read_u16(data.split_to(2).as_ref()) as usize;
  118. if length == 0 {
  119. assert_eq!(data.len(), 0);
  120. self.state = ChannelState::Data;
  121. } else {
  122. let header_id = data.split_to(1).as_ref()[0];
  123. let header_data = data.split_to(length - 1).as_ref().to_owned();
  124. self.state = ChannelState::Header(data);
  125. let event = ChannelEvent::Header(header_id, header_data);
  126. return Ok(Async::Ready(Some(event)));
  127. }
  128. }
  129. ChannelState::Data => {
  130. let data = try_ready!(self.recv_packet());
  131. if data.len() == 0 {
  132. self.receiver.close();
  133. self.state = ChannelState::Closed;
  134. return Ok(Async::Ready(None));
  135. } else {
  136. let event = ChannelEvent::Data(data);
  137. return Ok(Async::Ready(Some(event)));
  138. }
  139. }
  140. }
  141. }
  142. }
  143. }
  144. impl Stream for ChannelData {
  145. type Item = Bytes;
  146. type Error = ChannelError;
  147. fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
  148. let mut channel = match self.0.poll_lock() {
  149. Async::Ready(c) => c,
  150. Async::NotReady => return Ok(Async::NotReady),
  151. };
  152. loop {
  153. match try_ready!(channel.poll()) {
  154. Some(ChannelEvent::Header(..)) => (),
  155. Some(ChannelEvent::Data(data)) => return Ok(Async::Ready(Some(data))),
  156. None => return Ok(Async::Ready(None)),
  157. }
  158. }
  159. }
  160. }
  161. impl Stream for ChannelHeaders {
  162. type Item = (u8, Vec<u8>);
  163. type Error = ChannelError;
  164. fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
  165. let mut channel = match self.0.poll_lock() {
  166. Async::Ready(c) => c,
  167. Async::NotReady => return Ok(Async::NotReady),
  168. };
  169. match try_ready!(channel.poll()) {
  170. Some(ChannelEvent::Header(id, data)) => Ok(Async::Ready(Some((id, data)))),
  171. Some(ChannelEvent::Data(..)) | None => Ok(Async::Ready(None)),
  172. }
  173. }
  174. }