Refactor UpgradableNumber to make CarriedNumber's type signature more bearable

old-bit-manip
Nick Krichevsky 2023-05-05 19:50:01 -04:00
parent 6dce7f0ead
commit bf9b18c2d6
2 changed files with 37 additions and 31 deletions

View File

@ -1,6 +1,6 @@
//! arithutil contains utilities to generalize the implementation of arithmetic instructions
use std::{marker::PhantomData, ops::Add};
use std::ops::Add;
use thiserror::Error;
@ -16,16 +16,15 @@ pub enum Error {
/// `CarriedNumber` is a number combined with its carry bit, intended to be used to perform carry-based arithmetic
/// operations such as ADC.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CarriedNumber<N, V> {
pub struct CarriedNumber<N> {
num: N,
carry_bit: u8,
_phantom: PhantomData<V>,
}
impl<N, V> CarriedNumber<N, V>
impl<N> CarriedNumber<N>
where
N: Copy + Clone + PartialEq + UpgradableNumber<V>,
V: From<N> + From<u8> + Add<Output = V>,
N: Copy + Clone + PartialEq + UpgradableNumber,
N::Output: From<u8> + Add<Output = N::Output>,
{
/// Create a new [`CarriedNumber`] with a raw value and its carry bit. Note that the carry bit must be 0 or 1,
/// or an [`Error::InvalidCarryBit`] will be returned
@ -34,17 +33,13 @@ where
return Err(Error::InvalidCarryBit(carry_bit));
}
Ok(Self {
num,
carry_bit,
_phantom: PhantomData,
})
Ok(Self { num, carry_bit })
}
/// Get the combination between the value and the carry bit. The return value is the "upgraded" version of `N`.
/// See [`upgrade::UpgradableNumber`] for more details
pub fn value(self) -> V {
self.num.upgrade() + V::from(self.carry_bit)
pub fn value(self) -> N::Output {
self.num.upgrade() + self.carry_bit.into()
}
/// Get the number and its carry bit directly
@ -116,8 +111,8 @@ impl CarryingAdd<i8, u16> for u16 {
}
}
impl CarryingAdd<CarriedNumber<u8, u16>, u8> for u8 {
fn add_with_carry(self, rhs: CarriedNumber<u8, u16>) -> (u8, bool, bool) {
impl CarryingAdd<CarriedNumber<u8>, u8> for u8 {
fn add_with_carry(self, rhs: CarriedNumber<u8>) -> (u8, bool, bool) {
let total = rhs.value() + u16::from(self);
// Given this is adding two u8s, the largest value we can produce is 0x01FF (both operands 0xFF, plus a carry bit)
// which can never wrap, so we just take the lower byte to convert back to u8
@ -140,8 +135,8 @@ impl CarryingSub<u8, u8> for u8 {
}
}
impl CarryingSub<CarriedNumber<u8, u16>, u8> for u8 {
fn sub_with_carry(self, rhs: CarriedNumber<u8, u16>) -> (u8, bool, bool) {
impl CarryingSub<CarriedNumber<u8>, u8> for u8 {
fn sub_with_carry(self, rhs: CarriedNumber<u8>) -> (u8, bool, bool) {
let total = self.wrapping_sub(rhs.num).wrapping_sub(rhs.carry_bit);
// let half_carry = did_8bit_sub_half_carry(self, rhs.value());
let half_carry = did_8bit_sub_half_carry_including_carry_bit(self, rhs.num, rhs.carry_bit);
@ -309,7 +304,7 @@ mod tests {
#[test_case(0xFF, CarriedNumber::new(0xFF, 1).unwrap(), (0xFF, true, true))]
fn test_add_carried_number_to_u8(
value: u8,
carried_number: CarriedNumber<u8, u16>,
carried_number: CarriedNumber<u8>,
expected: (u8, bool, bool),
) {
assert_eq!(expected, value.add_with_carry(carried_number));
@ -332,11 +327,7 @@ mod tests {
#[test_case(0xAB, CarriedNumber::new(0x5F, 1).unwrap(), (0x4B, true, false); "half borrow with carry bit and wrap")]
#[test_case(0x0F, CarriedNumber::new(0xF0, 0).unwrap(), (0x1F, false, true); "full borrow without carry bit")]
#[test_case(0x0F, CarriedNumber::new(0xEE, 1).unwrap(), (0x20, false, true); "full borrow with carry bit")]
fn test_sub_u8_from_u8_with_carry(
lhs: u8,
rhs: CarriedNumber<u8, u16>,
expected: (u8, bool, bool),
) {
fn test_sub_u8_from_u8_with_carry(lhs: u8, rhs: CarriedNumber<u8>, expected: (u8, bool, bool)) {
assert_eq!(expected, lhs.sub_with_carry(rhs));
}
}

View File

@ -1,22 +1,37 @@
/// `UpgradableNumber` represents an number that can be "upgraded" to its next largest type.
/// This is slightly different than `From`, in that it only allows you to upgrade "one step up".
/// You cannot, for instance, upgrade a `u8` to a `u64` directly.
pub trait UpgradableNumber<U: From<Self>>
pub trait UpgradableNumber
where
Self: Sized,
{
fn upgrade(self) -> U {
type Output: From<Self>;
fn upgrade(self) -> Self::Output {
self.into()
}
}
impl UpgradableNumber<u16> for u8 {}
impl UpgradableNumber<u32> for u16 {}
impl UpgradableNumber<u64> for u32 {}
impl UpgradableNumber for u8 {
type Output = u16;
}
impl UpgradableNumber for u16 {
type Output = u32;
}
impl UpgradableNumber for u32 {
type Output = u64;
}
impl UpgradableNumber<i16> for i8 {}
impl UpgradableNumber<i32> for i16 {}
impl UpgradableNumber<i64> for i32 {}
impl UpgradableNumber for i8 {
type Output = i16;
}
impl UpgradableNumber for i16 {
type Output = i32;
}
impl UpgradableNumber for i32 {
type Output = i64;
}
#[cfg(test)]
mod tests {