143 lines
4.7 KiB
Python
143 lines
4.7 KiB
Python
"""Git 推送工具"""
|
||
import subprocess
|
||
import logging
|
||
from typing import Dict, Any
|
||
|
||
from app.tools.base import BaseTool, ToolResult
|
||
from app.tools.registry import ToolRegistry
|
||
|
||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
@ToolRegistry.register
class GitPushTool(BaseTool):
    """推送代码到远程 Git 仓库"""

    # Pushes a local branch to the "origin" remote of a Git repository.
    # NOTE(review): the class docstring above is kept verbatim in case
    # BaseTool/ToolRegistry exposes __doc__ as the tool description — confirm.

    @property
    def parameters_schema(self) -> Dict[str, Any]:
        """Return the JSON-schema object describing ``execute``'s arguments."""
        return {
            "type": "object",
            "properties": {
                "repo_path": {
                    "type": "string",
                    "description": "Git 仓库路径"
                },
                "branch_name": {
                    "type": "string",
                    "description": "分支名称"
                },
                "set_upstream": {
                    "type": "boolean",
                    "description": "是否设置上游分支",
                    "default": True
                },
                "auth_type": {
                    "type": "string",
                    "description": "认证类型",
                    "default": "password",
                    "enum": ["password", "token", "ssh"]
                },
                "username": {
                    "type": "string",
                    "description": "Git 用户名",
                    "default": ""
                },
                "password": {
                    "type": "string",
                    "description": "Git 密码或 Token",
                    "default": ""
                }
            },
            "required": ["repo_path", "branch_name"]
        }

    def execute(self,
                repo_path: str,
                branch_name: str,
                set_upstream: bool = True,
                auth_type: str = "password",
                username: str = "",
                password: str = "",
                **kwargs) -> ToolResult:
        """Push ``branch_name`` to the ``origin`` remote of ``repo_path``.

        When ``auth_type`` is "password" or "token" (case-insensitive) and both
        ``username`` and ``password`` are non-empty, the credentials are
        temporarily inserted into the ``origin`` URL for this push only. The
        previous URL is restored afterwards, even if the push fails, so the
        credentials are not left behind in ``.git/config``.

        Args:
            repo_path: Path of the local Git repository (used as the git cwd).
            branch_name: Name of the branch to push.
            set_upstream: When True, pass ``-u`` so the branch tracks origin.
            auth_type: "password", "token" or "ssh"; only "password"/"token"
                cause credentials to be injected.
            username: Git username.
            password: Git password or token (raw, not URL-encoded).
            **kwargs: Accepted and ignored.

        Returns:
            ToolResult: ``success=True`` with ``branch_name``, ``repo_path``
            and ``set_upstream`` in ``data``.

        Raises:
            ValueError: If ``repo_path`` or ``branch_name`` is empty or blank.
            RuntimeError: If ``git push`` exits with a non-zero status.
        """
        # Parameter validation.
        if not repo_path or not repo_path.strip():
            error_msg = "仓库路径不能为空"
            logger.error(error_msg)
            raise ValueError(error_msg)

        if not branch_name or not branch_name.strip():
            error_msg = "分支名称不能为空"
            logger.error(error_msg)
            raise ValueError(error_msg)

        # URL to put back on origin after the push; None means origin was not modified.
        original_url = None
        if auth_type.lower() in ["password", "token"] and username and password:
            original_url = self._inject_credentials(repo_path, username, password)

        push_cmd = ["git", "push"]
        if set_upstream:
            push_cmd.append("-u")
        push_cmd.extend(["origin", branch_name])

        try:
            result = self._run_git(push_cmd, repo_path)
        finally:
            # Strip the credentials again whether the push succeeded, failed or raised.
            if original_url is not None:
                self._restore_remote_url(repo_path, original_url)

        if result.returncode != 0:
            # git may echo the remote URL in its errors; never surface the secret.
            stderr = self._redact(result.stderr, password)
            error_msg = f"git push 失败: {stderr}"
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        logger.info(f"已推送到分支: {branch_name}")

        return ToolResult(
            success=True,
            data={
                "branch_name": branch_name,
                "repo_path": repo_path,
                "set_upstream": set_upstream
            },
            message=f"已推送到分支: {branch_name}"
        )

    @staticmethod
    def _run_git(cmd, repo_path: str):
        """Run a git command in ``repo_path`` and return the CompletedProcess (text output)."""
        # errors='replace' keeps non-UTF-8 git output from raising UnicodeDecodeError.
        return subprocess.run(
            cmd,
            capture_output=True,
            encoding='utf-8',
            errors='replace',
            cwd=repo_path
        )

    @staticmethod
    def _build_auth_url(url: str, username: str, password: str):
        """Return ``url`` with ``username:password@`` inserted after the scheme.

        Both credentials are percent-encoded so characters such as '@', ':'
        or '/' cannot corrupt the URL. Returns None when ``url`` has no
        "scheme://" part or already contains '@' (the same conditions under
        which the original code left the URL alone).
        """
        from urllib.parse import quote

        if "://" not in url or "@" in url:
            return None
        protocol, rest = url.split("://", 1)
        user = quote(username, safe="")
        secret = quote(password, safe="")
        return f"{protocol}://{user}:{secret}@{rest}"

    def _inject_credentials(self, repo_path: str, username: str, password: str):
        """Point origin at a credential-bearing URL for the upcoming push.

        Returns the previous origin URL (to be restored later), or None when
        origin was left untouched: the URL could not be read, is not of the
        "scheme://" form, already contains '@', or ``set-url`` failed.
        """
        result = self._run_git(["git", "remote", "get-url", "origin"], repo_path)
        if result.returncode != 0:
            return None

        current_url = result.stdout.strip()
        auth_url = self._build_auth_url(current_url, username, password)
        if auth_url is None:
            return None

        set_result = self._run_git(
            ["git", "remote", "set-url", "origin", auth_url], repo_path
        )
        if set_result.returncode != 0:
            # Best effort, as before: continue the push without credentials.
            detail = self._redact(set_result.stderr, password).strip()
            logger.warning(f"更新远程URL失败,将不带认证信息推送: {detail}")
            return None

        logger.info("已更新远程URL,添加认证信息")
        return current_url

    def _restore_remote_url(self, repo_path: str, url: str) -> None:
        """Set origin back to ``url``; failures are logged, not raised."""
        # Not raising here avoids masking an exception from the push itself.
        result = self._run_git(["git", "remote", "set-url", "origin", url], repo_path)
        if result.returncode != 0:
            logger.error(
                f"恢复远程URL失败,认证信息可能仍保留在 .git/config 中: {result.stderr.strip()}"
            )
        else:
            logger.info("已恢复原始远程URL")

    @staticmethod
    def _redact(text: str, secret: str) -> str:
        """Mask ``secret`` (raw and percent-encoded forms) in ``text`` with '***'."""
        from urllib.parse import quote

        if not text or not secret:
            return text
        # Replace the longer form first so a shorter form nested inside it
        # cannot leave a partial secret behind.
        for form in sorted({secret, quote(secret, safe="")}, key=len, reverse=True):
            text = text.replace(form, "***")
        return text
|