// int128.rs
  1. use std;
  2. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
  3. #[allow(non_camel_case_types)]
  4. pub struct u128 {
  5. high: u64,
  6. low: u64,
  7. }
  8. impl u128 {
  9. pub fn zero() -> u128 {
  10. u128::from_parts(0, 0)
  11. }
  12. pub fn from_parts(high: u64, low: u64) -> u128 {
  13. u128 {
  14. high: high,
  15. low: low,
  16. }
  17. }
  18. pub fn parts(&self) -> (u64, u64) {
  19. (self.high, self.low)
  20. }
  21. }
  22. impl std::ops::Add<u128> for u128 {
  23. type Output = u128;
  24. fn add(self, rhs: u128) -> u128 {
  25. let low = self.low + rhs.low;
  26. let high = self.high + rhs.high + if low < self.low { 1 } else { 0 };
  27. u128::from_parts(high, low)
  28. }
  29. }
  30. impl<'a> std::ops::Add<&'a u128> for u128 {
  31. type Output = u128;
  32. fn add(self, rhs: &'a u128) -> u128 {
  33. let low = self.low + rhs.low;
  34. let high = self.high + rhs.high + if low < self.low { 1 } else { 0 };
  35. u128::from_parts(high, low)
  36. }
  37. }
  38. impl std::convert::From<u8> for u128 {
  39. fn from(n: u8) -> u128 {
  40. u128::from_parts(0, n as u64)
  41. }
  42. }
  43. impl std::ops::Mul<u128> for u128 {
  44. type Output = u128;
  45. fn mul(self, rhs: u128) -> u128 {
  46. let top: [u64; 4] = [
  47. self.high >> 32,
  48. self.high & 0xFFFFFFFF,
  49. self.low >> 32,
  50. self.low & 0xFFFFFFFF,
  51. ];
  52. let bottom: [u64; 4] = [
  53. rhs.high >> 32,
  54. rhs.high & 0xFFFFFFFF,
  55. rhs.low >> 32,
  56. rhs.low & 0xFFFFFFFF,
  57. ];
  58. let mut rows = [u128::zero(); 16];
  59. for i in 0..4 {
  60. for j in 0..4 {
  61. let shift = i + j;
  62. let product = top[3 - i] * bottom[3 - j];
  63. let (high, low) = match shift {
  64. 0 => (0, product),
  65. 1 => (product >> 32, product << 32),
  66. 2 => (product, 0),
  67. 3 => (product << 32, 0),
  68. _ => {
  69. if product == 0 {
  70. (0, 0)
  71. } else {
  72. panic!("Overflow on mul {:?} {:?} ({} {})", self, rhs, i, j)
  73. }
  74. }
  75. };
  76. rows[j * 4 + i] = u128::from_parts(high, low);
  77. }
  78. }
  79. rows.iter().fold(u128::zero(), std::ops::Add::add)
  80. }
  81. }