first commit
This commit is contained in:
346
prompt_to_card_pipeline.py
Normal file
346
prompt_to_card_pipeline.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""End-to-end prompt -> cleaned text -> inferred JSON -> generated card image.
|
||||
|
||||
This script is built to connect the three stages described by the user:
|
||||
1) call get_clean_text(user_text) from a text-cleaning module file
|
||||
2) pass cleaned text into infer_json_usage.py with --json-only --template
|
||||
3) load a checkpoint and generate a card image from inferred metadata
|
||||
|
||||
The model-loading part is intentionally pluggable because checkpoint structures vary.
|
||||
If your .pt checkpoint cannot be used directly as a callable pipeline, provide a
|
||||
generator module implementing:
|
||||
|
||||
def build_pipeline(checkpoint_path: str, device: str): ...
|
||||
def metadata_to_conditioning(meta: dict) -> str: ... # optional
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
import importlib.util
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Mapping
|
||||
|
||||
|
||||
def _load_module_from_file(module_file: str):
|
||||
module_path = Path(module_file).resolve()
|
||||
if not module_path.exists():
|
||||
raise FileNotFoundError(f"Module file not found: {module_path}")
|
||||
|
||||
spec = importlib.util.spec_from_file_location(module_path.stem, str(module_path))
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError(f"Cannot import module from file: {module_path}")
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
print("module successfully charged")
|
||||
return module
|
||||
|
||||
|
||||
def _load_function_from_file(module_file: str, function_name: str) -> Callable[..., Any]:
    """Load a named callable from a Python source file.

    Args:
        module_file: Path to the ``.py`` file to import.
        function_name: Attribute name of the callable to fetch.

    Returns:
        The callable defined in the module.

    Raises:
        AttributeError: If the module has no attribute ``function_name``.
        TypeError: If the attribute exists but is not callable.
    """
    module = _load_module_from_file(module_file)
    if not hasattr(module, function_name):
        raise AttributeError(f"{module_file} has no function named '{function_name}'")
    func = getattr(module, function_name)
    if not callable(func):
        raise TypeError(f"{function_name} in {module_file} is not callable")
    return func
|
||||
|
||||
|
||||
def _extract_json_from_output(raw: str) -> Mapping[str, Any]:
|
||||
print("_extract_json_from_output")
|
||||
stripped = raw.strip()
|
||||
if not stripped:
|
||||
raise ValueError("Inference command returned empty output")
|
||||
|
||||
try:
|
||||
parsed = json.loads(stripped)
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("Inference output is JSON but not an object")
|
||||
return parsed
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Fallback: parse the last JSON object in mixed stdout.
|
||||
last_open = stripped.rfind("{")
|
||||
last_close = stripped.rfind("}")
|
||||
if last_open == -1 or last_close == -1 or last_close <= last_open:
|
||||
raise ValueError(f"Could not parse JSON from inference output:\n{raw}")
|
||||
|
||||
candidate = stripped[last_open : last_close + 1]
|
||||
parsed = json.loads(candidate)
|
||||
print("json parsed with success")
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("Parsed fallback JSON is not an object")
|
||||
return parsed
|
||||
|
||||
|
||||
def run_infer_json_cli(
    infer_script_path: str,
    template_path: str,
    cleaned_text: str,
    python_executable: str | None = None,
) -> Mapping[str, Any]:
    """Run ``infer_json_usage.py --json-only --template <tpl> <text>`` and parse stdout.

    Args:
        infer_script_path: Path to the ``infer_json_usage.py`` script.
        template_path: Path to the JSON template file passed via ``--template``.
        cleaned_text: Cleaned user text, passed as the positional argument.
        python_executable: Interpreter to run the script with; defaults to the
            current interpreter (``sys.executable``).

    Returns:
        The JSON object the script printed, as a dict.

    Raises:
        FileNotFoundError: If the script or template file does not exist.
        RuntimeError: If the subprocess exits non-zero.
        ValueError: If stdout contains no parseable JSON object.
    """
    infer_script = Path(infer_script_path).resolve()
    if not infer_script.exists():
        raise FileNotFoundError(f"infer_json_usage.py not found: {infer_script}")

    template_file = Path(template_path).resolve()
    if not template_file.exists():
        raise FileNotFoundError(f"Template file not found: {template_file}")

    # List-form argv with shell=False semantics: cleaned_text is passed safely
    # as a single argument, no quoting/injection concerns.
    cmd = [
        python_executable or sys.executable,
        str(infer_script),
        "--json-only",
        "--template",
        str(template_file),
        cleaned_text,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True, check=False)

    if result.returncode != 0:
        stderr = result.stderr.strip()
        raise RuntimeError(
            "JSON inference command failed. "
            f"exit={result.returncode}, stderr={stderr or '<empty>'}"
        )
    return _extract_json_from_output(result.stdout)
|
||||
|
||||
|
||||
def default_metadata_to_conditioning(meta: Mapping[str, Any]) -> str:
    """Build a text-to-image conditioning prompt from inferred card metadata.

    Combines name, type(s), hp, up to two attack names, and the description
    (when present) into a single semicolon-separated prompt. Missing fields
    fall back to generic defaults ("Unknown Pokemon", "normal", "60").

    Args:
        meta: Inferred card metadata (JSON object from the inference step).

    Returns:
        The conditioning prompt string.
    """
    name = str(meta.get("name", "Unknown Pokemon"))

    # "types" may be a list or absent; fall back to the singular "type" key.
    types = meta.get("types") or []
    if isinstance(types, list):
        type_text = ", ".join(str(item) for item in types if item) or str(meta.get("type", "normal"))
    else:
        type_text = str(meta.get("type", "normal"))

    # Attacks may be dicts with a "name" key or plain strings.
    attacks = meta.get("attacks") or []
    attack_names = []
    if isinstance(attacks, list):
        for attack in attacks:
            if isinstance(attack, dict):
                value = attack.get("name")
                if value:
                    attack_names.append(str(value))
            elif attack:
                attack_names.append(str(attack))

    hp = str(meta.get("hp", "60"))
    description = str(meta.get("description", ""))

    parts = [
        f"Pokemon trading card illustration of {name}",
        f"type: {type_text}",
        f"hp: {hp}",
    ]
    if attack_names:
        # Keep the prompt short: include at most two attack names.
        parts.append(f"attacks: {', '.join(attack_names[:2])}")
    if description:
        parts.append(f"description: {description}")
    return "; ".join(parts)
|
||||
|
||||
|
||||
@dataclass
class CheckpointCardGenerator:
    """Turns inferred card metadata into a rendered card image.

    The generation pipeline comes either from a user-supplied generator module
    (which must define ``build_pipeline(checkpoint_path, device)`` and may
    define ``metadata_to_conditioning(meta)``) or, as a best effort, directly
    from a torch checkpoint that deserializes to a callable.
    """

    checkpoint_path: str  # path to the .pt checkpoint file
    device: str = "cpu"  # torch map_location / pipeline device
    generator_module_path: str = ""  # optional custom generator module file

    def __post_init__(self) -> None:
        # Load the custom generator module at most once. The previous version
        # executed the module file twice (once for the pipeline, once for the
        # conditioning function), duplicating any import-time side effects.
        self._generator_module = (
            _load_module_from_file(self.generator_module_path)
            if self.generator_module_path
            else None
        )
        self._pipe = self._build_pipe()
        self._metadata_to_conditioning = self._build_conditioning_function()

    def _build_pipe(self):
        """Return a callable, diffusers-style generation pipeline.

        Raises:
            AttributeError: If a custom module lacks ``build_pipeline``.
            TypeError: If ``build_pipeline`` exists but is not callable.
            RuntimeError: If torch is unavailable or the checkpoint cannot be
                turned into a callable pipeline.
        """
        if self._generator_module is not None:
            module = self._generator_module
            if not hasattr(module, "build_pipeline"):
                raise AttributeError(
                    "Custom generator module must define build_pipeline(checkpoint_path, device)."
                )
            build_pipeline = getattr(module, "build_pipeline")
            if not callable(build_pipeline):
                raise TypeError("build_pipeline exists but is not callable")
            return build_pipeline(self.checkpoint_path, self.device)

        # Best-effort direct checkpoint loading for simple callable pipeline dumps.
        try:
            torch = importlib.import_module("torch")
        except ModuleNotFoundError as exc:
            raise RuntimeError(
                "torch is required to load checkpoint files. Install torch or provide --generator-module."
            ) from exc
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)

        if callable(checkpoint):
            return checkpoint

        # Some dumps wrap the pipeline in a dict under a conventional key.
        if isinstance(checkpoint, dict):
            for key in ("pipe", "pipeline", "model"):
                candidate = checkpoint.get(key)
                if callable(candidate):
                    return candidate

        raise RuntimeError(
            "Could not construct a callable generation pipeline from checkpoint. "
            "Pass --generator-module with a build_pipeline() function for your model layout."
        )

    def _build_conditioning_function(self) -> Callable[[Mapping[str, Any]], str]:
        """Prefer the custom module's ``metadata_to_conditioning``; else the default."""
        if self._generator_module is not None:
            func = getattr(self._generator_module, "metadata_to_conditioning", None)
            if callable(func):
                return func
        return default_metadata_to_conditioning

    def generate_card_from_metadata(
        self,
        meta: Mapping[str, Any],
        num_inference_steps: int = 30,
        guidance_scale: float = 7.5,
        save_path: str | None = None,
    ):
        """Generate one card image from metadata and optionally save it.

        Args:
            meta: Inferred card metadata (JSON object).
            num_inference_steps: Diffusion steps forwarded to the pipeline.
            guidance_scale: Classifier-free guidance forwarded to the pipeline.
            save_path: If given, the image is saved there (parent dirs created).

        Returns:
            The first generated image object (PIL-style: has ``.save``).

        Raises:
            RuntimeError: If the pipeline result lacks a non-empty ``.images``.
        """
        conditioning = self._metadata_to_conditioning(meta)
        result = self._pipe(
            conditioning,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )

        if not hasattr(result, "images") or not result.images:
            raise RuntimeError(
                "Pipeline call did not return an object with non-empty .images. "
                "Ensure your pipeline follows diffusers-style output."
            )

        image = result.images[0]
        if save_path:
            output_file = Path(save_path).resolve()
            output_file.parent.mkdir(parents=True, exist_ok=True)
            image.save(str(output_file))
        return image
|
||||
|
||||
|
||||
def _build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run text cleaning + JSON inference + card generation in one command.",
|
||||
)
|
||||
parser.add_argument("text", help="User input text.")
|
||||
parser.add_argument(
|
||||
"--text-cleaner-path",
|
||||
required=True,
|
||||
help="Path to text-cleaning-pipeline.py that defines get_clean_text(text).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--infer-script-path",
|
||||
required=True,
|
||||
help="Path to infer_json_usage.py.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--template",
|
||||
required=True,
|
||||
help="Path to JSON template file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
required=True,
|
||||
help="Path to model checkpoint (example: pokemon_card_lora/training_history.pt).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--generator-module",
|
||||
default="",
|
||||
help="Optional module path defining build_pipeline() and metadata_to_conditioning().",
|
||||
)
|
||||
parser.add_argument("--device", default="cpu", help="Checkpoint loading device (default: cpu).")
|
||||
parser.add_argument("--num-inference-steps", type=int, default=30)
|
||||
parser.add_argument("--guidance-scale", type=float, default=7.5)
|
||||
parser.add_argument("--save-path", default="generated_card.png")
|
||||
parser.add_argument(
|
||||
"--python-executable",
|
||||
default=sys.executable,
|
||||
help="Python executable used to run infer_json_usage.py (default: current interpreter).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print-json",
|
||||
action="store_true",
|
||||
help="Print inferred JSON to stdout.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print-clean-text",
|
||||
action="store_true",
|
||||
help="Print cleaned text to stdout.",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: clean text, infer card JSON, then render the card image.

    Raises:
        TypeError: If the user-supplied ``get_clean_text`` does not return a str.
    """
    args = _build_parser().parse_args()

    # Stage 1: clean the raw user text with the user-supplied module.
    get_clean_text = _load_function_from_file(args.text_cleaner_path, "get_clean_text")
    cleaned_text = get_clean_text(args.text)
    if not isinstance(cleaned_text, str):
        raise TypeError("get_clean_text(...) must return a string")

    # Stage 2: infer structured card metadata via the external CLI script.
    inferred_json = run_infer_json_cli(
        infer_script_path=args.infer_script_path,
        template_path=args.template,
        cleaned_text=cleaned_text,
        python_executable=args.python_executable,
    )

    # Stage 3: build the generator from the checkpoint and render the card.
    generator = CheckpointCardGenerator(
        checkpoint_path=args.checkpoint,
        device=args.device,
        generator_module_path=args.generator_module,
    )
    generator.generate_card_from_metadata(
        inferred_json,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        save_path=args.save_path,
    )

    # Optional inspection output.
    if args.print_clean_text:
        print(cleaned_text)
    if args.print_json:
        print(json.dumps(inferred_json, indent=2))

    print(f"Card generated and saved to: {args.save_path}")
|
||||
|
||||
|
||||
# Script entry point: only run the pipeline when executed directly.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user