implement tower shuffle

This commit is contained in:
2025-08-28 15:47:43 -05:00
parent 0b39710a78
commit 2ef80e8878
2 changed files with 82 additions and 0 deletions

74
src/tower_shuffle.py Normal file
View File

@@ -0,0 +1,74 @@
from dataclasses import dataclass
import numpy as np
@dataclass
class Tower:
floors: list[np.ndarray]
def split_tower(self) -> tuple[np.ndarray, np.ndarray]:
discard = np.array([], dtype=int)
keep = np.array([], dtype=int)
balance = self.balance()
for idx, floor in enumerate(self.floors):
div = len(floor)//2 + balance[idx]
floor_shuffle = np.random.permutation(len(floor))
keep = np.concatenate((keep, floor[floor_shuffle[:div]]))
discard = np.concatenate((discard, floor[floor_shuffle[div:]]))
diff = len(discard) - len(keep)
assert 0 <= diff <= 1
return keep, discard
def balance(self) -> list[int]:
odd_floors = np.array([idx for idx, el in enumerate(self.floors) if len(el) & 1])
balance = np.zeros(len(self.floors), dtype=int)
if len(odd_floors) == 0:
return balance.tolist()
shuffle = np.random.permutation(len(odd_floors))[:len(odd_floors) // 2]
odd_floors = odd_floors[shuffle]
balance[odd_floors] = 1
return balance.tolist()
def update_tower(self, keep: np.ndarray, other_discard: np.ndarray):
new_floors = []
for floor in self.floors:
new_floor = np.intersect1d(floor, keep)
if len(new_floor):
new_floors.append(new_floor)
self.floors = new_floors
self.floors.insert(0, other_discard)
def __str__(self):
strVal = ""
for idx, val in enumerate(reversed(self.floors)):
strVal += f"Floor {idx}: {val.tolist()}\n"
return strVal
@dataclass
class TowerShuffle:
total_positions: int
left_tower: Tower
right_tower: Tower
@classmethod
def new(cls, total_pos:int):
assert total_pos >= 4
rand_pos = np.random.permutation(total_pos)
return TowerShuffle(
total_positions=total_pos,
left_tower=Tower(floors=[rand_pos[:total_pos//2]]),
right_tower=Tower(floors=[rand_pos[total_pos//2:]]),
)
def shuffle(self):
left_keep, left_discard = self.left_tower.split_tower()
right_keep, right_discard = self.right_tower.split_tower()
self.left_tower.update_tower(left_keep, right_discard)
self.right_tower.update_tower(right_keep, left_discard)
def __str__(self):
return f"""Left Tower:
{self.left_tower}
Right Tower:
{self.right_tower}"""

View File

@@ -0,0 +1,8 @@
from src.tower_shuffle import TowerShuffle
def test_tower_shuffle():
tower = TowerShuffle.new(19)
print(tower)
for _ in range(100):
tower.shuffle()
print(tower)