Fix LoRA loading and conditioning to match training notebook

card_generator_adapter.py had two mismatches with the training notebook:

1. LoRA loading: used pipe.load_lora_weights() (diffusers format), but the
   adapter was saved with PEFT's save_pretrained(), so the keys didn't match
   and no LoRA weights were actually applied. Now uses
   PeftModel.from_pretrained() + merge_and_unload().

2. Conditioning: built a natural-language prompt, but the LoRA was trained
   on the json.dumps(meta) serialization. Now uses the same JSON
   serialization to match.
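For context, a minimal sketch of why the old loader silently no-opped and what the matching load path looks like. This is an illustration only: the rank and target_modules values are assumptions, not taken from the notebook.

```python
# Sketch only: PEFT-format vs diffusers-format LoRA checkpoints.
# r=8 and target_modules are assumed values for the example.
from diffusers import UNet2DConditionModel
from peft import LoraConfig, PeftModel, get_peft_model

base = "runwayml/stable-diffusion-v1-5"
unet = UNet2DConditionModel.from_pretrained(base, subfolder="unet")

# How the training notebook saved the adapter (PEFT format): this writes
# adapter_config.json + adapter_model.safetensors with "base_model.model.*"
# keys, which diffusers' pipe.load_lora_weights() did not map onto the UNet,
# so the base model ran unmodified.
peft_unet = get_peft_model(unet, LoraConfig(r=8, target_modules=["to_q", "to_v"]))
peft_unet.save_pretrained("checkpoint/")

# The matching load path this commit switches to: reload with PEFT,
# then fold the LoRA deltas into the base weights.
fresh_unet = UNet2DConditionModel.from_pretrained(base, subfolder="unet")
fresh_unet = PeftModel.from_pretrained(fresh_unet, "checkpoint/")
fresh_unet = fresh_unet.merge_and_unload()
```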
@@ -1,24 +1,33 @@
 """Adapter to load the LoRA checkpoint and define conditioning logic.
 
+Matches the training notebook (pokemon_card_training_2.ipynb):
+- LoRA saved via PEFT's save_pretrained() → loaded via PeftModel.from_pretrained()
+- Conditioning = JSON serialization of metadata (same format used during training)
+
 Customize this file to match your model architecture, then use:
 
     --generator-module card_generator_adapter.py
 """
 
 from __future__ import annotations
 
+import json
 from typing import Any, Mapping
 
 
 def build_pipeline(checkpoint_path: str, device: str):
-    """Load LoRA adapter and return a callable pipeline.
+    """Load LoRA adapter via PEFT and return a callable SD pipeline.
 
-    The pipeline must accept:
-        pipeline(prompt_or_conditioning, num_inference_steps=30, guidance_scale=7.5)
-    and return an object with .images attribute.
+    The LoRA was trained with peft.get_peft_model() on the UNet and saved
+    with unet.save_pretrained(). We reload it the same way, merge weights,
+    and plug the merged UNet into a StableDiffusionPipeline.
     """
     from pathlib import Path
 
+    import torch
+    from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
+    from peft import PeftModel
+    from transformers import CLIPTextModel, CLIPTokenizer
     checkpoint_input = Path(checkpoint_path).expanduser().resolve()
     if checkpoint_input.is_dir():
         checkpoint_dir = checkpoint_input
@@ -26,82 +35,41 @@ def build_pipeline(checkpoint_path: str, device: str):
         checkpoint_dir = checkpoint_input.parent
     else:
         raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_input}")
 
-    # Load base Stable Diffusion model + LoRA adapter (PEFT)
-    try:
-        from diffusers import StableDiffusionPipeline
-        import torch
-    except ImportError as e:
-        raise RuntimeError(
-            f"diffusers and torch required. Install: pip install diffusers torch "
-            f"(error: {e})"
-        )
-
-    # Load base model
-    model_id = "runwayml/stable-diffusion-v1-5"
-    pipe = StableDiffusionPipeline.from_pretrained(
-        model_id,
-        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
-    )
-    pipe = pipe.to(device)
-
-    # Load LoRA weights from adapter_model.safetensors
     adapter_path = checkpoint_dir / "adapter_model.safetensors"
-    if adapter_path.exists():
-        try:
-            pipe.load_lora_weights(str(checkpoint_dir))
-        except Exception as e:
-            message = str(e)
-            if "PEFT backend is required" in message:
-                raise RuntimeError(
-                    "Failed to load LoRA: PEFT backend is missing. "
-                    "Install required packages with: pip install peft transformers accelerate safetensors"
-                ) from e
-            raise RuntimeError(
-                f"Failed to load LoRA from {checkpoint_dir}: {e}\n"
-                "Ensure adapter_config.json and adapter_model.safetensors are present."
-            ) from e
-    else:
+    if not adapter_path.exists():
         raise FileNotFoundError(
             f"LoRA adapter not found at {adapter_path}. "
             f"Expected: adapter_model.safetensors in {checkpoint_dir}"
         )
+
+    model_id = "runwayml/stable-diffusion-v1-5"
+    dtype = torch.float16 if device == "cuda" else torch.float32
+
+    # Load base components
+    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=dtype)
+    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=dtype)
+    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=dtype)
+    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+
+    # Load LoRA via PEFT, then merge into base weights
+    unet = PeftModel.from_pretrained(unet, str(checkpoint_dir))
+    unet = unet.merge_and_unload()
+
+    pipe = StableDiffusionPipeline.from_pretrained(
+        model_id,
+        vae=vae,
+        unet=unet,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        safety_checker=None,
+        torch_dtype=dtype,
+    )
+    pipe = pipe.to(device)
 
     return pipe
 
 
 def metadata_to_conditioning(meta: Mapping[str, Any]) -> str:
-    """Convert metadata dict to a Stable Diffusion prompt.
-
-    LoRA is trained on Pokemon cards, so describe it as such.
-    """
-    name = str(meta.get("name", "Unknown Pokemon"))
-    pokemon_type = str(meta.get("type", "normal")).capitalize()
-    secondary = meta.get("secondary_type")
-
-    hp = str(meta.get("hp", "60"))
-
-    attacks = meta.get("attacks") or []
-    attack_list = []
-    if isinstance(attacks, list):
-        for atk in attacks:
-            if isinstance(atk, dict):
-                attack_list.append(str(atk.get("name", "")).lower())
-            elif atk:
-                attack_list.append(str(atk).lower())
-
-    # Build a descriptive prompt for card generation
-    prompt = f"Pokemon trading card of {name}, {pokemon_type}-type Pokemon"
-    if secondary:
-        prompt += f"/{secondary.capitalize()}"
-    prompt += f", HP {hp}"
-
-    if attack_list:
-        prompt += f", with attacks: {', '.join(attack_list[:2])}"
-
-    description = meta.get("description", "").strip()
-    if description:
-        prompt += f". {description}"
-
-    prompt += ". High quality illustration, official Pokemon card style."
-    return prompt
+    """Serialize metadata dict to JSON, matching the training conditioning format."""
+    return json.dumps(meta, sort_keys=True, ensure_ascii=False)
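For reference, a usage sketch of the fixed adapter. The checkpoint path and metadata fields below are illustrative, not taken from the training set; the call signature and .images attribute follow the interface documented in the module.

```python
# Usage sketch: generate one card image with the fixed adapter.
from card_generator_adapter import build_pipeline, metadata_to_conditioning

pipe = build_pipeline("runs/lora-checkpoint", device="cuda")

meta = {"name": "Pikachu", "type": "electric", "hp": 60}
# Conditioning is now the exact json.dumps() string the LoRA saw in training.
conditioning = metadata_to_conditioning(meta)

result = pipe(conditioning, num_inference_steps=30, guidance_scale=7.5)
result.images[0].save("card.png")
```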