implement tower shuffle
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -2,12 +2,16 @@ from dataclasses import dataclass
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from src.tower_shuffle import TowerShuffle
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Keypad:
|
class Keypad:
|
||||||
keypad: np.ndarray
|
keypad: np.ndarray
|
||||||
k: int # number of keys
|
k: int # number of keys
|
||||||
p: int # properties per key
|
p: int # properties per key
|
||||||
keypad_cache: list #
|
keypad_cache: list #
|
||||||
|
tower_shuffler: TowerShuffle
|
||||||
max_cache_size: int = 100
|
max_cache_size: int = 100
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -17,11 +21,16 @@ class Keypad:
|
|||||||
# Reshape into a 3x4 matrix
|
# Reshape into a 3x4 matrix
|
||||||
keypad = array.reshape(k, p)
|
keypad = array.reshape(k, p)
|
||||||
set_view = keypad.T
|
set_view = keypad.T
|
||||||
|
|
||||||
for set_idx in set_view:
|
for set_idx in set_view:
|
||||||
np.random.shuffle(set_idx)
|
np.random.shuffle(set_idx)
|
||||||
|
|
||||||
return Keypad(keypad=set_view.T, k=k, p=p, keypad_cache=[])
|
return Keypad(keypad=set_view.T, k=k, p=p, keypad_cache=[], tower_shuffler=TowerShuffle.new(p))
|
||||||
|
|
||||||
|
def tower_shuffle(self):
|
||||||
|
selected_positions = self.tower_shuffler.left_tower.tolist()
|
||||||
|
new_key_idxs = np.random.permutation(self.k)
|
||||||
|
self.keypad[:, selected_positions] = self.keypad[new_key_idxs, :][:, selected_positions]
|
||||||
|
self.tower_shuffler.shuffle()
|
||||||
|
|
||||||
def split_shuffle(self):
|
def split_shuffle(self):
|
||||||
"""
|
"""
|
||||||
@@ -49,9 +58,9 @@ class Keypad:
|
|||||||
def _shuffle(self) -> np.ndarray:
|
def _shuffle(self) -> np.ndarray:
|
||||||
column_permutation = np.random.permutation(self.p)
|
column_permutation = np.random.permutation(self.p)
|
||||||
column_subset = column_permutation[:self.p // 2]
|
column_subset = column_permutation[:self.p // 2]
|
||||||
perm_indices = np.random.permutation(self.k)
|
new_key_idxs = np.random.permutation(self.k)
|
||||||
shuffled_sets = self.keypad.copy()
|
shuffled_sets = self.keypad.copy()
|
||||||
shuffled_sets[:, column_subset] = shuffled_sets[perm_indices, :][:, column_subset]
|
shuffled_sets[:, column_subset] = shuffled_sets[new_key_idxs, :][:, column_subset]
|
||||||
return shuffled_sets
|
return shuffled_sets
|
||||||
|
|
||||||
def full_shuffle(self):
|
def full_shuffle(self):
|
||||||
|
|||||||
@@ -39,11 +39,17 @@ class Tower:
|
|||||||
self.floors.insert(0, other_discard)
|
self.floors.insert(0, other_discard)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
strVal = ""
|
str_val = ""
|
||||||
|
floor_numb = [i for i in reversed(range(len(self.floors)))]
|
||||||
for idx, val in enumerate(reversed(self.floors)):
|
for idx, val in enumerate(reversed(self.floors)):
|
||||||
strVal += f"Floor {idx}: {val.tolist()}\n"
|
str_val += f"Floor {floor_numb[idx]}: {val.tolist()}\n"
|
||||||
return strVal
|
return str_val
|
||||||
|
|
||||||
|
def tolist(self) -> list[int]:
|
||||||
|
tower = []
|
||||||
|
for floor in self.floors:
|
||||||
|
tower.extend(floor.tolist())
|
||||||
|
return tower
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TowerShuffle:
|
class TowerShuffle:
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ def total_shuffle_states(k: int, p: int) -> int:
|
|||||||
class ShuffleTypes(Enum):
|
class ShuffleTypes(Enum):
|
||||||
FULL_SHUFFLE = "FULL_SHUFFLE"
|
FULL_SHUFFLE = "FULL_SHUFFLE"
|
||||||
SPLIT_SHUFFLE = "SPLIT_SHUFFLE"
|
SPLIT_SHUFFLE = "SPLIT_SHUFFLE"
|
||||||
|
TOWER_SHUFFLE = "TOWER_SHUFFLE"
|
||||||
|
|
||||||
|
|
||||||
def observations(target_passcode: list[int], number_of_keys:int, properties_per_key: int, min_complexity: int, min_disparity: int, shuffle_type: ShuffleTypes, number_of_observations: int = 100):
|
def observations(target_passcode: list[int], number_of_keys:int, properties_per_key: int, min_complexity: int, min_disparity: int, shuffle_type: ShuffleTypes, number_of_observations: int = 100):
|
||||||
@@ -34,6 +35,8 @@ def observations(target_passcode: list[int], number_of_keys:int, properties_per_
|
|||||||
keypad.full_shuffle()
|
keypad.full_shuffle()
|
||||||
case ShuffleTypes.SPLIT_SHUFFLE:
|
case ShuffleTypes.SPLIT_SHUFFLE:
|
||||||
keypad.split_shuffle()
|
keypad.split_shuffle()
|
||||||
|
case ShuffleTypes.TOWER_SHUFFLE:
|
||||||
|
keypad.tower_shuffle()
|
||||||
case _:
|
case _:
|
||||||
raise Exception(f"no shuffle type {shuffle_type}")
|
raise Exception(f"no shuffle type {shuffle_type}")
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from typing import Iterable
|
|||||||
|
|
||||||
# Project root = parent of *this* file's directory
|
# Project root = parent of *this* file's directory
|
||||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||||
OUTPUT_DIR = PROJECT_ROOT / "example" / "obs_json"
|
BASE_DIR = PROJECT_ROOT / "example"
|
||||||
PNG_DIR = PROJECT_ROOT / "example" / "obs_png"
|
PNG_DIR = PROJECT_ROOT / "example" / "obs_png"
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -23,6 +23,7 @@ def new_observation_sequence(
|
|||||||
complexity: int,
|
complexity: int,
|
||||||
disparity: int,
|
disparity: int,
|
||||||
numb_runs: int,
|
numb_runs: int,
|
||||||
|
shuffle_type: ShuffleTypes
|
||||||
) -> ObservationSequence:
|
) -> ObservationSequence:
|
||||||
passcode = passcode_generator(number_of_keys, properties_per_key, passcode_len, complexity, disparity)
|
passcode = passcode_generator(number_of_keys, properties_per_key, passcode_len, complexity, disparity)
|
||||||
obs_seq = ObservationSequence(target_passcode=passcode, observations=[])
|
obs_seq = ObservationSequence(target_passcode=passcode, observations=[])
|
||||||
@@ -32,7 +33,7 @@ def new_observation_sequence(
|
|||||||
properties_per_key=properties_per_key,
|
properties_per_key=properties_per_key,
|
||||||
min_complexity=complexity,
|
min_complexity=complexity,
|
||||||
min_disparity=disparity,
|
min_disparity=disparity,
|
||||||
shuffle_type=ShuffleTypes.SPLIT_SHUFFLE,
|
shuffle_type=shuffle_type,
|
||||||
number_of_observations=numb_runs,
|
number_of_observations=numb_runs,
|
||||||
)
|
)
|
||||||
for obs in obs_gen:
|
for obs in obs_gen:
|
||||||
@@ -50,14 +51,14 @@ def _next_json_filename(base_dir: Path) -> Path:
|
|||||||
return candidate
|
return candidate
|
||||||
counter += 1
|
counter += 1
|
||||||
|
|
||||||
def save_observation_sequence_to_json(seq: ObservationSequence, filename: Path | None = None) -> None:
|
def save_observation_sequence_to_json(seq: ObservationSequence, shuffle_type: ShuffleTypes, filename: Path | None = None) -> None:
|
||||||
"""
|
"""
|
||||||
Save ObservationSequence to JSON.
|
Save ObservationSequence to JSON.
|
||||||
- If filename is None, put it under PROJECT_ROOT/output/obs_json/ as observation_{n}.json
|
- If filename is None, put it under PROJECT_ROOT/output/obs_json/ as observation_{n}.json
|
||||||
- Creates directory if needed
|
- Creates directory if needed
|
||||||
"""
|
"""
|
||||||
if filename is None:
|
if filename is None:
|
||||||
base_dir = OUTPUT_DIR
|
base_dir = BASE_DIR / shuffle_type.name
|
||||||
base_dir.mkdir(parents=True, exist_ok=True)
|
base_dir.mkdir(parents=True, exist_ok=True)
|
||||||
filename = _next_json_filename(base_dir)
|
filename = _next_json_filename(base_dir)
|
||||||
else:
|
else:
|
||||||
@@ -223,12 +224,12 @@ def _next_run_dir(base_dir: Path) -> Path:
|
|||||||
return run_dir
|
return run_dir
|
||||||
counter += 1
|
counter += 1
|
||||||
|
|
||||||
def render_sequence_to_pngs(seq: ObservationSequence, out_dir: Path | None = None) -> None:
|
def render_sequence_to_pngs(seq: ObservationSequence, shuffle_type: ShuffleTypes, out_dir: Path | None = None) -> None:
|
||||||
"""
|
"""
|
||||||
Render each observation to its own PNG inside a fresh run directory.
|
Render each observation to its own PNG inside a fresh run directory.
|
||||||
Default: PROJECT_ROOT/output/obs_png/run_XXX/observation_001.png
|
Default: PROJECT_ROOT/output/obs_png/run_XXX/observation_001.png
|
||||||
"""
|
"""
|
||||||
base_dir = PNG_DIR if out_dir is None else out_dir
|
base_dir = BASE_DIR / shuffle_type.name / "obs_png" if out_dir is None else out_dir
|
||||||
base_dir.mkdir(parents=True, exist_ok=True)
|
base_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Create a fresh run dir
|
# Create a fresh run dir
|
||||||
@@ -238,6 +239,7 @@ def render_sequence_to_pngs(seq: ObservationSequence, out_dir: Path | None = Non
|
|||||||
filename = run_dir / f"observation_{i:03d}.png"
|
filename = run_dir / f"observation_{i:03d}.png"
|
||||||
render_observation_to_png(seq.target_passcode, obs, filename)
|
render_observation_to_png(seq.target_passcode, obs, filename)
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
obs_seq = new_observation_sequence(6, 9,4,0,0,numb_runs=50)
|
shuffle_type = ShuffleTypes.TOWER_SHUFFLE
|
||||||
save_observation_sequence_to_json(obs_seq)
|
obs_seq = new_observation_sequence(6, 9,4,0,0 ,numb_runs=50, shuffle_type=shuffle_type)
|
||||||
render_sequence_to_pngs(obs_seq)
|
save_observation_sequence_to_json(obs_seq, shuffle_type)
|
||||||
|
render_sequence_to_pngs(obs_seq, shuffle_type)
|
||||||
@@ -29,3 +29,11 @@ def test_full_shuffle():
|
|||||||
keypad.full_shuffle()
|
keypad.full_shuffle()
|
||||||
print(keypad.keypad)
|
print(keypad.keypad)
|
||||||
|
|
||||||
|
def test_tower_shuffle():
|
||||||
|
p = 4 # properties_per_key
|
||||||
|
k = 3 # number_of_keys
|
||||||
|
keypad = Keypad.new_keypad(k, p)
|
||||||
|
print()
|
||||||
|
for _ in range(10):
|
||||||
|
print(keypad.keypad)
|
||||||
|
keypad.tower_shuffle()
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from src.tower_shuffle import TowerShuffle
|
from src.tower_shuffle import TowerShuffle
|
||||||
|
|
||||||
def test_tower_shuffle():
|
def test_tower_shuffle():
|
||||||
tower = TowerShuffle.new(19)
|
tower = TowerShuffle.new(9)
|
||||||
print(tower)
|
print(tower)
|
||||||
for _ in range(100):
|
for _ in range(100):
|
||||||
tower.shuffle()
|
tower.shuffle()
|
||||||
|
|||||||
Reference in New Issue
Block a user