127 lines
5.1 KiB
Python
127 lines
5.1 KiB
Python
import json
import os
from typing import List, Optional

from ...utils.multi_objective.evaluator import MultiObjectiveEvaluator
from ...utils.multi_objective.storage import MultiObjectiveGenerationStorage
from .individual import MultiObjectiveIndividual
from ..llm_integration import LLMClient
from ...config.settings import DEFAULT_EVOLUTION_PARAMS
from ..evolution_algorithms.multi_objective import MultiObjectiveEvolution
from ..operators.initialize_operator import InitializeOperator


class MultiObjectiveEvolutionEngine:
    """Multi-objective evolution engine.

    Wires together, for one problem directory, the generation storage, the
    evaluator, the LLM client, the initialization operator and the
    multi-objective evolution algorithm, and drives the evolutionary loop.
    """

    # Upper bound on select/crossover/mutate rounds per population slot when
    # producing offspring. Without it, an LLM that keeps returning children
    # with empty code would make the offspring loop spin forever.
    _MAX_ATTEMPTS_PER_SLOT = 10

    def __init__(self, problem_path: str):
        """Create an engine for the problem stored under ``problem_path``.

        Args:
            problem_path: Directory that contains ``problem_config.json``; it is
                also handed to the storage, evaluator and LLM client.

        Raises:
            OSError: If ``problem_config.json`` cannot be opened.
            json.JSONDecodeError: If ``problem_config.json`` is not valid JSON.
        """
        self.problem_path = problem_path
        self.storage = MultiObjectiveGenerationStorage(problem_path)
        self.evaluator = MultiObjectiveEvaluator(problem_path)
        self.llm_client = LLMClient.from_config(problem_path)
        self.initialize_operator = InitializeOperator(self.llm_client)

        # Read the problem config once; initialize_population reuses it.
        self.problem_config = self._load_problem_config()

        # Problem-specific "evolution_params" override the global defaults key by key.
        self.evolution_params = {
            **DEFAULT_EVOLUTION_PARAMS,
            **self.problem_config.get("evolution_params", {}),
        }

        # The algorithm object supplies select / crossover / mutate / survive.
        self.evolution_algorithm = MultiObjectiveEvolution(
            self.evolution_params,
            self.llm_client,
        )

        print(f"多目标进化参数:{self.evolution_params}")

    def initialize_population(self, size: int) -> List["MultiObjectiveIndividual"]:
        """Build generation 0 from LLM-generated algorithm ideas.

        First asks the ideas generator for ``size`` ideas for the problem
        description, then generates one implementation per returned idea.
        Ideas for which no code is produced are skipped, so the number of
        individuals depends on how many ideas come back and how many of them
        yield code.

        Args:
            size: Number of ideas to request from the ideas generator.

        Returns:
            Individuals with ``generation=0`` whose ``metadata["idea"]`` holds
            the idea each one was generated from.
        """
        config = self.problem_config

        # Generate several different algorithm ideas first to diversify the population.
        ideas = self.initialize_operator.ideas_generator.generate_ideas(
            config["description"],
            size,
        )

        # Turn each idea into a concrete implementation.
        population = []
        for idea in ideas:
            code = self.initialize_operator.generate_initial_code(
                config["description"],
                config["function_name"],
                config["input_format"],
                config["output_format"],
                idea,
            )
            if code:
                population.append(MultiObjectiveIndividual(
                    code,
                    generation=0,
                    metadata={"idea": idea},  # keep the originating idea
                ))

        print(f"成功生成初始种群,包含{len(population)}个不同个体")
        return population

    def run_evolution(
        self,
        generations: Optional[int] = None,
        population_size: Optional[int] = None,
    ) -> List["MultiObjectiveIndividual"]:
        """Run the multi-objective evolutionary loop.

        Args:
            generations: Number of generations to run; ``None`` uses
                ``evolution_params["generations"]``. An explicit ``0`` is
                honoured and only evaluates the initial population.
            population_size: Target population size; ``None`` uses
                ``evolution_params["population_size"]``.

        Returns:
            The population after the last survival step (or the evaluated
            initial population when ``generations`` is 0).

        Raises:
            ValueError: If ``generations`` is negative or ``population_size``
                is smaller than 1.
            RuntimeError: If no valid initial individual could be generated.
        """
        # `is None` instead of `or`: an explicit 0 must not fall back to the default.
        if generations is None:
            generations = self.evolution_params["generations"]
        if population_size is None:
            population_size = self.evolution_params["population_size"]
        if generations < 0:
            raise ValueError(f"generations must be >= 0, got {generations}")
        if population_size < 1:
            raise ValueError(f"population_size must be >= 1, got {population_size}")

        print(f"开始多目标进化,代数:{generations},种群大小:{population_size}")
        population = self.initialize_population(population_size)
        if not population:
            # Parent selection on an empty population can never make progress.
            raise RuntimeError("failed to generate any valid initial individual")

        # Evaluate the initial population.
        print("\n评估初始种群...")
        for ind in population:
            ind.fitnesses = self.evaluator.evaluate(ind.code)
            print(f"初始个体适应度:{ind.fitnesses}")

        # Main evolutionary loop.
        for gen in range(generations):
            print(f"\n开始第 {gen+1}/{generations} 代进化")

            offspring = self._generate_offspring(population, population_size)

            # Survivor selection over parents + offspring.
            population = self.evolution_algorithm.survive(population, offspring, population_size)

            # Persist this generation.
            self.storage.save_generation(gen, population)

            # Report the size of the current non-dominated front.
            front = self._calculate_pareto_front(population)
            print(f"当前代Pareto前沿大小:{len(front)}")

        print("\n多目标进化完成!")
        return population

    def _generate_offspring(
        self,
        population: List["MultiObjectiveIndividual"],
        population_size: int,
    ) -> List["MultiObjectiveIndividual"]:
        """Produce evaluated offspring from ``population``.

        Repeats rounds of: select two parents, cross them over, mutate each
        child, and evaluate and keep every mutated child whose ``code`` is
        non-empty. It stops once at least ``population_size`` offspring exist.
        All children of the final round are kept, so the result may exceed
        ``population_size``. It also gives up after
        ``population_size * _MAX_ATTEMPTS_PER_SLOT`` rounds so that a
        persistently failing generator cannot hang the run; any shortfall is
        reported.
        """
        offspring = []
        max_attempts = population_size * self._MAX_ATTEMPTS_PER_SLOT
        attempts = 0
        while len(offspring) < population_size and attempts < max_attempts:
            attempts += 1
            parents = self.evolution_algorithm.select(population, 2)
            for child in self.evolution_algorithm.crossover(parents):
                mutated_child = self.evolution_algorithm.mutate(child)
                if mutated_child.code:
                    mutated_child.fitnesses = self.evaluator.evaluate(mutated_child.code)
                    print(f"新个体适应度:{mutated_child.fitnesses}")
                    offspring.append(mutated_child)

        if len(offspring) < population_size:
            print(f"警告:尝试{attempts}轮后仅生成{len(offspring)}/{population_size}个有效子代")
        return offspring

    def _calculate_pareto_front(
        self, population: List["MultiObjectiveIndividual"]
    ) -> List["MultiObjectiveIndividual"]:
        """Return the individuals not dominated by any member of ``population``.

        Dominance is decided by each individual's ``dominates`` method.
        Input order is preserved. This is a pairwise O(n^2) comparison.
        """
        return [
            ind
            for ind in population
            if not any(other.dominates(ind) for other in population)
        ]

    def _load_problem_config(self) -> dict:
        """Read and parse ``problem_config.json`` from the problem directory.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        config_path = os.path.join(self.problem_path, "problem_config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            return json.load(f)