Replace debug print() statements with proper logging #21

Merged
llabeyrie merged 1 commit from fix/replace-print-with-logging into main 2026-03-19 19:06:41 +00:00

View File

@@ -19,12 +19,15 @@ import argparse
import importlib import importlib
import importlib.util import importlib.util
import json import json
import logging
import subprocess import subprocess
import sys import sys
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Mapping from typing import Any, Callable, Mapping
logger = logging.getLogger(__name__)
def _load_module_from_file(module_file: str): def _load_module_from_file(module_file: str):
module_path = Path(module_file).resolve() module_path = Path(module_file).resolve()
@@ -37,14 +40,14 @@ def _load_module_from_file(module_file: str):
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) spec.loader.exec_module(module)
print("module successfully charged") logger.debug("Module loaded from %s", module_path)
return module return module
def _load_function_from_file(module_file: str, function_name: str) -> Callable[..., Any]: def _load_function_from_file(module_file: str, function_name: str) -> Callable[..., Any]:
print("model charging 1") logger.debug("Loading function '%s' from %s", function_name, module_file)
module = _load_module_from_file(module_file) module = _load_module_from_file(module_file)
print("model charged 1") logger.debug("Function '%s' loaded", function_name)
if not hasattr(module, function_name): if not hasattr(module, function_name):
raise AttributeError(f"{module_file} has no function named '{function_name}'") raise AttributeError(f"{module_file} has no function named '{function_name}'")
func = getattr(module, function_name) func = getattr(module, function_name)
@@ -54,7 +57,7 @@ def _load_function_from_file(module_file: str, function_name: str) -> Callable[.
def _extract_json_from_output(raw: str) -> Mapping[str, Any]: def _extract_json_from_output(raw: str) -> Mapping[str, Any]:
print("_extract_json_from_output") logger.debug("Extracting JSON from inference output")
stripped = raw.strip() stripped = raw.strip()
if not stripped: if not stripped:
raise ValueError("Inference command returned empty output") raise ValueError("Inference command returned empty output")
@@ -75,7 +78,7 @@ def _extract_json_from_output(raw: str) -> Mapping[str, Any]:
candidate = stripped[last_open : last_close + 1] candidate = stripped[last_open : last_close + 1]
parsed = json.loads(candidate) parsed = json.loads(candidate)
print("json parsed with success") logger.debug("JSON parsed from fallback extraction")
if not isinstance(parsed, dict): if not isinstance(parsed, dict):
raise ValueError("Parsed fallback JSON is not an object") raise ValueError("Parsed fallback JSON is not an object")
return parsed return parsed
@@ -88,7 +91,7 @@ def run_infer_json_cli(
python_executable: str | None = None, python_executable: str | None = None,
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
infer_script = Path(infer_script_path).resolve() infer_script = Path(infer_script_path).resolve()
print("run_infer_json_cli") logger.debug("Running JSON inference CLI: %s", infer_script)
if not infer_script.exists(): if not infer_script.exists():
raise FileNotFoundError(f"infer_json_usage.py not found: {infer_script}") raise FileNotFoundError(f"infer_json_usage.py not found: {infer_script}")
@@ -104,7 +107,7 @@ def run_infer_json_cli(
str(template_file), str(template_file),
cleaned_text, cleaned_text,
] ]
print("will start result") logger.debug("Launching inference subprocess")
result = subprocess.run(cmd, capture_output=True, text=True, check=False) result = subprocess.run(cmd, capture_output=True, text=True, check=False)
if result.returncode != 0: if result.returncode != 0:
@@ -113,12 +116,12 @@ def run_infer_json_cli(
"JSON inference command failed. " "JSON inference command failed. "
f"exit={result.returncode}, stderr={stderr or '<empty>'}" f"exit={result.returncode}, stderr={stderr or '<empty>'}"
) )
print("result is done") logger.debug("Inference subprocess completed successfully")
return _extract_json_from_output(result.stdout) return _extract_json_from_output(result.stdout)
def default_metadata_to_conditioning(meta: Mapping[str, Any]) -> str: def default_metadata_to_conditioning(meta: Mapping[str, Any]) -> str:
print("default_metadata_to_conditioning") logger.debug("Building default conditioning from metadata")
name = str(meta.get("name", "Unknown Pokemon")) name = str(meta.get("name", "Unknown Pokemon"))
types = meta.get("types") or [] types = meta.get("types") or []
if isinstance(types, list): if isinstance(types, list):
@@ -164,18 +167,18 @@ class CheckpointCardGenerator:
def _build_pipe(self): def _build_pipe(self):
if self.generator_module_path: if self.generator_module_path:
print("getting module") logger.debug("Loading generator module: %s", self.generator_module_path)
module = _load_module_from_file(self.generator_module_path) module = _load_module_from_file(self.generator_module_path)
print("module got") logger.debug("Generator module loaded")
if not hasattr(module, "build_pipeline"): if not hasattr(module, "build_pipeline"):
raise AttributeError( raise AttributeError(
"Custom generator module must define build_pipeline(checkpoint_path, device)." "Custom generator module must define build_pipeline(checkpoint_path, device)."
) )
print("building pipeline") logger.debug("Building pipeline from generator module")
build_pipeline = getattr(module, "build_pipeline") build_pipeline = getattr(module, "build_pipeline")
if not callable(build_pipeline): if not callable(build_pipeline):
raise TypeError("build_pipeline exists but is not callable") raise TypeError("build_pipeline exists but is not callable")
print("pipeline build") logger.debug("Pipeline built successfully")
return build_pipeline(self.checkpoint_path, self.device) return build_pipeline(self.checkpoint_path, self.device)
# Best-effort direct checkpoint loading for simple callable pipeline dumps. # Best-effort direct checkpoint loading for simple callable pipeline dumps.
@@ -185,9 +188,9 @@ class CheckpointCardGenerator:
raise RuntimeError( raise RuntimeError(
"torch is required to load checkpoint files. Install torch or provide --generator-module." "torch is required to load checkpoint files. Install torch or provide --generator-module."
) from exc ) from exc
print("loading checkpoint") logger.debug("Loading checkpoint from %s", self.checkpoint_path)
checkpoint = torch.load(self.checkpoint_path, map_location=self.device) checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
print("checkpoint loaded") logger.debug("Checkpoint loaded")
if callable(checkpoint): if callable(checkpoint):
return checkpoint return checkpoint
@@ -205,9 +208,9 @@ class CheckpointCardGenerator:
def _build_conditioning_function(self) -> Callable[[Mapping[str, Any]], str]: def _build_conditioning_function(self) -> Callable[[Mapping[str, Any]], str]:
if self.generator_module_path: if self.generator_module_path:
print("model charge 2") logger.debug("Loading conditioning function from %s", self.generator_module_path)
module = _load_module_from_file(self.generator_module_path) module = _load_module_from_file(self.generator_module_path)
print("model charged 2") logger.debug("Conditioning function module loaded")
if hasattr(module, "metadata_to_conditioning"): if hasattr(module, "metadata_to_conditioning"):
func = getattr(module, "metadata_to_conditioning") func = getattr(module, "metadata_to_conditioning")
if callable(func): if callable(func):
@@ -296,42 +299,37 @@ def _build_parser() -> argparse.ArgumentParser:
def main() -> None: def main() -> None:
args = _build_parser().parse_args() args = _build_parser().parse_args()
print("main get clean text") logger.debug("Loading text cleaner")
get_clean_text = _load_function_from_file(args.text_cleaner_path, "get_clean_text") get_clean_text = _load_function_from_file(args.text_cleaner_path, "get_clean_text")
print("main got clean text")
logger.debug("Cleaning input text")
cleaned_text = get_clean_text(args.text) cleaned_text = get_clean_text(args.text)
print("main got args.text")
if not isinstance(cleaned_text, str): if not isinstance(cleaned_text, str):
raise TypeError("get_clean_text(...) must return a string") raise TypeError("get_clean_text(...) must return a string")
print("main get inferred")
logger.debug("Running JSON inference")
inferred_json = run_infer_json_cli( inferred_json = run_infer_json_cli(
infer_script_path=args.infer_script_path, infer_script_path=args.infer_script_path,
template_path=args.template, template_path=args.template,
cleaned_text=cleaned_text, cleaned_text=cleaned_text,
python_executable=args.python_executable, python_executable=args.python_executable,
) )
print("main got inferred")
print("main get generator")
logger.debug("Initializing card generator")
generator = CheckpointCardGenerator( generator = CheckpointCardGenerator(
checkpoint_path=args.checkpoint, checkpoint_path=args.checkpoint,
device=args.device, device=args.device,
generator_module_path=args.generator_module, generator_module_path=args.generator_module,
) )
print("main got generator and will generate card")
logger.debug("Generating card image")
generator.generate_card_from_metadata( generator.generate_card_from_metadata(
inferred_json, inferred_json,
num_inference_steps=args.num_inference_steps, num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale, guidance_scale=args.guidance_scale,
save_path=args.save_path, save_path=args.save_path,
) )
print("main card generated") logger.info("Card saved to %s", args.save_path)
if args.print_clean_text: if args.print_clean_text:
@@ -339,8 +337,6 @@ def main() -> None:
if args.print_json: if args.print_json:
print(json.dumps(inferred_json, indent=2)) print(json.dumps(inferred_json, indent=2))
print(f"Card generated and saved to: {args.save_path}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()