# SIT/services/code_generator_service.py
#
# 320 lines
# 10 KiB
# Python
# Raw Permalink Normal View History
#
# 2026-01-29 09:22:54 +00:00
"""
代码生成服务
提供C/C++代码自动生成功能
"""
from pathlib import Path
from typing import List, Optional

from jinja2 import Environment, FileSystemLoader, Template, TemplateNotFound

from config import Config, ProtocolType, SerializationType, LanguageStandard
from models.mapping_graph import GraphEdge
from models.message import Message
from utils.logger import get_logger
# Module-level logger used by CodeGeneratorService below.
logger = get_logger(__name__)
class CodeGeneratorService:
    """Generate a C/C++ message-converter project.

    The receiver, deserializer, mapper, serializer and sender modules are
    rendered from Jinja2 templates under ``Config.TEMPLATES_DIR``; if a
    template is missing or fails to render, a minimal placeholder is used
    instead. The top-level converter, a demo ``main.cpp`` and the
    ``CMakeLists.txt`` are produced from built-in text.
    """

    def __init__(self):
        # Jinja2 environment rooted at the project's template directory.
        # trim_blocks / lstrip_blocks stop block tags from leaving stray
        # newlines and leading whitespace in the generated C++.
        self.env = Environment(
            loader=FileSystemLoader(str(Config.TEMPLATES_DIR)),
            trim_blocks=True,
            lstrip_blocks=True
        )

    def generate_converter_code(self,
                                source_message: Message,
                                target_message: Message,
                                mapping_path: List[GraphEdge],
                                protocol: ProtocolType,
                                serialization: SerializationType,
                                language_standard: LanguageStandard = LanguageStandard.CPP17,
                                output_dir: Optional[Path] = None) -> tuple[bool, str, dict]:
        """Generate every source file of the message converter.

        Args:
            source_message: Message that is received and deserialized.
            target_message: Message that is produced by mapping and sent.
            mapping_path: Mapping-graph edges; forwarded to the templates.
            protocol: Transport protocol; selects the receiver template.
            serialization: Serialization format; forwarded to the templates
                and printed in the converter header comment.
            language_standard: C++ standard; its ``value`` with ``"C++"``
                removed is written as ``CMAKE_CXX_STANDARD``.
            output_dir: If truthy (a ``Path`` or path string), the generated
                files are also written there; the directory is created.

        Returns:
            ``(success, message, files)``: ``files`` maps file name to file
            content. On any exception ``success`` is False, ``message``
            carries the error text and ``files`` is empty.
        """
        try:
            template_data = {
                'source_message': source_message,
                'target_message': target_message,
                'mapping_path': mapping_path,
                'protocol': protocol,
                'serialization': serialization,
                'language_standard': language_standard,
            }

            # Each module yields a header/source pair named <base>.h / <base>.cpp,
            # listed in data-flow order: receive -> decode -> map -> encode -> send.
            modules = (
                ('receiver', self._generate_receiver_code),
                ('deserializer', self._generate_deserializer_code),
                ('mapper', self._generate_mapper_code),
                ('serializer', self._generate_serializer_code),
                ('sender', self._generate_sender_code),
                ('message_converter', self._generate_converter_code),
            )
            generated_files = {}
            for base_name, generate in modules:
                code = generate(template_data)
                generated_files[f'{base_name}.h'] = code['header']
                generated_files[f'{base_name}.cpp'] = code['source']

            generated_files['CMakeLists.txt'] = self._generate_cmake(template_data)
            # CMakeLists.txt builds `converter_demo` from main.cpp; without this
            # file CMake configuration of the generated project fails.
            generated_files['main.cpp'] = self._generate_main(template_data)

            if output_dir:
                self._write_files(generated_files, output_dir)

            logger.info(f"Generated converter code: {len(generated_files)} files")
            return True, "代码生成成功", generated_files
        except Exception as e:
            # Top-level boundary: report failure to the caller instead of raising.
            logger.error(f"Failed to generate converter code: {e}")
            return False, f"代码生成失败: {str(e)}", {}

    def _render_module(self, template_name: str, data: dict) -> dict:
        """Render ``template_name`` as ``{'header': ..., 'source': ...}``.

        The same template is rendered twice, with ``part='header'`` and then
        ``part='source'``.
        """
        return {
            'header': self._render_template_or_default(template_name, data, 'header'),
            'source': self._render_template_or_default(template_name, data, 'source'),
        }

    def _generate_receiver_code(self, data: dict) -> dict:
        """Generate the receiver module; the template depends on ``data['protocol']``."""
        if data['protocol'] == ProtocolType.DDS:
            template_name = 'cpp/dds_receiver.template'
        else:
            # TCP, and the fallback for every other protocol.
            template_name = 'cpp/tcp_receiver.template'
        return self._render_module(template_name, data)

    def _generate_deserializer_code(self, data: dict) -> dict:
        """Generate the deserializer module."""
        return self._render_module('cpp/deserializer.template', data)

    def _generate_mapper_code(self, data: dict) -> dict:
        """Generate the field-mapper module."""
        return self._render_module('cpp/mapper.template', data)

    def _generate_serializer_code(self, data: dict) -> dict:
        """Generate the serializer module."""
        return self._render_module('cpp/serializer.template', data)

    def _generate_sender_code(self, data: dict) -> dict:
        """Generate the sender module."""
        return self._render_module('cpp/sender.template', data)

    def _generate_converter_code(self, data: dict) -> dict:
        """Generate the top-level MessageConverter from built-in text (no template)."""
        return {
            'header': self._generate_default_converter_header(data),
            'source': self._generate_default_converter_source(data),
        }

    def _generate_cmake(self, data: dict) -> str:
        """Return CMakeLists.txt: a static library of all modules plus a demo executable.

        NOTE(review): the standard is derived by stripping ``"C++"`` from
        ``language_standard.value``; confirm every LanguageStandard value has
        that exact prefix.
        """
        cmake_content = f"""cmake_minimum_required(VERSION 3.10)
project(MessageConverter)
set(CMAKE_CXX_STANDARD {data['language_standard'].value.replace('C++', '')})
# 源文件
set(SOURCES
receiver.cpp
deserializer.cpp
mapper.cpp
serializer.cpp
sender.cpp
message_converter.cpp
)
# 头文件
set(HEADERS
receiver.h
deserializer.h
mapper.h
serializer.h
sender.h
message_converter.h
)
# 创建库
add_library(message_converter STATIC ${{SOURCES}})
# 创建可执行文件(示例)
add_executable(converter_demo main.cpp)
target_link_libraries(converter_demo message_converter)
"""
        return cmake_content

    def _generate_main(self, data: dict) -> str:
        """Return the demo ``main.cpp`` that CMakeLists.txt compiles.

        It constructs a MessageConverter and calls ``run()`` once
        (receive -> convert -> send).
        """
        return """#include "message_converter.h"

int main() {
    MessageConverter converter;
    converter.run();
    return 0;
}
"""

    def _render_template_or_default(self, template_name: str, data: dict, part: str) -> str:
        """Render ``template_name`` for ``part``, falling back to placeholder code.

        A missing template is expected and only logged as a warning. Any
        other failure, such as a syntax error or an undefined variable in an
        existing template, is logged as an error so it is not mistaken for a
        missing file. Generation still falls back to the placeholder in that
        case (best effort).
        """
        try:
            template = self.env.get_template(template_name)
            return template.render(**data, part=part)
        except TemplateNotFound as e:
            logger.warning(f"Template {template_name} not found, using default: {e}")
        except Exception as e:
            logger.error(f"Template {template_name} failed to render, using default: {e}")
        return self._generate_default_code(data, part)

    def _generate_default_code(self, data: dict, part: str) -> str:
        """Return placeholder code for a module whose template could not be rendered.

        NOTE(review): the placeholder declares none of the classes that
        message_converter.h uses (DataReceiver, Deserializer, ...), so a
        project built only from placeholders will not compile.
        """
        if part == 'header':
            return "// Default header\n#pragma once\n"
        return "// Default source\n"

    def _generate_default_converter_header(self, data: dict) -> str:
        """Return ``message_converter.h`` declaring the MessageConverter class.

        The class comment records the source/target message names, the
        protocol and the serialization format from ``data``.
        """
        source_msg = data['source_message']
        target_msg = data['target_message']
        header = f"""#pragma once
#include <vector>
#include <cstdint>
#include "receiver.h"
#include "deserializer.h"
#include "mapper.h"
#include "serializer.h"
#include "sender.h"
/**
* 消息转换器: {source_msg.full_name} -> {target_msg.full_name}
* 协议: {data['protocol'].value}
* 序列化: {data['serialization'].value}
*/
class MessageConverter {{
public:
MessageConverter();
~MessageConverter();
// 执行转换
bool convert(const std::vector<uint8_t>& input_data, std::vector<uint8_t>& output_data);
// 完整的接收-转换-发送流程
void run();
private:
DataReceiver receiver_;
Deserializer deserializer_;
MessageMapper mapper_;
Serializer serializer_;
DataSender sender_;
}};
"""
        return header

    def _generate_default_converter_source(self, data: dict) -> str:
        """Return ``message_converter.cpp``.

        ``convert`` runs deserialize -> map -> serialize and returns false on
        ``std::exception``. ``run`` receives one payload and converts it; it
        sends the result only if conversion succeeded.
        """
        source = """#include "message_converter.h"
#include <iostream>
MessageConverter::MessageConverter() {
// 初始化各个组件
}
MessageConverter::~MessageConverter() {
// 清理资源
}
bool MessageConverter::convert(const std::vector<uint8_t>& input_data,
std::vector<uint8_t>& output_data) {
try {
// 1. 反序列化输入数据
auto source_message = deserializer_.deserialize(input_data);
// 2. 执行字段映射
auto target_message = mapper_.map(source_message);
// 3. 序列化输出数据
output_data = serializer_.serialize(target_message);
return true;
} catch (const std::exception& e) {
std::cerr << "Conversion failed: " << e.what() << std::endl;
return false;
}
}
void MessageConverter::run() {
// 1. 接收数据
auto input_data = receiver_.receive();
// 2. 转换
std::vector<uint8_t> output_data;
if (convert(input_data, output_data)) {
// 3. 发送数据
sender_.send(output_data);
}
}
"""
        return source

    def _write_files(self, files: dict, output_dir: Path):
        """Write each ``{filename: content}`` entry as UTF-8 text under ``output_dir``.

        ``output_dir`` may be a ``Path`` or a path string; it is created,
        including parents, if missing. Existing files are overwritten.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        for filename, content in files.items():
            file_path = output_dir / filename
            file_path.write_text(content, encoding='utf-8')
            logger.info(f"Written file: {file_path}")