Quellcode durchsuchen

Move AudioKeyManager to tokio

Paul Lietar vor 8 Jahren
Ursprung
Commit
855a7e87a7
8 geänderte Dateien mit 164 neuen und 83 gelöschten Zeilen
  1. 2 2
      src/audio_decrypt.rs
  2. 45 69
      src/audio_key.rs
  3. 2 2
      src/cache/default_cache.rs
  4. 29 0
      src/component.rs
  5. 1 0
      src/lib.rs
  6. 2 1
      src/player.rs
  7. 30 9
      src/session.rs
  8. 53 0
      src/util/mod.rs

+ 2 - 2
src/audio_decrypt.rs

@@ -17,7 +17,7 @@ pub struct AudioDecrypt<T: io::Read> {
 
 impl<T: io::Read> AudioDecrypt<T> {
     pub fn new(key: AudioKey, reader: T) -> AudioDecrypt<T> {
-        let cipher = aes::ctr(aes::KeySize::KeySize128, &key, AUDIO_AESIV);
+        let cipher = aes::ctr(aes::KeySize::KeySize128, &key.0, AUDIO_AESIV);
         AudioDecrypt {
             cipher: cipher,
             key: key,
@@ -45,7 +45,7 @@ impl<T: io::Read + io::Seek> io::Seek for AudioDecrypt<T> {
         let iv = BigUint::from_bytes_be(AUDIO_AESIV)
                      .add(BigUint::from_u64(newpos / 16).unwrap())
                      .to_bytes_be();
-        self.cipher = aes::ctr(aes::KeySize::KeySize128, &self.key, &iv);
+        self.cipher = aes::ctr(aes::KeySize::KeySize128, &self.key.0, &iv);
 
         let buf = vec![0u8; skip as usize];
         let mut buf2 = vec![0u8; skip as usize];

+ 45 - 69
src/audio_key.rs

@@ -1,92 +1,68 @@
-use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
-use eventual;
+use byteorder::{BigEndian, ByteOrder, WriteBytesExt};
+use futures::sync::oneshot;
 use std::collections::HashMap;
-use std::io::{Cursor, Read, Write};
+use std::io::Write;
 
+use util::SeqGenerator;
 use util::{SpotifyId, FileId};
-use session::{Session, PacketHandler};
 
-pub type AudioKey = [u8; 16];
 #[derive(Debug,Hash,PartialEq,Eq,Copy,Clone)]
-pub struct AudioKeyError;
+pub struct AudioKey(pub [u8; 16]);
 
 #[derive(Debug,Hash,PartialEq,Eq,Copy,Clone)]
-struct AudioKeyId(SpotifyId, FileId);
+pub struct AudioKeyError;
 
-pub struct AudioKeyManager {
-    next_seq: u32,
-    pending: HashMap<u32, AudioKeyId>,
-    cache: HashMap<AudioKeyId, Vec<eventual::Complete<AudioKey, AudioKeyError>>>,
+type Result = ::std::result::Result<AudioKey, AudioKeyError>;
+
+component! {
+    AudioKeyManager : AudioKeyManagerInner {
+        sequence: SeqGenerator<u32> = SeqGenerator::new(0),
+        pending: HashMap<u32, oneshot::Sender<Result>> = HashMap::new(),
+    }
 }
 
 impl AudioKeyManager {
-    pub fn new() -> AudioKeyManager {
-        AudioKeyManager {
-            next_seq: 1,
-            pending: HashMap::new(),
-            cache: HashMap::new(),
+    pub fn dispatch(&self, cmd: u8, data: Vec<u8>) {
+        let seq = BigEndian::read_u32(&data[..4]);
+
+        let sender = self.lock(|inner| inner.pending.remove(&seq));
+
+        if let Some(sender) = sender {
+            match cmd {
+                0xd => {
+                    let mut key = [0u8; 16];
+                    key.copy_from_slice(&data[4..20]);
+                    sender.complete(Ok(AudioKey(key)));
+                }
+                0xe => {
+                    warn!("error audio key {:x} {:x}", data[4], data[5]);
+                    sender.complete(Err(AudioKeyError));
+                }
+                _ => (),
+            }
         }
     }
 
-    fn send_key_request(&mut self, session: &Session, track: SpotifyId, file: FileId) -> u32 {
-        let seq = self.next_seq;
-        self.next_seq += 1;
+    pub fn request<'a>(&self, track: SpotifyId, file: FileId) -> oneshot::Receiver<Result> {
+        let (tx, rx) = oneshot::channel();
+
+        let seq = self.lock(move |inner| {
+            let seq = inner.sequence.get();
+            inner.pending.insert(seq, tx);
+            seq
+        });
 
+        self.send_key_request(seq, track, file);
+        rx
+    }
+
+    fn send_key_request<'a>(&self, seq: u32, track: SpotifyId, file: FileId) {
         let mut data: Vec<u8> = Vec::new();
         data.write(&file.0).unwrap();
         data.write(&track.to_raw()).unwrap();
         data.write_u32::<BigEndian>(seq).unwrap();
         data.write_u16::<BigEndian>(0x0000).unwrap();
 
-        session.send_packet(0xc, data);
-
-        seq
-    }
-
-    pub fn request(&mut self,
-                   session: &Session,
-                   track: SpotifyId,
-                   file: FileId)
-                   -> eventual::Future<AudioKey, AudioKeyError> {
-
-        let id = AudioKeyId(track, file);
-        self.cache
-            .get_mut(&id)
-            .map(|ref mut requests| {
-                let (tx, rx) = eventual::Future::pair();
-                requests.push(tx);
-                rx
-            })
-            .unwrap_or_else(|| {
-                let seq = self.send_key_request(session, track, file);
-                self.pending.insert(seq, id.clone());
-
-                let (tx, rx) = eventual::Future::pair();
-                self.cache.insert(id, vec![tx]);
-                rx
-            })
-    }
-}
-
-impl PacketHandler for AudioKeyManager {
-    fn handle(&mut self, cmd: u8, data: Vec<u8>, _session: &Session) {
-        let mut data = Cursor::new(data);
-        let seq = data.read_u32::<BigEndian>().unwrap();
-
-        if let Some(callbacks) = self.pending.remove(&seq).and_then(|id| self.cache.remove(&id)) {
-            if cmd == 0xd {
-                let mut key = [0u8; 16];
-                data.read_exact(&mut key).unwrap();
-
-                for cb in callbacks {
-                    cb.complete(key);
-                }
-            } else if cmd == 0xe {
-                let error = AudioKeyError;
-                for cb in callbacks {
-                    cb.fail(error);
-                }
-            }
-        }
+        self.session().send_packet(0xc, data)
     }
 }

+ 2 - 2
src/cache/default_cache.rs

@@ -56,7 +56,7 @@ impl Cache for DefaultCache {
         value.and_then(|value| if value.len() == 16 {
             let mut result = [0u8; 16];
             result.clone_from_slice(&value);
-            Some(result)
+            Some(AudioKey(result))
         } else {
             None
         })
@@ -73,7 +73,7 @@ impl Cache for DefaultCache {
             key.extend_from_slice(&track.to_raw());
             key.extend_from_slice(&file.0);
 
-            db.set(&key, &audio_key.as_ref()).unwrap();
+            db.set(&key, &audio_key.0.as_ref()).unwrap();
         }
 
         xact.commit().unwrap();

+ 29 - 0
src/component.rs

@@ -0,0 +1,29 @@
+macro_rules! component {
+    ($name:ident : $inner:ident { $($key:ident : $ty:ty = $value:expr,)* }) => {
+        #[derive(Clone)]
+        pub struct $name(::std::sync::Arc<($crate::session::SessionWeak, ::std::sync::Mutex<$inner>)>);
+        impl $name {
+            #[allow(dead_code)]
+            pub fn new(session: $crate::session::SessionWeak) -> $name {
+                $name(::std::sync::Arc::new((session, ::std::sync::Mutex::new($inner {
+                    $($key : $value,)*
+                }))))
+            }
+
+            #[allow(dead_code)]
+            fn lock<F: FnOnce(&mut $inner) -> R, R>(&self, f: F) -> R {
+                let mut inner = (self.0).1.lock().expect("Mutex poisoned");
+                f(&mut inner)
+            }
+
+            #[allow(dead_code)]
+            fn session(&self) -> $crate::session::Session {
+                (self.0).0.upgrade()
+            }
+        }
+
+        struct $inner {
+            $($key : $ty,)*
+        }
+    }
+}

+ 1 - 0
src/lib.rs

@@ -49,6 +49,7 @@ extern crate portaudio;
 #[cfg(feature = "libpulse-sys")]
 extern crate libpulse_sys;
 
+#[macro_use] mod component;
 pub mod album_cover;
 pub mod audio_backend;
 pub mod audio_decrypt;

+ 2 - 1
src/player.rs

@@ -4,6 +4,7 @@ use std::sync::{mpsc, Mutex, Arc, MutexGuard};
 use std::thread;
 use std::io::{Read, Seek};
 use vorbis;
+use futures::Future;
 
 use audio_decrypt::AudioDecrypt;
 use audio_backend::Sink;
@@ -206,7 +207,7 @@ fn load_track(session: &Session, track_id: SpotifyId) -> Option<vorbis::Decoder<
         }
     };
 
-    let key = session.audio_key(track.id, file_id).await().unwrap();
+    let key = session.audio_key().request(track.id, file_id).wait().unwrap().unwrap();
 
     let audio_file = Subfile::new(AudioDecrypt::new(key, session.audio_file(file_id)), 0xa7);
     let decoder = vorbis::Decoder::new(audio_file).unwrap();

+ 30 - 9
src/session.rs

@@ -5,23 +5,23 @@ use eventual::Future;
 use eventual::Async;
 use std::io::{self, Read, Cursor};
 use std::result::Result;
-use std::sync::{Mutex, RwLock, Arc, mpsc};
+use std::sync::{Mutex, RwLock, Arc, mpsc, Weak};
 use std::str::FromStr;
 use futures::Future as Future_;
-use futures::Stream;
+use futures::{Stream, BoxFuture};
 use tokio_core::reactor::Handle;
 
 use album_cover::AlbumCover;
 use apresolve::apresolve_or_fallback;
 use audio_file::AudioFile;
-use audio_key::{AudioKeyManager, AudioKey, AudioKeyError};
+use audio_key::AudioKeyManager;
 use authentication::Credentials;
 use cache::Cache;
 use connection::{self, adaptor};
 use mercury::{MercuryManager, MercuryRequest, MercuryResponse};
 use metadata::{MetadataManager, MetadataRef, MetadataTrait};
 use stream::StreamManager;
-use util::{SpotifyId, FileId, ReadSeek};
+use util::{SpotifyId, FileId, ReadSeek, Lazy};
 
 use stream;
 
@@ -65,14 +65,18 @@ pub struct SessionInternal {
     mercury: Mutex<MercuryManager>,
     metadata: Mutex<MetadataManager>,
     stream: Mutex<StreamManager>,
-    audio_key: Mutex<AudioKeyManager>,
     rx_connection: Mutex<adaptor::StreamAdaptor<(u8, Vec<u8>), io::Error>>,
     tx_connection: Mutex<adaptor::SinkAdaptor<(u8, Vec<u8>)>>,
+
+    audio_key: Lazy<AudioKeyManager>,
 }
 
 #[derive(Clone)]
 pub struct Session(pub Arc<SessionInternal>);
 
+#[derive(Clone)]
+pub struct SessionWeak(pub Weak<SessionInternal>);
+
 pub fn device_id(name: &str) -> String {
     let mut h = Sha1::new();
     h.input_str(&name);
@@ -82,7 +86,7 @@ pub fn device_id(name: &str) -> String {
 impl Session {
     pub fn connect(config: Config, credentials: Credentials,
                    cache: Box<Cache + Send + Sync>, handle: Handle)
-        -> Box<Future_<Item=(Session, Box<Future_<Item=(), Error=io::Error>>), Error=io::Error>>
+        -> Box<Future_<Item=(Session, BoxFuture<(), io::Error>), Error=io::Error>>
     {
         let access_point = apresolve_or_fallback::<io::Error>(&handle);
 
@@ -108,7 +112,8 @@ impl Session {
     }
 
     fn create(transport: connection::Transport, config: Config,
-              cache: Box<Cache + Send + Sync>, username: String) -> (Session, Box<Future_<Item=(), Error=io::Error>>)
+              cache: Box<Cache + Send + Sync>, username: String)
+        -> (Session, BoxFuture<(), io::Error>)
     {
         let transport = transport.map(|(cmd, data)| (cmd, data.as_ref().to_owned()));
         let (tx, rx, task) = adaptor::adapt(transport);
@@ -127,12 +132,16 @@ impl Session {
             mercury: Mutex::new(MercuryManager::new()),
             metadata: Mutex::new(MetadataManager::new()),
             stream: Mutex::new(StreamManager::new()),
-            audio_key: Mutex::new(AudioKeyManager::new()),
+
+            audio_key: Lazy::new(),
         }));
 
         (session, task)
     }
 
+    pub fn audio_key(&self) -> &AudioKeyManager {
+        self.0.audio_key.get(|| AudioKeyManager::new(self.weak()))
+    }
 
     pub fn poll(&self) {
         let (cmd, data) = self.recv();
@@ -141,7 +150,7 @@ impl Session {
             0x4 => self.send_packet(0x49, data),
             0x4a => (),
             0x9 | 0xa => self.0.stream.lock().unwrap().handle(cmd, data, self),
-            0xd | 0xe => self.0.audio_key.lock().unwrap().handle(cmd, data, self),
+            0xd | 0xe => self.audio_key().dispatch(cmd, data),
             0x1b => {
                 self.0.data.write().unwrap().country = String::from_utf8(data).unwrap();
             }
@@ -158,6 +167,7 @@ impl Session {
         self.0.tx_connection.lock().unwrap().send((cmd, data))
     }
 
+    /*
     pub fn audio_key(&self, track: SpotifyId, file_id: FileId) -> Future<AudioKey, AudioKeyError> {
         self.0.cache
             .get_audio_key(track, file_id)
@@ -172,6 +182,7 @@ impl Session {
                     })
             })
     }
+    */
 
     pub fn audio_file(&self, file_id: FileId) -> Box<ReadSeek> {
         self.0.cache
@@ -241,6 +252,16 @@ impl Session {
     pub fn device_id(&self) -> &str {
         &self.config().device_id
     }
+
+    pub fn weak(&self) -> SessionWeak {
+        SessionWeak(Arc::downgrade(&self.0))
+    }
+}
+
+impl SessionWeak {
+    pub fn upgrade(&self) -> Session {
+        Session(self.0.upgrade().expect("Session died"))
+    }
 }
 
 pub trait PacketHandler {

+ 53 - 0
src/util/mod.rs

@@ -1,6 +1,7 @@
 use num::{BigUint, Integer, Zero, One};
 use rand::{Rng, Rand};
 use std::io;
+use std::mem;
 use std::ops::{Mul, Rem, Shr};
 use std::fs;
 use std::path::Path;
@@ -107,3 +108,55 @@ impl<'s> Iterator for StrChunks<'s> {
 pub trait ReadSeek : ::std::io::Read + ::std::io::Seek { }
 impl <T: ::std::io::Read + ::std::io::Seek> ReadSeek for T { }
 
+pub trait Seq {
+    fn next(&self) -> Self;
+}
+
+macro_rules! impl_seq {
+    ($($ty:ty)*) => { $(
+        impl Seq for $ty {
+            fn next(&self) -> Self { *self + 1 }
+        }
+    )* }
+}
+
+impl_seq!(u8 u16 u32 u64 usize);
+
+#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Default)]
+pub struct SeqGenerator<T: Seq>(T);
+
+impl <T: Seq> SeqGenerator<T> {
+    pub fn new(value: T) -> Self {
+        SeqGenerator(value)
+    }
+
+    pub fn get(&mut self) -> T {
+        let value = self.0.next();
+        mem::replace(&mut self.0, value)
+    }
+}
+
+use std::sync::Mutex;
+use std::cell::UnsafeCell;
+
+pub struct Lazy<T>(Mutex<bool>, UnsafeCell<Option<T>>);
+unsafe impl <T: Sync> Sync for Lazy<T> {}
+unsafe impl <T: Send> Send for Lazy<T> {}
+
+impl <T> Lazy<T> {
+    pub fn new() -> Lazy<T> {
+        Lazy(Mutex::new(false), UnsafeCell::new(None))
+    }
+
+    pub fn get<F: FnOnce() -> T>(&self, f: F) -> &T {
+        let mut inner = self.0.lock().unwrap();
+        if !*inner {
+            unsafe {
+                *self.1.get() = Some(f());
+            }
+            *inner = true;
+        }
+
+        unsafe { &*self.1.get() }.as_ref().unwrap()
+    }
+}