Refactor (#8)
Co-authored-by: Donovan <donovan.a.kelly@pm.me> Reviewed-on: https://git.infra.nkode.tech/dkelly/evilnkode/pulls/8
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
from src.evilkode import Evilkode
|
||||
from tqdm import tqdm
|
||||
from src.evilnkode import EvilNKode
|
||||
from dataclasses import dataclass
|
||||
from statistics import mean, variance
|
||||
from src.utils import observations, passcode_generator
|
||||
from src.keypad.keypad import BaseKeypad
|
||||
from pathlib import Path
|
||||
|
||||
from src.utils import ShuffleTypes, observations, passcode_generator
|
||||
import pickle
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -12,7 +13,7 @@ class Benchmark:
|
||||
iterations_to_replay: list[int]
|
||||
|
||||
|
||||
def shuffle_benchmark(
|
||||
def benchmark(
|
||||
number_of_keys: int,
|
||||
properties_per_key: int,
|
||||
passcode_len: int,
|
||||
@@ -20,86 +21,44 @@ def shuffle_benchmark(
|
||||
run_count: int,
|
||||
complexity: int,
|
||||
disparity: int,
|
||||
shuffle_type: ShuffleTypes,
|
||||
file_path: str = '../output',
|
||||
keypad: BaseKeypad,
|
||||
file_path: Path = '../output',
|
||||
overwrite: bool = False
|
||||
) -> Benchmark:
|
||||
# file_name_break = f"{shuffle_type.name.lower()}-{number_of_keys}-{properties_per_key}-{passcode_len}-{max_tries_before_lockout}-{complexity}-{disparity}-{run_count}.txt"
|
||||
# full_path_iter_break = Path(file_path) / "iterations_to_break" /file_name_break
|
||||
# if not overwrite and full_path_iter_break.exists():
|
||||
# print(f"file exists {file_path}")
|
||||
# with open(full_path_iter_break, "r") as fp:
|
||||
# iterations_to_break = fp.readline()
|
||||
# iterations_to_break = iterations_to_break.split(',')
|
||||
# iterations_to_break = [int(i) for i in iterations_to_break]
|
||||
# return Benchmark(
|
||||
# mean=mean(iterations_to_break),
|
||||
# variance=variance(iterations_to_break),
|
||||
# iterations_to_break=iterations_to_break
|
||||
# )
|
||||
shuffle_type = str(type(keypad)).lower().split('.')[-1].replace("'>", "")
|
||||
file_name = f"{shuffle_type}-{number_of_keys}-{properties_per_key}-{passcode_len}-{max_tries_before_lockout}-{complexity}-{disparity}-{run_count}.pkl"
|
||||
full_path = Path(file_path) / "benchmark" / file_name
|
||||
if not overwrite and full_path.exists():
|
||||
print(f"File exists: {full_path}")
|
||||
with open(full_path, "rb") as fp:
|
||||
return pickle.load(fp)
|
||||
|
||||
iterations_to_break = []
|
||||
iterations_to_replay = []
|
||||
for _ in range(run_count):
|
||||
for _ in tqdm(range(run_count)):
|
||||
passcode = passcode_generator(number_of_keys, properties_per_key, passcode_len, complexity, disparity)
|
||||
evilkode = Evilkode(
|
||||
evilnkode = EvilNKode(
|
||||
observations=observations(
|
||||
target_passcode=passcode,
|
||||
number_of_keys=number_of_keys,
|
||||
properties_per_key=properties_per_key,
|
||||
min_complexity=complexity,
|
||||
min_disparity=disparity,
|
||||
shuffle_type=shuffle_type,
|
||||
keypad=keypad,
|
||||
),
|
||||
number_of_keys=number_of_keys,
|
||||
properties_per_key=properties_per_key,
|
||||
passcode_len=passcode_len,
|
||||
max_tries_before_lockout=max_tries_before_lockout,
|
||||
)
|
||||
evilout = evilkode.run()
|
||||
evilout = evilnkode.run()
|
||||
iterations_to_break.append(evilout.iterations_to_break)
|
||||
iterations_to_replay.append(evilout.iterations_to_replay)
|
||||
|
||||
# full_path_iter_break.parent.mkdir(parents=True, exist_ok=True)
|
||||
# with open(full_path_iter_break, "w") as fp:
|
||||
# fp.write(",".join([str(i) for i in iterations_to_break])),
|
||||
|
||||
return Benchmark(
|
||||
benchmark_result = Benchmark(
|
||||
iterations_to_break=iterations_to_break,
|
||||
iterations_to_replay=iterations_to_replay
|
||||
)
|
||||
|
||||
if file_path:
|
||||
full_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(full_path, "wb") as fp:
|
||||
pickle.dump(benchmark_result, fp)
|
||||
|
||||
# def full_shuffle_benchmark(
|
||||
# number_of_keys: int,
|
||||
# properties_per_key: int,
|
||||
# passcode_len: int,
|
||||
# max_tries_before_lockout: int,
|
||||
# run_count: int,
|
||||
# complexity: int,
|
||||
# disparity: int,
|
||||
# ) -> Benchmark:
|
||||
# runs = []
|
||||
# for _ in range(run_count):
|
||||
# passcode = passcode_generator(number_of_keys, properties_per_key, passcode_len, complexity, disparity)
|
||||
# evilkode = Evilkode(
|
||||
# observations=observations(
|
||||
# target_passcode=passcode,
|
||||
# number_of_keys=number_of_keys,
|
||||
# properties_per_key=properties_per_key,
|
||||
# min_complexity=complexity,
|
||||
# min_disparity=disparity,
|
||||
# shuffle_type=ShuffleTypes.FULL_SHUFFLE,
|
||||
# ),
|
||||
# number_of_keys=number_of_keys,
|
||||
# properties_per_key=properties_per_key,
|
||||
# passcode_len=passcode_len,
|
||||
# max_tries_before_lockout=max_tries_before_lockout,
|
||||
# )
|
||||
# evilout = evilkode.run()
|
||||
# runs.append(evilout.iterations_to_break)
|
||||
#
|
||||
# return Benchmark(
|
||||
# mean=mean(runs),
|
||||
# variance=variance(runs),
|
||||
# iterations_to_break=runs
|
||||
# )
|
||||
return benchmark_result
|
||||
|
||||
@@ -19,17 +19,12 @@ class Observation:
|
||||
|
||||
@dataclass
|
||||
class EvilOutput:
|
||||
# possible_nkodes: list[list[int]]
|
||||
iterations_to_break: int
|
||||
iterations_to_replay: int
|
||||
|
||||
# @property
|
||||
# def number_of_possible_nkode(self):
|
||||
# return math.prod([len(el) for el in self.possible_nkodes])
|
||||
|
||||
|
||||
@dataclass
|
||||
class Evilkode:
|
||||
class EvilNKode:
|
||||
observations: Iterator[Observation]
|
||||
passcode_len: int
|
||||
number_of_keys: int
|
||||
@@ -37,7 +32,6 @@ class Evilkode:
|
||||
max_tries_before_lockout: int = 5
|
||||
possible_nkode = None
|
||||
|
||||
|
||||
def initialize(self):
|
||||
possible_values = set(range(self.number_of_keys * self.properties_per_key))
|
||||
self.possible_nkode = [possible_values.copy() for _ in range(self.passcode_len)]
|
||||
@@ -49,9 +43,9 @@ class Evilkode:
|
||||
if iterations_to_replay is None:
|
||||
replay_possibilities = self.replay_attack(obs)
|
||||
if replay_possibilities <= self.max_tries_before_lockout:
|
||||
iterations_to_replay = idx + 1
|
||||
iterations_to_replay = idx + 1
|
||||
if math.prod([len(el) for el in self.possible_nkode]) <= self.max_tries_before_lockout:
|
||||
assert iterations_to_replay <= idx +1
|
||||
assert iterations_to_replay <= idx + 1
|
||||
return EvilOutput(
|
||||
# possible_nkodes=[list(el) for el in self.possible_nkode],
|
||||
iterations_to_break=idx + 1,
|
||||
@@ -65,4 +59,4 @@ class Evilkode:
|
||||
possible_combos = 1
|
||||
for el in self.possible_nkode:
|
||||
possible_combos *= len({obs.flat_keypad.index(el2) // self.properties_per_key for el2 in el})
|
||||
return possible_combos
|
||||
return possible_combos
|
||||
@@ -1,92 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.tower_shuffle import TowerShuffle
|
||||
|
||||
|
||||
@dataclass
|
||||
class Keypad:
|
||||
keypad: np.ndarray
|
||||
k: int # number of keys
|
||||
p: int # properties per key
|
||||
keypad_cache: list #
|
||||
tower_shuffler: TowerShuffle
|
||||
max_cache_size: int = 100
|
||||
|
||||
@staticmethod
|
||||
def new_keypad(k: int, p: int):
|
||||
total_properties = k * p
|
||||
array = np.arange(total_properties)
|
||||
# Reshape into a 3x4 matrix
|
||||
keypad = array.reshape(k, p)
|
||||
set_view = keypad.T
|
||||
for set_idx in set_view:
|
||||
np.random.shuffle(set_idx)
|
||||
|
||||
return Keypad(keypad=set_view.T, k=k, p=p, keypad_cache=[], tower_shuffler=TowerShuffle.new(p))
|
||||
|
||||
def tower_shuffle(self):
|
||||
selected_positions = self.tower_shuffler.left_tower.tolist()
|
||||
new_key_idxs = np.random.permutation(self.k)
|
||||
self.keypad[:, selected_positions] = self.keypad[new_key_idxs, :][:, selected_positions]
|
||||
self.tower_shuffler.shuffle()
|
||||
|
||||
def split_shuffle(self):
|
||||
"""
|
||||
This is a modified split shuffle.
|
||||
It doesn't shuffle the keys only the properties in the keys.
|
||||
Shuffling the keys makes it hard for people to guess an nKode not a machine.
|
||||
This split shuffle includes a cache to prevent the same configuration from being used.
|
||||
This cache is not in any other implementation.
|
||||
Testing suggests it's not necessary.
|
||||
Getting the same keypad twice over 100 shuffles is very unlikely.
|
||||
"""
|
||||
shuffled_sets = self._shuffle()
|
||||
# Sort the shuffled sets by the first column
|
||||
sorted_set = shuffled_sets[np.argsort(shuffled_sets[:, 0])]
|
||||
while str(sorted_set) in self.keypad_cache:
|
||||
# continue shuffling until we get a unique configuration
|
||||
shuffled_sets = self._shuffle()
|
||||
sorted_set = shuffled_sets[np.argsort(shuffled_sets[:, 0])]
|
||||
|
||||
self.keypad_cache.append(str(sorted_set))
|
||||
self.keypad_cache = self.keypad_cache[:self.max_cache_size]
|
||||
self.keypad = shuffled_sets
|
||||
|
||||
|
||||
def _shuffle(self) -> np.ndarray:
|
||||
column_permutation = np.random.permutation(self.p)
|
||||
column_subset = column_permutation[:self.p // 2]
|
||||
new_key_idxs = np.random.permutation(self.k)
|
||||
shuffled_sets = self.keypad.copy()
|
||||
shuffled_sets[:, column_subset] = shuffled_sets[new_key_idxs, :][:, column_subset]
|
||||
return shuffled_sets
|
||||
|
||||
def full_shuffle(self):
|
||||
shuffled_matrix = np.array([np.random.permutation(row) for row in self.keypad.T])
|
||||
self.keypad = shuffled_matrix.T
|
||||
|
||||
def key_entry(self, target_passcode: list[int]) -> list[int]:
|
||||
"""
|
||||
Given target_values, return the row indices they are in.
|
||||
Assert that each element is >= 0 and < self.k * self.p.
|
||||
"""
|
||||
# Convert the list to a NumPy array for vectorized checks
|
||||
vals = np.array(target_passcode)
|
||||
|
||||
# Validate that each value is within the valid range
|
||||
if np.any((vals < 0) | (vals >= self.k * self.p)):
|
||||
raise ValueError("One or more values are out of the valid range.")
|
||||
|
||||
# Flatten the keypad to a 1D array
|
||||
flat = self.keypad.flatten()
|
||||
|
||||
# Create an inverse mapping from value -> row index
|
||||
inv_index = np.empty(self.k * self.p, dtype=int)
|
||||
# Each value v is at position i in 'flat', so row = i // p
|
||||
for i, v in enumerate(flat):
|
||||
inv_index[v] = i // self.p
|
||||
|
||||
# Use the inverse mapping to get row indices for all target values
|
||||
return inv_index[vals].tolist()
|
||||
0
src/keypad/__init__.py
Normal file
0
src/keypad/__init__.py
Normal file
123
src/keypad/keypad.py
Normal file
123
src/keypad/keypad.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from src.keypad.tower_shuffle import TowerShuffle
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Self
|
||||
|
||||
|
||||
class BaseKeypad(ABC):
|
||||
keypad: np.ndarray
|
||||
k: int # number of keys
|
||||
p: int # properties per key
|
||||
|
||||
@classmethod
|
||||
def _build_keypad(cls, k: int, p: int) -> np.ndarray:
|
||||
rng = np.random.default_rng()
|
||||
total = k * p
|
||||
array = np.arange(total)
|
||||
keypad = array.reshape(k, p)
|
||||
set_view = keypad.T.copy()
|
||||
for set_idx in set_view:
|
||||
rng.shuffle(set_idx)
|
||||
return set_view.T
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def new_keypad(cls, k: int, p: int) -> Self:
|
||||
raise NotImplementedError
|
||||
|
||||
def key_entry(self, target_passcode: list[int]) -> list[int]:
|
||||
vals = np.array(target_passcode)
|
||||
if np.any((vals < 0) | (vals >= self.k * self.p)):
|
||||
raise ValueError("One or more values are out of the valid range.")
|
||||
flat = self.keypad.flatten()
|
||||
inv_index = np.empty(self.k * self.p, dtype=int)
|
||||
for i, v in enumerate(flat):
|
||||
inv_index[v] = i // self.p
|
||||
return inv_index[vals].tolist()
|
||||
|
||||
@abstractmethod
|
||||
def shuffle(self):
|
||||
pass
|
||||
|
||||
def keypad_mat(self) -> list[list[int]]:
|
||||
return [el.tolist() for el in self.keypad]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlidingTowerShuffleKeypad(BaseKeypad):
|
||||
keypad: np.ndarray
|
||||
k: int # number of keys
|
||||
p: int # properties per key
|
||||
tower_shuffle: TowerShuffle
|
||||
|
||||
@classmethod
|
||||
def new_keypad(cls, k: int, p: int) -> Self:
|
||||
kp = cls._build_keypad(k, p)
|
||||
return cls(keypad=kp, k=k, p=p, tower_shuffle=TowerShuffle.new(p))
|
||||
|
||||
def shuffle(self):
|
||||
selected_positions = self.tower_shuffle.left_tower.tolist()
|
||||
shift = np.random.randint(1, self.k) # random int in [1, k-1]
|
||||
new_key_idxs = np.roll(np.arange(self.k), shift)
|
||||
shuffled_sets = self.keypad.copy()
|
||||
shuffled_sets[:, selected_positions] = shuffled_sets[new_key_idxs, :][:, selected_positions]
|
||||
self.keypad = shuffled_sets
|
||||
self.tower_shuffle.shuffle()
|
||||
|
||||
|
||||
@dataclass
|
||||
class RandomShuffleKeypad(BaseKeypad):
|
||||
keypad: np.ndarray
|
||||
k: int # number of keys
|
||||
p: int # properties per key
|
||||
|
||||
@classmethod
|
||||
def new_keypad(cls, k: int, p: int) -> Self:
|
||||
kp = cls._build_keypad(k, p)
|
||||
return cls(keypad=kp, k=k, p=p)
|
||||
|
||||
def shuffle(self):
|
||||
shuffled_matrix = np.array([np.random.permutation(row) for row in self.keypad.T])
|
||||
self.keypad = shuffled_matrix.T
|
||||
|
||||
|
||||
@dataclass
|
||||
class RandomSplitShuffleKeypad(BaseKeypad):
|
||||
keypad: np.ndarray
|
||||
k: int # number of keys
|
||||
p: int # properties per key
|
||||
|
||||
@classmethod
|
||||
def new_keypad(cls, k: int, p: int) -> Self:
|
||||
kp = cls._build_keypad(k, p)
|
||||
return cls(keypad=kp, k=k, p=p)
|
||||
|
||||
def shuffle(self):
|
||||
column_permutation = np.random.permutation(self.p)
|
||||
column_subset = column_permutation[:self.p // 2]
|
||||
new_key_idxs = np.random.permutation(self.k)
|
||||
shuffled_sets = self.keypad.copy()
|
||||
shuffled_sets[:, column_subset] = shuffled_sets[new_key_idxs, :][:, column_subset]
|
||||
self.keypad = shuffled_sets
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlidingSplitShuffleKeypad(BaseKeypad):
|
||||
keypad: np.ndarray
|
||||
k: int # number of keys
|
||||
p: int # properties per key
|
||||
|
||||
@classmethod
|
||||
def new_keypad(cls, k: int, p: int) -> Self:
|
||||
kp = cls._build_keypad(k, p)
|
||||
return cls(keypad=kp, k=k, p=p)
|
||||
|
||||
def shuffle(self):
|
||||
selected_positions = np.random.permutation(self.p)
|
||||
column_subset = selected_positions[:self.p // 2]
|
||||
shift = np.random.randint(1, self.k)
|
||||
new_key_idxs = np.roll(np.arange(self.k), shift)
|
||||
shuffled_sets = self.keypad.copy()
|
||||
shuffled_sets[:, column_subset] = shuffled_sets[new_key_idxs, :][:, column_subset]
|
||||
self.keypad = shuffled_sets
|
||||
50
src/utils.py
50
src/utils.py
@@ -1,65 +1,49 @@
|
||||
import random
|
||||
from enum import Enum
|
||||
from math import factorial, comb
|
||||
|
||||
from src.evilkode import Observation
|
||||
from src.keypad import Keypad
|
||||
from src.evilnkode import Observation
|
||||
from src.keypad.keypad import BaseKeypad
|
||||
from typing import Iterator
|
||||
|
||||
|
||||
def total_valid_nkode_states(k: int, p: int) -> int:
|
||||
return factorial(k) ** (p-1)
|
||||
return factorial(k) ** (p - 1)
|
||||
|
||||
|
||||
def total_shuffle_states(k: int, p: int) -> int:
|
||||
return comb((p-1), (p-1) // 2) * factorial(k)
|
||||
return comb((p - 1), (p - 1) // 2) * factorial(k)
|
||||
|
||||
|
||||
class ShuffleTypes(Enum):
|
||||
FULL_SHUFFLE = "FULL_SHUFFLE"
|
||||
SPLIT_SHUFFLE = "SPLIT_SHUFFLE"
|
||||
TOWER_SHUFFLE = "TOWER_SHUFFLE"
|
||||
|
||||
|
||||
def observations(target_passcode: list[int], number_of_keys:int, properties_per_key: int, min_complexity: int, min_disparity: int, shuffle_type: ShuffleTypes, number_of_observations: int = 100):
|
||||
k = number_of_keys
|
||||
p = properties_per_key
|
||||
keypad = Keypad.new_keypad(k, p)
|
||||
|
||||
def obs_gen():
|
||||
def observations(target_passcode: list[int], keypad: BaseKeypad, number_of_observations: int = 100) -> Iterator[
|
||||
Observation]:
|
||||
def obs():
|
||||
for _ in range(number_of_observations):
|
||||
yield Observation(
|
||||
keypad=keypad.keypad.copy(),
|
||||
keypad=keypad.keypad_mat(),
|
||||
key_selection=keypad.key_entry(target_passcode=target_passcode)
|
||||
)
|
||||
match shuffle_type:
|
||||
case ShuffleTypes.FULL_SHUFFLE:
|
||||
keypad.full_shuffle()
|
||||
case ShuffleTypes.SPLIT_SHUFFLE:
|
||||
keypad.split_shuffle()
|
||||
case ShuffleTypes.TOWER_SHUFFLE:
|
||||
keypad.tower_shuffle()
|
||||
case _:
|
||||
raise Exception(f"no shuffle type {shuffle_type}")
|
||||
keypad.shuffle()
|
||||
|
||||
return obs_gen()
|
||||
return obs()
|
||||
|
||||
|
||||
def passcode_generator(k: int, p: int, n: int, c: int, d: int) -> list[int]:
|
||||
assert n >= c
|
||||
assert p*k >= c
|
||||
assert p * k >= c
|
||||
|
||||
assert n >= d
|
||||
assert p >= d
|
||||
passcode_prop = []
|
||||
passcode_set = []
|
||||
valid_choices = {i for i in range(k*p)}
|
||||
repeat_set = n-d
|
||||
repeat_prop = n-c
|
||||
valid_choices = {i for i in range(k * p)}
|
||||
repeat_set = n - d
|
||||
repeat_prop = n - c
|
||||
prop_added = set()
|
||||
set_added = set()
|
||||
|
||||
for _ in range(n):
|
||||
prop = random.choice(list(valid_choices))
|
||||
prop_set = prop//p
|
||||
prop_set = prop // p
|
||||
passcode_prop.append(prop)
|
||||
passcode_set.append(prop_set)
|
||||
|
||||
|
||||
@@ -1,245 +0,0 @@
|
||||
import json
|
||||
from dataclasses import dataclass, asdict
|
||||
from evilkode import Observation
|
||||
from utils import observations, passcode_generator, ShuffleTypes
|
||||
from pathlib import Path
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from typing import Iterable
|
||||
|
||||
# Project root = parent of *this* file's directory
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
BASE_DIR = PROJECT_ROOT / "example"
|
||||
PNG_DIR = PROJECT_ROOT / "example" / "obs_png"
|
||||
|
||||
@dataclass
|
||||
class ObservationSequence:
|
||||
target_passcode: list[int]
|
||||
observations: list[Observation]
|
||||
|
||||
def new_observation_sequence(
|
||||
number_of_keys: int,
|
||||
properties_per_key: int,
|
||||
passcode_len: int,
|
||||
complexity: int,
|
||||
disparity: int,
|
||||
numb_runs: int,
|
||||
shuffle_type: ShuffleTypes
|
||||
) -> ObservationSequence:
|
||||
passcode = passcode_generator(number_of_keys, properties_per_key, passcode_len, complexity, disparity)
|
||||
obs_seq = ObservationSequence(target_passcode=passcode, observations=[])
|
||||
obs_gen = observations(
|
||||
target_passcode=passcode,
|
||||
number_of_keys=number_of_keys,
|
||||
properties_per_key=properties_per_key,
|
||||
min_complexity=complexity,
|
||||
min_disparity=disparity,
|
||||
shuffle_type=shuffle_type,
|
||||
number_of_observations=numb_runs,
|
||||
)
|
||||
for obs in obs_gen:
|
||||
obs.keypad = obs.keypad.tolist()
|
||||
obs_seq.observations.append(obs)
|
||||
|
||||
return obs_seq
|
||||
|
||||
def _next_json_filename(base_dir: Path) -> Path:
|
||||
"""Find the next available observation_X.json file in base_dir."""
|
||||
counter = 1
|
||||
while True:
|
||||
candidate = base_dir / f"observation_{counter}.json"
|
||||
if not candidate.exists():
|
||||
return candidate
|
||||
counter += 1
|
||||
|
||||
def save_observation_sequence_to_json(seq: ObservationSequence, shuffle_type: ShuffleTypes, filename: Path | None = None) -> None:
|
||||
"""
|
||||
Save ObservationSequence to JSON.
|
||||
- If filename is None, put it under PROJECT_ROOT/output/obs_json/ as observation_{n}.json
|
||||
- Creates directory if needed
|
||||
"""
|
||||
if filename is None:
|
||||
base_dir = BASE_DIR / shuffle_type.name / "obs_json"
|
||||
base_dir.mkdir(parents=True, exist_ok=True)
|
||||
filename = _next_json_filename(base_dir)
|
||||
else:
|
||||
filename.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with filename.open("w", encoding="utf-8") as f:
|
||||
json.dump(asdict(seq), f, indent=4)
|
||||
|
||||
# ---------- Helpers ----------
|
||||
def _load_font(preferred: str, size: int) -> ImageFont.FreeTypeFont | ImageFont.ImageFont:
|
||||
"""Try a preferred TTF, fall back to common monospace, then PIL default."""
|
||||
candidates = [
|
||||
preferred,
|
||||
"DejaVuSansMono.ttf", # common on Linux
|
||||
"Consolas.ttf", # Windows
|
||||
"Menlo.ttc", "Menlo.ttf", # macOS
|
||||
"Courier New.ttf",
|
||||
]
|
||||
for c in candidates:
|
||||
try:
|
||||
return ImageFont.truetype(c, size)
|
||||
except Exception:
|
||||
continue
|
||||
return ImageFont.load_default()
|
||||
|
||||
def _text_size(draw: ImageDraw.ImageDraw, text: str, font: ImageFont.ImageFont) -> tuple[int, int]:
|
||||
"""Get (w, h) using font bbox for accurate layout."""
|
||||
left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
|
||||
return right - left, bottom - top
|
||||
|
||||
def _join_nums(nums: Iterable[int]) -> str:
|
||||
return " ".join(str(n) for n in nums)
|
||||
|
||||
def _next_available_path(path: Path) -> Path:
|
||||
"""If path exists, append _1, _2, ..."""
|
||||
if not path.exists():
|
||||
return path
|
||||
base, suffix = path.stem, path.suffix or ".png"
|
||||
i = 1
|
||||
while True:
|
||||
candidate = path.with_name(f"{base}_{i}{suffix}")
|
||||
if not candidate.exists():
|
||||
return candidate
|
||||
i += 1
|
||||
|
||||
# ---------- Core rendering ----------
|
||||
def render_observation_to_png(
|
||||
target_passcode: list[int],
|
||||
obs: Observation,
|
||||
out_path: Path,
|
||||
*,
|
||||
header_font_name: str = "DejaVuSans.ttf",
|
||||
body_font_name: str = "DejaVuSans.ttf",
|
||||
header_size: int = 28,
|
||||
body_size: int = 24,
|
||||
margin: int = 32,
|
||||
row_padding_xy: tuple[int, int] = (16, 12), # (x, y) padding inside row box
|
||||
row_spacing: int = 14,
|
||||
header_spacing: int = 10,
|
||||
section_spacing: int = 18,
|
||||
bg_color: str = "white",
|
||||
fg_color: str = "black",
|
||||
row_fill: str = "#f7f7f7",
|
||||
row_outline: str = "#222222",
|
||||
):
|
||||
"""
|
||||
Render a single observation:
|
||||
- Top lines:
|
||||
Target Passcode: {target}
|
||||
Selected Keys: {selected keys}
|
||||
- Then a stack of row boxes representing the keypad rows.
|
||||
"""
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
out_path = _next_available_path(out_path)
|
||||
|
||||
# Fonts
|
||||
header_font = _load_font(header_font_name, header_size)
|
||||
body_font = _load_font(body_font_name, body_size)
|
||||
|
||||
# Prepare strings
|
||||
header1 = f"Target Passcode: {_join_nums(target_passcode)}"
|
||||
header2 = f"Selected Keys: {_join_nums(obs.key_selection)}"
|
||||
row_texts = [_join_nums(row) for row in obs.keypad]
|
||||
|
||||
# Measure to compute canvas size
|
||||
# Provisional image for measurement
|
||||
temp_img = Image.new("RGB", (1, 1), bg_color)
|
||||
d = ImageDraw.Draw(temp_img)
|
||||
|
||||
h1_w, h1_h = _text_size(d, header1, header_font)
|
||||
h2_w, h2_h = _text_size(d, header2, header_font)
|
||||
|
||||
row_text_sizes = [_text_size(d, t, body_font) for t in row_texts]
|
||||
row_box_widths = [tw + 2 * row_padding_xy[0] for (tw, th) in row_text_sizes]
|
||||
row_box_heights = [th + 2 * row_padding_xy[1] for (tw, th) in row_text_sizes]
|
||||
|
||||
content_width = max([h1_w, h2_w] + (row_box_widths or [0]))
|
||||
total_rows_height = sum(row_box_heights) + row_spacing * max(0, len(row_box_heights) - 1)
|
||||
|
||||
width = content_width + 2 * margin
|
||||
height = (
|
||||
margin
|
||||
+ h1_h
|
||||
+ header_spacing
|
||||
+ h2_h
|
||||
+ section_spacing
|
||||
+ total_rows_height
|
||||
+ margin
|
||||
)
|
||||
|
||||
# Create final image
|
||||
img = Image.new("RGB", (max(width, 300), max(height, 200)), bg_color)
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
# Draw headers
|
||||
x = margin
|
||||
y = margin
|
||||
draw.text((x, y), header1, font=header_font, fill=fg_color)
|
||||
y += h1_h + header_spacing
|
||||
draw.text((x, y), header2, font=header_font, fill=fg_color)
|
||||
y += h2_h + section_spacing
|
||||
|
||||
# Draw row boxes with evenly spaced numbers
|
||||
max_box_width = max(row_box_widths) if row_box_widths else 0
|
||||
for row, box_h in zip(obs.keypad, row_box_heights):
|
||||
box_left = x
|
||||
box_top = y
|
||||
box_right = x + max_box_width
|
||||
box_bottom = y + box_h
|
||||
|
||||
# draw row rectangle
|
||||
draw.rectangle(
|
||||
[box_left, box_top, box_right, box_bottom],
|
||||
fill=row_fill,
|
||||
outline=row_outline,
|
||||
width=2
|
||||
)
|
||||
|
||||
# evenly spaced numbers
|
||||
n = len(row)
|
||||
if n > 0:
|
||||
available_width = max_box_width - 2 * row_padding_xy[0]
|
||||
spacing = available_width / (n + 1)
|
||||
|
||||
for idx, num in enumerate(row, start=1):
|
||||
num_text = str(num)
|
||||
num_w, num_h = _text_size(draw, num_text, body_font)
|
||||
num_x = box_left + row_padding_xy[0] + spacing * idx - num_w / 2
|
||||
num_y = box_top + (box_h - num_h) // 2
|
||||
draw.text((num_x, num_y), num_text, font=body_font, fill=fg_color)
|
||||
|
||||
y = box_bottom + row_spacing
|
||||
|
||||
img.save(out_path, format="PNG")
|
||||
|
||||
def _next_run_dir(base_dir: Path) -> Path:
|
||||
"""Find the next available run directory under base_dir (run_001, run_002, ...)."""
|
||||
counter = 1
|
||||
while True:
|
||||
run_dir = base_dir / f"run_{counter:03d}"
|
||||
if not run_dir.exists():
|
||||
run_dir.mkdir(parents=True)
|
||||
return run_dir
|
||||
counter += 1
|
||||
|
||||
def render_sequence_to_pngs(seq: ObservationSequence, shuffle_type: ShuffleTypes, out_dir: Path | None = None) -> None:
|
||||
"""
|
||||
Render each observation to its own PNG inside a fresh run directory.
|
||||
Default: PROJECT_ROOT/output/obs_png/run_XXX/observation_001.png
|
||||
"""
|
||||
base_dir = BASE_DIR / shuffle_type.name / "obs_png" if out_dir is None else out_dir
|
||||
base_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create a fresh run dir
|
||||
run_dir = _next_run_dir(base_dir)
|
||||
|
||||
for i, obs in enumerate(seq.observations, start=1):
|
||||
filename = run_dir / f"observation_{i:03d}.png"
|
||||
render_observation_to_png(seq.target_passcode, obs, filename)
|
||||
if __name__ == "__main__":
|
||||
shuffle_type = ShuffleTypes.TOWER_SHUFFLE
|
||||
obs_seq = new_observation_sequence(6, 9,4,0,0 ,numb_runs=50, shuffle_type=shuffle_type)
|
||||
save_observation_sequence_to_json(obs_seq, shuffle_type)
|
||||
render_sequence_to_pngs(obs_seq, shuffle_type)
|
||||
Reference in New Issue
Block a user