Files
Juicepyter/prompt_to_card_pipeline.py
Louis Labeyrie e1317b5839 Replace debug print() statements with proper logging
Replace 27 unconditional print() calls with Python's logging module.
Debug messages now use logger.debug() and the card-saved message uses
logger.info(). Only legitimate user-facing output (--print-json,
--print-clean-text) remains on stdout.

Fixes #3
2026-03-19 19:45:09 +01:00

342 lines
12 KiB
Python

"""End-to-end prompt -> cleaned text -> inferred JSON -> generated card image.
This script is built to connect the three stages described by the user:
1) call get_clean_text(user_text) from a text-cleaning module file
2) pass cleaned text into infer_json_usage.py with --json-only --template
3) load a checkpoint and generate a card image from inferred metadata
The model-loading part is intentionally pluggable because checkpoint structures vary.
If your .pt checkpoint cannot be used directly as a callable pipeline, provide a
generator module implementing:
def build_pipeline(checkpoint_path: str, device: str): ...
def metadata_to_conditioning(meta: dict) -> str: ... # optional
"""
from __future__ import annotations
import argparse
import importlib
import importlib.util
import json
import logging
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Mapping
logger = logging.getLogger(__name__)
def _load_module_from_file(module_file: str):
    """Import a Python source file by path and return the module object.

    The module is named after the file stem and is executed immediately.
    It is not inserted into ``sys.modules``.

    Raises:
        FileNotFoundError: if the resolved path does not exist.
        ImportError: if no import spec or loader can be created for the file.
    """
    resolved = Path(module_file).resolve()
    if not resolved.exists():
        raise FileNotFoundError(f"Module file not found: {resolved}")

    module_spec = importlib.util.spec_from_file_location(resolved.stem, str(resolved))
    loader = module_spec.loader if module_spec is not None else None
    if loader is None:
        raise ImportError(f"Cannot import module from file: {resolved}")

    loaded = importlib.util.module_from_spec(module_spec)
    loader.exec_module(loaded)
    logger.debug("Module loaded from %s", resolved)
    return loaded
def _load_function_from_file(module_file: str, function_name: str) -> Callable[..., Any]:
    """Load ``function_name`` from the Python source file ``module_file``.

    Args:
        module_file: Path to the Python file to import.
        function_name: Name of the attribute to fetch from the imported module.

    Returns:
        The callable found under ``function_name``.

    Raises:
        AttributeError: if the module has no attribute named ``function_name``.
        TypeError: if the attribute exists but is not callable.
    """
    logger.debug("Loading function '%s' from %s", function_name, module_file)
    module = _load_module_from_file(module_file)

    if not hasattr(module, function_name):
        raise AttributeError(f"{module_file} has no function named '{function_name}'")
    func = getattr(module, function_name)
    if not callable(func):
        raise TypeError(f"{function_name} in {module_file} is not callable")

    # Report success only after validation, so a failing lookup never logs
    # a misleading "loaded" record right before the error.
    logger.debug("Function '%s' loaded", function_name)
    return func
def _extract_json_from_output(raw: str) -> Mapping[str, Any]:
    """Extract the JSON object printed by the inference command.

    The whole (stripped) output is parsed first. If that fails because the
    output mixes log text with JSON, the output is scanned for JSON objects
    and the last top-level one is returned.

    Args:
        raw: Captured stdout of the inference command.

    Returns:
        The decoded JSON object as a dict.

    Raises:
        ValueError: if the output is empty, is valid JSON but not an object,
            or contains no decodable JSON object.
    """
    logger.debug("Extracting JSON from inference output")
    stripped = raw.strip()
    if not stripped:
        raise ValueError("Inference command returned empty output")

    try:
        parsed = json.loads(stripped)
    except json.JSONDecodeError:
        pass
    else:
        if not isinstance(parsed, dict):
            raise ValueError("Inference output is JSON but not an object")
        return parsed

    # Fallback: parse the last JSON object in mixed stdout.
    # Searching for the last "{" would land on the innermost object of a
    # nested payload (e.g. an entry of "attacks") and yield an invalid slice,
    # so decode from each candidate "{" and skip past every object that
    # decodes successfully. The last success is the last top-level object.
    decoder = json.JSONDecoder()
    last_object = None
    start = stripped.find("{")
    while start != -1:
        try:
            value, consumed = decoder.raw_decode(stripped[start:])
        except json.JSONDecodeError:
            # Not the start of a valid object (e.g. a brace in log text).
            start = stripped.find("{", start + 1)
        else:
            last_object = value
            start = stripped.find("{", start + consumed)

    if last_object is None:
        raise ValueError(f"Could not parse JSON from inference output:\n{raw}")
    logger.debug("JSON parsed from fallback extraction")
    return last_object
def run_infer_json_cli(
    infer_script_path: str,
    template_path: str,
    cleaned_text: str,
    python_executable: str | None = None,
) -> Mapping[str, Any]:
    """Run the JSON inference script in a subprocess and return its JSON output.

    The script is invoked as
    ``<python> <script> --json-only --template <template> <cleaned_text>``.

    Args:
        infer_script_path: Path to the inference script.
        template_path: Path to the JSON template handed to the script.
        cleaned_text: Text passed as the final positional argument.
        python_executable: Interpreter to use; falls back to ``sys.executable``
            when empty or ``None``.

    Returns:
        The JSON object extracted from the subprocess stdout.

    Raises:
        FileNotFoundError: if the script or the template does not exist.
        RuntimeError: if the subprocess exits with a non-zero status.
    """
    script = Path(infer_script_path).resolve()
    logger.debug("Running JSON inference CLI: %s", script)
    if not script.exists():
        raise FileNotFoundError(f"infer_json_usage.py not found: {script}")

    template = Path(template_path).resolve()
    if not template.exists():
        raise FileNotFoundError(f"Template file not found: {template}")

    interpreter = python_executable if python_executable else sys.executable
    argv = [interpreter, str(script), "--json-only", "--template", str(template), cleaned_text]

    logger.debug("Launching inference subprocess")
    completed = subprocess.run(argv, capture_output=True, text=True, check=False)
    if completed.returncode == 0:
        logger.debug("Inference subprocess completed successfully")
        return _extract_json_from_output(completed.stdout)

    stderr = completed.stderr.strip()
    raise RuntimeError(
        "JSON inference command failed. "
        f"exit={completed.returncode}, stderr={stderr or '<empty>'}"
    )
def default_metadata_to_conditioning(meta: Mapping[str, Any]) -> str:
    """Build the default text prompt for the image pipeline from card metadata.

    Reads the ``name``, ``types`` (falling back to ``type``), ``hp``,
    ``attacks`` and ``description`` keys. Fields that are absent, ``None``
    (a JSON null) or blank fall back to their defaults instead of being
    rendered as the literal text "None".

    Args:
        meta: Card metadata mapping.

    Returns:
        A "; "-separated prompt string. At most the first two attack names
        are included, and the description is appended only when non-empty.
    """
    logger.debug("Building default conditioning from metadata")

    def text_or(key: str, fallback: str) -> str:
        # meta.get(key, default) would return None for a key that is present
        # with a null value, so absence and null are handled together here.
        value = meta.get(key)
        if value is None:
            return fallback
        text = str(value)
        return text if text.strip() else fallback

    name = text_or("name", "Unknown Pokemon")

    types = meta.get("types") or []
    type_text = ""
    if isinstance(types, list):
        type_text = ", ".join(str(item) for item in types if item)
    if not type_text:
        type_text = text_or("type", "normal")

    attacks = meta.get("attacks") or []
    attack_names = []
    if isinstance(attacks, list):
        for attack in attacks:
            # Attacks may be objects carrying a "name" or plain values.
            label = attack.get("name") if isinstance(attack, dict) else attack
            if label:
                attack_names.append(str(label))

    hp = text_or("hp", "60")
    description = text_or("description", "")

    parts = [
        f"Pokemon trading card illustration of {name}",
        f"type: {type_text}",
        f"hp: {hp}",
    ]
    if attack_names:
        parts.append(f"attacks: {', '.join(attack_names[:2])}")
    if description:
        parts.append(f"description: {description}")
    return "; ".join(parts)
@dataclass
class CheckpointCardGenerator:
    """Generate card images from metadata using a checkpoint-backed pipeline.

    On construction the generation pipeline and the metadata-to-prompt
    function are built eagerly. When ``generator_module_path`` is set, that
    module supplies ``build_pipeline(checkpoint_path, device)`` and,
    optionally, ``metadata_to_conditioning(meta)``. Otherwise the checkpoint
    is loaded directly with torch and must be, or contain, a callable.
    """

    checkpoint_path: str
    device: str = "cpu"
    generator_module_path: str = ""

    def __post_init__(self) -> None:
        """Build the pipeline and the conditioning function."""
        self._pipe = self._build_pipe()
        self._metadata_to_conditioning = self._build_conditioning_function()

    def _load_generator_module(self):
        """Import the custom generator module once and cache it on the instance.

        Both builders need the module. Importing it a second time would
        execute its top-level code again.
        """
        module = getattr(self, "_generator_module", None)
        if module is None:
            logger.debug("Loading generator module: %s", self.generator_module_path)
            module = _load_module_from_file(self.generator_module_path)
            logger.debug("Generator module loaded")
            self._generator_module = module
        return module

    def _build_pipe(self):
        """Return the callable generation pipeline.

        Raises:
            AttributeError: if the custom module lacks ``build_pipeline``.
            TypeError: if ``build_pipeline`` is not callable.
            RuntimeError: if torch is unavailable or the checkpoint does not
                yield a callable pipeline.
        """
        if self.generator_module_path:
            module = self._load_generator_module()
            if not hasattr(module, "build_pipeline"):
                raise AttributeError(
                    "Custom generator module must define build_pipeline(checkpoint_path, device)."
                )
            build_pipeline = getattr(module, "build_pipeline")
            if not callable(build_pipeline):
                raise TypeError("build_pipeline exists but is not callable")
            logger.debug("Building pipeline from generator module")
            pipe = build_pipeline(self.checkpoint_path, self.device)
            # Logged only once the builder has actually returned.
            logger.debug("Pipeline built successfully")
            return pipe

        # Best-effort direct checkpoint loading for simple callable pipeline dumps.
        try:
            torch = importlib.import_module("torch")
        except ModuleNotFoundError as exc:
            raise RuntimeError(
                "torch is required to load checkpoint files. Install torch or provide --generator-module."
            ) from exc

        logger.debug("Loading checkpoint from %s", self.checkpoint_path)
        # NOTE(security): torch.load deserializes the file and, depending on
        # the torch version and its weights_only setting, may unpickle
        # arbitrary objects. Only load checkpoints from trusted sources.
        # NOTE(review): recent torch releases default to weights_only=True,
        # which may reject pickled pipeline objects - confirm against the
        # installed version.
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
        logger.debug("Checkpoint loaded")

        if callable(checkpoint):
            return checkpoint
        if isinstance(checkpoint, dict):
            for key in ("pipe", "pipeline", "model"):
                candidate = checkpoint.get(key)
                if callable(candidate):
                    return candidate
        raise RuntimeError(
            "Could not construct a callable generation pipeline from checkpoint. "
            "Pass --generator-module with a build_pipeline() function for your model layout."
        )

    def _build_conditioning_function(self) -> Callable[[Mapping[str, Any]], str]:
        """Return the custom ``metadata_to_conditioning`` if provided, else the default."""
        if self.generator_module_path:
            module = self._load_generator_module()
            func = getattr(module, "metadata_to_conditioning", None)
            if callable(func):
                logger.debug("Using metadata_to_conditioning from generator module")
                return func
        return default_metadata_to_conditioning

    def generate_card_from_metadata(
        self,
        meta: Mapping[str, Any],
        num_inference_steps: int = 30,
        guidance_scale: float = 7.5,
        save_path: str | None = None,
    ):
        """Generate one card image from ``meta`` and optionally save it.

        Args:
            meta: Card metadata converted to a prompt by the conditioning function.
            num_inference_steps: Forwarded to the pipeline call.
            guidance_scale: Forwarded to the pipeline call.
            save_path: When truthy, the image is saved there. Missing parent
                directories are created.

        Returns:
            The first entry of the pipeline result's ``images``.

        Raises:
            RuntimeError: if the result has no non-empty ``images`` attribute.
        """
        conditioning = self._metadata_to_conditioning(meta)
        result = self._pipe(
            conditioning,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
        if not hasattr(result, "images") or not result.images:
            raise RuntimeError(
                "Pipeline call did not return an object with non-empty .images. "
                "Ensure your pipeline follows diffusers-style output."
            )
        image = result.images[0]
        if save_path:
            output_file = Path(save_path).resolve()
            output_file.parent.mkdir(parents=True, exist_ok=True)
            image.save(str(output_file))
        return image
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the prompt-to-card pipeline.

    Defines the positional ``text`` argument, four required path options,
    the optional generator module, generation settings, and the two
    stdout-printing flags.
    """
    cli = argparse.ArgumentParser(
        description="Run text cleaning + JSON inference + card generation in one command.",
    )
    cli.add_argument("text", help="User input text.")

    # Required path options share the same shape, so register them from a table.
    required_paths = (
        (
            "--text-cleaner-path",
            "Path to text-cleaning-pipeline.py that defines get_clean_text(text).",
        ),
        ("--infer-script-path", "Path to infer_json_usage.py."),
        ("--template", "Path to JSON template file."),
        (
            "--checkpoint",
            "Path to model checkpoint (example: pokemon_card_lora/training_history.pt).",
        ),
    )
    for flag, description in required_paths:
        cli.add_argument(flag, required=True, help=description)

    cli.add_argument(
        "--generator-module",
        default="",
        help="Optional module path defining build_pipeline() and metadata_to_conditioning().",
    )
    cli.add_argument("--device", default="cpu", help="Checkpoint loading device (default: cpu).")
    cli.add_argument("--num-inference-steps", type=int, default=30)
    cli.add_argument("--guidance-scale", type=float, default=7.5)
    cli.add_argument("--save-path", default="generated_card.png")
    cli.add_argument(
        "--python-executable",
        default=sys.executable,
        help="Python executable used to run infer_json_usage.py (default: current interpreter).",
    )

    # Boolean switches controlling what is echoed to stdout.
    stdout_flags = (
        ("--print-json", "Print inferred JSON to stdout."),
        ("--print-clean-text", "Print cleaned text to stdout."),
    )
    for flag, description in stdout_flags:
        cli.add_argument(flag, action="store_true", help=description)
    return cli
def main() -> None:
    """Run the full pipeline from the command line.

    Steps: parse arguments, clean the input text with the user-supplied
    ``get_clean_text``, infer card JSON via the inference script, build the
    card generator, generate and save the card image, then optionally print
    the cleaned text and/or the inferred JSON to stdout.

    Raises:
        TypeError: if ``get_clean_text`` does not return a string.
    """
    # Without a configured handler the INFO "Card saved" record is dropped:
    # logging's last-resort handler only emits WARNING and above.
    # basicConfig logs to stderr, so stdout stays reserved for --print-json
    # and --print-clean-text. It is a no-op when the root logger already has
    # handlers, so a host application's configuration is left untouched.
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    args = _build_parser().parse_args()

    logger.debug("Loading text cleaner")
    get_clean_text = _load_function_from_file(args.text_cleaner_path, "get_clean_text")
    logger.debug("Cleaning input text")
    cleaned_text = get_clean_text(args.text)
    if not isinstance(cleaned_text, str):
        raise TypeError("get_clean_text(...) must return a string")

    logger.debug("Running JSON inference")
    inferred_json = run_infer_json_cli(
        infer_script_path=args.infer_script_path,
        template_path=args.template,
        cleaned_text=cleaned_text,
        python_executable=args.python_executable,
    )

    logger.debug("Initializing card generator")
    generator = CheckpointCardGenerator(
        checkpoint_path=args.checkpoint,
        device=args.device,
        generator_module_path=args.generator_module,
    )

    logger.debug("Generating card image")
    generator.generate_card_from_metadata(
        inferred_json,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        save_path=args.save_path,
    )
    logger.info("Card saved to %s", args.save_path)

    if args.print_clean_text:
        print(cleaned_text)
    if args.print_json:
        print(json.dumps(inferred_json, indent=2))
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()