# test_config.py - tests for loading, merging and serializing R2RConfig
  1. from copy import deepcopy
  2. from pathlib import Path
  3. import pytest
  4. import toml
  5. from core.base.utils import deep_update
  6. from core.main.config import R2RConfig
  7. @pytest.fixture
  8. def base_config():
  9. """Load the base r2r.toml config"""
  10. config_path = Path(__file__).parent.parent.parent / "r2r.toml"
  11. with open(config_path) as f:
  12. return toml.load(f)
  13. @pytest.fixture
  14. def config_dir():
  15. """Get the path to the configs directory"""
  16. return Path(__file__).parent.parent.parent / "core" / "configs"
  17. @pytest.fixture
  18. def all_config_files(config_dir):
  19. """Get list of all TOML files in the configs directory"""
  20. return list(config_dir.glob("*.toml"))
  21. @pytest.fixture
  22. def all_configs(all_config_files):
  23. """Load all config files"""
  24. configs = {}
  25. for config_file in all_config_files:
  26. with open(config_file) as f:
  27. configs[config_file.name] = toml.load(f)
  28. return configs
  29. @pytest.fixture
  30. def full_config(all_configs):
  31. """Get the full.toml config"""
  32. return all_configs["full.toml"]
  33. @pytest.fixture
  34. def all_merged_configs(base_config, all_configs):
  35. """Create merged configurations for all config files"""
  36. merged = {}
  37. for config_name, config_data in all_configs.items():
  38. merged[config_name] = deep_update(deepcopy(base_config), config_data)
  39. return merged
  40. @pytest.fixture
  41. def merged_config(base_config, full_config):
  42. """Create the expected merged configuration"""
  43. return deep_update(deepcopy(base_config), full_config)
  44. def test_base_config_loading(base_config):
  45. """Test that the base config loads correctly with all expected values"""
  46. config = R2RConfig(base_config)
  47. # Test critical base values
  48. assert config.database.graph_creation_settings.clustering_mode == "local"
  49. assert (
  50. config.database.graph_creation_settings.generation_config.model
  51. == "openai/gpt-4o-mini"
  52. )
  53. assert config.ingestion.provider == "r2r"
  54. assert config.orchestration.provider == "simple"
  55. def test_full_config_override(base_config, full_config):
  56. """Test that the full config properly overrides base values"""
  57. config = R2RConfig(full_config)
  58. # Test overridden values
  59. assert config.database.graph_creation_settings.clustering_mode == "remote"
  60. assert (
  61. config.database.graph_creation_settings.generation_config.model
  62. == "openai/gpt-4o-mini"
  63. )
  64. assert config.ingestion.provider == "unstructured_local"
  65. assert config.orchestration.provider == "hatchet"
  66. def test_nested_config_preservation(merged_config):
  67. """Test that nested configurations are properly preserved during merging"""
  68. config = R2RConfig(merged_config)
  69. assert (
  70. config.database.graph_creation_settings.generation_config.model
  71. == "openai/gpt-4o-mini"
  72. )
  73. assert (
  74. config.database.graph_creation_settings.generation_config.temperature
  75. == 0.1
  76. )
  77. assert (
  78. config.database.graph_creation_settings.max_knowledge_relationships
  79. == 100
  80. )
  81. def test_new_values_in_override(merged_config):
  82. """Test that new values in the override config are properly added"""
  83. config = R2RConfig(merged_config)
  84. # Test new orchestration values
  85. assert config.orchestration.kg_creation_concurrency_limit == 32
  86. assert config.orchestration.ingestion_concurrency_limit == 16
  87. assert config.orchestration.kg_concurrency_limit == 8
  88. def test_config_type_consistency(merged_config):
  89. """Test that configuration values maintain their expected types"""
  90. config = R2RConfig(merged_config)
  91. # Test type consistency for various fields
  92. assert isinstance(
  93. config.database.graph_creation_settings.max_knowledge_relationships,
  94. int,
  95. )
  96. assert isinstance(
  97. config.database.graph_creation_settings.clustering_mode, str
  98. )
  99. assert isinstance(config.ingestion.chunking_strategy, str)
  100. def get_config_files():
  101. """Helper function to get list of config files"""
  102. config_dir = Path(__file__).parent.parent.parent / "core" / "configs"
  103. return ["r2r.toml"] + [f.name for f in config_dir.glob("*.toml")]
  104. @pytest.mark.parametrize("config_file", get_config_files())
  105. def test_config_required_keys(config_file):
  106. """Test that all required keys are present in all config files"""
  107. if config_file == "r2r.toml":
  108. file_path = Path(__file__).parent.parent.parent / "r2r.toml"
  109. else:
  110. file_path = (
  111. Path(__file__).parent.parent.parent
  112. / "core"
  113. / "configs"
  114. / config_file
  115. )
  116. with open(file_path) as f:
  117. config_data = toml.load(f)
  118. config = R2RConfig(config_data)
  119. # Test required sections
  120. for section in R2RConfig.REQUIRED_KEYS:
  121. assert hasattr(config, section), f"Missing required section: {section}"
  122. # Test required keys in each section
  123. for section, required_keys in R2RConfig.REQUIRED_KEYS.items():
  124. if required_keys: # Skip empty required_keys lists
  125. section_config = getattr(config, section)
  126. for key in required_keys:
  127. if isinstance(section_config, dict):
  128. assert (
  129. key in section_config
  130. ), f"Missing required key {key} in section {section}"
  131. else:
  132. assert hasattr(
  133. section_config, key
  134. ), f"Missing required key {key} in section {section}"
  135. def test_serialization_roundtrip(merged_config):
  136. """Test that configuration can be serialized and deserialized without data loss"""
  137. config = R2RConfig(merged_config)
  138. serialized = config.to_toml()
  139. # Load the serialized config back
  140. roundtrip_config = R2RConfig(toml.loads(serialized))
  141. # Test key values after roundtrip
  142. assert (
  143. roundtrip_config.database.graph_creation_settings.clustering_mode
  144. == config.database.graph_creation_settings.clustering_mode
  145. )
  146. assert (
  147. roundtrip_config.database.graph_creation_settings.generation_config.model
  148. == config.database.graph_creation_settings.generation_config.model
  149. )
  150. assert (
  151. roundtrip_config.orchestration.provider
  152. == config.orchestration.provider
  153. )
  154. def test_all_merged_configs(base_config, all_merged_configs):
  155. """Test that all configs properly merge with base config"""
  156. for config_name, merged_data in all_merged_configs.items():
  157. config = R2RConfig(merged_data)
  158. # Test that the config loads without errors
  159. assert config is not None
  160. # Verify that base values are preserved unless explicitly overridden
  161. if not hasattr(
  162. config.database.graph_creation_settings, "clustering_mode"
  163. ):
  164. assert (
  165. config.database.graph_creation_settings.clustering_mode
  166. == "local"
  167. )
  168. # Verify that generation_config model is preserved unless explicitly overridden
  169. if "generation_config" not in merged_data.get("database", {}).get(
  170. "graph_creation_settings", {}
  171. ):
  172. assert (
  173. config.database.graph_creation_settings.generation_config.model
  174. == "openai/gpt-4o-mini"
  175. )
  176. def test_all_config_overrides(base_config, all_configs):
  177. """Test that all config files can be loaded independently"""
  178. for config_name, config_data in all_configs.items():
  179. config = R2RConfig(config_data)
  180. assert config is not None