diff --git a/day14/day14.cpp b/day14/day14.cpp index 6562929..b228d46 100644 --- a/day14/day14.cpp +++ b/day14/day14.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include constexpr char IGNORE_CHAR = 'X'; @@ -21,22 +22,24 @@ class InstructionBlock { } // We are dealing with a 36 bit numebr so we must use a long long to guarantee we can fit it - long long maskNumber(long long num) const { + long long maskValue(long long num) const { long long res = num; for (int i = 0; i < this->mask.size(); i++) { char maskChar = this->mask.at(this->mask.size() - i - 1); if (maskChar == IGNORE_CHAR) { continue; - } else if (maskChar == '0') { - res &= ~(1LL << i); - } else if (maskChar == '1') { - res |= (1LL << i); + } else { + res = this->setBitAt(res, i, maskChar); } } return res; } + std::vector maskMemoryAddress(long long address) const { + return this->recursivelyMaskAddresses(address); + } + const std::vector> &getStoreInstructions() const { return this->storeInstructions; } @@ -44,6 +47,40 @@ class InstructionBlock { private: std::string mask; std::vector> storeInstructions; + + long long setBitAt(long long value, int position, char bit) const { + auto res = value; + if (bit == '0') { + res &= ~(1LL << position); + } else if (bit == '1') { + res |= (1LL << position); + } else { + throw std::invalid_argument("bit must be zero or one"); + } + + return res; + } + + std::vector recursivelyMaskAddresses(long long address, int startIdx = 0) const { + std::vector results; + long long res = address; + for (int i = startIdx; i < this->mask.size(); i++) { + char maskChar = this->mask.at(this->mask.size() - i - 1); + if (maskChar == IGNORE_CHAR) { + auto masked0 = this->setBitAt(res, i, '0'); + auto masked1 = this->setBitAt(res, i, '1'); + auto results0 = this->recursivelyMaskAddresses(masked0, i + 1); + auto results1 = this->recursivelyMaskAddresses(masked1, i + 1); + results.insert(results.end(), results0.begin(), results0.end()); + results.insert(results.end(), results1.begin(), results1.end()); + } else if (maskChar == '1') { + res = this->setBitAt(res, i, maskChar); + } + } + + results.push_back(res); + return results; + } }; std::vector readInput(const std::string &filename) { @@ -91,8 +128,24 @@ long long part1(const std::vector &instructions) { std::unordered_map memory; for (const InstructionBlock &instruction : instructions) { for (const std::pair &storeInstruction : instruction.getStoreInstructions()) { - auto maskedStorage = instruction.maskNumber(storeInstruction.second); - memory[storeInstruction.first] = maskedStorage; + auto maskedValue = instruction.maskValue(storeInstruction.second); + memory[storeInstruction.first] = maskedValue; + } + } + + return std::accumulate( + memory.cbegin(), memory.cend(), 0LL, [](long long total, std::pair memoryItem) { + return total + memoryItem.second; + }); +} + +long long part2(const std::vector &instructions) { + std::unordered_map memory; + for (const InstructionBlock &instruction : instructions) { + for (const std::pair &storeInstruction : instruction.getStoreInstructions()) { + for (const auto maskedAddress : instruction.maskMemoryAddress(storeInstruction.first)) { + memory[maskedAddress] = storeInstruction.second; + } } } @@ -112,4 +165,5 @@ int main(int argc, char *argv[]) { auto parsedInput = parseInput(input); std::cout << part1(parsedInput) << std::endl; + std::cout << part2(parsedInput) << std::endl; }