Compare commits

4 Commits

Author SHA1 Message Date
5e49efd7cb Merge pull request 'Fix LoRA loading and conditioning to match training notebook' (#23) from fix/lora-loading-and-conditioning into main
Reviewed-on: #23
2026-03-19 23:31:31 +00:00
fe830dea2e Fix LoRA loading and conditioning to match training notebook
card_generator_adapter.py had two mismatches with the training notebook:

1. LoRA loading: used pipe.load_lora_weights() (diffusers format) but the
   adapter was saved with PEFT's save_pretrained() — keys didn't match,
   so no LoRA weights were actually applied. Now uses
   PeftModel.from_pretrained() + merge_and_unload().

2. Conditioning: built a natural language prompt, but the LoRA was trained
   on json.dumps(meta) serialization. Now uses JSON serialization to match.
2026-03-20 00:28:05 +01:00
dbf4946875 Merge pull request 'Fix image detection after print-to-logging migration' (#22) from fix/app-image-detection into main
Reviewed-on: #22
2026-03-19 19:15:44 +00:00
e03daea1f3 Fix image detection after print-to-logging migration
app.py relied on parsing stdout for the save path printed by the
pipeline. After PR #21 replaced that print with logger.info(), the
stdout parsing returned None, causing "Aucune image générée détectée".

Now check the known save path directly (APP_DIR / save_path) and
only fall back to stdout parsing if the file isn't found.
2026-03-19 20:14:17 +01:00
2 changed files with 49 additions and 78 deletions

app.py — 7 lines changed
View File

@@ -37,6 +37,7 @@ def _extract_image_from_stdout(stdout: str) -> Path | None:
def run_prompt_pipeline(prompt_text: str) -> tuple[Path | None, str, list[str]]:
save_path = "generated_card.png"
cmd = [
sys.executable, "prompt_to_card_pipeline.py",
prompt_text,
@@ -46,7 +47,7 @@ def run_prompt_pipeline(prompt_text: str) -> tuple[Path | None, str, list[str]]:
"--template", "clean-text-to-keywords/json_template_example.json",
"--generator-module", "card_generator_adapter.py",
"--device", "cuda",
"--save-path", "generated_card.png",
"--save-path", save_path,
"--print-json",
]
@@ -63,7 +64,9 @@ def run_prompt_pipeline(prompt_text: str) -> tuple[Path | None, str, list[str]]:
if result.returncode != 0:
return None, full_output.strip() or "Erreur inconnue pendant le pipeline.", cmd
image_path = _extract_image_from_stdout(result.stdout or "")
image_path = APP_DIR / save_path
if not image_path.exists():
image_path = _extract_image_from_stdout(result.stdout or "")
return image_path, full_output.strip(), cmd
# ------------------------------------------------------------------ #

card_generator_adapter.py
View File

@@ -1,24 +1,33 @@
"""Adapter to load the LoRA checkpoint and define conditioning logic.
Matches the training notebook (pokemon_card_training_2.ipynb):
- LoRA saved via PEFT's save_pretrained() → loaded via PeftModel.from_pretrained()
- Conditioning = JSON serialization of metadata (same format used during training)
Customize this file to match your model architecture, then use:
--generator-module card_generator_adapter.py
"""
from __future__ import annotations
import json
from typing import Any, Mapping
def build_pipeline(checkpoint_path: str, device: str):
"""Load LoRA adapter and return a callable pipeline.
"""Load LoRA adapter via PEFT and return a callable SD pipeline.
The pipeline must accept:
pipeline(prompt_or_conditioning, num_inference_steps=30, guidance_scale=7.5)
and return an object with .images attribute.
The LoRA was trained with peft.get_peft_model() on the UNet and saved
with unet.save_pretrained(). We reload it the same way, merge weights,
and plug the merged UNet into a StableDiffusionPipeline.
"""
from pathlib import Path
import torch
from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
from peft import PeftModel
from transformers import CLIPTextModel, CLIPTokenizer
checkpoint_input = Path(checkpoint_path).expanduser().resolve()
if checkpoint_input.is_dir():
checkpoint_dir = checkpoint_input
@@ -27,81 +36,40 @@ def build_pipeline(checkpoint_path: str, device: str):
else:
raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_input}")
# Load base Stable Diffusion model + LoRA adapter (PEFT)
try:
from diffusers import StableDiffusionPipeline
import torch
except ImportError as e:
raise RuntimeError(
f"diffusers and torch required. Install: pip install diffusers torch "
f"(error: {e})"
)
# Load base model
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(
model_id,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
pipe = pipe.to(device)
# Load LoRA weights from adapter_model.safetensors
adapter_path = checkpoint_dir / "adapter_model.safetensors"
if adapter_path.exists():
try:
pipe.load_lora_weights(str(checkpoint_dir))
except Exception as e:
message = str(e)
if "PEFT backend is required" in message:
raise RuntimeError(
"Failed to load LoRA: PEFT backend is missing. "
"Install required packages with: pip install peft transformers accelerate safetensors"
) from e
raise RuntimeError(
f"Failed to load LoRA from {checkpoint_dir}: {e}\n"
"Ensure adapter_config.json and adapter_model.safetensors are present."
) from e
else:
if not adapter_path.exists():
raise FileNotFoundError(
f"LoRA adapter not found at {adapter_path}. "
f"Expected: adapter_model.safetensors in {checkpoint_dir}"
)
model_id = "runwayml/stable-diffusion-v1-5"
dtype = torch.float16 if device == "cuda" else torch.float32
# Load base components
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=dtype)
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=dtype)
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
# Load LoRA via PEFT, then merge into base weights
unet = PeftModel.from_pretrained(unet, str(checkpoint_dir))
unet = unet.merge_and_unload()
pipe = StableDiffusionPipeline.from_pretrained(
model_id,
vae=vae,
unet=unet,
text_encoder=text_encoder,
tokenizer=tokenizer,
safety_checker=None,
torch_dtype=dtype,
)
pipe = pipe.to(device)
return pipe
def metadata_to_conditioning(meta: Mapping[str, Any]) -> str:
"""Convert metadata dict to a Stable Diffusion prompt.
LoRA is trained on Pokemon cards, so describe it as such.
"""
name = str(meta.get("name", "Unknown Pokemon"))
pokemon_type = str(meta.get("type", "normal")).capitalize()
secondary = meta.get("secondary_type")
hp = str(meta.get("hp", "60"))
attacks = meta.get("attacks") or []
attack_list = []
if isinstance(attacks, list):
for atk in attacks:
if isinstance(atk, dict):
attack_list.append(str(atk.get("name", "")).lower())
elif atk:
attack_list.append(str(atk).lower())
# Build a descriptive prompt for card generation
prompt = f"Pokemon trading card of {name}, {pokemon_type}-type Pokemon"
if secondary:
prompt += f"/{secondary.capitalize()}"
prompt += f", HP {hp}"
if attack_list:
prompt += f", with attacks: {', '.join(attack_list[:2])}"
description = meta.get("description", "").strip()
if description:
prompt += f". {description}"
prompt += ". High quality illustration, official Pokemon card style."
return prompt
"""Serialize metadata dict to JSON, matching the training conditioning format."""
return json.dumps(meta, sort_keys=True, ensure_ascii=False)