# Changelog:
# - app.py: check the known save path directly instead of parsing stdout
#   (broken after PR #21 removed the print statement).
# - card_generator_adapter.py: two mismatches with the training notebook:
#   1. LoRA loading used pipe.load_lora_weights() (diffusers format) but the
#      adapter was saved with PEFT's save_pretrained() — keys didn't match.
#      Now uses PeftModel.from_pretrained() + merge_and_unload().
#   2. Conditioning built a natural language prompt, but the LoRA was trained
#      on json.dumps(meta). Now uses JSON serialization to match.
"""Adapter to load the LoRA checkpoint and define conditioning logic.
|
|
|
|
Matches the training notebook (pokemon_card_training_2.ipynb):
|
|
- LoRA saved via PEFT's save_pretrained() → loaded via PeftModel.from_pretrained()
|
|
- Conditioning = JSON serialization of metadata (same format used during training)
|
|
|
|
Customize this file to match your model architecture, then use:
|
|
--generator-module card_generator_adapter.py
|
|
"""

from __future__ import annotations

import json
from typing import Any, Mapping


def build_pipeline(checkpoint_path: str, device: str):
    """Load the LoRA adapter via PEFT and return a callable SD pipeline.

    The LoRA was trained with peft.get_peft_model() on the UNet and saved
    with unet.save_pretrained(). We reload it the same way, merge the
    weights, and plug the merged UNet into a StableDiffusionPipeline.

    Args:
        checkpoint_path: Directory containing the PEFT adapter files, or a
            file inside that directory (its parent directory is used).
        device: Torch device string, e.g. "cuda" or "cpu". float16 is used
            on CUDA, float32 otherwise.

    Returns:
        A StableDiffusionPipeline on *device* with the LoRA weights merged
        into the UNet.

    Raises:
        FileNotFoundError: If the checkpoint path or the adapter weights
            file is missing. Raised before the heavy ML imports so a bad
            path fails fast.
    """
    from pathlib import Path

    # Validate inputs BEFORE importing torch/diffusers/peft/transformers:
    # those imports take seconds, and a missing path should not surface as
    # a confusing downstream error after they load.
    checkpoint_input = Path(checkpoint_path).expanduser().resolve()
    if checkpoint_input.is_dir():
        checkpoint_dir = checkpoint_input
    elif checkpoint_input.exists():
        # A file inside the checkpoint directory was given; use its parent.
        checkpoint_dir = checkpoint_input.parent
    else:
        raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_input}")

    # PEFT writes adapter_model.safetensors by default but falls back to
    # adapter_model.bin when safetensors is unavailable — accept either.
    adapter_candidates = (
        checkpoint_dir / "adapter_model.safetensors",
        checkpoint_dir / "adapter_model.bin",
    )
    if not any(path.exists() for path in adapter_candidates):
        raise FileNotFoundError(
            f"LoRA adapter not found at {adapter_candidates[0]}. "
            f"Expected: adapter_model.safetensors in {checkpoint_dir}"
        )

    import torch
    from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
    from peft import PeftModel
    from transformers import CLIPTextModel, CLIPTokenizer

    model_id = "runwayml/stable-diffusion-v1-5"
    dtype = torch.float16 if device == "cuda" else torch.float32

    # Load base components
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=dtype)
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=dtype)
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=dtype)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    # Load LoRA via PEFT, then merge into the base weights so inference
    # runs without adapter indirection.
    unet = PeftModel.from_pretrained(unet, str(checkpoint_dir))
    unet = unet.merge_and_unload()

    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        vae=vae,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        safety_checker=None,
        torch_dtype=dtype,
    )
    pipe = pipe.to(device)

    return pipe


def metadata_to_conditioning(meta: Mapping[str, Any]) -> str:
    """Serialize card metadata to the JSON conditioning string.

    Must byte-match the serialization used during training (json.dumps of
    the metadata dict); otherwise the conditioning text the LoRA sees at
    inference drifts from what it was trained on. sort_keys makes the
    output independent of dict insertion order; ensure_ascii=False keeps
    non-ASCII card text (e.g. accented names) verbatim.

    NOTE(review): confirm the training notebook also passed sort_keys=True
    and ensure_ascii=False — a plain json.dumps(meta) there would produce
    a slightly different string.

    Args:
        meta: Card metadata. Any Mapping is accepted; it is copied to a
            plain dict because json.dumps only serializes dict instances.

    Returns:
        Deterministic JSON string for the metadata.
    """
    return json.dumps(dict(meta), sort_keys=True, ensure_ascii=False)
|