test_config.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. from copy import deepcopy
  2. from pathlib import Path
  3. import pytest
  4. import toml
  5. from core.base.utils import deep_update
  6. from core.main.config import R2RConfig
  7. # Skip all tests in this file until config files are properly set up
  8. pytestmark = pytest.mark.skip("Config tests need to be updated with proper file paths")
  9. ###############################################################################
  10. # Fixtures
  11. ###############################################################################
  12. @pytest.fixture
  13. def base_config():
  14. """Load the base r2r.toml config (new structure)"""
  15. config_path = Path(__file__).parent.parent.parent / "r2r/r2r.toml"
  16. with open(config_path) as f:
  17. return toml.load(f)
  18. @pytest.fixture
  19. def config_dir():
  20. """Get the path to the configs directory."""
  21. return Path(__file__).parent.parent.parent / "core" / "configs"
  22. @pytest.fixture
  23. def all_config_files(config_dir):
  24. """Get list of all TOML files in the configs directory."""
  25. return list(config_dir.glob("*.toml"))
  26. @pytest.fixture
  27. def all_configs(all_config_files):
  28. """Load all config files."""
  29. configs = {}
  30. for config_file in all_config_files:
  31. with open(config_file) as f:
  32. configs[config_file.name] = toml.load(f)
  33. return configs
  34. @pytest.fixture
  35. def full_config(all_configs):
  36. """Return the full override config (full.toml)"""
  37. return all_configs["full.toml"]
  38. @pytest.fixture
  39. def all_merged_configs(base_config, all_configs):
  40. """Merge every override config into the base config."""
  41. merged = {}
  42. for config_name, config_data in all_configs.items():
  43. merged[config_name] = deep_update(deepcopy(base_config), config_data)
  44. return merged
  45. @pytest.fixture
  46. def merged_config(base_config, full_config):
  47. """Merge the full override config into the base config."""
  48. return deep_update(deepcopy(base_config), full_config)
  49. ###############################################################################
  50. # Tests
  51. ###############################################################################
  52. def test_base_config_loading(base_config):
  53. """Test that the base config loads correctly with the new expected values.
  54. """
  55. config = R2RConfig(base_config)
  56. # Verify that the database graph creation settings are present and set
  57. assert (config.database.graph_creation_settings.
  58. graph_entity_description_prompt == "graph_entity_description")
  59. assert (config.database.graph_creation_settings.graph_extraction_prompt ==
  60. "graph_extraction")
  61. assert (config.database.graph_creation_settings.automatic_deduplication
  62. is True)
  63. # Verify other key sections
  64. assert config.ingestion.provider == "r2r"
  65. assert config.orchestration.provider == "simple"
  66. assert config.app.default_max_upload_size == 214748364800
  67. def test_full_config_override(full_config):
  68. """Test that full.toml properly overrides the base values.
  69. For example, assume the full override changes:
  70. - ingestion.provider from "r2r" to "unstructured_local"
  71. - orchestration.provider from "simple" to "hatchet"
  72. - and adds a new nested key in database.graph_creation_settings.
  73. """
  74. config = R2RConfig(full_config)
  75. assert config.ingestion.provider == "unstructured_local"
  76. assert config.orchestration.provider == "hatchet"
  77. # Check that a new nested key has been added
  78. assert (config.database.graph_creation_settings.max_knowledge_relationships
  79. == 100)
  80. def test_nested_config_preservation(merged_config):
  81. """Test that nested configuration values are preserved after merging."""
  82. config = R2RConfig(merged_config)
  83. assert (config.database.graph_creation_settings.max_knowledge_relationships
  84. == 100)
  85. def test_new_values_in_override(merged_config):
  86. """Test that new keys in the override config are added.
  87. In the old tests we asserted values for orchestration concurrency keys. In
  88. the new config structure these keys have been removed (or renamed).
  89. Therefore, we now check for them only if they exist.
  90. """
  91. config = R2RConfig(merged_config)
  92. # If the override adds an ingestion concurrency limit, check it.
  93. if hasattr(config.orchestration, "ingestion_concurrency_limit"):
  94. assert config.orchestration.ingestion_concurrency_limit == 16
  95. # Optionally, if new keys like graph_search_results_creation_concurrency_limit are defined, check them:
  96. if hasattr(config.orchestration,
  97. "graph_search_results_creation_concurrency_limit"):
  98. assert (config.orchestration.
  99. graph_search_results_creation_concurrency_limit == 32)
  100. if hasattr(config.orchestration, "graph_search_results_concurrency_limit"):
  101. assert config.orchestration.graph_search_results_concurrency_limit == 8
  102. def test_config_type_consistency(merged_config):
  103. """Test that configuration values maintain their expected types."""
  104. config = R2RConfig(merged_config)
  105. assert isinstance(
  106. config.database.graph_creation_settings.
  107. graph_entity_description_prompt,
  108. str,
  109. )
  110. assert isinstance(
  111. config.database.graph_creation_settings.automatic_deduplication, bool)
  112. assert isinstance(config.ingestion.chunking_strategy, str)
  113. if hasattr(config.database.graph_creation_settings,
  114. "max_knowledge_relationships"):
  115. assert isinstance(
  116. config.database.graph_creation_settings.
  117. max_knowledge_relationships,
  118. int,
  119. )
  120. def get_config_files():
  121. """Helper function to return the list of configuration file names."""
  122. config_dir = Path(__file__).parent.parent.parent / "core" / "configs"
  123. return ["r2r.toml"] + [f.name for f in config_dir.glob("*.toml")]
  124. @pytest.mark.parametrize("config_file", get_config_files())
  125. def test_config_required_keys(config_file):
  126. """Test that all required sections and keys (per R2RConfig.REQUIRED_KEYS)
  127. exist.
  128. In the new structure the 'agent' section no longer includes the key
  129. 'generation_config', so we filter that out.
  130. """
  131. if config_file == "r2r.toml":
  132. file_path = Path(__file__).parent.parent.parent / "r2r/r2r.toml"
  133. else:
  134. file_path = (Path(__file__).parent.parent.parent / "core" / "configs" /
  135. config_file)
  136. with open(file_path) as f:
  137. config_data = toml.load(f)
  138. config = R2RConfig(config_data)
  139. # Check for required sections
  140. for section in R2RConfig.REQUIRED_KEYS:
  141. assert hasattr(config, section), f"Missing required section: {section}"
  142. # Check for required keys in each section.
  143. # For the agent section, remove 'generation_config' since it no longer exists.
  144. for section, required_keys in R2RConfig.REQUIRED_KEYS.items():
  145. keys_to_check = required_keys
  146. if section == "agent":
  147. keys_to_check = [
  148. key for key in required_keys if key != "generation_config"
  149. ]
  150. if keys_to_check:
  151. section_config = getattr(config, section)
  152. for key in keys_to_check:
  153. if isinstance(section_config, dict):
  154. assert key in section_config, (
  155. f"Missing required key {key} in section {section}")
  156. else:
  157. assert hasattr(section_config, key), (
  158. f"Missing required key {key} in section {section}")
  159. def test_serialization_roundtrip(merged_config):
  160. """Test that serializing and then deserializing the config does not lose
  161. data."""
  162. config = R2RConfig(merged_config)
  163. serialized = config.to_toml()
  164. # Load the serialized config back
  165. roundtrip_config = R2RConfig(toml.loads(serialized))
  166. # Compare a couple of key values after roundtrip.
  167. assert (roundtrip_config.database.graph_creation_settings.
  168. graph_entity_description_prompt == config.database.
  169. graph_creation_settings.graph_entity_description_prompt)
  170. assert (roundtrip_config.orchestration.provider ==
  171. config.orchestration.provider)
  172. def test_all_merged_configs(base_config, all_merged_configs):
  173. """Test that every override file properly merges with the base config."""
  174. for config_name, merged_data in all_merged_configs.items():
  175. config = R2RConfig(merged_data)
  176. assert config is not None
  177. # Example: if the override does not change app.default_max_upload_size,
  178. # it should remain as in the base config.
  179. if "default_max_upload_size" not in merged_data.get("app", {}):
  180. assert config.app.default_max_upload_size == 214748364800
  181. def test_all_config_overrides(all_configs):
  182. """Test that all configuration files can be loaded independently."""
  183. for config_name, config_data in all_configs.items():
  184. config = R2RConfig(config_data)
  185. assert config is not None