ai-demo/app/tools/git/git_push_tool.py

143 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Git 推送工具"""
import logging
import subprocess
from typing import Dict, Any
from urllib.parse import quote, urlsplit, urlunsplit

from app.tools.base import BaseTool, ToolResult
from app.tools.registry import ToolRegistry
# Logger for this module, named after the module's import path (``__name__``).
logger = logging.getLogger(name=__name__)
@ToolRegistry.register
class GitPushTool(BaseTool):
    """Push code to a remote Git repository (the ``origin`` remote)."""

    @property
    def parameters_schema(self) -> Dict[str, Any]:
        """JSON-Schema-style description of the arguments ``execute`` accepts."""
        return {
            "type": "object",
            "properties": {
                "repo_path": {
                    "type": "string",
                    "description": "Git 仓库路径"
                },
                "branch_name": {
                    "type": "string",
                    "description": "分支名称"
                },
                "set_upstream": {
                    "type": "boolean",
                    "description": "是否设置上游分支",
                    "default": True
                },
                "auth_type": {
                    "type": "string",
                    "description": "认证类型",
                    "default": "password",
                    "enum": ["password", "token", "ssh"]
                },
                "username": {
                    "type": "string",
                    "description": "Git 用户名",
                    "default": ""
                },
                "password": {
                    "type": "string",
                    "description": "Git 密码或 Token",
                    "default": ""
                }
            },
            "required": ["repo_path", "branch_name"]
        }

    def execute(self,
                repo_path: str,
                branch_name: str,
                set_upstream: bool = True,
                auth_type: str = "password",
                username: str = "",
                password: str = "",
                **kwargs) -> ToolResult:
        """Push ``branch_name`` to the ``origin`` remote.

        When ``auth_type`` is ``"password"`` or ``"token"`` and both
        ``username`` and ``password`` are non-empty, the HTTP(S) URL of
        ``origin`` is temporarily rewritten to carry the (percent-encoded)
        credentials for the push, and the original URL is put back
        afterwards so the secret is not left in ``.git/config``.

        Args:
            repo_path: Path of the local Git repository; used as the working
                directory of every git command.
            branch_name: Name of the branch to push.
            set_upstream: When true, pass ``-u`` so the local branch tracks
                ``origin/<branch_name>``.
            auth_type: ``"password"``, ``"token"`` or ``"ssh"``; credentials
                are only injected for ``"password"``/``"token"``.
            username: User name for HTTP(S) authentication.
            password: Password or token for HTTP(S) authentication.
            **kwargs: Extra arguments; accepted and ignored.

        Returns:
            ToolResult whose ``data`` holds ``branch_name``, ``repo_path`` and
            ``set_upstream``.

        Raises:
            ValueError: If ``repo_path`` or ``branch_name`` is empty or blank.
            RuntimeError: If ``git push`` exits with a non-zero status.
        """
        # Argument validation.
        if not repo_path or not repo_path.strip():
            error_msg = "仓库路径不能为空"
            logger.error(error_msg)
            raise ValueError(error_msg)
        if not branch_name or not branch_name.strip():
            error_msg = "分支名称不能为空"
            logger.error(error_msg)
            raise ValueError(error_msg)

        use_credentials = (
            (auth_type or "").lower() in ("password", "token")
            and bool(username)
            and bool(password)
        )
        # The pre-rewrite origin URL when we replaced it, else None.
        original_url = (
            self._inject_credentials(repo_path, username, password)
            if use_credentials else None
        )

        push_args = ["push", "-u"] if set_upstream else ["push"]
        push_args += ["origin", branch_name]
        try:
            result = self._run_git(repo_path, *push_args)
        finally:
            # Restore the original URL even if the push raised, so the
            # credential-bearing URL never persists in the repository config.
            if original_url is not None:
                self._restore_remote_url(repo_path, original_url)

        if result.returncode != 0:
            # git may echo the remote URL; never surface the secret.
            error_msg = f"git push 失败: {self._redact(result.stderr, password)}"
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        logger.info(f"已推送到分支: {branch_name}")
        return ToolResult(
            success=True,
            data={
                "branch_name": branch_name,
                "repo_path": repo_path,
                "set_upstream": set_upstream
            },
            message=f"已推送到分支: {branch_name}"
        )

    @staticmethod
    def _run_git(repo_path: str, *args: str) -> subprocess.CompletedProcess:
        """Run ``git <args>`` in ``repo_path`` and capture its text output."""
        return subprocess.run(
            ["git", *args],
            capture_output=True,
            encoding="utf-8",
            # Non-UTF-8 output (e.g. localized Windows messages) must not
            # crash the tool with UnicodeDecodeError.
            errors="replace",
            cwd=repo_path,
        )

    @staticmethod
    def _build_auth_url(url: str, username: str, password: str) -> "str | None":
        """Return ``url`` with ``username:password`` as its userinfo.

        Both parts are percent-encoded so characters such as ``@``, ``:``,
        ``/`` or ``#`` cannot corrupt the URL. Any userinfo already present is
        replaced so the caller's credentials are the ones used. Returns None
        for non-HTTP(S) URLs (ssh/scp-style, git://, local paths), where URL
        credentials do not apply.
        """
        parts = urlsplit(url)
        if parts.scheme.lower() not in ("http", "https"):
            return None
        host = parts.netloc.rpartition("@")[2]
        if not host:
            return None
        userinfo = f"{quote(username, safe='')}:{quote(password, safe='')}"
        return urlunsplit(parts._replace(netloc=f"{userinfo}@{host}"))

    def _inject_credentials(self, repo_path: str, username: str,
                            password: str) -> "str | None":
        """Point ``origin`` at a credential-bearing URL.

        Returns the previous ``origin`` URL when it was replaced (so the
        caller can restore it), or None when nothing was changed.
        """
        result = self._run_git(repo_path, "remote", "get-url", "origin")
        if result.returncode != 0:
            # No usable origin: let the push itself report the problem.
            return None
        current_url = result.stdout.strip()
        auth_url = self._build_auth_url(current_url, username, password)
        if auth_url is None:
            return None
        # NOTE: the URL is briefly visible in this process's argv.
        set_result = self._run_git(
            repo_path, "remote", "set-url", "origin", auth_url
        )
        if set_result.returncode != 0:
            logger.warning(
                "更新远程URL失败: %s", self._redact(set_result.stderr, password)
            )
            return None
        logger.info("已更新远程URL添加认证信息")
        return current_url

    def _restore_remote_url(self, repo_path: str, original_url: str) -> None:
        """Set ``origin`` back to ``original_url``; log (don't raise) on failure."""
        result = self._run_git(
            repo_path, "remote", "set-url", "origin", original_url
        )
        if result.returncode != 0:
            logger.warning("恢复远程URL失败: %s", result.stderr)

    @staticmethod
    def _redact(text: str, secret: str) -> str:
        """Replace ``secret`` (raw and percent-encoded) in ``text`` with ``***``."""
        if not text or not secret:
            return text
        for variant in (quote(secret, safe=""), secret):
            text = text.replace(variant, "***")
        return text