Refactor UpgradableNumber to make CarriedNumber's type signature more bearable

This commit is contained in:
Nick Krichevsky 2023-05-05 19:50:01 -04:00
parent 6dce7f0ead
commit bf9b18c2d6
2 changed files with 37 additions and 31 deletions

View file

@ -1,6 +1,6 @@
//! arithutil contains utilities to generalize the implementation of arithmetic instructions //! arithutil contains utilities to generalize the implementation of arithmetic instructions
use std::{marker::PhantomData, ops::Add}; use std::ops::Add;
use thiserror::Error; use thiserror::Error;
@ -16,16 +16,15 @@ pub enum Error {
/// `CarriedNumber` is a number combined with its carry bit, intended to be used to perform carry-based arithmetic /// `CarriedNumber` is a number combined with its carry bit, intended to be used to perform carry-based arithmetic
/// operations such as ADC. /// operations such as ADC.
#[derive(Debug, Clone, Copy, PartialEq)] #[derive(Debug, Clone, Copy, PartialEq)]
pub struct CarriedNumber<N, V> { pub struct CarriedNumber<N> {
num: N, num: N,
carry_bit: u8, carry_bit: u8,
_phantom: PhantomData<V>,
} }
impl<N, V> CarriedNumber<N, V> impl<N> CarriedNumber<N>
where where
N: Copy + Clone + PartialEq + UpgradableNumber<V>, N: Copy + Clone + PartialEq + UpgradableNumber,
V: From<N> + From<u8> + Add<Output = V>, N::Output: From<u8> + Add<Output = N::Output>,
{ {
/// Create a new [`CarriedNumber`] with a raw value and its carry bit. Note that the carry bit must be 0 or 1, /// Create a new [`CarriedNumber`] with a raw value and its carry bit. Note that the carry bit must be 0 or 1,
/// or an [`Error::InvalidCarryBit`] will be returned /// or an [`Error::InvalidCarryBit`] will be returned
@ -34,17 +33,13 @@ where
return Err(Error::InvalidCarryBit(carry_bit)); return Err(Error::InvalidCarryBit(carry_bit));
} }
Ok(Self { Ok(Self { num, carry_bit })
num,
carry_bit,
_phantom: PhantomData,
})
} }
/// Get the combination between the value and the carry bit. The return value is the "upgraded" version of `N`. /// Get the combination between the value and the carry bit. The return value is the "upgraded" version of `N`.
/// See [`upgrade::UpgradableNumber`] for more details /// See [`upgrade::UpgradableNumber`] for more details
pub fn value(self) -> V { pub fn value(self) -> N::Output {
self.num.upgrade() + V::from(self.carry_bit) self.num.upgrade() + self.carry_bit.into()
} }
/// Get the number and its carry bit directly /// Get the number and its carry bit directly
@ -116,8 +111,8 @@ impl CarryingAdd<i8, u16> for u16 {
} }
} }
impl CarryingAdd<CarriedNumber<u8, u16>, u8> for u8 { impl CarryingAdd<CarriedNumber<u8>, u8> for u8 {
fn add_with_carry(self, rhs: CarriedNumber<u8, u16>) -> (u8, bool, bool) { fn add_with_carry(self, rhs: CarriedNumber<u8>) -> (u8, bool, bool) {
let total = rhs.value() + u16::from(self); let total = rhs.value() + u16::from(self);
// Given this is adding two u8s, the largest value we can produce is 0x01FF (both operands 0xFF, plus a carry bit) // Given this is adding two u8s, the largest value we can produce is 0x01FF (both operands 0xFF, plus a carry bit)
// which can never wrap, so we just take the lower byte to convert back to u8 // which can never wrap, so we just take the lower byte to convert back to u8
@ -140,8 +135,8 @@ impl CarryingSub<u8, u8> for u8 {
} }
} }
impl CarryingSub<CarriedNumber<u8, u16>, u8> for u8 { impl CarryingSub<CarriedNumber<u8>, u8> for u8 {
fn sub_with_carry(self, rhs: CarriedNumber<u8, u16>) -> (u8, bool, bool) { fn sub_with_carry(self, rhs: CarriedNumber<u8>) -> (u8, bool, bool) {
let total = self.wrapping_sub(rhs.num).wrapping_sub(rhs.carry_bit); let total = self.wrapping_sub(rhs.num).wrapping_sub(rhs.carry_bit);
// let half_carry = did_8bit_sub_half_carry(self, rhs.value()); // let half_carry = did_8bit_sub_half_carry(self, rhs.value());
let half_carry = did_8bit_sub_half_carry_including_carry_bit(self, rhs.num, rhs.carry_bit); let half_carry = did_8bit_sub_half_carry_including_carry_bit(self, rhs.num, rhs.carry_bit);
@ -309,7 +304,7 @@ mod tests {
#[test_case(0xFF, CarriedNumber::new(0xFF, 1).unwrap(), (0xFF, true, true))] #[test_case(0xFF, CarriedNumber::new(0xFF, 1).unwrap(), (0xFF, true, true))]
fn test_add_carried_number_to_u8( fn test_add_carried_number_to_u8(
value: u8, value: u8,
carried_number: CarriedNumber<u8, u16>, carried_number: CarriedNumber<u8>,
expected: (u8, bool, bool), expected: (u8, bool, bool),
) { ) {
assert_eq!(expected, value.add_with_carry(carried_number)); assert_eq!(expected, value.add_with_carry(carried_number));
@ -332,11 +327,7 @@ mod tests {
#[test_case(0xAB, CarriedNumber::new(0x5F, 1).unwrap(), (0x4B, true, false); "half borrow with carry bit and wrap")] #[test_case(0xAB, CarriedNumber::new(0x5F, 1).unwrap(), (0x4B, true, false); "half borrow with carry bit and wrap")]
#[test_case(0x0F, CarriedNumber::new(0xF0, 0).unwrap(), (0x1F, false, true); "full borrow without carry bit")] #[test_case(0x0F, CarriedNumber::new(0xF0, 0).unwrap(), (0x1F, false, true); "full borrow without carry bit")]
#[test_case(0x0F, CarriedNumber::new(0xEE, 1).unwrap(), (0x20, false, true); "full borrow with carry bit")] #[test_case(0x0F, CarriedNumber::new(0xEE, 1).unwrap(), (0x20, false, true); "full borrow with carry bit")]
fn test_sub_u8_from_u8_with_carry( fn test_sub_u8_from_u8_with_carry(lhs: u8, rhs: CarriedNumber<u8>, expected: (u8, bool, bool)) {
lhs: u8,
rhs: CarriedNumber<u8, u16>,
expected: (u8, bool, bool),
) {
assert_eq!(expected, lhs.sub_with_carry(rhs)); assert_eq!(expected, lhs.sub_with_carry(rhs));
} }
} }

View file

@ -1,22 +1,37 @@
/// `UpgradableNumber` represents an number that can be "upgraded" to its next largest type. /// `UpgradableNumber` represents an number that can be "upgraded" to its next largest type.
/// This is slightly different than `From`, in that it only allows you to upgrade "one step up". /// This is slightly different than `From`, in that it only allows you to upgrade "one step up".
/// You cannot, for instance, upgrade a `u8` to a `u64` directly. /// You cannot, for instance, upgrade a `u8` to a `u64` directly.
pub trait UpgradableNumber<U: From<Self>> pub trait UpgradableNumber
where where
Self: Sized, Self: Sized,
{ {
fn upgrade(self) -> U { type Output: From<Self>;
fn upgrade(self) -> Self::Output {
self.into() self.into()
} }
} }
impl UpgradableNumber<u16> for u8 {} impl UpgradableNumber for u8 {
impl UpgradableNumber<u32> for u16 {} type Output = u16;
impl UpgradableNumber<u64> for u32 {} }
impl UpgradableNumber for u16 {
type Output = u32;
}
impl UpgradableNumber for u32 {
type Output = u64;
}
impl UpgradableNumber<i16> for i8 {} impl UpgradableNumber for i8 {
impl UpgradableNumber<i32> for i16 {} type Output = i16;
impl UpgradableNumber<i64> for i32 {} }
impl UpgradableNumber for i16 {
type Output = i32;
}
impl UpgradableNumber for i32 {
type Output = i64;
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {