rich_utils.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. from pathlib import Path
  2. from typing import Sequence
  3. import rich
  4. import rich.syntax
  5. import rich.tree
  6. from hydra.core.hydra_config import HydraConfig
  7. from lightning.pytorch.utilities import rank_zero_only
  8. from omegaconf import DictConfig, OmegaConf, open_dict
  9. from rich.prompt import Prompt
  10. from fish_speech.utils import logger as log
  11. @rank_zero_only
  12. def print_config_tree(
  13. cfg: DictConfig,
  14. print_order: Sequence[str] = (
  15. "data",
  16. "model",
  17. "callbacks",
  18. "logger",
  19. "trainer",
  20. "paths",
  21. "extras",
  22. ),
  23. resolve: bool = False,
  24. save_to_file: bool = False,
  25. ) -> None:
  26. """Prints content of DictConfig using Rich library and its tree structure.
  27. Args:
  28. cfg (DictConfig): Configuration composed by Hydra.
  29. print_order (Sequence[str], optional): Determines in what order config components are printed.
  30. resolve (bool, optional): Whether to resolve reference fields of DictConfig.
  31. save_to_file (bool, optional): Whether to export config to the hydra output folder.
  32. """ # noqa: E501
  33. style = "dim"
  34. tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
  35. queue = []
  36. # add fields from `print_order` to queue
  37. for field in print_order:
  38. queue.append(field) if field in cfg else log.warning(
  39. f"Field '{field}' not found in config. "
  40. + f"Skipping '{field}' config printing..."
  41. )
  42. # add all the other fields to queue (not specified in `print_order`)
  43. for field in cfg:
  44. if field not in queue:
  45. queue.append(field)
  46. # generate config tree from queue
  47. for field in queue:
  48. branch = tree.add(field, style=style, guide_style=style)
  49. config_group = cfg[field]
  50. if isinstance(config_group, DictConfig):
  51. branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
  52. else:
  53. branch_content = str(config_group)
  54. branch.add(rich.syntax.Syntax(branch_content, "yaml"))
  55. # print config tree
  56. rich.print(tree)
  57. # save config tree to file
  58. if save_to_file:
  59. with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
  60. rich.print(tree, file=file)
  61. @rank_zero_only
  62. def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
  63. """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501
  64. if not cfg.get("tags"):
  65. if "id" in HydraConfig().cfg.hydra.job:
  66. raise ValueError("Specify tags before launching a multirun!")
  67. log.warning("No tags provided in config. Prompting user to input tags...")
  68. tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
  69. tags = [t.strip() for t in tags.split(",") if t != ""]
  70. with open_dict(cfg):
  71. cfg.tags = tags
  72. log.info(f"Tags: {cfg.tags}")
  73. if save_to_file:
  74. with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
  75. rich.print(cfg.tags, file=file)