diff --git a/backend/apps/db/constant.py b/backend/apps/db/constant.py index 3b6a66c1..1509fcf4 100644 --- a/backend/apps/db/constant.py +++ b/backend/apps/db/constant.py @@ -2,6 +2,7 @@ # Date: 2025/7/16 from enum import Enum +from typing import List from common.utils.utils import equals_ignore_case @@ -15,26 +16,28 @@ def __init__(self, type_name): class DB(Enum): - excel = ('excel', 'Excel/CSV', '"', '"', ConnectType.sqlalchemy, 'PostgreSQL') - redshift = ('redshift', 'AWS Redshift', '"', '"', ConnectType.py_driver, 'AWS_Redshift') - ck = ('ck', 'ClickHouse', '"', '"', ConnectType.sqlalchemy, 'ClickHouse') - dm = ('dm', '达梦', '"', '"', ConnectType.py_driver, 'DM') - doris = ('doris', 'Apache Doris', '`', '`', ConnectType.py_driver, 'Doris') - es = ('es', 'Elasticsearch', '"', '"', ConnectType.py_driver, 'Elasticsearch') - kingbase = ('kingbase', 'Kingbase', '"', '"', ConnectType.py_driver, 'Kingbase') - sqlServer = ('sqlServer', 'Microsoft SQL Server', '[', ']', ConnectType.sqlalchemy, 'Microsoft_SQL_Server') - mysql = ('mysql', 'MySQL', '`', '`', ConnectType.sqlalchemy, 'MySQL') - oracle = ('oracle', 'Oracle', '"', '"', ConnectType.sqlalchemy, 'Oracle') - pg = ('pg', 'PostgreSQL', '"', '"', ConnectType.sqlalchemy, 'PostgreSQL') - starrocks = ('starrocks', 'StarRocks', '`', '`', ConnectType.py_driver, 'StarRocks') - - def __init__(self, type, db_name, prefix, suffix, connect_type: ConnectType, template_name: str): + excel = ('excel', 'Excel/CSV', '"', '"', ConnectType.sqlalchemy, 'PostgreSQL', []) + redshift = ('redshift', 'AWS Redshift', '"', '"', ConnectType.py_driver, 'AWS_Redshift', []) + ck = ('ck', 'ClickHouse', '"', '"', ConnectType.sqlalchemy, 'ClickHouse', []) + dm = ('dm', '达梦', '"', '"', ConnectType.py_driver, 'DM', []) + doris = ('doris', 'Apache Doris', '`', '`', ConnectType.py_driver, 'Doris', []) + es = ('es', 'Elasticsearch', '"', '"', ConnectType.py_driver, 'Elasticsearch', []) + kingbase = ('kingbase', 'Kingbase', '"', '"', ConnectType.py_driver, 'Kingbase', []) + sqlServer = ('sqlServer', 'Microsoft SQL Server', '[', ']', ConnectType.sqlalchemy, 'Microsoft_SQL_Server', []) + mysql = ('mysql', 'MySQL', '`', '`', ConnectType.sqlalchemy, 'MySQL', ['local_infile']) + oracle = ('oracle', 'Oracle', '"', '"', ConnectType.sqlalchemy, 'Oracle', []) + pg = ('pg', 'PostgreSQL', '"', '"', ConnectType.sqlalchemy, 'PostgreSQL', []) + starrocks = ('starrocks', 'StarRocks', '`', '`', ConnectType.py_driver, 'StarRocks', []) + + def __init__(self, type, db_name, prefix, suffix, connect_type: ConnectType, template_name: str, + illegalParams: List[str]): self.type = type self.db_name = db_name self.prefix = prefix self.suffix = suffix self.connect_type = connect_type self.template_name = template_name + self.illegalParams = illegalParams @classmethod def get_db(cls, type, default_if_none=False): diff --git a/backend/apps/db/db.py b/backend/apps/db/db.py index 28f43a68..5add5f81 100644 --- a/backend/apps/db/db.py +++ b/backend/apps/db/db.py @@ -5,7 +5,7 @@ import urllib.parse from datetime import datetime, date, time, timedelta from decimal import Decimal -from typing import Optional +from typing import Optional, List import oracledb import psycopg2 @@ -57,6 +57,7 @@ def get_uri(ds: CoreDatasource) -> str: def get_uri_from_config(type: str, conf: DatasourceConf) -> str: db_url: str if equals_ignore_case(type, "mysql"): + checkParams(conf.extraJdbc, DB.mysql.illegalParams) if conf.extraJdbc is not None and conf.extraJdbc != '': db_url = f"mysql+pymysql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}" else: @@ -682,3 +683,11 @@ def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema): except Exception as e: raise ValueError(f"Parse SQL Error: {e}") + + +def checkParams(extraParams: str, illegalParams: List[str]): + kvs = extraParams.split('&') + for kv in kvs: + k, v = kv.split('=') + if k in illegalParams: + raise HTTPException(status_code=500, detail=f'Illegal Parameter: {k}')