From 9fdf79842dd8d3ced7127ba51a4df2e28550ee8b Mon Sep 17 00:00:00 2001 From: Donovan Date: Sun, 9 Mar 2025 07:39:29 -0500 Subject: [PATCH] refactor asserts --- docs/render_markdown.py | 2 +- src/customer.py | 12 ++++++--- src/customer_attributes.py | 20 +++++++------- src/nkode_api.py | 50 +++++++++++++++++++---------------- src/user.py | 2 +- src/user_cipher_keys.py | 10 ++++--- src/user_interface.py | 23 +++++++++------- src/user_signup_session.py | 15 ++++++----- src/utils.py | 6 +++-- test/test_nkode_interface.py | 2 +- test/test_user_cipher_keys.py | 4 +-- test/test_user_interface.py | 2 +- 12 files changed, 84 insertions(+), 64 deletions(-) diff --git a/docs/render_markdown.py b/docs/render_markdown.py index 5b61aaa..ddd1439 100644 --- a/docs/render_markdown.py +++ b/docs/render_markdown.py @@ -181,7 +181,7 @@ if __name__ == "__main__": """ REFRESH USER KEYS """ - user.user_keys = UserCipherKeys.new( + user.user_keys = UserCipherKeys.create( customer.attributes.keypad_size, customer.attributes.set_vals, user.user_keys.max_nkode_len diff --git a/src/customer.py b/src/customer.py index bc278bf..8b25694 100644 --- a/src/customer.py +++ b/src/customer.py @@ -14,13 +14,19 @@ class Customer(BaseModel): # TODO: validate policy and keypad size don't conflict - def add_new_user(self, user: User): + if user.username in self.users: + raise ValueError(f"User with username '{user.username}' already exists") self.users[user.username] = user def valid_key_entry(self, username, selected_keys) -> bool: - assert (username in self.users.keys()) - assert (all(0 <= key_idx < self.attributes.keypad_size.numb_of_keys for key_idx in selected_keys)) + if username not in self.users: + raise ValueError(f"User '{username}' does not exist") + + keypad_size = self.attributes.keypad_size.numb_of_keys + if not all(0 <= key_idx < keypad_size for key_idx in selected_keys): + raise ValueError(f"Invalid key indices. Must be between 0 and {keypad_size - 1}") + passcode_len = len(selected_keys) user = self.users[username] diff --git a/src/customer_attributes.py b/src/customer_attributes.py index 768af45..10929ce 100644 --- a/src/customer_attributes.py +++ b/src/customer_attributes.py @@ -1,5 +1,5 @@ +from typing import ClassVar from pydantic import BaseModel, model_validator - from src.models import KeypadSize from src.utils import generate_random_nonrepeating_list @@ -8,20 +8,21 @@ class CustomerAttributes(BaseModel): attr_vals: list[int] set_vals: list[int] keypad_size: KeypadSize + MAX_KEYS: ClassVar[int] = 256 + MAX_ATTRS_PER_KEY: ClassVar[int] = 256 @model_validator(mode='after') - def check_keys_vs_attrs(self): + def check_keys_vs_attrs(self) -> 'CustomerAttributes': if self.keypad_size.is_dispersable: raise ValueError("number of keys must be less than the number of " "attributes per key to be dispersion resistant") return self - @staticmethod - def new(keypad_size: KeypadSize): - assert (keypad_size.numb_of_keys <= 256) - assert (keypad_size.attrs_per_key <= 256) - - return CustomerAttributes( + @classmethod + def create(cls, keypad_size: KeypadSize) -> 'CustomerAttributes': + if keypad_size.numb_of_keys > cls.MAX_KEYS or keypad_size.attrs_per_key > cls.MAX_ATTRS_PER_KEY: + raise ValueError(f"Keys and attributes per key must not exceed {cls.MAX_KEYS}") + return cls( attr_vals=generate_random_nonrepeating_list(keypad_size.numb_of_attrs), set_vals=generate_random_nonrepeating_list(keypad_size.attrs_per_key), keypad_size=keypad_size, @@ -38,5 +39,6 @@ class CustomerAttributes(BaseModel): return self.set_vals[set_idx] def get_set_index(self, set_val: int) -> int: - assert (set_val in self.set_vals) + if set_val not in self.set_vals: + raise ValueError(f"Set value {set_val} not found in set values") return self.set_vals.index(set_val) diff --git a/src/nkode_api.py b/src/nkode_api.py index 5d9cd4e..d425ab0 100644 --- a/src/nkode_api.py +++ b/src/nkode_api.py @@ -16,18 +16,18 @@ class NKodeAPI(BaseModel): def create_new_customer(self, keypad_size: KeypadSize, nkode_policy: NKodePolicy) -> UUID: new_customer = Customer( customer_id=uuid4(), - attributes=CustomerAttributes.new(keypad_size), + attributes=CustomerAttributes.create(keypad_size), users={}, nkode_policy=nkode_policy ) self.customers[new_customer.customer_id] = new_customer - return new_customer.customer_id def generate_signup_interface(self, customer_id: UUID) -> tuple[UUID, list[int]]: - assert (customer_id in self.customers.keys()) + if customer_id not in self.customers.keys(): + raise ValueError(f"Customer with ID '{customer_id}' does not exist") customer = self.customers[customer_id] - login_interface = UserInterface.new(customer.attributes.keypad_size) + login_interface = UserInterface.create(customer.attributes.keypad_size) set_interface = login_interface.sign_up_interface() new_session = UserSignupSession( session_id=uuid4(), @@ -46,10 +46,13 @@ class NKodeAPI(BaseModel): key_selection: list[int], session_id: UUID ) -> list[int]: - assert (customer_id in self.customers.keys()) + if customer_id not in self.customers.keys(): + raise ValueError(f"Customer ID {customer_id} not found") customer = self.customers[customer_id] - assert (username not in customer.users.keys()) - assert (session_id in self.signup_sessions.keys()) + if username in customer.users.keys(): + raise ValueError(f"Username '{username}' already exists for this customer") + if session_id not in self.signup_sessions.keys(): + raise ValueError(f"Session ID {session_id} not found") self.signup_sessions[session_id].set_user_nkode(username, key_selection) return self.signup_sessions[session_id].confirm_interface @@ -60,20 +63,20 @@ class NKodeAPI(BaseModel): confirm_key_entry: list[int], session_id: UUID ) -> bool: - assert ( - session_id in self.signup_sessions.keys() and - customer_id == self.signup_sessions[session_id].customer_id and - username == self.signup_sessions[session_id].username - ) + if session_id not in self.signup_sessions.keys(): + raise AssertionError(f"Session ID {session_id} not found in signup sessions") + session = self.signup_sessions[session_id] + if customer_id != session.customer_id: + raise AssertionError(f"Customer ID mismatch: {customer_id} vs {session.customer_id}") + if username != session.username: + raise AssertionError(f"Username mismatch: {username} vs {session.username}") customer = self.customers[customer_id] - passcode = self.signup_sessions[session_id].deduce_passcode(confirm_key_entry) - new_user_keys = UserCipherKeys.new( + new_user_keys = UserCipherKeys.create( customer.attributes.keypad_size, customer.attributes.set_vals, customer.nkode_policy.max_nkode_len ) - enciphered_passcode = new_user_keys.encipher_nkode(passcode, customer.attributes) new_user = User( username=username, @@ -86,21 +89,22 @@ class NKodeAPI(BaseModel): return True def get_login_interface(self, username: str, customer_id: UUID) -> list[int]: - """ - TODO: how do we prevent a targeted denial-of-service attack? - """ - assert (customer_id in self.customers.keys()) + if customer_id not in self.customers.keys(): + raise ValueError("Customer ID not found") customer = self.customers[customer_id] - assert (username in customer.users.keys()) + if username not in customer.users.keys(): + raise ValueError("Username not found") user = customer.users[username] user.user_interface.partial_interface_shuffle() return user.user_interface.interface def login(self, customer_id: UUID, username: str, key_selection: list[int]) -> bool: - assert (customer_id in self.customers.keys()) + if customer_id not in self.customers.keys(): + raise ValueError("Customer ID not found") customer = self.customers[customer_id] return customer.valid_key_entry(username, key_selection) def renew_attributes(self, customer_id: UUID) -> bool: - assert (customer_id in self.customers.keys()) - return self.customers[customer_id].renew_keys() + if customer_id not in self.customers.keys(): + raise ValueError("Customer ID not found") + return self.customers[customer_id].renew_keys() \ No newline at end of file diff --git a/src/user.py b/src/user.py index e9c7b75..9e7a864 100644 --- a/src/user.py +++ b/src/user.py @@ -20,7 +20,7 @@ class User(BaseModel): self.user_keys.alpha_key = xor_lists(self.user_keys.alpha_key, attrs_xor) def refresh_passcode(self, passcode_attr_idx: list[int], customer_interface: CustomerAttributes): - self.user_keys = UserCipherKeys.new( + self.user_keys = UserCipherKeys.create( customer_interface.keypad_size, customer_interface.set_vals, self.user_keys.max_nkode_len diff --git a/src/user_cipher_keys.py b/src/user_cipher_keys.py index 391bffc..8f1bbe8 100644 --- a/src/user_cipher_keys.py +++ b/src/user_cipher_keys.py @@ -17,9 +17,10 @@ class UserCipherKeys(BaseModel): salt: bytes max_nkode_len: int - @staticmethod - def new(keypad_size: KeypadSize, set_values: list[int], max_nkode_len: int): - assert len(set_values) == keypad_size.attrs_per_key + @classmethod + def create(cls, keypad_size: KeypadSize, set_values: list[int], max_nkode_len: int) -> 'UserCipherKeys': + if len(set_values) != keypad_size.attrs_per_key: + raise ValueError("Invalid set values") set_key = generate_random_nonrepeating_list(keypad_size.attrs_per_key) set_key = xor_lists(set_key, set_values) @@ -34,7 +35,8 @@ class UserCipherKeys(BaseModel): ) def pad_user_mask(self, user_mask: list[int], set_vals: list[int]) -> list[int]: - assert (len(user_mask) <= self.max_nkode_len) + if len(user_mask) >= self.max_nkode_len: + raise ValueError("User mask is too long") padded_user_mask = user_mask.copy() for _ in range(self.max_nkode_len - len(user_mask)): padded_user_mask.append(choice(set_vals)) diff --git a/src/user_interface.py b/src/user_interface.py index 673272e..bb23e39 100644 --- a/src/user_interface.py +++ b/src/user_interface.py @@ -9,9 +9,8 @@ class UserInterface(BaseModel): interface: list[int] keypad_size: KeypadSize - @staticmethod - def new(keypad_size: KeypadSize): - # Todo: this a hack do a proper random interface + @classmethod + def create(cls, keypad_size: KeypadSize) -> 'UserInterface': interface = UserInterface( interface=list(range(keypad_size.numb_of_attrs)), keypad_size=keypad_size @@ -20,7 +19,8 @@ class UserInterface(BaseModel): return interface def sign_up_interface(self): - assert (not self.keypad_size.is_dispersable) + if self.keypad_size.is_dispersable: + raise ValueError("Keypad size is dispersable") self.random_interface_shuffle() interface_matrix = self.interface_keypad_matrix() attr_set_view = matrix_transpose(interface_matrix) @@ -47,7 +47,8 @@ class UserInterface(BaseModel): self.interface = matrix_to_list(keypad_view) def disperse_interface(self): - assert (self.keypad_size.is_dispersable) + if not self.keypad_size.is_dispersable: + raise ValueError("Keypad size is not dispersable") user_interface_matrix = list_to_matrix(self.interface, self.keypad_size.attrs_per_key) shuffled_keys = secure_fisher_yates_shuffle(user_interface_matrix) @@ -60,6 +61,7 @@ class UserInterface(BaseModel): self.interface = matrix_to_list(dispersed_interface) def partial_interface_shuffle(self): + # TODO: this should be split shuffle numb_of_selected_sets = self.keypad_size.attrs_per_key // 2 # randomly shuffle half the sets. if attrs_per_key is odd, randomly add one 50% of the time numb_of_selected_sets += choice([0, 1]) if (self.keypad_size.attrs_per_key & 1) == 1 else 0 @@ -80,7 +82,8 @@ class UserInterface(BaseModel): attr_rotation: list[int] ) -> list[list[int]]: transposed_user_interface = matrix_transpose(user_interface) - assert (len(attr_rotation) == len(transposed_user_interface)) + if len(attr_rotation) != len(transposed_user_interface): + raise ValueError("attr_rotation must be the same length as the transposed user interface") for idx, attr_set in enumerate(transposed_user_interface): rotation = attr_rotation[idx] transposed_user_interface[idx] = attr_set[rotation:] + attr_set[:rotation] @@ -93,12 +96,12 @@ class UserInterface(BaseModel): for attr in key: graph[attr] = set(key) graph[attr].remove(attr) - return graph def get_attr_idx_by_keynumb_setidx(self, key_numb: int, set_idx: int) -> int: - assert (0 <= key_numb < self.keypad_size.numb_of_keys) - assert (0 <= set_idx < self.keypad_size.attrs_per_key) + if not (0 <= key_numb < self.keypad_size.numb_of_keys): + raise ValueError(f"key_numb must be between 0 and {self.keypad_size.numb_of_keys - 1}") + if not (0 <= set_idx < self.keypad_size.attrs_per_key): + raise ValueError(f"set_idx must be between 0 and {self.keypad_size.attrs_per_key - 1}") keypad_attr_idx = self.interface_keypad_matrix() - return keypad_attr_idx[key_numb][set_idx] diff --git a/src/user_signup_session.py b/src/user_signup_session.py index 2bedf59..539f5fd 100644 --- a/src/user_signup_session.py +++ b/src/user_signup_session.py @@ -10,7 +10,6 @@ class UserSignupSession(BaseModel): session_id: UUID customer_id: UUID login_interface: UserInterface - keypad_size: KeypadSize set_interface: list[int] | None = None confirm_interface: list[int] | None = None @@ -18,28 +17,30 @@ class UserSignupSession(BaseModel): username: str | None = None def deduce_passcode(self, confirm_key_entry: list[int]) -> list[int]: - assert (all(0 <= key <= self.keypad_size.numb_of_keys for key in confirm_key_entry)) + if not all(0 <= key <= self.keypad_size.numb_of_keys for key in confirm_key_entry): + raise ValueError("Key values must be within valid range") attrs_per_key = self.keypad_size.attrs_per_key - set_key_entry = self.set_key_entry - assert (len(set_key_entry) == len(confirm_key_entry)) + if len(set_key_entry) != len(confirm_key_entry): + raise ValueError("Key entry lengths must match") set_interface = self.set_interface confirm_interface = self.confirm_interface set_key_vals = [set_interface[key * attrs_per_key:(key + 1) * attrs_per_key] for key in set_key_entry] confirm_key_vals = [confirm_interface[key * attrs_per_key:(key + 1) * attrs_per_key] for key in confirm_key_entry] - passcode = [] for idx in range(len(set_key_entry)): set_key = set(set_key_vals[idx]) confirm_key = set(confirm_key_vals[idx]) intersection = list(set_key.intersection(confirm_key)) - assert (len(intersection) == 1) + if len(intersection) != 1: + raise ValueError("Intersection must contain exactly one element") passcode.append(intersection[0]) return passcode def set_user_nkode(self, username: str, key_selection: list[int]): - assert (all(0 <= key <= self.keypad_size.numb_of_keys for key in key_selection)) + if not all(0 <= key <= self.keypad_size.numb_of_keys for key in key_selection): + raise ValueError("Key values must be within valid range") set_interface = UserInterface( interface=self.set_interface, keypad_size=self.keypad_size diff --git a/src/utils.py b/src/utils.py index d2c6692..8a28da2 100644 --- a/src/utils.py +++ b/src/utils.py @@ -10,12 +10,14 @@ def secure_fisher_yates_shuffle(arr: list) -> list: def generate_random_nonrepeating_list(list_len: int, min_val: int = 0, max_val: int = 2 ** 16) -> list[int]: - assert (max_val - min_val >= list_len) + if max_val - min_val < list_len: + raise ValueError("Range of values is less than the list length requested") return secure_fisher_yates_shuffle(list(range(min_val, max_val)))[:list_len] def xor_lists(l1: list[int], l2: list[int]): - assert len(l1) == len(l2) + if len(l1) != len(l2): + raise ValueError("Lists must be of equal length") return [l2[i] ^ l1[i] for i in range(len(l1))] diff --git a/test/test_nkode_interface.py b/test/test_nkode_interface.py index 05f048d..da3d4d4 100644 --- a/test/test_nkode_interface.py +++ b/test/test_nkode_interface.py @@ -8,7 +8,7 @@ from src.models import KeypadSize [KeypadSize(numb_of_keys=10, attrs_per_key=11)] ) def test_attr_set_idx(keypad_size): - user_interface = UserInterface.new(keypad_size) + user_interface = UserInterface.create(keypad_size) for attr_idx in range(keypad_size.numb_of_attrs): user_interface_idx = user_interface.interface[attr_idx] diff --git a/test/test_user_cipher_keys.py b/test/test_user_cipher_keys.py index 528551d..1b7e11f 100644 --- a/test/test_user_cipher_keys.py +++ b/test/test_user_cipher_keys.py @@ -27,13 +27,13 @@ def test_encode_decode_base64(passcode_len): (KeypadSize(numb_of_keys=8, attrs_per_key=11), 12), ]) def test_decode_mask(keypad_size, max_nkode_len): - customer = CustomerAttributes.new(keypad_size) + customer = CustomerAttributes.create(keypad_size) passcode_entry = generate_random_nonrepeating_list( keypad_size.numb_of_attrs, max_val=keypad_size.numb_of_attrs)[:4] passcode_values = [customer.attr_vals[idx] for idx in passcode_entry] set_vals = customer.set_vals - user_keys = UserCipherKeys.new(keypad_size, set_vals, max_nkode_len) + user_keys = UserCipherKeys.create(keypad_size, set_vals, max_nkode_len) passcode = user_keys.encipher_nkode(passcode_entry, customer) orig_passcode_set_vals = [customer.get_attr_set_val(attr) for attr in passcode_values] diff --git a/test/test_user_interface.py b/test/test_user_interface.py index d3fd2f3..563d4ec 100644 --- a/test/test_user_interface.py +++ b/test/test_user_interface.py @@ -5,7 +5,7 @@ from src.models import KeypadSize @pytest.fixture() def user_interface(): - return UserInterface.new(keypad_size=KeypadSize(attrs_per_key=7, numb_of_keys=10)) + return UserInterface.create(keypad_size=KeypadSize(attrs_per_key=7, numb_of_keys=10)) def test_dispersion(user_interface):