"""End-to-end prompt -> cleaned text -> inferred JSON -> generated card image.
|
|
|
|
This script is built to connect the three stages described by the user:
|
|
1) call get_clean_text(user_text) from a text-cleaning module file
|
|
2) pass cleaned text into infer_json_usage.py with --json-only --template
|
|
3) load a checkpoint and generate a card image from inferred metadata
|
|
|
|
The model-loading part is intentionally pluggable because checkpoint structures vary.
|
|
If your .pt checkpoint cannot be used directly as a callable pipeline, provide a
|
|
generator module implementing:
|
|
|
|
def build_pipeline(checkpoint_path: str, device: str): ...
|
|
def metadata_to_conditioning(meta: dict) -> str: ... # optional
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import importlib
|
|
import importlib.util
|
|
import json
|
|
import subprocess
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any, Callable, Mapping
|
|
|
|
|
|
def _load_module_from_file(module_file: str):
|
|
module_path = Path(module_file).resolve()
|
|
if not module_path.exists():
|
|
raise FileNotFoundError(f"Module file not found: {module_path}")
|
|
|
|
spec = importlib.util.spec_from_file_location(module_path.stem, str(module_path))
|
|
if spec is None or spec.loader is None:
|
|
raise ImportError(f"Cannot import module from file: {module_path}")
|
|
|
|
module = importlib.util.module_from_spec(spec)
|
|
spec.loader.exec_module(module)
|
|
print("module successfully charged")
|
|
return module
|
|
|
|
|
|
def _load_function_from_file(module_file: str, function_name: str) -> Callable[..., Any]:
    """Load a named callable from a Python module file.

    Args:
        module_file: Path to a ``.py`` file.
        function_name: Attribute name of the callable to fetch.

    Returns:
        The callable attribute.

    Raises:
        AttributeError: If the module has no attribute ``function_name``.
        TypeError: If the attribute exists but is not callable.
    """
    module = _load_module_from_file(module_file)
    if not hasattr(module, function_name):
        raise AttributeError(f"{module_file} has no function named '{function_name}'")
    func = getattr(module, function_name)
    if not callable(func):
        raise TypeError(f"{function_name} in {module_file} is not callable")
    return func
|
|
|
|
|
|
def _extract_json_from_output(raw: str) -> Mapping[str, Any]:
|
|
print("_extract_json_from_output")
|
|
stripped = raw.strip()
|
|
if not stripped:
|
|
raise ValueError("Inference command returned empty output")
|
|
|
|
try:
|
|
parsed = json.loads(stripped)
|
|
if not isinstance(parsed, dict):
|
|
raise ValueError("Inference output is JSON but not an object")
|
|
return parsed
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
# Fallback: parse the last JSON object in mixed stdout.
|
|
last_open = stripped.rfind("{")
|
|
last_close = stripped.rfind("}")
|
|
if last_open == -1 or last_close == -1 or last_close <= last_open:
|
|
raise ValueError(f"Could not parse JSON from inference output:\n{raw}")
|
|
|
|
candidate = stripped[last_open : last_close + 1]
|
|
parsed = json.loads(candidate)
|
|
print("json parsed with success")
|
|
if not isinstance(parsed, dict):
|
|
raise ValueError("Parsed fallback JSON is not an object")
|
|
return parsed
|
|
|
|
|
|
def run_infer_json_cli(
    infer_script_path: str,
    template_path: str,
    cleaned_text: str,
    python_executable: str | None = None,
) -> Mapping[str, Any]:
    """Run infer_json_usage.py as a subprocess and parse its JSON output.

    Args:
        infer_script_path: Path to infer_json_usage.py.
        template_path: Path to the JSON template file passed via --template.
        cleaned_text: The cleaned user text, passed as the positional argument.
        python_executable: Interpreter to run the script with; defaults to the
            current interpreter (sys.executable).

    Returns:
        The JSON object parsed from the subprocess stdout.

    Raises:
        FileNotFoundError: If the script or template file is missing.
        RuntimeError: If the subprocess exits non-zero.
        ValueError: If stdout contains no parseable JSON object.
    """
    infer_script = Path(infer_script_path).resolve()
    if not infer_script.exists():
        raise FileNotFoundError(f"infer_json_usage.py not found: {infer_script}")

    template_file = Path(template_path).resolve()
    if not template_file.exists():
        raise FileNotFoundError(f"Template file not found: {template_file}")

    cmd = [
        python_executable or sys.executable,
        str(infer_script),
        "--json-only",
        "--template",
        str(template_file),
        cleaned_text,
    ]
    # argv list (shell=False) so cleaned_text is never shell-interpreted;
    # check=False because we raise our own error with captured stderr.
    result = subprocess.run(cmd, capture_output=True, text=True, check=False)

    if result.returncode != 0:
        stderr = result.stderr.strip()
        raise RuntimeError(
            "JSON inference command failed. "
            f"exit={result.returncode}, stderr={stderr or '<empty>'}"
        )
    return _extract_json_from_output(result.stdout)
|
|
|
|
|
|
def default_metadata_to_conditioning(meta: Mapping[str, Any]) -> str:
    """Turn inferred card metadata into a text conditioning prompt.

    Missing fields fall back to neutral defaults ("Unknown Pokemon", type
    "normal", hp "60"); at most the first two attack names are included.

    Args:
        meta: Inferred card metadata (dict-like).

    Returns:
        A "; "-joined prompt string describing the card.
    """
    name = str(meta.get("name", "Unknown Pokemon"))

    # "types" may be a list of names; otherwise fall back to the scalar "type".
    types = meta.get("types") or []
    if isinstance(types, list):
        type_text = ", ".join(str(item) for item in types if item) or str(meta.get("type", "normal"))
    else:
        type_text = str(meta.get("type", "normal"))

    # Attacks may be dicts with a "name" key or plain strings.
    attacks = meta.get("attacks") or []
    attack_names = []
    if isinstance(attacks, list):
        for attack in attacks:
            if isinstance(attack, dict):
                value = attack.get("name")
                if value:
                    attack_names.append(str(value))
            elif attack:
                attack_names.append(str(attack))

    hp = str(meta.get("hp", "60"))
    description = str(meta.get("description", ""))

    parts = [
        f"Pokemon trading card illustration of {name}",
        f"type: {type_text}",
        f"hp: {hp}",
    ]
    if attack_names:
        parts.append(f"attacks: {', '.join(attack_names[:2])}")
    if description:
        parts.append(f"description: {description}")
    return "; ".join(parts)
|
|
|
|
|
|
@dataclass
class CheckpointCardGenerator:
    """Builds a callable image pipeline from a checkpoint and renders cards.

    Attributes:
        checkpoint_path: Path to the model checkpoint (.pt file).
        device: Device string used when loading the checkpoint (default "cpu").
        generator_module_path: Optional path to a module providing
            ``build_pipeline(checkpoint_path, device)`` and, optionally,
            ``metadata_to_conditioning(meta)``.
    """

    checkpoint_path: str
    device: str = "cpu"
    generator_module_path: str = ""

    def __post_init__(self) -> None:
        # Load the custom generator module at most once and reuse it for both
        # the pipeline and the conditioning function; loading it separately in
        # each builder would execute its import-time side effects twice.
        self._generator_module = (
            _load_module_from_file(self.generator_module_path)
            if self.generator_module_path
            else None
        )
        self._pipe = self._build_pipe()
        self._metadata_to_conditioning = self._build_conditioning_function()

    def _build_pipe(self):
        """Return a callable generation pipeline.

        Prefers the custom module's build_pipeline(); otherwise attempts a
        best-effort direct checkpoint load via torch.

        Raises:
            AttributeError: If the custom module lacks build_pipeline.
            TypeError: If build_pipeline exists but is not callable.
            RuntimeError: If torch is unavailable or no callable pipeline can
                be recovered from the checkpoint.
        """
        if self._generator_module is not None:
            build_pipeline = getattr(self._generator_module, "build_pipeline", None)
            if build_pipeline is None:
                raise AttributeError(
                    "Custom generator module must define build_pipeline(checkpoint_path, device)."
                )
            if not callable(build_pipeline):
                raise TypeError("build_pipeline exists but is not callable")
            return build_pipeline(self.checkpoint_path, self.device)

        # Best-effort direct checkpoint loading for simple callable pipeline dumps.
        try:
            torch = importlib.import_module("torch")
        except ModuleNotFoundError as exc:
            raise RuntimeError(
                "torch is required to load checkpoint files. Install torch or provide --generator-module."
            ) from exc

        # NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects
        # pickled pipeline objects — if loading fails there, weights_only=False
        # is needed (only for trusted checkpoints). Confirm target torch version.
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)

        if callable(checkpoint):
            return checkpoint

        # Common dump layouts nest the pipeline under a well-known key.
        if isinstance(checkpoint, dict):
            for key in ("pipe", "pipeline", "model"):
                candidate = checkpoint.get(key)
                if callable(candidate):
                    return candidate

        raise RuntimeError(
            "Could not construct a callable generation pipeline from checkpoint. "
            "Pass --generator-module with a build_pipeline() function for your model layout."
        )

    def _build_conditioning_function(self) -> Callable[[Mapping[str, Any]], str]:
        """Prefer the module's metadata_to_conditioning(); else use the default."""
        if self._generator_module is not None:
            func = getattr(self._generator_module, "metadata_to_conditioning", None)
            if callable(func):
                return func
        return default_metadata_to_conditioning

    def generate_card_from_metadata(
        self,
        meta: Mapping[str, Any],
        num_inference_steps: int = 30,
        guidance_scale: float = 7.5,
        save_path: str | None = None,
    ):
        """Render one card image from inferred metadata.

        Args:
            meta: Inferred card metadata (dict-like).
            num_inference_steps: Step count forwarded to the pipeline.
            guidance_scale: Guidance scale forwarded to the pipeline.
            save_path: Optional output path; parent directories are created.

        Returns:
            The first image of the pipeline output (diffusers-style ``.images``).

        Raises:
            RuntimeError: If the pipeline result lacks a non-empty ``.images``.
        """
        conditioning = self._metadata_to_conditioning(meta)
        result = self._pipe(
            conditioning,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )

        if not hasattr(result, "images") or not result.images:
            raise RuntimeError(
                "Pipeline call did not return an object with non-empty .images. "
                "Ensure your pipeline follows diffusers-style output."
            )

        image = result.images[0]
        if save_path:
            output_file = Path(save_path).resolve()
            output_file.parent.mkdir(parents=True, exist_ok=True)
            image.save(str(output_file))
        return image
|
|
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
|
|
parser = argparse.ArgumentParser(
|
|
description="Run text cleaning + JSON inference + card generation in one command.",
|
|
)
|
|
parser.add_argument("text", help="User input text.")
|
|
parser.add_argument(
|
|
"--text-cleaner-path",
|
|
required=True,
|
|
help="Path to text-cleaning-pipeline.py that defines get_clean_text(text).",
|
|
)
|
|
parser.add_argument(
|
|
"--infer-script-path",
|
|
required=True,
|
|
help="Path to infer_json_usage.py.",
|
|
)
|
|
parser.add_argument(
|
|
"--template",
|
|
required=True,
|
|
help="Path to JSON template file.",
|
|
)
|
|
parser.add_argument(
|
|
"--checkpoint",
|
|
required=True,
|
|
help="Path to model checkpoint (example: pokemon_card_lora/training_history.pt).",
|
|
)
|
|
parser.add_argument(
|
|
"--generator-module",
|
|
default="",
|
|
help="Optional module path defining build_pipeline() and metadata_to_conditioning().",
|
|
)
|
|
parser.add_argument("--device", default="cpu", help="Checkpoint loading device (default: cpu).")
|
|
parser.add_argument("--num-inference-steps", type=int, default=30)
|
|
parser.add_argument("--guidance-scale", type=float, default=7.5)
|
|
parser.add_argument("--save-path", default="generated_card.png")
|
|
parser.add_argument(
|
|
"--python-executable",
|
|
default=sys.executable,
|
|
help="Python executable used to run infer_json_usage.py (default: current interpreter).",
|
|
)
|
|
parser.add_argument(
|
|
"--print-json",
|
|
action="store_true",
|
|
help="Print inferred JSON to stdout.",
|
|
)
|
|
parser.add_argument(
|
|
"--print-clean-text",
|
|
action="store_true",
|
|
help="Print cleaned text to stdout.",
|
|
)
|
|
return parser
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: clean text, infer card JSON, then render a card image."""
    args = _build_parser().parse_args()

    # Stage 1: clean the user text via the user-supplied module.
    get_clean_text = _load_function_from_file(args.text_cleaner_path, "get_clean_text")
    cleaned_text = get_clean_text(args.text)
    if not isinstance(cleaned_text, str):
        raise TypeError("get_clean_text(...) must return a string")

    # Stage 2: infer card metadata JSON via the subprocess CLI.
    inferred_json = run_infer_json_cli(
        infer_script_path=args.infer_script_path,
        template_path=args.template,
        cleaned_text=cleaned_text,
        python_executable=args.python_executable,
    )

    # Stage 3: build the pipeline from the checkpoint and render the card.
    generator = CheckpointCardGenerator(
        checkpoint_path=args.checkpoint,
        device=args.device,
        generator_module_path=args.generator_module,
    )
    generator.generate_card_from_metadata(
        inferred_json,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        save_path=args.save_path,
    )

    if args.print_clean_text:
        print(cleaned_text)
    if args.print_json:
        print(json.dumps(inferred_json, indent=2))

    print(f"Card generated and saved to: {args.save_path}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main() |