Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions backend/apps/db/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Date: 2025/7/16

from enum import Enum
from typing import List

from common.utils.utils import equals_ignore_case

Expand All @@ -15,26 +16,28 @@ def __init__(self, type_name):


class DB(Enum):
excel = ('excel', 'Excel/CSV', '"', '"', ConnectType.sqlalchemy, 'PostgreSQL')
redshift = ('redshift', 'AWS Redshift', '"', '"', ConnectType.py_driver, 'AWS_Redshift')
ck = ('ck', 'ClickHouse', '"', '"', ConnectType.sqlalchemy, 'ClickHouse')
dm = ('dm', '达梦', '"', '"', ConnectType.py_driver, 'DM')
doris = ('doris', 'Apache Doris', '`', '`', ConnectType.py_driver, 'Doris')
es = ('es', 'Elasticsearch', '"', '"', ConnectType.py_driver, 'Elasticsearch')
kingbase = ('kingbase', 'Kingbase', '"', '"', ConnectType.py_driver, 'Kingbase')
sqlServer = ('sqlServer', 'Microsoft SQL Server', '[', ']', ConnectType.sqlalchemy, 'Microsoft_SQL_Server')
mysql = ('mysql', 'MySQL', '`', '`', ConnectType.sqlalchemy, 'MySQL')
oracle = ('oracle', 'Oracle', '"', '"', ConnectType.sqlalchemy, 'Oracle')
pg = ('pg', 'PostgreSQL', '"', '"', ConnectType.sqlalchemy, 'PostgreSQL')
starrocks = ('starrocks', 'StarRocks', '`', '`', ConnectType.py_driver, 'StarRocks')

def __init__(self, type, db_name, prefix, suffix, connect_type: ConnectType, template_name: str):
excel = ('excel', 'Excel/CSV', '"', '"', ConnectType.sqlalchemy, 'PostgreSQL', [])
redshift = ('redshift', 'AWS Redshift', '"', '"', ConnectType.py_driver, 'AWS_Redshift', [])
ck = ('ck', 'ClickHouse', '"', '"', ConnectType.sqlalchemy, 'ClickHouse', [])
dm = ('dm', '达梦', '"', '"', ConnectType.py_driver, 'DM', [])
doris = ('doris', 'Apache Doris', '`', '`', ConnectType.py_driver, 'Doris', [])
es = ('es', 'Elasticsearch', '"', '"', ConnectType.py_driver, 'Elasticsearch', [])
kingbase = ('kingbase', 'Kingbase', '"', '"', ConnectType.py_driver, 'Kingbase', [])
sqlServer = ('sqlServer', 'Microsoft SQL Server', '[', ']', ConnectType.sqlalchemy, 'Microsoft_SQL_Server', [])
mysql = ('mysql', 'MySQL', '`', '`', ConnectType.sqlalchemy, 'MySQL', ['local_infile'])
oracle = ('oracle', 'Oracle', '"', '"', ConnectType.sqlalchemy, 'Oracle', [])
pg = ('pg', 'PostgreSQL', '"', '"', ConnectType.sqlalchemy, 'PostgreSQL', [])
starrocks = ('starrocks', 'StarRocks', '`', '`', ConnectType.py_driver, 'StarRocks', [])

def __init__(self, type, db_name, prefix, suffix, connect_type: ConnectType, template_name: str,
illegalParams: List[str]):
self.type = type
self.db_name = db_name
self.prefix = prefix
self.suffix = suffix
self.connect_type = connect_type
self.template_name = template_name
self.illegalParams = illegalParams

@classmethod
def get_db(cls, type, default_if_none=False):
Expand Down
11 changes: 10 additions & 1 deletion backend/apps/db/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import urllib.parse
from datetime import datetime, date, time, timedelta
from decimal import Decimal
from typing import Optional
from typing import Optional, List

import oracledb
import psycopg2
Expand Down Expand Up @@ -57,6 +57,7 @@ def get_uri(ds: CoreDatasource) -> str:
def get_uri_from_config(type: str, conf: DatasourceConf) -> str:
db_url: str
if equals_ignore_case(type, "mysql"):
checkParams(conf.extraJdbc, DB.mysql.illegalParams)
if conf.extraJdbc is not None and conf.extraJdbc != '':
db_url = f"mysql+pymysql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}"
else:
Expand Down Expand Up @@ -682,3 +683,11 @@ def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema):

except Exception as e:
raise ValueError(f"Parse SQL Error: {e}")


def checkParams(extraParams: str, illegalParams: List[str]):
kvs = extraParams.split('&')
for kv in kvs:
k, v = kv.split('=')
if k in illegalParams:
raise HTTPException(status_code=500, detail=f'Illegal Parameter: {k}')