"""Adapter to load the LoRA checkpoint and define conditioning logic. Matches the training notebook (pokemon_card_training_2.ipynb): - LoRA saved via PEFT's save_pretrained() → loaded via PeftModel.from_pretrained() - Conditioning = JSON serialization of metadata (same format used during training) Customize this file to match your model architecture, then use: --generator-module card_generator_adapter.py """ from __future__ import annotations import json from typing import Any, Mapping def build_pipeline(checkpoint_path: str, device: str): """Load LoRA adapter via PEFT and return a callable SD pipeline. The LoRA was trained with peft.get_peft_model() on the UNet and saved with unet.save_pretrained(). We reload it the same way, merge weights, and plug the merged UNet into a StableDiffusionPipeline. """ from pathlib import Path import torch from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel from peft import PeftModel from transformers import CLIPTextModel, CLIPTokenizer checkpoint_input = Path(checkpoint_path).expanduser().resolve() if checkpoint_input.is_dir(): checkpoint_dir = checkpoint_input elif checkpoint_input.exists(): checkpoint_dir = checkpoint_input.parent else: raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_input}") adapter_path = checkpoint_dir / "adapter_model.safetensors" if not adapter_path.exists(): raise FileNotFoundError( f"LoRA adapter not found at {adapter_path}. " f"Expected: adapter_model.safetensors in {checkpoint_dir}" ) model_id = "runwayml/stable-diffusion-v1-5" dtype = torch.float16 if device == "cuda" else torch.float32 # Load base components vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=dtype) unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=dtype) text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=dtype) tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") # Load LoRA via PEFT, then merge into base weights unet = PeftModel.from_pretrained(unet, str(checkpoint_dir)) unet = unet.merge_and_unload() pipe = StableDiffusionPipeline.from_pretrained( model_id, vae=vae, unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, safety_checker=None, torch_dtype=dtype, ) pipe = pipe.to(device) return pipe def metadata_to_conditioning(meta: Mapping[str, Any]) -> str: """Serialize metadata dict to JSON, matching the training conditioning format.""" return json.dumps(meta, sort_keys=True, ensure_ascii=False)