Fix LoRA loading and conditioning to match training notebook #23

Merged
llabeyrie merged 1 commit from fix/lora-loading-and-conditioning into main 2026-03-19 23:31:31 +00:00
Showing only changes of commit fe830dea2e - Show all commits

View File

@@ -1,24 +1,33 @@
"""Adapter to load the LoRA checkpoint and define conditioning logic. """Adapter to load the LoRA checkpoint and define conditioning logic.
Matches the training notebook (pokemon_card_training_2.ipynb):
- LoRA saved via PEFT's save_pretrained() → loaded via PeftModel.from_pretrained()
- Conditioning = JSON serialization of metadata (same format used during training)
Customize this file to match your model architecture, then use: Customize this file to match your model architecture, then use:
--generator-module card_generator_adapter.py --generator-module card_generator_adapter.py
""" """
from __future__ import annotations from __future__ import annotations
import json
from typing import Any, Mapping from typing import Any, Mapping
def build_pipeline(checkpoint_path: str, device: str): def build_pipeline(checkpoint_path: str, device: str):
"""Load LoRA adapter and return a callable pipeline. """Load LoRA adapter via PEFT and return a callable SD pipeline.
The pipeline must accept: The LoRA was trained with peft.get_peft_model() on the UNet and saved
pipeline(prompt_or_conditioning, num_inference_steps=30, guidance_scale=7.5) with unet.save_pretrained(). We reload it the same way, merge weights,
and plug the merged UNet into a StableDiffusionPipeline.
and return an object with .images attribute.
""" """
from pathlib import Path from pathlib import Path
import torch
from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
from peft import PeftModel
from transformers import CLIPTextModel, CLIPTokenizer
checkpoint_input = Path(checkpoint_path).expanduser().resolve() checkpoint_input = Path(checkpoint_path).expanduser().resolve()
if checkpoint_input.is_dir(): if checkpoint_input.is_dir():
checkpoint_dir = checkpoint_input checkpoint_dir = checkpoint_input
@@ -27,81 +36,40 @@ def build_pipeline(checkpoint_path: str, device: str):
else: else:
raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_input}") raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_input}")
# Load base Stable Diffusion model + LoRA adapter (PEFT)
try:
from diffusers import StableDiffusionPipeline
import torch
except ImportError as e:
raise RuntimeError(
f"diffusers and torch required. Install: pip install diffusers torch "
f"(error: {e})"
)
# Load base model
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(
model_id,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
pipe = pipe.to(device)
# Load LoRA weights from adapter_model.safetensors
adapter_path = checkpoint_dir / "adapter_model.safetensors" adapter_path = checkpoint_dir / "adapter_model.safetensors"
if adapter_path.exists(): if not adapter_path.exists():
try:
pipe.load_lora_weights(str(checkpoint_dir))
except Exception as e:
message = str(e)
if "PEFT backend is required" in message:
raise RuntimeError(
"Failed to load LoRA: PEFT backend is missing. "
"Install required packages with: pip install peft transformers accelerate safetensors"
) from e
raise RuntimeError(
f"Failed to load LoRA from {checkpoint_dir}: {e}\n"
"Ensure adapter_config.json and adapter_model.safetensors are present."
) from e
else:
raise FileNotFoundError( raise FileNotFoundError(
f"LoRA adapter not found at {adapter_path}. " f"LoRA adapter not found at {adapter_path}. "
f"Expected: adapter_model.safetensors in {checkpoint_dir}" f"Expected: adapter_model.safetensors in {checkpoint_dir}"
) )
model_id = "runwayml/stable-diffusion-v1-5"
dtype = torch.float16 if device == "cuda" else torch.float32
# Load base components
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=dtype)
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=dtype)
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
# Load LoRA via PEFT, then merge into base weights
unet = PeftModel.from_pretrained(unet, str(checkpoint_dir))
unet = unet.merge_and_unload()
pipe = StableDiffusionPipeline.from_pretrained(
model_id,
vae=vae,
unet=unet,
text_encoder=text_encoder,
tokenizer=tokenizer,
safety_checker=None,
torch_dtype=dtype,
)
pipe = pipe.to(device)
return pipe return pipe
def metadata_to_conditioning(meta: Mapping[str, Any]) -> str: def metadata_to_conditioning(meta: Mapping[str, Any]) -> str:
"""Convert metadata dict to a Stable Diffusion prompt. """Serialize metadata dict to JSON, matching the training conditioning format."""
return json.dumps(meta, sort_keys=True, ensure_ascii=False)
LoRA is trained on Pokemon cards, so describe it as such.
"""
name = str(meta.get("name", "Unknown Pokemon"))
pokemon_type = str(meta.get("type", "normal")).capitalize()
secondary = meta.get("secondary_type")
hp = str(meta.get("hp", "60"))
attacks = meta.get("attacks") or []
attack_list = []
if isinstance(attacks, list):
for atk in attacks:
if isinstance(atk, dict):
attack_list.append(str(atk.get("name", "")).lower())
elif atk:
attack_list.append(str(atk).lower())
# Build a descriptive prompt for card generation
prompt = f"Pokemon trading card of {name}, {pokemon_type}-type Pokemon"
if secondary:
prompt += f"/{secondary.capitalize()}"
prompt += f", HP {hp}"
if attack_list:
prompt += f", with attacks: {', '.join(attack_list[:2])}"
description = meta.get("description", "").strip()
if description:
prompt += f". {description}"
prompt += ". High quality illustration, official Pokemon card style."
return prompt