implement word list and image gen
3  .gitignore  vendored  Normal file
@@ -0,0 +1,3 @@
.idea
output
venv
1654  data/banned_words.txt  Normal file
File diff suppressed because it is too large
4345  data/imageable.txt  Normal file
File diff suppressed because it is too large
6801  data/nounlist.csv  Normal file
File diff suppressed because it is too large
3  requirements.txt  Normal file
@@ -0,0 +1,3 @@
openai~=1.57.3
requests
nltk
0  src/__init__.py  Normal file
44  src/db.py  Normal file
@@ -0,0 +1,44 @@
import sqlite3
from pathlib import Path
import uuid


def create_db(db_path: Path):
    conn = sqlite3.connect(db_path)
    c = conn.cursor()

    # Create the table if it doesn't exist
    c.execute("""
        CREATE TABLE IF NOT EXISTS images (
            id TEXT PRIMARY KEY,
            prompt TEXT,
            model TEXT,
            size TEXT,
            quality TEXT,
            filename TEXT
        )
    """)

    # Commit and close the connection
    conn.commit()
    conn.close()


def insert_into_db(db_path: Path, image_id: str, prompt: str, model: str = "dall-e-3", size: str = "1024x1024", quality: str = "standard"):
    conn = sqlite3.connect(db_path)
    c = conn.cursor()

    # Insert the record into the database
    c.execute(
        "INSERT INTO images (id, prompt, model, size, quality, filename) VALUES (?, ?, ?, ?, ?, ?)",
        (
            image_id,
            prompt,
            model,
            size,
            quality,
            f"{image_id}.png",
        ),
    )

    # Commit and close the connection
    conn.commit()
    conn.close()
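A minimal usage sketch for src/db.py (not part of the commit; the database path, prompt, and id below are illustrative). Note that ids are stored as TEXT, so callers should pass a string:

    from pathlib import Path
    import uuid
    from src.db import create_db, insert_into_db

    db = Path("./output/images.db")  # hypothetical location
    create_db(db)                    # creates the images table if missing
    image_id = str(uuid.uuid4())     # stored as TEXT, so pass a string
    insert_into_db(db, image_id, prompt="a red fox icon", quality="hd")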
47  src/image_gen.py  Normal file
@@ -0,0 +1,47 @@
from openai import OpenAI
from pathlib import Path
import uuid
import requests
from enum import Enum
from typing import Literal


class Styles(Enum):
    emoji = "emoji"
    pixel_art = "pixel art"
    svg = "svg"
    cartoon = "cartoon"


client = OpenAI()


def image_style(base_prompt: str, style: Styles) -> str:
    return f"create {style.value} style of {base_prompt}"


def icon_gen(prompt: str, quality: Literal["hd", "standard"], output: Path = Path("./output")):
    # Make sure the output directory exists
    output_path = Path(output) / quality
    output_path.mkdir(parents=True, exist_ok=True)

    # Generate the image using the OpenAI client
    response = client.images.generate(
        model="dall-e-3",
        prompt=prompt,
        size="1024x1024",
        quality=quality,
        n=1,
    )

    # Extract the image URL
    image_url = response.data[0].url

    # Generate a UUID for the filename
    image_id = str(uuid.uuid4())
    image_filename = f"{image_id}.png"
    image_filepath = output_path / image_filename

    # Download the image
    image_response = requests.get(image_url)
    image_response.raise_for_status()
    with open(image_filepath, "wb") as f:
        f.write(image_response.content)
    print(image_id)
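A quick sketch of what image_style produces (the prompt is illustrative; also note that importing src.image_gen constructs an OpenAI client at module load, so OPENAI_API_KEY must be set):

    from src.image_gen import image_style, Styles

    print(image_style("a red fox", Styles.pixel_art))
    # -> "create pixel art style of a red fox"
    # The enum's value ("pixel art"), not its name ("pixel_art"), goes into the prompt.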
95  src/wordlist.py  Normal file
@@ -0,0 +1,95 @@
import csv
import os
import requests
import zipfile
from io import BytesIO
from pathlib import Path
import nltk
from nltk.corpus import wordnet

# Ensure the WordNet corpus is available before is_imageable() queries it
nltk.download("wordnet", quiet=True)


def get_noun_list(output_dir: Path = Path("../data")) -> list[str]:
    nounlist = output_dir / "nounlist.csv"
    if not nounlist.exists():
        # URL of the ZIP file
        url = "https://www.kaggle.com/api/v1/datasets/download/leite0407/list-of-nouns"

        # Create the data directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)

        # Download the ZIP file
        print("Downloading the ZIP file...")
        response = requests.get(url, stream=True)

        if response.status_code == 200:
            print("Download complete. Extracting files...")
            with zipfile.ZipFile(BytesIO(response.content)) as zip_ref:
                zip_ref.extractall(output_dir)
            print(f"Files extracted to {output_dir}")
        else:
            raise Exception(f"Failed to download the file. Status code: {response.status_code}")

    with open(nounlist, mode="r", encoding="utf-8") as csv_file:
        reader = csv.reader(csv_file)
        return [row[0] for row in reader if row]  # Skip empty rows


def get_banned_wordlist(output_dir: Path) -> list[str]:
    banned_word_file = output_dir / "banned_words.txt"

    if not banned_word_file.exists():
        print("Getting banned words...")
        # Sources of banned words
        sources = [
            "https://raw.githubusercontent.com/LDNOOBW/List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words/master/en",
            "https://www.cs.cmu.edu/~biglou/resources/bad-words.txt",
        ]

        banned_words = set()

        # Download and combine
        for url in sources:
            response = requests.get(url)
            if response.status_code == 200:
                banned_words.update(response.text.splitlines())

        # Save to banned_words.txt
        with open(banned_word_file, "w") as file:
            file.write("\n".join(sorted(banned_words)))

        print("Saved banned words.")

    with open(banned_word_file, "r") as file:
        banned_words = file.read().splitlines()

    if not banned_words:
        raise Exception("no banned words found")
    return banned_words


def filter_banned_words_list(noun_list: list[str], banned_words: list[str]) -> list[str]:
    filtered_list = set(noun_list) - set(banned_words)
    print(f"Removed {len(noun_list) - len(filtered_list)} banned words")
    return list(filtered_list)


def filter_nonimageable_words(nouns: list[str]) -> list[str]:
    return [noun for noun in nouns if is_imageable(noun)]


def is_imageable(noun: str) -> bool:
    # Get all the noun synsets for the given word
    synsets = wordnet.synsets(noun, pos=wordnet.NOUN)
    if not synsets:
        return False

    # The canonical synset for "physical_entity"
    physical_entity = wordnet.synset("physical_entity.n.01")

    for syn in synsets:
        # Traverse up the hypernym tree
        for ancestor in syn.closure(lambda s: s.hypernyms()):
            if ancestor == physical_entity:
                return True

    return False
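A small sketch of the WordNet check behind is_imageable (the example words are illustrative):

    import nltk
    from nltk.corpus import wordnet

    nltk.download("wordnet", quiet=True)  # the corpus must be present

    physical_entity = wordnet.synset("physical_entity.n.01")
    dog = wordnet.synsets("dog", pos=wordnet.NOUN)[0]
    print(physical_entity in dog.closure(lambda s: s.hypernyms()))  # True -> "dog" is imageable

    # An abstract noun such as "honesty" has no hypernym path to
    # physical_entity.n.01, so is_imageable filters it out.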
6  tests/image_gen_test.py  Normal file
@@ -0,0 +1,6 @@
from src.image_gen import image_style, Styles, icon_gen
from pathlib import Path

if __name__ == "__main__":
    dog_emoji = image_style("dog in a hat on a beach with a drink", style=Styles.cartoon)
    icon_gen(dog_emoji, output=Path("../output"), quality="hd")
16  tests/wordlist_test.py  Normal file
@@ -0,0 +1,16 @@
from src.wordlist import get_noun_list, get_banned_wordlist, filter_banned_words_list, filter_nonimageable_words
from pathlib import Path

if __name__ == "__main__":
    nounlist = get_noun_list(Path("../data"))
    banned_words = get_banned_wordlist(Path("../data"))
    print(len(nounlist))
    print(len(banned_words))
    filtered_words = filter_banned_words_list(noun_list=nounlist, banned_words=banned_words)
    print(len(filtered_words))

    imageable = filter_nonimageable_words(filtered_words)
    print(len(imageable))
    print(f"removed as non-imageable: {len(filtered_words) - len(imageable)}")
    with open("../data/imageable.txt", "w") as fp:
        fp.write("\n".join(imageable))