2 Commits

Author SHA1 Message Date
86ccd0fe31 update evilnkode.ipynb tests 2025-12-03 10:18:15 -06:00
dkelly
e24fe3b512 Refactor (#8)
Co-authored-by: Donovan <donovan.a.kelly@pm.me>
Reviewed-on: https://git.infra.nkode.tech/dkelly/evilnkode/pulls/8
2025-09-09 18:22:22 +00:00
3 changed files with 167 additions and 29 deletions

File diff suppressed because one or more lines are too long

View File

@@ -1,6 +1,7 @@
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
@dataclass @dataclass
class Tower: class Tower:
floors: list[np.ndarray] floors: list[np.ndarray]
@@ -10,7 +11,7 @@ class Tower:
keep = np.array([], dtype=int) keep = np.array([], dtype=int)
balance = self.balance() balance = self.balance()
for idx, floor in enumerate(self.floors): for idx, floor in enumerate(self.floors):
div = len(floor)//2 + balance[idx] div = len(floor) // 2 + balance[idx]
floor_shuffle = np.random.permutation(len(floor)) floor_shuffle = np.random.permutation(len(floor))
keep = np.concatenate((keep, floor[floor_shuffle[:div]])) keep = np.concatenate((keep, floor[floor_shuffle[:div]]))
discard = np.concatenate((discard, floor[floor_shuffle[div:]])) discard = np.concatenate((discard, floor[floor_shuffle[div:]]))
@@ -28,7 +29,6 @@ class Tower:
balance[odd_floors] = 1 balance[odd_floors] = 1
return balance.tolist() return balance.tolist()
def update_tower(self, keep: np.ndarray, other_discard: np.ndarray): def update_tower(self, keep: np.ndarray, other_discard: np.ndarray):
new_floors = [] new_floors = []
for floor in self.floors: for floor in self.floors:
@@ -51,20 +51,22 @@ class Tower:
tower.extend(floor.tolist()) tower.extend(floor.tolist())
return tower return tower
@dataclass @dataclass
class TowerShuffle: class TowerShuffle:
# TODO: I don't think total_positions is used anywhere
total_positions: int total_positions: int
left_tower: Tower left_tower: Tower
right_tower: Tower right_tower: Tower
@classmethod @classmethod
def new(cls, total_pos:int): def new(cls, total_pos: int):
assert total_pos >= 3 assert total_pos >= 3
rand_pos = np.random.permutation(total_pos) rand_pos = np.random.permutation(total_pos)
return TowerShuffle( return TowerShuffle(
total_positions=total_pos, total_positions=total_pos,
left_tower=Tower(floors=[rand_pos[:total_pos//2]]), left_tower=Tower(floors=[rand_pos[:total_pos // 2]]),
right_tower=Tower(floors=[rand_pos[total_pos//2:]]), right_tower=Tower(floors=[rand_pos[total_pos // 2:]]),
) )
def shuffle(self): def shuffle(self):

View File

@@ -1,7 +1,8 @@
from src.keypad.tower_shuffle import TowerShuffle from src.keypad.tower_shuffle import TowerShuffle
def test_tower_shuffle(): def test_tower_shuffle():
tower = TowerShuffle.new(9) tower = TowerShuffle.new(13)
print(tower) print(tower)
for _ in range(100): for _ in range(100):
tower.shuffle() tower.shuffle()