implement tower shuffle
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -2,12 +2,16 @@ from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.tower_shuffle import TowerShuffle
|
||||
|
||||
|
||||
@dataclass
|
||||
class Keypad:
|
||||
keypad: np.ndarray
|
||||
k: int # number of keys
|
||||
p: int # properties per key
|
||||
keypad_cache: list #
|
||||
tower_shuffler: TowerShuffle
|
||||
max_cache_size: int = 100
|
||||
|
||||
@staticmethod
|
||||
@@ -17,11 +21,16 @@ class Keypad:
|
||||
# Reshape into a 3x4 matrix
|
||||
keypad = array.reshape(k, p)
|
||||
set_view = keypad.T
|
||||
|
||||
for set_idx in set_view:
|
||||
np.random.shuffle(set_idx)
|
||||
|
||||
return Keypad(keypad=set_view.T, k=k, p=p, keypad_cache=[])
|
||||
return Keypad(keypad=set_view.T, k=k, p=p, keypad_cache=[], tower_shuffler=TowerShuffle.new(p))
|
||||
|
||||
def tower_shuffle(self):
|
||||
selected_positions = self.tower_shuffler.left_tower.tolist()
|
||||
new_key_idxs = np.random.permutation(self.k)
|
||||
self.keypad[:, selected_positions] = self.keypad[new_key_idxs, :][:, selected_positions]
|
||||
self.tower_shuffler.shuffle()
|
||||
|
||||
def split_shuffle(self):
|
||||
"""
|
||||
@@ -49,9 +58,9 @@ class Keypad:
|
||||
def _shuffle(self) -> np.ndarray:
|
||||
column_permutation = np.random.permutation(self.p)
|
||||
column_subset = column_permutation[:self.p // 2]
|
||||
perm_indices = np.random.permutation(self.k)
|
||||
new_key_idxs = np.random.permutation(self.k)
|
||||
shuffled_sets = self.keypad.copy()
|
||||
shuffled_sets[:, column_subset] = shuffled_sets[perm_indices, :][:, column_subset]
|
||||
shuffled_sets[:, column_subset] = shuffled_sets[new_key_idxs, :][:, column_subset]
|
||||
return shuffled_sets
|
||||
|
||||
def full_shuffle(self):
|
||||
|
||||
@@ -39,11 +39,17 @@ class Tower:
|
||||
self.floors.insert(0, other_discard)
|
||||
|
||||
def __str__(self):
|
||||
strVal = ""
|
||||
str_val = ""
|
||||
floor_numb = [i for i in reversed(range(len(self.floors)))]
|
||||
for idx, val in enumerate(reversed(self.floors)):
|
||||
strVal += f"Floor {idx}: {val.tolist()}\n"
|
||||
return strVal
|
||||
str_val += f"Floor {floor_numb[idx]}: {val.tolist()}\n"
|
||||
return str_val
|
||||
|
||||
def tolist(self) -> list[int]:
|
||||
tower = []
|
||||
for floor in self.floors:
|
||||
tower.extend(floor.tolist())
|
||||
return tower
|
||||
|
||||
@dataclass
|
||||
class TowerShuffle:
|
||||
|
||||
@@ -16,6 +16,7 @@ def total_shuffle_states(k: int, p: int) -> int:
|
||||
class ShuffleTypes(Enum):
|
||||
FULL_SHUFFLE = "FULL_SHUFFLE"
|
||||
SPLIT_SHUFFLE = "SPLIT_SHUFFLE"
|
||||
TOWER_SHUFFLE = "TOWER_SHUFFLE"
|
||||
|
||||
|
||||
def observations(target_passcode: list[int], number_of_keys:int, properties_per_key: int, min_complexity: int, min_disparity: int, shuffle_type: ShuffleTypes, number_of_observations: int = 100):
|
||||
@@ -34,6 +35,8 @@ def observations(target_passcode: list[int], number_of_keys:int, properties_per_
|
||||
keypad.full_shuffle()
|
||||
case ShuffleTypes.SPLIT_SHUFFLE:
|
||||
keypad.split_shuffle()
|
||||
case ShuffleTypes.TOWER_SHUFFLE:
|
||||
keypad.tower_shuffle()
|
||||
case _:
|
||||
raise Exception(f"no shuffle type {shuffle_type}")
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Iterable
|
||||
|
||||
# Project root = parent of *this* file's directory
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
OUTPUT_DIR = PROJECT_ROOT / "example" / "obs_json"
|
||||
BASE_DIR = PROJECT_ROOT / "example"
|
||||
PNG_DIR = PROJECT_ROOT / "example" / "obs_png"
|
||||
|
||||
@dataclass
|
||||
@@ -23,6 +23,7 @@ def new_observation_sequence(
|
||||
complexity: int,
|
||||
disparity: int,
|
||||
numb_runs: int,
|
||||
shuffle_type: ShuffleTypes
|
||||
) -> ObservationSequence:
|
||||
passcode = passcode_generator(number_of_keys, properties_per_key, passcode_len, complexity, disparity)
|
||||
obs_seq = ObservationSequence(target_passcode=passcode, observations=[])
|
||||
@@ -32,7 +33,7 @@ def new_observation_sequence(
|
||||
properties_per_key=properties_per_key,
|
||||
min_complexity=complexity,
|
||||
min_disparity=disparity,
|
||||
shuffle_type=ShuffleTypes.SPLIT_SHUFFLE,
|
||||
shuffle_type=shuffle_type,
|
||||
number_of_observations=numb_runs,
|
||||
)
|
||||
for obs in obs_gen:
|
||||
@@ -50,14 +51,14 @@ def _next_json_filename(base_dir: Path) -> Path:
|
||||
return candidate
|
||||
counter += 1
|
||||
|
||||
def save_observation_sequence_to_json(seq: ObservationSequence, filename: Path | None = None) -> None:
|
||||
def save_observation_sequence_to_json(seq: ObservationSequence, shuffle_type: ShuffleTypes, filename: Path | None = None) -> None:
|
||||
"""
|
||||
Save ObservationSequence to JSON.
|
||||
- If filename is None, put it under PROJECT_ROOT/output/obs_json/ as observation_{n}.json
|
||||
- Creates directory if needed
|
||||
"""
|
||||
if filename is None:
|
||||
base_dir = OUTPUT_DIR
|
||||
base_dir = BASE_DIR / shuffle_type.name
|
||||
base_dir.mkdir(parents=True, exist_ok=True)
|
||||
filename = _next_json_filename(base_dir)
|
||||
else:
|
||||
@@ -223,12 +224,12 @@ def _next_run_dir(base_dir: Path) -> Path:
|
||||
return run_dir
|
||||
counter += 1
|
||||
|
||||
def render_sequence_to_pngs(seq: ObservationSequence, out_dir: Path | None = None) -> None:
|
||||
def render_sequence_to_pngs(seq: ObservationSequence, shuffle_type: ShuffleTypes, out_dir: Path | None = None) -> None:
|
||||
"""
|
||||
Render each observation to its own PNG inside a fresh run directory.
|
||||
Default: PROJECT_ROOT/output/obs_png/run_XXX/observation_001.png
|
||||
"""
|
||||
base_dir = PNG_DIR if out_dir is None else out_dir
|
||||
base_dir = BASE_DIR / shuffle_type.name / "obs_png" if out_dir is None else out_dir
|
||||
base_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create a fresh run dir
|
||||
@@ -238,6 +239,7 @@ def render_sequence_to_pngs(seq: ObservationSequence, out_dir: Path | None = Non
|
||||
filename = run_dir / f"observation_{i:03d}.png"
|
||||
render_observation_to_png(seq.target_passcode, obs, filename)
|
||||
if __name__ == "__main__":
|
||||
obs_seq = new_observation_sequence(6, 9,4,0,0,numb_runs=50)
|
||||
save_observation_sequence_to_json(obs_seq)
|
||||
render_sequence_to_pngs(obs_seq)
|
||||
shuffle_type = ShuffleTypes.TOWER_SHUFFLE
|
||||
obs_seq = new_observation_sequence(6, 9,4,0,0 ,numb_runs=50, shuffle_type=shuffle_type)
|
||||
save_observation_sequence_to_json(obs_seq, shuffle_type)
|
||||
render_sequence_to_pngs(obs_seq, shuffle_type)
|
||||
@@ -29,3 +29,11 @@ def test_full_shuffle():
|
||||
keypad.full_shuffle()
|
||||
print(keypad.keypad)
|
||||
|
||||
def test_tower_shuffle():
|
||||
p = 4 # properties_per_key
|
||||
k = 3 # number_of_keys
|
||||
keypad = Keypad.new_keypad(k, p)
|
||||
print()
|
||||
for _ in range(10):
|
||||
print(keypad.keypad)
|
||||
keypad.tower_shuffle()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from src.tower_shuffle import TowerShuffle
|
||||
|
||||
def test_tower_shuffle():
|
||||
tower = TowerShuffle.new(19)
|
||||
tower = TowerShuffle.new(9)
|
||||
print(tower)
|
||||
for _ in range(100):
|
||||
tower.shuffle()
|
||||
|
||||
Reference in New Issue
Block a user