first commit
This commit is contained in:
107
card_generator_adapter.py
Normal file
107
card_generator_adapter.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""Adapter to load the LoRA checkpoint and define conditioning logic.
|
||||
|
||||
Customize this file to match your model architecture, then use:
|
||||
--generator-module card_generator_adapter.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Mapping
|
||||
|
||||
|
||||
def build_pipeline(checkpoint_path: str, device: str):
|
||||
"""Load LoRA adapter and return a callable pipeline.
|
||||
|
||||
The pipeline must accept:
|
||||
pipeline(prompt_or_conditioning, num_inference_steps=30, guidance_scale=7.5)
|
||||
|
||||
and return an object with .images attribute.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
checkpoint_input = Path(checkpoint_path).expanduser().resolve()
|
||||
if checkpoint_input.is_dir():
|
||||
checkpoint_dir = checkpoint_input
|
||||
elif checkpoint_input.exists():
|
||||
checkpoint_dir = checkpoint_input.parent
|
||||
else:
|
||||
raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_input}")
|
||||
|
||||
# Load base Stable Diffusion model + LoRA adapter (PEFT)
|
||||
try:
|
||||
from diffusers import StableDiffusionPipeline
|
||||
import torch
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
f"diffusers and torch required. Install: pip install diffusers torch "
|
||||
f"(error: {e})"
|
||||
)
|
||||
|
||||
# Load base model
|
||||
model_id = "runwayml/stable-diffusion-v1-5"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
||||
)
|
||||
pipe = pipe.to(device)
|
||||
|
||||
# Load LoRA weights from adapter_model.safetensors
|
||||
adapter_path = checkpoint_dir / "adapter_model.safetensors"
|
||||
if adapter_path.exists():
|
||||
try:
|
||||
pipe.load_lora_weights(str(checkpoint_dir))
|
||||
except Exception as e:
|
||||
message = str(e)
|
||||
if "PEFT backend is required" in message:
|
||||
raise RuntimeError(
|
||||
"Failed to load LoRA: PEFT backend is missing. "
|
||||
"Install required packages with: pip install peft transformers accelerate safetensors"
|
||||
) from e
|
||||
raise RuntimeError(
|
||||
f"Failed to load LoRA from {checkpoint_dir}: {e}\n"
|
||||
"Ensure adapter_config.json and adapter_model.safetensors are present."
|
||||
) from e
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"LoRA adapter not found at {adapter_path}. "
|
||||
f"Expected: adapter_model.safetensors in {checkpoint_dir}"
|
||||
)
|
||||
|
||||
return pipe
|
||||
|
||||
|
||||
def metadata_to_conditioning(meta: Mapping[str, Any]) -> str:
|
||||
"""Convert metadata dict to a Stable Diffusion prompt.
|
||||
|
||||
LoRA is trained on Pokemon cards, so describe it as such.
|
||||
"""
|
||||
name = str(meta.get("name", "Unknown Pokemon"))
|
||||
pokemon_type = str(meta.get("type", "normal")).capitalize()
|
||||
secondary = meta.get("secondary_type")
|
||||
|
||||
hp = str(meta.get("hp", "60"))
|
||||
|
||||
attacks = meta.get("attacks") or []
|
||||
attack_list = []
|
||||
if isinstance(attacks, list):
|
||||
for atk in attacks:
|
||||
if isinstance(atk, dict):
|
||||
attack_list.append(str(atk.get("name", "")).lower())
|
||||
elif atk:
|
||||
attack_list.append(str(atk).lower())
|
||||
|
||||
# Build a descriptive prompt for card generation
|
||||
prompt = f"Pokemon trading card of {name}, {pokemon_type}-type Pokemon"
|
||||
if secondary:
|
||||
prompt += f"/{secondary.capitalize()}"
|
||||
prompt += f", HP {hp}"
|
||||
|
||||
if attack_list:
|
||||
prompt += f", with attacks: {', '.join(attack_list[:2])}"
|
||||
|
||||
description = meta.get("description", "").strip()
|
||||
if description:
|
||||
prompt += f". {description}"
|
||||
|
||||
prompt += ". High quality illustration, official Pokemon card style."
|
||||
return prompt
|
||||
Reference in New Issue
Block a user