update evilnkode.ipynb tests
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -1,6 +1,7 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Tower:
|
class Tower:
|
||||||
floors: list[np.ndarray]
|
floors: list[np.ndarray]
|
||||||
@@ -10,7 +11,7 @@ class Tower:
|
|||||||
keep = np.array([], dtype=int)
|
keep = np.array([], dtype=int)
|
||||||
balance = self.balance()
|
balance = self.balance()
|
||||||
for idx, floor in enumerate(self.floors):
|
for idx, floor in enumerate(self.floors):
|
||||||
div = len(floor)//2 + balance[idx]
|
div = len(floor) // 2 + balance[idx]
|
||||||
floor_shuffle = np.random.permutation(len(floor))
|
floor_shuffle = np.random.permutation(len(floor))
|
||||||
keep = np.concatenate((keep, floor[floor_shuffle[:div]]))
|
keep = np.concatenate((keep, floor[floor_shuffle[:div]]))
|
||||||
discard = np.concatenate((discard, floor[floor_shuffle[div:]]))
|
discard = np.concatenate((discard, floor[floor_shuffle[div:]]))
|
||||||
@@ -28,7 +29,6 @@ class Tower:
|
|||||||
balance[odd_floors] = 1
|
balance[odd_floors] = 1
|
||||||
return balance.tolist()
|
return balance.tolist()
|
||||||
|
|
||||||
|
|
||||||
def update_tower(self, keep: np.ndarray, other_discard: np.ndarray):
|
def update_tower(self, keep: np.ndarray, other_discard: np.ndarray):
|
||||||
new_floors = []
|
new_floors = []
|
||||||
for floor in self.floors:
|
for floor in self.floors:
|
||||||
@@ -51,20 +51,22 @@ class Tower:
|
|||||||
tower.extend(floor.tolist())
|
tower.extend(floor.tolist())
|
||||||
return tower
|
return tower
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TowerShuffle:
|
class TowerShuffle:
|
||||||
|
# TODO: I don't think total_positions is used anywhere
|
||||||
total_positions: int
|
total_positions: int
|
||||||
left_tower: Tower
|
left_tower: Tower
|
||||||
right_tower: Tower
|
right_tower: Tower
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def new(cls, total_pos:int):
|
def new(cls, total_pos: int):
|
||||||
assert total_pos >= 3
|
assert total_pos >= 3
|
||||||
rand_pos = np.random.permutation(total_pos)
|
rand_pos = np.random.permutation(total_pos)
|
||||||
return TowerShuffle(
|
return TowerShuffle(
|
||||||
total_positions=total_pos,
|
total_positions=total_pos,
|
||||||
left_tower=Tower(floors=[rand_pos[:total_pos//2]]),
|
left_tower=Tower(floors=[rand_pos[:total_pos // 2]]),
|
||||||
right_tower=Tower(floors=[rand_pos[total_pos//2:]]),
|
right_tower=Tower(floors=[rand_pos[total_pos // 2:]]),
|
||||||
)
|
)
|
||||||
|
|
||||||
def shuffle(self):
|
def shuffle(self):
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
from src.keypad.tower_shuffle import TowerShuffle
|
from src.keypad.tower_shuffle import TowerShuffle
|
||||||
|
|
||||||
|
|
||||||
def test_tower_shuffle():
|
def test_tower_shuffle():
|
||||||
tower = TowerShuffle.new(9)
|
tower = TowerShuffle.new(13)
|
||||||
print(tower)
|
print(tower)
|
||||||
for _ in range(100):
|
for _ in range(100):
|
||||||
tower.shuffle()
|
tower.shuffle()
|
||||||
|
|||||||
Reference in New Issue
Block a user