rich_utils.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. from pathlib import Path
  2. from typing import Sequence
  3. import rich
  4. import rich.syntax
  5. import rich.tree
  6. from hydra.core.hydra_config import HydraConfig
  7. from lightning.pytorch.utilities import rank_zero_only
  8. from omegaconf import DictConfig, OmegaConf, open_dict
  9. from rich.prompt import Prompt
  10. from fish_speech.utils import logger as log
  11. @rank_zero_only
  12. def print_config_tree(
  13. cfg: DictConfig,
  14. print_order: Sequence[str] = (
  15. "data",
  16. "model",
  17. "callbacks",
  18. "logger",
  19. "trainer",
  20. "paths",
  21. "extras",
  22. ),
  23. resolve: bool = False,
  24. save_to_file: bool = False,
  25. ) -> None:
  26. """Prints content of DictConfig using Rich library and its tree structure.
  27. Args:
  28. cfg (DictConfig): Configuration composed by Hydra.
  29. print_order (Sequence[str], optional): Determines in what order config components are printed.
  30. resolve (bool, optional): Whether to resolve reference fields of DictConfig.
  31. save_to_file (bool, optional): Whether to export config to the hydra output folder.
  32. """ # noqa: E501
  33. style = "dim"
  34. tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
  35. queue = []
  36. # add fields from `print_order` to queue
  37. for field in print_order:
  38. (
  39. queue.append(field)
  40. if field in cfg
  41. else log.warning(
  42. f"Field '{field}' not found in config. "
  43. + f"Skipping '{field}' config printing..."
  44. )
  45. )
  46. # add all the other fields to queue (not specified in `print_order`)
  47. for field in cfg:
  48. if field not in queue:
  49. queue.append(field)
  50. # generate config tree from queue
  51. for field in queue:
  52. branch = tree.add(field, style=style, guide_style=style)
  53. config_group = cfg[field]
  54. if isinstance(config_group, DictConfig):
  55. branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
  56. else:
  57. branch_content = str(config_group)
  58. branch.add(rich.syntax.Syntax(branch_content, "yaml"))
  59. # print config tree
  60. rich.print(tree)
  61. # save config tree to file
  62. if save_to_file:
  63. with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
  64. rich.print(tree, file=file)
  65. @rank_zero_only
  66. def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
  67. """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501
  68. if not cfg.get("tags"):
  69. if "id" in HydraConfig().cfg.hydra.job:
  70. raise ValueError("Specify tags before launching a multirun!")
  71. log.warning("No tags provided in config. Prompting user to input tags...")
  72. tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
  73. tags = [t.strip() for t in tags.split(",") if t != ""]
  74. with open_dict(cfg):
  75. cfg.tags = tags
  76. log.info(f"Tags: {cfg.tags}")
  77. if save_to_file:
  78. with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
  79. rich.print(cfg.tags, file=file)