diff --git a/README.md b/README.md new file mode 100644 index 0000000..542e11e --- /dev/null +++ b/README.md @@ -0,0 +1,7 @@ +# Icon Generator for nKode + +## Install + +- download [wordnet](https://github.com/nltk/nltk_data/blob/gh-pages/packages/corpora/wordnet.zip) +- `mkdir ~/nltk_data/corpora` +- unzip wordnet and copy the folder into the new directory above \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 7b8f4b6..bb1e04f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,3 @@ openai~=1.57.3 +requests~=2.32.3 +nltk~=3.9.1 \ No newline at end of file diff --git a/src/db.py b/src/db.py index f7b9731..75923f6 100644 --- a/src/db.py +++ b/src/db.py @@ -1,44 +1,90 @@ -import sqlite3 -from pathlib import Path +from sqlalchemy import create_engine, Column, String, Enum, func +from sqlalchemy.orm import sessionmaker, declarative_base import uuid +import hashlib +from typing import Literal + +# Define the base for ORM mapping +Base = declarative_base() + +# Define the Images table as a Python class +class Image(Base): + __tablename__ = 'images' + + id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) + prompt = Column(String, nullable=False) + model = Column(Enum("dall-e-3", "dall-e-2", name="model_enum"), nullable=False) + size = Column(String, nullable=False, default="1024x1024") + quality = Column(Enum("standard", "hd", name="quality_enum"), nullable=False, default="standard") + filename = Column(String, nullable=False) + hash = Column(String, nullable=True) + +# Create a database engine and session factory +def create_db(db_path: str): + engine = create_engine(f'sqlite:///{db_path}') + Base.metadata.create_all(engine) + return engine + +def hash_image_row( + prompt: str, + model: str, + size: str, + quality: str, +) -> str: + # Create a unique hash from the concatenated parameters + concat_string = f"{prompt}:{model}:{size}:{quality}" + unique_hash = hashlib.sha256(concat_string.encode()).hexdigest() + return unique_hash -def create_db(db_path: Path): - conn = sqlite3.connect(db_path) - c = conn.cursor() - - # Create the table if it doesn't exist - c.execute(""" - CREATE TABLE IF NOT EXISTS images ( - id TEXT PRIMARY KEY, - prompt TEXT, - model TEXT, - size TEXT, - quality TEXT, - filename TEXT +def insert_image_row( + session, + image_id: str, + prompt: str, + hash: str, + model: Literal["dall-e-3", "dall-e-2"] = "dall-e-3", + size: Literal["1024x1024"] = "1024x1024", + quality: Literal["standard", "hd"] = "standard", +): + image = Image( + id=image_id, + prompt=prompt, + model=model, + size=size, + quality=quality, + filename=f"{image_id}.png", + hash=hash ) - """) - - # Commit and close the connection - conn.commit() - conn.close() + session.add(image) + session.commit() -def insert_into_db(db_path: Path, image_id: uuid, prompt: str, model: str = "dall-e-3", size: str = "1024x1024", quality: str = "standard"): - conn = sqlite3.connect(db_path) - c = conn.cursor() - # Insert the record into the database - c.execute( - "INSERT INTO images (id, prompt, model, size, quality, filename) VALUES (?, ?, ?, ?, ?, ?)", - ( - image_id, - prompt, - model, - size, - quality, - f"{image_id}.png" - ) +def count_hash(session, hash_value: str): + return session.query(func.count(Image.hash)).filter(Image.hash == hash_value).scalar() + + +if __name__ == "__main__": + # Example usage + db_path = "../output/images.db" + engine = create_db(db_path) + Session = sessionmaker(bind=engine) + session = Session() + + prompt = "A futuristic cityscape" + model = "dall-e-3" + size = "1024x1024" + quality = "standard" + hash = hash_image_row(prompt, model, size, quality) + # Insert an example record + insert_image_row( + session, + image_id=str(uuid.uuid4()), + prompt=prompt, + model=model, + size=size, + quality="standard", + hash=hash ) - # Commit and close the connection - conn.commit() - conn.close() + + session.close() + diff --git a/src/image_gen.py b/src/image_gen.py index d1bf1aa..6d8fc46 100644 --- a/src/image_gen.py +++ b/src/image_gen.py @@ -45,3 +45,8 @@ def icon_gen(prompt: str, quality: Literal["hd", "standard"], output: Path = "./ with open(image_filepath, "wb") as f: f.write(image_response.content) print(image_id) + + +if __name__ == "__main__": + dog_emoji = image_style("dog in a hat on a beach with a drink", style=Styles.cartoon) + icon_gen(dog_emoji, output=Path("../output"), quality="hd") diff --git a/src/wordlist.py b/src/wordlist.py index db06200..2ec2ff6 100644 --- a/src/wordlist.py +++ b/src/wordlist.py @@ -93,3 +93,18 @@ def is_imageable(noun): return True return False + + +if __name__ == "__main__": + nounlist = get_noun_list(Path("../data")) + banned_words = get_banned_wordlist(Path("../data")) + print(len(nounlist)) + print(len( banned_words)) + filtered_words = filter_banned_words_list(noun_list=nounlist, banned_words=banned_words) + print(len(filtered_words)) + + imageable = filter_nonimageable_words(filtered_words) + print(len(imageable)) + print(f"imageable: {len(filtered_words) - len(imageable)}") + with open("../data/imageable.txt", "w") as fp: + fp.write("\n".join(imageable)) \ No newline at end of file diff --git a/tests/image_gen_test.py b/tests/image_gen_test.py deleted file mode 100644 index 78e5f6a..0000000 --- a/tests/image_gen_test.py +++ /dev/null @@ -1,6 +0,0 @@ -from src.image_gen import image_style, Styles, icon_gen -from pathlib import Path - -if __name__ == "__main__": - dog_emoji = image_style("dog in a hat on a beach with a drink", style=Styles.cartoon) - icon_gen(dog_emoji, output=Path("../output"), quality="hd") diff --git a/tests/wordlist_test.py b/tests/wordlist_test.py deleted file mode 100644 index e82a47d..0000000 --- a/tests/wordlist_test.py +++ /dev/null @@ -1,16 +0,0 @@ -from src.wordlist import get_noun_list, get_banned_wordlist, filter_banned_words_list, filter_nonimageable_words -from pathlib import Path - -if __name__ == "__main__": - nounlist = get_noun_list(Path("../data")) - banned_words = get_banned_wordlist(Path("../data")) - print(len(nounlist)) - print(len( banned_words)) - filtered_words = filter_banned_words_list(noun_list=nounlist, banned_words=banned_words) - print(len(filtered_words)) - - imageable = filter_nonimageable_words(filtered_words) - print(len(imageable)) - print(f"imageable: {len(filtered_words) - len(imageable)}") - with open("../data/imageable.txt", "w") as fp: - fp.write("\n".join(imageable)) \ No newline at end of file