Files
Juicepyter/card_generator_adapter.py
Louis Labeyrie 7749d5ec35 Fix image detection and LoRA loading to match training notebook
app.py: check the known save path directly instead of parsing stdout
(broken after PR #21 removed the print statement).

card_generator_adapter.py: two mismatches with the training notebook:
1. LoRA loading used pipe.load_lora_weights() (diffusers format) but the
   adapter was saved with PEFT's save_pretrained() — keys didn't match.
   Now uses PeftModel.from_pretrained() + merge_and_unload().
2. Conditioning built a natural language prompt, but the LoRA was trained
   on json.dumps(meta). Now uses JSON serialization to match.
2026-03-20 00:24:25 +01:00

76 lines
2.7 KiB
Python

"""Adapter to load the LoRA checkpoint and define conditioning logic.
Matches the training notebook (pokemon_card_training_2.ipynb):
- LoRA saved via PEFT's save_pretrained() → loaded via PeftModel.from_pretrained()
- Conditioning = JSON serialization of metadata (same format used during training)
Customize this file to match your model architecture, then use:
--generator-module card_generator_adapter.py
"""
from __future__ import annotations
import json
from typing import Any, Mapping
def build_pipeline(checkpoint_path: str, device: str):
    """Load the LoRA adapter via PEFT and return a Stable Diffusion pipeline.

    The LoRA was trained with peft.get_peft_model() on the UNet and saved
    with unet.save_pretrained(). We reload it the same way, merge the LoRA
    weights into the base UNet, and plug the merged UNet into a
    StableDiffusionPipeline.

    Args:
        checkpoint_path: Directory containing the PEFT adapter files, or a
            path to a file inside that directory.
        device: Torch device string ("cuda" or "cpu"); fp16 is used on cuda.

    Returns:
        A StableDiffusionPipeline on `device` with the LoRA weights merged in.

    Raises:
        FileNotFoundError: If the checkpoint path or adapter weights are missing.
    """
    from pathlib import Path

    import torch
    from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
    from peft import PeftModel
    from transformers import CLIPTextModel, CLIPTokenizer

    checkpoint_input = Path(checkpoint_path).expanduser().resolve()
    if checkpoint_input.is_dir():
        checkpoint_dir = checkpoint_input
    elif checkpoint_input.exists():
        # A file inside the checkpoint directory was passed; use its parent.
        checkpoint_dir = checkpoint_input.parent
    else:
        raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_input}")

    # PEFT's save_pretrained() writes adapter_model.safetensors by default,
    # but older PEFT versions (or safe_serialization=False) write
    # adapter_model.bin — PeftModel.from_pretrained() loads either, so
    # accept both here instead of rejecting valid .bin checkpoints.
    adapter_candidates = ("adapter_model.safetensors", "adapter_model.bin")
    if not any((checkpoint_dir / name).exists() for name in adapter_candidates):
        raise FileNotFoundError(
            f"LoRA adapter not found in {checkpoint_dir}. "
            f"Expected one of: {', '.join(adapter_candidates)}"
        )

    model_id = "runwayml/stable-diffusion-v1-5"
    # fp16 halves GPU memory; CPU inference needs fp32.
    dtype = torch.float16 if device == "cuda" else torch.float32

    # Load base SD 1.5 components.
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=dtype)
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=dtype)
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=dtype)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    # Load the LoRA via PEFT (matching how it was saved), then merge it into
    # the base weights so inference runs without a PEFT wrapper.
    unet = PeftModel.from_pretrained(unet, str(checkpoint_dir))
    unet = unet.merge_and_unload()

    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        vae=vae,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        safety_checker=None,
        torch_dtype=dtype,
    )
    pipe = pipe.to(device)
    return pipe
def metadata_to_conditioning(meta: Mapping[str, Any]) -> str:
    """Turn a metadata mapping into the conditioning string the model expects.

    The LoRA was trained on JSON-serialized metadata, so inference uses the
    identical serialization: keys sorted (deterministic output for equivalent
    dicts) and non-ASCII characters emitted raw rather than escaped.
    """
    conditioning = json.dumps(meta, ensure_ascii=False, sort_keys=True)
    return conditioning