"""End-to-end prompt -> cleaned text -> inferred JSON -> generated card image. This script is built to connect the three stages described by the user: 1) call get_clean_text(user_text) from a text-cleaning module file 2) pass cleaned text into infer_json_usage.py with --json-only --template 3) load a checkpoint and generate a card image from inferred metadata The model-loading part is intentionally pluggable because checkpoint structures vary. If your .pt checkpoint cannot be used directly as a callable pipeline, provide a generator module implementing: def build_pipeline(checkpoint_path: str, device: str): ... def metadata_to_conditioning(meta: dict) -> str: ... # optional """ from __future__ import annotations import argparse import importlib import importlib.util import json import subprocess import sys from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, Mapping def _load_module_from_file(module_file: str): module_path = Path(module_file).resolve() if not module_path.exists(): raise FileNotFoundError(f"Module file not found: {module_path}") spec = importlib.util.spec_from_file_location(module_path.stem, str(module_path)) if spec is None or spec.loader is None: raise ImportError(f"Cannot import module from file: {module_path}") module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) print("module successfully charged") return module def _load_function_from_file(module_file: str, function_name: str) -> Callable[..., Any]: print("model charging 1") module = _load_module_from_file(module_file) print("model charged 1") if not hasattr(module, function_name): raise AttributeError(f"{module_file} has no function named '{function_name}'") func = getattr(module, function_name) if not callable(func): raise TypeError(f"{function_name} in {module_file} is not callable") return func def _extract_json_from_output(raw: str) -> Mapping[str, Any]: print("_extract_json_from_output") stripped = raw.strip() if not stripped: raise ValueError("Inference command returned empty output") try: parsed = json.loads(stripped) if not isinstance(parsed, dict): raise ValueError("Inference output is JSON but not an object") return parsed except json.JSONDecodeError: pass # Fallback: parse the last JSON object in mixed stdout. last_open = stripped.rfind("{") last_close = stripped.rfind("}") if last_open == -1 or last_close == -1 or last_close <= last_open: raise ValueError(f"Could not parse JSON from inference output:\n{raw}") candidate = stripped[last_open : last_close + 1] parsed = json.loads(candidate) print("json parsed with success") if not isinstance(parsed, dict): raise ValueError("Parsed fallback JSON is not an object") return parsed def run_infer_json_cli( infer_script_path: str, template_path: str, cleaned_text: str, python_executable: str | None = None, ) -> Mapping[str, Any]: infer_script = Path(infer_script_path).resolve() print("run_infer_json_cli") if not infer_script.exists(): raise FileNotFoundError(f"infer_json_usage.py not found: {infer_script}") template_file = Path(template_path).resolve() if not template_file.exists(): raise FileNotFoundError(f"Template file not found: {template_file}") cmd = [ python_executable or sys.executable, str(infer_script), "--json-only", "--template", str(template_file), cleaned_text, ] print("will start result") result = subprocess.run(cmd, capture_output=True, text=True, check=False) if result.returncode != 0: stderr = result.stderr.strip() raise RuntimeError( "JSON inference command failed. 
" f"exit={result.returncode}, stderr={stderr or ''}" ) print("result is done") return _extract_json_from_output(result.stdout) def default_metadata_to_conditioning(meta: Mapping[str, Any]) -> str: print("default_metadata_to_conditioning") name = str(meta.get("name", "Unknown Pokemon")) types = meta.get("types") or [] if isinstance(types, list): type_text = ", ".join(str(item) for item in types if item) or str(meta.get("type", "normal")) else: type_text = str(meta.get("type", "normal")) attacks = meta.get("attacks") or [] attack_names = [] if isinstance(attacks, list): for attack in attacks: if isinstance(attack, dict): value = attack.get("name") if value: attack_names.append(str(value)) elif attack: attack_names.append(str(attack)) hp = str(meta.get("hp", "60")) description = str(meta.get("description", "")) parts = [ f"Pokemon trading card illustration of {name}", f"type: {type_text}", f"hp: {hp}", ] if attack_names: parts.append(f"attacks: {', '.join(attack_names[:2])}") if description: parts.append(f"description: {description}") return "; ".join(parts) @dataclass class CheckpointCardGenerator: checkpoint_path: str device: str = "cpu" generator_module_path: str = "" def __post_init__(self) -> None: self._pipe = self._build_pipe() self._metadata_to_conditioning = self._build_conditioning_function() def _build_pipe(self): if self.generator_module_path: print("getting module") module = _load_module_from_file(self.generator_module_path) print("module got") if not hasattr(module, "build_pipeline"): raise AttributeError( "Custom generator module must define build_pipeline(checkpoint_path, device)." ) print("building pipeline") build_pipeline = getattr(module, "build_pipeline") if not callable(build_pipeline): raise TypeError("build_pipeline exists but is not callable") print("pipeline build") return build_pipeline(self.checkpoint_path, self.device) # Best-effort direct checkpoint loading for simple callable pipeline dumps. try: torch = importlib.import_module("torch") except ModuleNotFoundError as exc: raise RuntimeError( "torch is required to load checkpoint files. Install torch or provide --generator-module." ) from exc print("loading checkpoint") checkpoint = torch.load(self.checkpoint_path, map_location=self.device) print("checkpoint loaded") if callable(checkpoint): return checkpoint if isinstance(checkpoint, dict): for key in ("pipe", "pipeline", "model"): candidate = checkpoint.get(key) if callable(candidate): return candidate raise RuntimeError( "Could not construct a callable generation pipeline from checkpoint. " "Pass --generator-module with a build_pipeline() function for your model layout." ) def _build_conditioning_function(self) -> Callable[[Mapping[str, Any]], str]: if self.generator_module_path: print("model charge 2") module = _load_module_from_file(self.generator_module_path) print("model charged 2") if hasattr(module, "metadata_to_conditioning"): func = getattr(module, "metadata_to_conditioning") if callable(func): return func return default_metadata_to_conditioning def generate_card_from_metadata( self, meta: Mapping[str, Any], num_inference_steps: int = 30, guidance_scale: float = 7.5, save_path: str | None = None, ): conditioning = self._metadata_to_conditioning(meta) result = self._pipe( conditioning, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, ) if not hasattr(result, "images") or not result.images: raise RuntimeError( "Pipeline call did not return an object with non-empty .images. " "Ensure your pipeline follows diffusers-style output." 


@dataclass
class CheckpointCardGenerator:
    checkpoint_path: str
    device: str = "cpu"
    generator_module_path: str = ""

    def __post_init__(self) -> None:
        self._pipe = self._build_pipe()
        self._metadata_to_conditioning = self._build_conditioning_function()

    def _build_pipe(self):
        if self.generator_module_path:
            print("loading generator module")
            module = _load_module_from_file(self.generator_module_path)
            print("generator module loaded")
            if not hasattr(module, "build_pipeline"):
                raise AttributeError(
                    "Custom generator module must define build_pipeline(checkpoint_path, device)."
                )
            print("building pipeline")
            build_pipeline = getattr(module, "build_pipeline")
            if not callable(build_pipeline):
                raise TypeError("build_pipeline exists but is not callable")
            print("pipeline built")
            return build_pipeline(self.checkpoint_path, self.device)

        # Best-effort direct checkpoint loading for simple callable pipeline dumps.
        try:
            torch = importlib.import_module("torch")
        except ModuleNotFoundError as exc:
            raise RuntimeError(
                "torch is required to load checkpoint files. Install torch or provide --generator-module."
            ) from exc
        print("loading checkpoint")
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
        print("checkpoint loaded")
        if callable(checkpoint):
            return checkpoint
        if isinstance(checkpoint, dict):
            for key in ("pipe", "pipeline", "model"):
                candidate = checkpoint.get(key)
                if callable(candidate):
                    return candidate
        raise RuntimeError(
            "Could not construct a callable generation pipeline from checkpoint. "
            "Pass --generator-module with a build_pipeline() function for your model layout."
        )

    def _build_conditioning_function(self) -> Callable[[Mapping[str, Any]], str]:
        if self.generator_module_path:
            print("loading generator module for conditioning")
            module = _load_module_from_file(self.generator_module_path)
            print("generator module loaded for conditioning")
            if hasattr(module, "metadata_to_conditioning"):
                func = getattr(module, "metadata_to_conditioning")
                if callable(func):
                    return func
        return default_metadata_to_conditioning

    def generate_card_from_metadata(
        self,
        meta: Mapping[str, Any],
        num_inference_steps: int = 30,
        guidance_scale: float = 7.5,
        save_path: str | None = None,
    ):
        conditioning = self._metadata_to_conditioning(meta)
        result = self._pipe(
            conditioning,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
        if not hasattr(result, "images") or not result.images:
            raise RuntimeError(
                "Pipeline call did not return an object with non-empty .images. "
                "Ensure your pipeline follows diffusers-style output."
            )
        image = result.images[0]
        if save_path:
            output_file = Path(save_path).resolve()
            output_file.parent.mkdir(parents=True, exist_ok=True)
            image.save(str(output_file))
        return image


def _build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Run text cleaning + JSON inference + card generation in one command.",
    )
    parser.add_argument("text", help="User input text.")
    parser.add_argument(
        "--text-cleaner-path",
        required=True,
        help="Path to text-cleaning-pipeline.py that defines get_clean_text(text).",
    )
    parser.add_argument(
        "--infer-script-path",
        required=True,
        help="Path to infer_json_usage.py.",
    )
    parser.add_argument(
        "--template",
        required=True,
        help="Path to JSON template file.",
    )
    parser.add_argument(
        "--checkpoint",
        required=True,
        help="Path to model checkpoint (example: pokemon_card_lora/training_history.pt).",
    )
    parser.add_argument(
        "--generator-module",
        default="",
        help="Optional module path defining build_pipeline() and metadata_to_conditioning().",
    )
    parser.add_argument("--device", default="cpu", help="Checkpoint loading device (default: cpu).")
    parser.add_argument("--num-inference-steps", type=int, default=30)
    parser.add_argument("--guidance-scale", type=float, default=7.5)
    parser.add_argument("--save-path", default="generated_card.png")
    parser.add_argument(
        "--python-executable",
        default=sys.executable,
        help="Python executable used to run infer_json_usage.py (default: current interpreter).",
    )
    parser.add_argument(
        "--print-json",
        action="store_true",
        help="Print inferred JSON to stdout.",
    )
    parser.add_argument(
        "--print-clean-text",
        action="store_true",
        help="Print cleaned text to stdout.",
    )
    return parser


def main() -> None:
    args = _build_parser().parse_args()

    print("main: loading get_clean_text")
    get_clean_text = _load_function_from_file(args.text_cleaner_path, "get_clean_text")
    print("main: get_clean_text loaded")
    cleaned_text = get_clean_text(args.text)
    print("main: text cleaned")
    if not isinstance(cleaned_text, str):
        raise TypeError("get_clean_text(...) must return a string")

    print("main: running JSON inference")
    inferred_json = run_infer_json_cli(
        infer_script_path=args.infer_script_path,
        template_path=args.template,
        cleaned_text=cleaned_text,
        python_executable=args.python_executable,
    )
    print("main: JSON inference done")

    print("main: building card generator")
    generator = CheckpointCardGenerator(
        checkpoint_path=args.checkpoint,
        device=args.device,
        generator_module_path=args.generator_module,
    )
    print("main: generator ready, generating card")
    generator.generate_card_from_metadata(
        inferred_json,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        save_path=args.save_path,
    )
    print("main: card generated")

    if args.print_clean_text:
        print(cleaned_text)
    if args.print_json:
        print(json.dumps(inferred_json, indent=2))
    print(f"Card generated and saved to: {args.save_path}")


if __name__ == "__main__":
    main()
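
# Illustrative sketch of a custom --generator-module (not part of this script,
# and not guaranteed to match your checkpoint layout). It assumes a
# diffusers-style text-to-image pipeline; the base model id and the way the
# .pt weights are applied are placeholders you would need to adapt.
#
#     import torch
#     from diffusers import StableDiffusionPipeline
#
#     def build_pipeline(checkpoint_path: str, device: str):
#         pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
#         state = torch.load(checkpoint_path, map_location=device)
#         # Apply fine-tuned / LoRA weights from `state` here, depending on how
#         # the checkpoint was saved.
#         return pipe.to(device)
#
#     def metadata_to_conditioning(meta: dict) -> str:  # optional override
#         return f"Pokemon trading card of {meta.get('name', 'Unknown Pokemon')}"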