# Juicepyter/prompt_to_card_pipeline.py
"""End-to-end prompt -> cleaned text -> inferred JSON -> generated card image.
This script is built to connect the three stages described by the user:
1) call get_clean_text(user_text) from a text-cleaning module file
2) pass cleaned text into infer_json_usage.py with --json-only --template
3) load a checkpoint and generate a card image from inferred metadata
The model-loading part is intentionally pluggable because checkpoint structures vary.
If your .pt checkpoint cannot be used directly as a callable pipeline, provide a
generator module implementing:
def build_pipeline(checkpoint_path: str, device: str): ...
def metadata_to_conditioning(meta: dict) -> str: ... # optional
"""
from __future__ import annotations
import argparse
import importlib
import importlib.util
import json
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Mapping


def _load_module_from_file(module_file: str):
    """Import a Python module directly from a file path."""
    module_path = Path(module_file).resolve()
    if not module_path.exists():
        raise FileNotFoundError(f"Module file not found: {module_path}")
    spec = importlib.util.spec_from_file_location(module_path.stem, str(module_path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot import module from file: {module_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    print("module loaded successfully")
    return module


def _load_function_from_file(module_file: str, function_name: str) -> Callable[..., Any]:
    """Load a named callable from a module file, with explicit error messages."""
    print("loading function module")
    module = _load_module_from_file(module_file)
    print("function module loaded")
    if not hasattr(module, function_name):
        raise AttributeError(f"{module_file} has no function named '{function_name}'")
    func = getattr(module, function_name)
    if not callable(func):
        raise TypeError(f"{function_name} in {module_file} is not callable")
    return func


def _extract_json_from_output(raw: str) -> Mapping[str, Any]:
    """Parse a JSON object from inference stdout, tolerating extra log lines."""
    print("_extract_json_from_output")
    stripped = raw.strip()
    if not stripped:
        raise ValueError("Inference command returned empty output")
    try:
        parsed = json.loads(stripped)
        if not isinstance(parsed, dict):
            raise ValueError("Inference output is JSON but not an object")
        return parsed
    except json.JSONDecodeError:
        pass
    # Fallback: parse the last JSON object in mixed stdout.
    last_open = stripped.rfind("{")
    last_close = stripped.rfind("}")
    if last_open == -1 or last_close == -1 or last_close <= last_open:
        raise ValueError(f"Could not parse JSON from inference output:\n{raw}")
    candidate = stripped[last_open : last_close + 1]
    parsed = json.loads(candidate)
    print("JSON parsed successfully")
    if not isinstance(parsed, dict):
        raise ValueError("Parsed fallback JSON is not an object")
    return parsed
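
# Example for the fallback path above (illustrative input):
#   _extract_json_from_output('loading model...\n{"name": "Pikachu"}')
# returns {"name": "Pikachu"}, since the last {...} span parses as an object.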


def run_infer_json_cli(
    infer_script_path: str,
    template_path: str,
    cleaned_text: str,
    python_executable: str | None = None,
) -> Mapping[str, Any]:
    """Run infer_json_usage.py in a subprocess and return its JSON output."""
    infer_script = Path(infer_script_path).resolve()
    print("run_infer_json_cli")
    if not infer_script.exists():
        raise FileNotFoundError(f"infer_json_usage.py not found: {infer_script}")
    template_file = Path(template_path).resolve()
    if not template_file.exists():
        raise FileNotFoundError(f"Template file not found: {template_file}")
    cmd = [
        python_executable or sys.executable,
        str(infer_script),
        "--json-only",
        "--template",
        str(template_file),
        cleaned_text,
    ]
    print("starting JSON inference subprocess")
    result = subprocess.run(cmd, capture_output=True, text=True, check=False)
    if result.returncode != 0:
        stderr = result.stderr.strip()
        raise RuntimeError(
            "JSON inference command failed. "
            f"exit={result.returncode}, stderr={stderr or '<empty>'}"
        )
    print("JSON inference subprocess finished")
    return _extract_json_from_output(result.stdout)
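
# For reference, run_infer_json_cli above assembles the equivalent of
# (template path illustrative):
#   python infer_json_usage.py --json-only --template card_template.json "<cleaned text>"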


def default_metadata_to_conditioning(meta: Mapping[str, Any]) -> str:
    print("default_metadata_to_conditioning")
    name = str(meta.get("name", "Unknown Pokemon"))
    types = meta.get("types") or []
    if isinstance(types, list):
        type_text = ", ".join(str(item) for item in types if item) or str(meta.get("type", "normal"))
    else:
        type_text = str(meta.get("type", "normal"))
    attacks = meta.get("attacks") or []
    attack_names = []
    if isinstance(attacks, list):
        for attack in attacks:
            if isinstance(attack, dict):
                value = attack.get("name")
                if value:
                    attack_names.append(str(value))
            elif attack:
                attack_names.append(str(attack))
    hp = str(meta.get("hp", "60"))
    description = str(meta.get("description", ""))
    parts = [
        f"Pokemon trading card illustration of {name}",
        f"type: {type_text}",
        f"hp: {hp}",
    ]
    if attack_names:
        parts.append(f"attacks: {', '.join(attack_names[:2])}")
    if description:
        parts.append(f"description: {description}")
    return "; ".join(parts)


@dataclass
class CheckpointCardGenerator:
    checkpoint_path: str
    device: str = "cpu"
    generator_module_path: str = ""

    def __post_init__(self) -> None:
        self._pipe = self._build_pipe()
        self._metadata_to_conditioning = self._build_conditioning_function()

    def _build_pipe(self):
        if self.generator_module_path:
            print("loading custom generator module")
            module = _load_module_from_file(self.generator_module_path)
            print("custom generator module loaded")
            if not hasattr(module, "build_pipeline"):
                raise AttributeError(
                    "Custom generator module must define build_pipeline(checkpoint_path, device)."
                )
            build_pipeline = getattr(module, "build_pipeline")
            if not callable(build_pipeline):
                raise TypeError("build_pipeline exists but is not callable")
            print("building pipeline")
            return build_pipeline(self.checkpoint_path, self.device)
        # Best-effort direct checkpoint loading for simple callable pipeline dumps.
        try:
            torch = importlib.import_module("torch")
        except ModuleNotFoundError as exc:
            raise RuntimeError(
                "torch is required to load checkpoint files. Install torch or provide --generator-module."
            ) from exc
        print("loading checkpoint")
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
        print("checkpoint loaded")
        if callable(checkpoint):
            return checkpoint
        if isinstance(checkpoint, dict):
            for key in ("pipe", "pipeline", "model"):
                candidate = checkpoint.get(key)
                if callable(candidate):
                    return candidate
        raise RuntimeError(
            "Could not construct a callable generation pipeline from checkpoint. "
            "Pass --generator-module with a build_pipeline() function for your model layout."
        )
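
    # Checkpoint layouts the fallback above can handle (illustrative):
    #   torch.save(pipe, "model.pt")                -> whole callable pipeline pickled
    #   torch.save({"pipeline": pipe}, "model.pt")  -> callable under a known dict key
    # A plain state_dict of tensors is not callable, so it needs --generator-module.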

    def _build_conditioning_function(self) -> Callable[[Mapping[str, Any]], str]:
        if self.generator_module_path:
            print("checking generator module for metadata_to_conditioning")
            module = _load_module_from_file(self.generator_module_path)
            if hasattr(module, "metadata_to_conditioning"):
                func = getattr(module, "metadata_to_conditioning")
                if callable(func):
                    return func
        return default_metadata_to_conditioning

    def generate_card_from_metadata(
        self,
        meta: Mapping[str, Any],
        num_inference_steps: int = 30,
        guidance_scale: float = 7.5,
        save_path: str | None = None,
    ):
        conditioning = self._metadata_to_conditioning(meta)
        result = self._pipe(
            conditioning,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
        if not hasattr(result, "images") or not result.images:
            raise RuntimeError(
                "Pipeline call did not return an object with non-empty .images. "
                "Ensure your pipeline follows diffusers-style output."
            )
        image = result.images[0]
        if save_path:
            output_file = Path(save_path).resolve()
            output_file.parent.mkdir(parents=True, exist_ok=True)
            image.save(str(output_file))
        return image
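
# Usage sketch for the class above (values illustrative; my_generator.py is a
# placeholder):
#   generator = CheckpointCardGenerator(
#       checkpoint_path="pokemon_card_lora/training_history.pt",
#       generator_module_path="my_generator.py",
#   )
#   generator.generate_card_from_metadata({"name": "Pikachu", "hp": 60}, save_path="card.png")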


def _build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Run text cleaning + JSON inference + card generation in one command.",
    )
    parser.add_argument("text", help="User input text.")
    parser.add_argument(
        "--text-cleaner-path",
        required=True,
        help="Path to text-cleaning-pipeline.py that defines get_clean_text(text).",
    )
    parser.add_argument(
        "--infer-script-path",
        required=True,
        help="Path to infer_json_usage.py.",
    )
    parser.add_argument(
        "--template",
        required=True,
        help="Path to JSON template file.",
    )
    parser.add_argument(
        "--checkpoint",
        required=True,
        help="Path to model checkpoint (example: pokemon_card_lora/training_history.pt).",
    )
    parser.add_argument(
        "--generator-module",
        default="",
        help="Optional module path defining build_pipeline() and metadata_to_conditioning().",
    )
    parser.add_argument("--device", default="cpu", help="Checkpoint loading device (default: cpu).")
    parser.add_argument("--num-inference-steps", type=int, default=30)
    parser.add_argument("--guidance-scale", type=float, default=7.5)
    parser.add_argument("--save-path", default="generated_card.png")
    parser.add_argument(
        "--python-executable",
        default=sys.executable,
        help="Python executable used to run infer_json_usage.py (default: current interpreter).",
    )
    parser.add_argument(
        "--print-json",
        action="store_true",
        help="Print inferred JSON to stdout.",
    )
    parser.add_argument(
        "--print-clean-text",
        action="store_true",
        help="Print cleaned text to stdout.",
    )
    return parser
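
# Example invocation (file names are placeholders for your local paths):
#   python prompt_to_card_pipeline.py "a small electric mouse pokemon" \
#     --text-cleaner-path text-cleaning-pipeline.py \
#     --infer-script-path infer_json_usage.py \
#     --template card_template.json \
#     --checkpoint pokemon_card_lora/training_history.pt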


def main() -> None:
    args = _build_parser().parse_args()
    print("loading get_clean_text")
    get_clean_text = _load_function_from_file(args.text_cleaner_path, "get_clean_text")
    print("cleaning input text")
    cleaned_text = get_clean_text(args.text)
    if not isinstance(cleaned_text, str):
        raise TypeError("get_clean_text(...) must return a string")
    print("running JSON inference")
    inferred_json = run_infer_json_cli(
        infer_script_path=args.infer_script_path,
        template_path=args.template,
        cleaned_text=cleaned_text,
        python_executable=args.python_executable,
    )
    print("JSON inference done")
    print("building card generator")
    generator = CheckpointCardGenerator(
        checkpoint_path=args.checkpoint,
        device=args.device,
        generator_module_path=args.generator_module,
    )
    print("generating card")
    generator.generate_card_from_metadata(
        inferred_json,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        save_path=args.save_path,
    )
    if args.print_clean_text:
        print(cleaned_text)
    if args.print_json:
        print(json.dumps(inferred_json, indent=2))
    print(f"Card generated and saved to: {args.save_path}")


if __name__ == "__main__":
    main()