diff --git a/app.py b/app.py index 0b4b828..cce9f23 100644 --- a/app.py +++ b/app.py @@ -37,6 +37,7 @@ def _extract_image_from_stdout(stdout: str) -> Path | None: def run_prompt_pipeline(prompt_text: str) -> tuple[Path | None, str, list[str]]: + save_path = "generated_card.png" cmd = [ sys.executable, "prompt_to_card_pipeline.py", prompt_text, @@ -46,7 +47,7 @@ def run_prompt_pipeline(prompt_text: str) -> tuple[Path | None, str, list[str]]: "--template", "clean-text-to-keywords/json_template_example.json", "--generator-module", "card_generator_adapter.py", "--device", "cuda", - "--save-path", "generated_card.png", + "--save-path", save_path, "--print-json", ] @@ -63,7 +64,9 @@ def run_prompt_pipeline(prompt_text: str) -> tuple[Path | None, str, list[str]]: if result.returncode != 0: return None, full_output.strip() or "Erreur inconnue pendant le pipeline.", cmd - image_path = _extract_image_from_stdout(result.stdout or "") + image_path = APP_DIR / save_path + if not image_path.exists(): + image_path = _extract_image_from_stdout(result.stdout or "") return image_path, full_output.strip(), cmd # ------------------------------------------------------------------ # diff --git a/card_generator_adapter.py b/card_generator_adapter.py index c88271e..3260e18 100644 --- a/card_generator_adapter.py +++ b/card_generator_adapter.py @@ -1,24 +1,33 @@ """Adapter to load the LoRA checkpoint and define conditioning logic. +Matches the training notebook (pokemon_card_training_2.ipynb): +- LoRA saved via PEFT's save_pretrained() → loaded via PeftModel.from_pretrained() +- Conditioning = JSON serialization of metadata (same format used during training) + Customize this file to match your model architecture, then use: --generator-module card_generator_adapter.py """ from __future__ import annotations +import json from typing import Any, Mapping def build_pipeline(checkpoint_path: str, device: str): - """Load LoRA adapter and return a callable pipeline. - - The pipeline must accept: - pipeline(prompt_or_conditioning, num_inference_steps=30, guidance_scale=7.5) - - and return an object with .images attribute. + """Load LoRA adapter via PEFT and return a callable SD pipeline. + + The LoRA was trained with peft.get_peft_model() on the UNet and saved + with unet.save_pretrained(). We reload it the same way, merge weights, + and plug the merged UNet into a StableDiffusionPipeline. """ from pathlib import Path + import torch + from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel + from peft import PeftModel + from transformers import CLIPTextModel, CLIPTokenizer + checkpoint_input = Path(checkpoint_path).expanduser().resolve() if checkpoint_input.is_dir(): checkpoint_dir = checkpoint_input @@ -26,82 +35,41 @@ def build_pipeline(checkpoint_path: str, device: str): checkpoint_dir = checkpoint_input.parent else: raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_input}") - - # Load base Stable Diffusion model + LoRA adapter (PEFT) - try: - from diffusers import StableDiffusionPipeline - import torch - except ImportError as e: - raise RuntimeError( - f"diffusers and torch required. Install: pip install diffusers torch " - f"(error: {e})" - ) - - # Load base model - model_id = "runwayml/stable-diffusion-v1-5" - pipe = StableDiffusionPipeline.from_pretrained( - model_id, - torch_dtype=torch.float16 if device == "cuda" else torch.float32, - ) - pipe = pipe.to(device) - - # Load LoRA weights from adapter_model.safetensors + adapter_path = checkpoint_dir / "adapter_model.safetensors" - if adapter_path.exists(): - try: - pipe.load_lora_weights(str(checkpoint_dir)) - except Exception as e: - message = str(e) - if "PEFT backend is required" in message: - raise RuntimeError( - "Failed to load LoRA: PEFT backend is missing. " - "Install required packages with: pip install peft transformers accelerate safetensors" - ) from e - raise RuntimeError( - f"Failed to load LoRA from {checkpoint_dir}: {e}\n" - "Ensure adapter_config.json and adapter_model.safetensors are present." - ) from e - else: + if not adapter_path.exists(): raise FileNotFoundError( f"LoRA adapter not found at {adapter_path}. " f"Expected: adapter_model.safetensors in {checkpoint_dir}" ) - + + model_id = "runwayml/stable-diffusion-v1-5" + dtype = torch.float16 if device == "cuda" else torch.float32 + + # Load base components + vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=dtype) + unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=dtype) + text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=dtype) + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + + # Load LoRA via PEFT, then merge into base weights + unet = PeftModel.from_pretrained(unet, str(checkpoint_dir)) + unet = unet.merge_and_unload() + + pipe = StableDiffusionPipeline.from_pretrained( + model_id, + vae=vae, + unet=unet, + text_encoder=text_encoder, + tokenizer=tokenizer, + safety_checker=None, + torch_dtype=dtype, + ) + pipe = pipe.to(device) + return pipe def metadata_to_conditioning(meta: Mapping[str, Any]) -> str: - """Convert metadata dict to a Stable Diffusion prompt. - - LoRA is trained on Pokemon cards, so describe it as such. - """ - name = str(meta.get("name", "Unknown Pokemon")) - pokemon_type = str(meta.get("type", "normal")).capitalize() - secondary = meta.get("secondary_type") - - hp = str(meta.get("hp", "60")) - - attacks = meta.get("attacks") or [] - attack_list = [] - if isinstance(attacks, list): - for atk in attacks: - if isinstance(atk, dict): - attack_list.append(str(atk.get("name", "")).lower()) - elif atk: - attack_list.append(str(atk).lower()) - - # Build a descriptive prompt for card generation - prompt = f"Pokemon trading card of {name}, {pokemon_type}-type Pokemon" - if secondary: - prompt += f"/{secondary.capitalize()}" - prompt += f", HP {hp}" - - if attack_list: - prompt += f", with attacks: {', '.join(attack_list[:2])}" - - description = meta.get("description", "").strip() - if description: - prompt += f". {description}" - - prompt += ". High quality illustration, official Pokemon card style." - return prompt + """Serialize metadata dict to JSON, matching the training conditioning format.""" + return json.dumps(meta, sort_keys=True, ensure_ascii=False)