From 2ef80e8878b0bc0b4dad0f51e453d50c64f56797 Mon Sep 17 00:00:00 2001 From: Donovan Date: Thu, 28 Aug 2025 15:47:43 -0500 Subject: [PATCH] implement tower shuffle --- src/tower_shuffle.py | 74 +++++++++++++++++++++++++++++++++++++ tests/test_tower_shuffle.py | 8 ++++ 2 files changed, 82 insertions(+) create mode 100644 src/tower_shuffle.py create mode 100644 tests/test_tower_shuffle.py diff --git a/src/tower_shuffle.py b/src/tower_shuffle.py new file mode 100644 index 0000000..6b2fc74 --- /dev/null +++ b/src/tower_shuffle.py @@ -0,0 +1,74 @@ +from dataclasses import dataclass +import numpy as np + +@dataclass +class Tower: + floors: list[np.ndarray] + + def split_tower(self) -> tuple[np.ndarray, np.ndarray]: + discard = np.array([], dtype=int) + keep = np.array([], dtype=int) + balance = self.balance() + for idx, floor in enumerate(self.floors): + div = len(floor)//2 + balance[idx] + floor_shuffle = np.random.permutation(len(floor)) + keep = np.concatenate((keep, floor[floor_shuffle[:div]])) + discard = np.concatenate((discard, floor[floor_shuffle[div:]])) + diff = len(discard) - len(keep) + assert 0 <= diff <= 1 + return keep, discard + + def balance(self) -> list[int]: + odd_floors = np.array([idx for idx, el in enumerate(self.floors) if len(el) & 1]) + balance = np.zeros(len(self.floors), dtype=int) + if len(odd_floors) == 0: + return balance.tolist() + shuffle = np.random.permutation(len(odd_floors))[:len(odd_floors) // 2] + odd_floors = odd_floors[shuffle] + balance[odd_floors] = 1 + return balance.tolist() + + + def update_tower(self, keep: np.ndarray, other_discard: np.ndarray): + new_floors = [] + for floor in self.floors: + new_floor = np.intersect1d(floor, keep) + if len(new_floor): + new_floors.append(new_floor) + self.floors = new_floors + self.floors.insert(0, other_discard) + + def __str__(self): + strVal = "" + for idx, val in enumerate(reversed(self.floors)): + strVal += f"Floor {idx}: {val.tolist()}\n" + return strVal + + +@dataclass +class TowerShuffle: + total_positions: int + left_tower: Tower + right_tower: Tower + + @classmethod + def new(cls, total_pos:int): + assert total_pos >= 4 + rand_pos = np.random.permutation(total_pos) + return TowerShuffle( + total_positions=total_pos, + left_tower=Tower(floors=[rand_pos[:total_pos//2]]), + right_tower=Tower(floors=[rand_pos[total_pos//2:]]), + ) + + def shuffle(self): + left_keep, left_discard = self.left_tower.split_tower() + right_keep, right_discard = self.right_tower.split_tower() + self.left_tower.update_tower(left_keep, right_discard) + self.right_tower.update_tower(right_keep, left_discard) + + def __str__(self): + return f"""Left Tower: +{self.left_tower} +Right Tower: +{self.right_tower}""" diff --git a/tests/test_tower_shuffle.py b/tests/test_tower_shuffle.py new file mode 100644 index 0000000..a19c21b --- /dev/null +++ b/tests/test_tower_shuffle.py @@ -0,0 +1,8 @@ +from src.tower_shuffle import TowerShuffle + +def test_tower_shuffle(): + tower = TowerShuffle.new(19) + print(tower) + for _ in range(100): + tower.shuffle() + print(tower)