"""Adapter to load the LoRA checkpoint and define conditioning logic.

Customize this file to match your model architecture, then use:
    --generator-module card_generator_adapter.py
"""

from __future__ import annotations

from pathlib import Path
from typing import Any, Mapping

# Base Stable Diffusion checkpoint the LoRA adapter is applied on top of.
_BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"
# File name of the LoRA weights expected inside the checkpoint directory.
_ADAPTER_WEIGHT_NAME = "adapter_model.safetensors"
# How many attack names are mentioned in the generated prompt.
_MAX_PROMPT_ATTACKS = 2


def _resolve_checkpoint_dir(checkpoint_path: str) -> Path:
    """Return the directory that should contain the adapter files.

    Raises:
        FileNotFoundError: if *checkpoint_path* does not exist.
    """
    checkpoint_input = Path(checkpoint_path).expanduser().resolve()
    if checkpoint_input.is_dir():
        return checkpoint_input
    if checkpoint_input.exists():
        # A file inside the checkpoint (e.g. the weights file) was given.
        return checkpoint_input.parent
    raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_input}")


def _is_cuda_device(device: Any) -> bool:
    """Return True for "cuda" and indexed CUDA devices such as "cuda:1"."""
    return str(device).split(":", 1)[0] == "cuda"


def build_pipeline(checkpoint_path: str, device: str):
    """Load the base Stable Diffusion model plus the LoRA adapter.

    The returned pipeline must accept:
        pipeline(prompt_or_conditioning, num_inference_steps=30, guidance_scale=7.5)
    and return an object with an ``.images`` attribute.

    Args:
        checkpoint_path: Directory holding ``adapter_model.safetensors`` (and
            ``adapter_config.json``), or a path to a file inside it.
        device: Target device string passed to ``pipe.to``; CUDA devices
            (including indexed ones like ``"cuda:0"``) load in float16,
            everything else in float32.

    Returns:
        A ``diffusers.StableDiffusionPipeline`` with the LoRA weights loaded.

    Raises:
        FileNotFoundError: if the checkpoint path or adapter file is missing.
        RuntimeError: if diffusers/torch are missing or the LoRA fails to load.
    """
    checkpoint_dir = _resolve_checkpoint_dir(checkpoint_path)

    # Validate the adapter before importing/downloading the (large) base model
    # so a bad path fails immediately.
    adapter_path = checkpoint_dir / _ADAPTER_WEIGHT_NAME
    if not adapter_path.is_file():
        raise FileNotFoundError(
            f"LoRA adapter not found at {adapter_path}. "
            f"Expected: {_ADAPTER_WEIGHT_NAME} in {checkpoint_dir}"
        )

    # Load base Stable Diffusion model + LoRA adapter (PEFT)
    try:
        import torch
        from diffusers import StableDiffusionPipeline
    except ImportError as e:
        raise RuntimeError(
            f"diffusers and torch required. Install: pip install diffusers torch "
            f"(error: {e})"
        ) from e

    dtype = torch.float16 if _is_cuda_device(device) else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained(_BASE_MODEL_ID, torch_dtype=dtype)
    pipe = pipe.to(device)

    try:
        # Name the weight file explicitly so diffusers loads exactly the file
        # validated above rather than guessing among *.safetensors in the dir.
        pipe.load_lora_weights(str(checkpoint_dir), weight_name=_ADAPTER_WEIGHT_NAME)
    except Exception as e:
        if "PEFT backend is required" in str(e):
            raise RuntimeError(
                "Failed to load LoRA: PEFT backend is missing. "
                "Install required packages with: pip install peft transformers accelerate safetensors"
            ) from e
        raise RuntimeError(
            f"Failed to load LoRA from {checkpoint_dir}: {e}\n"
            "Ensure adapter_config.json and adapter_model.safetensors are present."
        ) from e

    return pipe


def _meta_str(meta: Mapping[str, Any], key: str, default: str) -> str:
    """Return ``str(meta[key])``, or *default* when the key is missing or None."""
    value = meta.get(key)
    return default if value is None else str(value)


def _attack_names(attacks: Any) -> list[str]:
    """Return lower-cased, non-empty attack names.

    *attacks* may be a list/tuple of plain names or of mappings with a
    ``name`` key; anything else yields no names.
    """
    if not isinstance(attacks, (list, tuple)):
        return []
    names: list[str] = []
    for atk in attacks:
        raw = atk.get("name") if isinstance(atk, Mapping) else atk
        if not raw:
            continue
        label = str(raw).strip().lower()
        if label:  # skip whitespace-only names instead of emitting ", ,"
            names.append(label)
    return names


def metadata_to_conditioning(meta: Mapping[str, Any]) -> str:
    """Convert card metadata to a Stable Diffusion prompt.

    The LoRA is trained on Pokemon cards, so the prompt describes the image
    as such.

    Args:
        meta: Card metadata. Recognized keys (all optional): ``name``,
            ``type``, ``secondary_type``, ``hp``, ``attacks`` (list/tuple of
            names or of mappings with a ``name`` key) and ``description``.
            Missing or ``None`` values fall back to defaults or are omitted.

    Returns:
        A single prompt string, e.g. ``"Pokemon trading card of Charizard,
        Fire/Flying-type Pokemon, HP 60. High quality illustration, official
        Pokemon card style."``
    """
    name = _meta_str(meta, "name", "Unknown Pokemon")
    type_label = _meta_str(meta, "type", "normal").capitalize()
    secondary = meta.get("secondary_type")
    if secondary:
        # Dual types read "Fire/Flying-type", so join before the "-type" suffix.
        type_label += f"/{str(secondary).capitalize()}"
    hp = _meta_str(meta, "hp", "60")

    # Build a descriptive prompt for card generation
    parts = [
        f"Pokemon trading card of {name}",
        f"{type_label}-type Pokemon",
        f"HP {hp}",
    ]
    attack_names = _attack_names(meta.get("attacks"))
    if attack_names:
        parts.append(f"with attacks: {', '.join(attack_names[:_MAX_PROMPT_ATTACKS])}")
    prompt = ", ".join(parts)

    # Drop trailing periods so the sentence joiner below doesn't produce "..".
    description = _meta_str(meta, "description", "").strip().rstrip(".").rstrip()
    if description:
        prompt += f". {description}"

    prompt += ". High quality illustration, official Pokemon card style."
    return prompt