Fix image detection and LoRA loading to match training notebook

app.py: check the known save path directly instead of parsing stdout
(broken after PR #21 removed the print statement).

card_generator_adapter.py: two mismatches with the training notebook:
1. LoRA loading used pipe.load_lora_weights() (diffusers format) but the
   adapter was saved with PEFT's save_pretrained() — keys didn't match.
   Now uses PeftModel.from_pretrained() + merge_and_unload().
2. Conditioning built a natural language prompt, but the LoRA was trained
   on json.dumps(meta). Now uses JSON serialization to match.
This commit is contained in:
2026-03-19 20:14:17 +01:00
parent 984bbbec18
commit 7749d5ec35
2 changed files with 49 additions and 78 deletions

5
app.py
View File

@@ -37,6 +37,7 @@ def _extract_image_from_stdout(stdout: str) -> Path | None:
def run_prompt_pipeline(prompt_text: str) -> tuple[Path | None, str, list[str]]: def run_prompt_pipeline(prompt_text: str) -> tuple[Path | None, str, list[str]]:
save_path = "generated_card.png"
cmd = [ cmd = [
sys.executable, "prompt_to_card_pipeline.py", sys.executable, "prompt_to_card_pipeline.py",
prompt_text, prompt_text,
@@ -46,7 +47,7 @@ def run_prompt_pipeline(prompt_text: str) -> tuple[Path | None, str, list[str]]:
"--template", "clean-text-to-keywords/json_template_example.json", "--template", "clean-text-to-keywords/json_template_example.json",
"--generator-module", "card_generator_adapter.py", "--generator-module", "card_generator_adapter.py",
"--device", "cuda", "--device", "cuda",
"--save-path", "generated_card.png", "--save-path", save_path,
"--print-json", "--print-json",
] ]
@@ -63,6 +64,8 @@ def run_prompt_pipeline(prompt_text: str) -> tuple[Path | None, str, list[str]]:
if result.returncode != 0: if result.returncode != 0:
return None, full_output.strip() or "Erreur inconnue pendant le pipeline.", cmd return None, full_output.strip() or "Erreur inconnue pendant le pipeline.", cmd
image_path = APP_DIR / save_path
if not image_path.exists():
image_path = _extract_image_from_stdout(result.stdout or "") image_path = _extract_image_from_stdout(result.stdout or "")
return image_path, full_output.strip(), cmd return image_path, full_output.strip(), cmd

card_generator_adapter.py
View File

@@ -1,24 +1,33 @@
"""Adapter to load the LoRA checkpoint and define conditioning logic. """Adapter to load the LoRA checkpoint and define conditioning logic.
Matches the training notebook (pokemon_card_training_2.ipynb):
- LoRA saved via PEFT's save_pretrained() → loaded via PeftModel.from_pretrained()
- Conditioning = JSON serialization of metadata (same format used during training)
Customize this file to match your model architecture, then use: Customize this file to match your model architecture, then use:
--generator-module card_generator_adapter.py --generator-module card_generator_adapter.py
""" """
from __future__ import annotations from __future__ import annotations
import json
from typing import Any, Mapping from typing import Any, Mapping
def build_pipeline(checkpoint_path: str, device: str): def build_pipeline(checkpoint_path: str, device: str):
"""Load LoRA adapter and return a callable pipeline. """Load LoRA adapter via PEFT and return a callable SD pipeline.
The pipeline must accept: The LoRA was trained with peft.get_peft_model() on the UNet and saved
pipeline(prompt_or_conditioning, num_inference_steps=30, guidance_scale=7.5) with unet.save_pretrained(). We reload it the same way, merge weights,
and plug the merged UNet into a StableDiffusionPipeline.
and return an object with .images attribute.
""" """
from pathlib import Path from pathlib import Path
import torch
from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
from peft import PeftModel
from transformers import CLIPTextModel, CLIPTokenizer
checkpoint_input = Path(checkpoint_path).expanduser().resolve() checkpoint_input = Path(checkpoint_path).expanduser().resolve()
if checkpoint_input.is_dir(): if checkpoint_input.is_dir():
checkpoint_dir = checkpoint_input checkpoint_dir = checkpoint_input
@@ -27,81 +36,40 @@ def build_pipeline(checkpoint_path: str, device: str):
else: else:
raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_input}") raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_input}")
# Load base Stable Diffusion model + LoRA adapter (PEFT)
try:
from diffusers import StableDiffusionPipeline
import torch
except ImportError as e:
raise RuntimeError(
f"diffusers and torch required. Install: pip install diffusers torch "
f"(error: {e})"
)
# Load base model
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(
model_id,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
pipe = pipe.to(device)
# Load LoRA weights from adapter_model.safetensors
adapter_path = checkpoint_dir / "adapter_model.safetensors" adapter_path = checkpoint_dir / "adapter_model.safetensors"
if adapter_path.exists(): if not adapter_path.exists():
try:
pipe.load_lora_weights(str(checkpoint_dir))
except Exception as e:
message = str(e)
if "PEFT backend is required" in message:
raise RuntimeError(
"Failed to load LoRA: PEFT backend is missing. "
"Install required packages with: pip install peft transformers accelerate safetensors"
) from e
raise RuntimeError(
f"Failed to load LoRA from {checkpoint_dir}: {e}\n"
"Ensure adapter_config.json and adapter_model.safetensors are present."
) from e
else:
raise FileNotFoundError( raise FileNotFoundError(
f"LoRA adapter not found at {adapter_path}. " f"LoRA adapter not found at {adapter_path}. "
f"Expected: adapter_model.safetensors in {checkpoint_dir}" f"Expected: adapter_model.safetensors in {checkpoint_dir}"
) )
model_id = "runwayml/stable-diffusion-v1-5"
dtype = torch.float16 if device == "cuda" else torch.float32
# Load base components
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=dtype)
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=dtype)
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
# Load LoRA via PEFT, then merge into base weights
unet = PeftModel.from_pretrained(unet, str(checkpoint_dir))
unet = unet.merge_and_unload()
pipe = StableDiffusionPipeline.from_pretrained(
model_id,
vae=vae,
unet=unet,
text_encoder=text_encoder,
tokenizer=tokenizer,
safety_checker=None,
torch_dtype=dtype,
)
pipe = pipe.to(device)
return pipe return pipe
def metadata_to_conditioning(meta: Mapping[str, Any]) -> str: def metadata_to_conditioning(meta: Mapping[str, Any]) -> str:
"""Convert metadata dict to a Stable Diffusion prompt. """Serialize metadata dict to JSON, matching the training conditioning format."""
return json.dumps(meta, sort_keys=True, ensure_ascii=False)
LoRA is trained on Pokemon cards, so describe it as such.
"""
name = str(meta.get("name", "Unknown Pokemon"))
pokemon_type = str(meta.get("type", "normal")).capitalize()
secondary = meta.get("secondary_type")
hp = str(meta.get("hp", "60"))
attacks = meta.get("attacks") or []
attack_list = []
if isinstance(attacks, list):
for atk in attacks:
if isinstance(atk, dict):
attack_list.append(str(atk.get("name", "")).lower())
elif atk:
attack_list.append(str(atk).lower())
# Build a descriptive prompt for card generation
prompt = f"Pokemon trading card of {name}, {pokemon_type}-type Pokemon"
if secondary:
prompt += f"/{secondary.capitalize()}"
prompt += f", HP {hp}"
if attack_list:
prompt += f", with attacks: {', '.join(attack_list[:2])}"
description = meta.get("description", "").strip()
if description:
prompt += f". {description}"
prompt += ". High quality illustration, official Pokemon card style."
return prompt