# lead/tsp_data/tsp_solver.py
# 2025-03-17 16:40:01 +08:00
#
# 105 lines
# 3.4 KiB
# Python

import os
import numpy as np
import matplotlib.pyplot as plt
def read_tsp_file(file_path):
    """Parse the 2-D node coordinates from a TSPLIB-style ``.tsp`` file.

    Reading starts after the ``NODE_COORD_SECTION`` line and stops at the
    first ``EOF`` line, at the start of any other ``*_SECTION`` block, or at
    the end of the file, whichever comes first.  Inside the coordinate
    section only lines with exactly three whitespace-separated fields are
    used; every other line (blank lines, 3-D coordinates, ...) is ignored.

    Args:
        file_path: Path of the file to read.

    Returns:
        A list of ``(index, x, y)`` tuples in file order, where ``index`` is
        an ``int`` and ``x``/``y`` are ``float``.  The list is empty when the
        file contains no coordinate section.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a three-field line in the coordinate section is not
            numeric.
    """
    cities = []
    # Only ASCII keywords and numbers are interpreted, so undecodable bytes
    # (e.g. in a COMMENT line) are replaced instead of aborting the parse,
    # and the result no longer depends on the platform default encoding.
    with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
        reading_coords = False
        for line in f:
            stripped = line.strip()
            # Accept "NODE_COORD_SECTION", "NODE_COORD_SECTION:" and
            # "NODE_COORD_SECTION :" as the section header.
            if stripped.rstrip(': ') == 'NODE_COORD_SECTION':
                reading_coords = True
                continue
            if stripped == 'EOF':
                break
            if not reading_coords:
                continue

            parts = stripped.split()
            # A new section keyword ends the coordinate list; without this
            # check the three-field rows of a following section would be
            # appended as extra cities.
            if parts and parts[0].rstrip(':').endswith('_SECTION'):
                break
            if len(parts) == 3:
                cities.append((int(parts[0]), float(parts[1]), float(parts[2])))
    return cities
def nearest_neighbor_tsp(cities):
    """Build a closed tour with the greedy nearest-neighbour heuristic.

    The tour starts at city 1, repeatedly moves to the closest city that has
    not been visited yet (Euclidean distance), and finally returns to city 1.

    City numbers in the result are 1-based positions in ``cities`` (city
    ``k`` is ``cities[k - 1]``); the index stored in each tuple is not used.

    Args:
        cities: Sequence of ``(index, x, y)`` tuples.

    Returns:
        A list of 1-based city numbers that starts and ends with ``1``, or an
        empty list when ``cities`` is empty.
    """
    n = len(cities)
    if n == 0:
        # No start city exists; previously this raised KeyError.
        return []

    def distance(city1, city2):
        # Euclidean distance between two cities given by 1-based number.
        _, x1, y1 = cities[city1 - 1]
        _, x2, y2 = cities[city2 - 1]
        return np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)

    unvisited = set(range(1, n + 1))
    current = 1
    path = [current]
    unvisited.remove(current)

    # Greedy step: always move to the nearest city that is still unvisited.
    while unvisited:
        next_city = min(unvisited, key=lambda candidate: distance(current, candidate))
        path.append(next_city)
        unvisited.remove(next_city)
        current = next_city

    # Close the tour by returning to the start city.
    path.append(1)
    return path
def plot_tsp_solution(cities, path, filename, *, annotate=False):
    """Draw a tour over the city scatter plot and save it as a PNG.

    The image is written to ``solutions/<stem>_solution.png`` (relative to
    the current working directory), where ``<stem>`` is the base name of
    ``filename`` without its extension.  The ``solutions`` directory is
    created if needed.  The plot title shows the base name of ``filename``
    and the total Euclidean length of the tour.

    Args:
        cities: Sequence of ``(index, x, y)`` tuples.
        path: Sequence of 1-based city numbers; consecutive entries are
            joined by a line segment (city ``k`` is ``cities[k - 1]``).
        filename: Path of the source file; only its base name is used.
        annotate: When true, label every city with its stored index.
            Defaults to ``False``.

    Returns:
        None.
    """
    fig = plt.figure(figsize=(12, 8), dpi=300)
    try:
        # Scatter all cities.
        x_coords = [x for _, x, _ in cities]
        y_coords = [y for _, _, y in cities]
        plt.scatter(x_coords, y_coords, c='blue', s=50)

        if annotate:
            for index, x, y in cities:
                plt.annotate(f'{index}', (x, y), xytext=(5, 5), textcoords='offset points')

        # Draw every leg of the tour and accumulate its length in one pass.
        total_distance = 0
        for start, end in zip(path, path[1:]):
            _, x1, y1 = cities[start - 1]
            _, x2, y2 = cities[end - 1]
            plt.plot([x1, x2], [y1, y2], 'r-', alpha=0.6)
            total_distance += np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)

        plt.title(f'TSP解决方案: {os.path.basename(filename)}\n总路径长度: {total_distance:.2f}')
        plt.xlabel('X坐标')
        plt.ylabel('Y坐标')
        plt.grid(True)

        output_filename = os.path.join('solutions', os.path.splitext(os.path.basename(filename))[0] + '_solution.png')
        os.makedirs('solutions', exist_ok=True)
        plt.savefig(output_filename, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving failed, so
        # figures do not pile up when many files are processed in a row.
        plt.close(fig)
def main():
    """Solve and plot every ``.tsp`` file found in the ``data`` directory.

    For each file the cities are read, a nearest-neighbour tour is computed
    and printed, and the plot is saved by ``plot_tsp_solution``.  Files are
    processed in sorted name order.  A file that yields no cities is reported
    and skipped so that it does not abort the remaining files.

    Raises:
        OSError: If the ``data`` directory cannot be listed (for example
            when it does not exist).
    """
    data_dir = 'data'
    # os.listdir returns names in arbitrary order; sort for reproducible runs.
    tsp_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.tsp'))
    for tsp_file in tsp_files:
        file_path = os.path.join(data_dir, tsp_file)
        print(f'\n处理文件: {tsp_file}')

        # Read the city data.
        cities = read_tsp_file(file_path)
        print(f'城市数量: {len(cities)}')
        if not cities:
            # Nothing to route or draw for this file.
            print('未读取到城市坐标,跳过该文件')
            continue

        # Solve with the nearest-neighbour heuristic.
        path = nearest_neighbor_tsp(cities)
        print(f'计算的路径: {path}')

        # Draw and save the result.
        plot_tsp_solution(cities, path, file_path)
        print(f'结果已保存到 solutions/{os.path.splitext(tsp_file)[0]}_solution.png')
# Run the batch solver only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()