int128.rs 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. use std;
  /// An unsigned 128-bit integer stored as two 64-bit halves.
  ///
  /// NOTE(review): this lowercase name shadows the built-in `u128` type of newer
  /// Rust releases wherever this item is in scope — confirm that is intended.
  #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
  // The lowercase type name `u128` would otherwise trip `non_camel_case_types`.
  #[allow(non_camel_case_types)]
  pub struct u128 {
      /// Most significant 64 bits (receives the carry out of `low` on addition).
      high: u64,
      /// Least significant 64 bits.
      low: u64,
  }
  8. impl u128 {
  9. pub fn from_parts(high: u64, low: u64) -> u128 {
  10. u128 { high: high, low: low }
  11. }
  12. pub fn parts(&self) -> (u64, u64) {
  13. (self.high, self.low)
  14. }
  15. }
  16. impl std::num::Zero for u128 {
  17. fn zero() -> u128 {
  18. u128::from_parts(0, 0)
  19. }
  20. }
  21. impl std::ops::Add<u128> for u128 {
  22. type Output = u128;
  23. fn add(self, rhs: u128) -> u128 {
  24. let low = self.low + rhs.low;
  25. let high = self.high + rhs.high +
  26. if low < self.low { 1 } else { 0 };
  27. u128::from_parts(high, low)
  28. }
  29. }
  30. impl <'a> std::ops::Add<&'a u128> for u128 {
  31. type Output = u128;
  32. fn add(self, rhs: &'a u128) -> u128 {
  33. let low = self.low + rhs.low;
  34. let high = self.high + rhs.high +
  35. if low < self.low { 1 } else { 0 };
  36. u128::from_parts(high, low)
  37. }
  38. }
  39. impl std::convert::From<u8> for u128 {
  40. fn from(n: u8) -> u128 {
  41. u128::from_parts(0, n as u64)
  42. }
  43. }
  44. impl std::ops::Mul<u128> for u128 {
  45. type Output = u128;
  46. fn mul(self, rhs: u128) -> u128 {
  47. let top: [u64; 4] =
  48. [self.high >> 32, self.high & 0xFFFFFFFF,
  49. self.low >> 32, self.low & 0xFFFFFFFF];
  50. let bottom : [u64; 4] =
  51. [rhs.high >> 32, rhs.high & 0xFFFFFFFF,
  52. rhs.low >> 32, rhs.low & 0xFFFFFFFF];
  53. let mut rows = [std::num::Zero::zero(); 16];
  54. for i in 0..4 {
  55. for j in 0..4 {
  56. let shift = i + j;
  57. let product = top[3-i] * bottom[3-j];
  58. let (high, low) = match shift {
  59. 0 => (0, product),
  60. 1 => (product >> 32, product << 32),
  61. 2 => (product, 0),
  62. 3 => (product << 32, 0),
  63. _ => {
  64. if product != 0 {
  65. panic!("Overflow on mul {:?} {:?} ({} {})",
  66. self, rhs, i, j)
  67. } else {
  68. (0, 0)
  69. }
  70. }
  71. };
  72. rows[j * 4 + i] = u128::from_parts(high, low);
  73. }
  74. }
  75. rows.iter().sum::<u128>()
  76. }
  77. }