diff --git a/prompt_to_card_pipeline.py b/prompt_to_card_pipeline.py index bf34091..bc80d08 100644 --- a/prompt_to_card_pipeline.py +++ b/prompt_to_card_pipeline.py @@ -19,12 +19,15 @@ import argparse import importlib import importlib.util import json +import logging import subprocess import sys from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, Mapping +logger = logging.getLogger(__name__) + def _load_module_from_file(module_file: str): module_path = Path(module_file).resolve() @@ -37,14 +40,14 @@ def _load_module_from_file(module_file: str): module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - print("module successfully charged") + logger.debug("Module loaded from %s", module_path) return module def _load_function_from_file(module_file: str, function_name: str) -> Callable[..., Any]: - print("model charging 1") + logger.debug("Loading function '%s' from %s", function_name, module_file) module = _load_module_from_file(module_file) - print("model charged 1") + logger.debug("Function '%s' loaded", function_name) if not hasattr(module, function_name): raise AttributeError(f"{module_file} has no function named '{function_name}'") func = getattr(module, function_name) @@ -54,7 +57,7 @@ def _load_function_from_file(module_file: str, function_name: str) -> Callable[. def _extract_json_from_output(raw: str) -> Mapping[str, Any]: - print("_extract_json_from_output") + logger.debug("Extracting JSON from inference output") stripped = raw.strip() if not stripped: raise ValueError("Inference command returned empty output") @@ -75,7 +78,7 @@ def _extract_json_from_output(raw: str) -> Mapping[str, Any]: candidate = stripped[last_open : last_close + 1] parsed = json.loads(candidate) - print("json parsed with success") + logger.debug("JSON parsed from fallback extraction") if not isinstance(parsed, dict): raise ValueError("Parsed fallback JSON is not an object") return parsed @@ -88,7 +91,7 @@ def run_infer_json_cli( python_executable: str | None = None, ) -> Mapping[str, Any]: infer_script = Path(infer_script_path).resolve() - print("run_infer_json_cli") + logger.debug("Running JSON inference CLI: %s", infer_script) if not infer_script.exists(): raise FileNotFoundError(f"infer_json_usage.py not found: {infer_script}") @@ -104,7 +107,7 @@ def run_infer_json_cli( str(template_file), cleaned_text, ] - print("will start result") + logger.debug("Launching inference subprocess") result = subprocess.run(cmd, capture_output=True, text=True, check=False) if result.returncode != 0: @@ -113,12 +116,12 @@ def run_infer_json_cli( "JSON inference command failed. " f"exit={result.returncode}, stderr={stderr or ''}" ) - print("result is done") + logger.debug("Inference subprocess completed successfully") return _extract_json_from_output(result.stdout) def default_metadata_to_conditioning(meta: Mapping[str, Any]) -> str: - print("default_metadata_to_conditioning") + logger.debug("Building default conditioning from metadata") name = str(meta.get("name", "Unknown Pokemon")) types = meta.get("types") or [] if isinstance(types, list): @@ -164,18 +167,18 @@ class CheckpointCardGenerator: def _build_pipe(self): if self.generator_module_path: - print("getting module") + logger.debug("Loading generator module: %s", self.generator_module_path) module = _load_module_from_file(self.generator_module_path) - print("module got") + logger.debug("Generator module loaded") if not hasattr(module, "build_pipeline"): raise AttributeError( "Custom generator module must define build_pipeline(checkpoint_path, device)." ) - print("building pipeline") + logger.debug("Building pipeline from generator module") build_pipeline = getattr(module, "build_pipeline") if not callable(build_pipeline): raise TypeError("build_pipeline exists but is not callable") - print("pipeline build") + logger.debug("Pipeline built successfully") return build_pipeline(self.checkpoint_path, self.device) # Best-effort direct checkpoint loading for simple callable pipeline dumps. @@ -185,9 +188,9 @@ class CheckpointCardGenerator: raise RuntimeError( "torch is required to load checkpoint files. Install torch or provide --generator-module." ) from exc - print("loading checkpoint") + logger.debug("Loading checkpoint from %s", self.checkpoint_path) checkpoint = torch.load(self.checkpoint_path, map_location=self.device) - print("checkpoint loaded") + logger.debug("Checkpoint loaded") if callable(checkpoint): return checkpoint @@ -205,9 +208,9 @@ class CheckpointCardGenerator: def _build_conditioning_function(self) -> Callable[[Mapping[str, Any]], str]: if self.generator_module_path: - print("model charge 2") + logger.debug("Loading conditioning function from %s", self.generator_module_path) module = _load_module_from_file(self.generator_module_path) - print("model charged 2") + logger.debug("Conditioning function module loaded") if hasattr(module, "metadata_to_conditioning"): func = getattr(module, "metadata_to_conditioning") if callable(func): @@ -296,42 +299,37 @@ def _build_parser() -> argparse.ArgumentParser: def main() -> None: args = _build_parser().parse_args() - print("main get clean text") - + logger.debug("Loading text cleaner") get_clean_text = _load_function_from_file(args.text_cleaner_path, "get_clean_text") - print("main got clean text") + logger.debug("Cleaning input text") cleaned_text = get_clean_text(args.text) - print("main got args.text") if not isinstance(cleaned_text, str): raise TypeError("get_clean_text(...) must return a string") - print("main get inferred") + logger.debug("Running JSON inference") inferred_json = run_infer_json_cli( infer_script_path=args.infer_script_path, template_path=args.template, cleaned_text=cleaned_text, python_executable=args.python_executable, ) - print("main got inferred") - print("main get generator") - - + logger.debug("Initializing card generator") generator = CheckpointCardGenerator( checkpoint_path=args.checkpoint, device=args.device, generator_module_path=args.generator_module, ) - print("main got generator and will generate card") + logger.debug("Generating card image") generator.generate_card_from_metadata( inferred_json, num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale, save_path=args.save_path, ) - print("main card generated") + logger.info("Card saved to %s", args.save_path) if args.print_clean_text: @@ -339,8 +337,6 @@ def main() -> None: if args.print_json: print(json.dumps(inferred_json, indent=2)) - print(f"Card generated and saved to: {args.save_path}") - if __name__ == "__main__": main() \ No newline at end of file