Añadiendo todos los archivos del proyecto (incluidos secretos y venv)
This commit is contained in:
27
venv/lib/python3.12/site-packages/mysql/__init__.py
Normal file
27
venv/lib/python3.12/site-packages/mysql/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# Copyright (c) 2014, 2025, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
Binary file not shown.
27
venv/lib/python3.12/site-packages/mysql/ai/__init__.py
Normal file
27
venv/lib/python3.12/site-packages/mysql/ai/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
Binary file not shown.
43
venv/lib/python3.12/site-packages/mysql/ai/genai/__init__.py
Normal file
43
venv/lib/python3.12/site-packages/mysql/ai/genai/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""GenAI package for MySQL Connector/Python.
|
||||
|
||||
Performs optional dependency checks and exposes public classes:
|
||||
- MyEmbeddings
|
||||
- MyLLM
|
||||
- MyVectorStore
|
||||
"""
|
||||
from mysql.ai.utils import check_dependencies as _check_dependencies
|
||||
|
||||
_check_dependencies(["GENAI"])
|
||||
del _check_dependencies
|
||||
|
||||
from .embedding import MyEmbeddings
|
||||
from .generation import MyLLM
|
||||
from .vector_store import MyVectorStore
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
197
venv/lib/python3.12/site-packages/mysql/ai/genai/embedding.py
Normal file
197
venv/lib/python3.12/site-packages/mysql/ai/genai/embedding.py
Normal file
@@ -0,0 +1,197 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Embeddings integration utilities for MySQL Connector/Python.
|
||||
|
||||
Provides MyEmbeddings class to generate embeddings via MySQL HeatWave
|
||||
using ML_EMBED_TABLE and ML_EMBED_ROW.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from mysql.ai.utils import (
|
||||
atomic_transaction,
|
||||
execute_sql,
|
||||
format_value_sql,
|
||||
source_schema,
|
||||
sql_table_from_df,
|
||||
sql_table_to_df,
|
||||
temporary_sql_tables,
|
||||
)
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
|
||||
class MyEmbeddings(Embeddings):
|
||||
"""
|
||||
Embedding generator class that uses a MySQL database to compute embeddings for input text.
|
||||
|
||||
This class batches input text into temporary SQL tables, invokes MySQL's ML_EMBED_TABLE
|
||||
to generate embeddings, and retrieves the results as lists of floats.
|
||||
|
||||
Attributes:
|
||||
_db_connection (MySQLConnectionAbstract): MySQL connection used for all database operations.
|
||||
schema_name (str): Name of the database schema to use.
|
||||
options_placeholder (str): SQL-ready placeholder string for ML_EMBED_TABLE options.
|
||||
options_params (dict): Dictionary of concrete option values to be passed as SQL parameters.
|
||||
"""
|
||||
|
||||
_db_connection: MySQLConnectionAbstract = PrivateAttr()
|
||||
|
||||
def __init__(
|
||||
self, db_connection: MySQLConnectionAbstract, options: Optional[Dict] = None
|
||||
):
|
||||
"""
|
||||
Initialize MyEmbeddings with a database connection and optional embedding parameters.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwgenai-ml-embed-row.html
|
||||
A full list of supported options can be found under "options"
|
||||
|
||||
NOTE: The supported "options" are the intersection of the options provided in
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwgenai-ml-embed-row.html
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwgenai-ml-embed-table.html
|
||||
|
||||
Args:
|
||||
db_connection: Active MySQL connector database connection.
|
||||
options: Optional dictionary of options for embedding operations.
|
||||
|
||||
Raises:
|
||||
ValueError: If the schema name is not valid
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
super().__init__()
|
||||
self._db_connection = db_connection
|
||||
self.schema_name = source_schema(db_connection)
|
||||
options = options or {}
|
||||
self.options_placeholder, self.options_params = format_value_sql(options)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""
|
||||
Generate embeddings for a list of input texts using the MySQL ML embedding procedure.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwgenai-ml-embed-table.html
|
||||
|
||||
Args:
|
||||
texts: List of input strings to embed.
|
||||
|
||||
Returns:
|
||||
List of lists of floats, with each inner list containing the embedding for a text.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported.
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
ValueError:
|
||||
If one or more text entries were unable to be embedded.
|
||||
|
||||
Implementation notes:
|
||||
- Creates a temporary table to pass input text to the MySQL embedding service.
|
||||
- Adds a primary key to ensure results preserve input order.
|
||||
- Calls ML_EMBED_TABLE and fetches the resulting embeddings.
|
||||
- Deletes the temporary table after use to avoid polluting the database.
|
||||
- Embedding vectors are extracted from the "embeddings" column of the result table.
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
df = pd.DataFrame({"id": range(len(texts)), "text": texts})
|
||||
|
||||
with (
|
||||
atomic_transaction(self._db_connection) as cursor,
|
||||
temporary_sql_tables(self._db_connection) as temporary_tables,
|
||||
):
|
||||
qualified_table_name, table_name = sql_table_from_df(
|
||||
cursor, self.schema_name, df
|
||||
)
|
||||
temporary_tables.append((self.schema_name, table_name))
|
||||
|
||||
# ML_EMBED_TABLE expects input/output columns and options as parameters
|
||||
embed_query = (
|
||||
"CALL sys.ML_EMBED_TABLE("
|
||||
f"'{qualified_table_name}.text', "
|
||||
f"'{qualified_table_name}.embeddings', "
|
||||
f"{self.options_placeholder}"
|
||||
")"
|
||||
)
|
||||
execute_sql(cursor, embed_query, params=self.options_params)
|
||||
|
||||
# Read back all columns, including "embeddings"
|
||||
df_embeddings = sql_table_to_df(cursor, self.schema_name, table_name)
|
||||
|
||||
if df_embeddings["embeddings"].isnull().any() or any(
|
||||
e is None for e in df_embeddings["embeddings"]
|
||||
):
|
||||
raise ValueError(
|
||||
"Failure to generate embeddings for one or more text entry."
|
||||
)
|
||||
|
||||
# Convert fetched embeddings to lists of floats
|
||||
embeddings = df_embeddings["embeddings"].tolist()
|
||||
embeddings = [list(e) for e in embeddings]
|
||||
|
||||
return embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""
|
||||
Generate an embedding for a single text string.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwgenai-ml-embed-row.html
|
||||
|
||||
Args:
|
||||
text: The input string to embed.
|
||||
|
||||
Returns:
|
||||
List of floats representing the embedding vector.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported.
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Example:
|
||||
>>> MyEmbeddings(db_conn).embed_query("Hello world")
|
||||
[0.1, 0.2, ...]
|
||||
"""
|
||||
with atomic_transaction(self._db_connection) as cursor:
|
||||
execute_sql(
|
||||
cursor,
|
||||
f'SELECT sys.ML_EMBED_ROW("%s", {self.options_placeholder})',
|
||||
params=(text, *self.options_params),
|
||||
)
|
||||
return list(cursor.fetchone()[0])
|
||||
162
venv/lib/python3.12/site-packages/mysql/ai/genai/generation.py
Normal file
162
venv/lib/python3.12/site-packages/mysql/ai/genai/generation.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""GenAI LLM integration utilities for MySQL Connector/Python.
|
||||
|
||||
Provides MyLLM wrapper that issues ML_GENERATE calls via SQL.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
from typing import Any, List, Optional
|
||||
|
||||
try:
|
||||
from langchain_core.language_models.llms import LLM
|
||||
except ImportError:
|
||||
from langchain.llms.base import LLM
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from mysql.ai.utils import atomic_transaction, execute_sql, format_value_sql
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
|
||||
class MyLLM(LLM):
|
||||
"""
|
||||
Custom Large Language Model (LLM) interface for MySQL HeatWave.
|
||||
|
||||
This class wraps the generation functionality provided by HeatWave LLMs,
|
||||
exposing an interface compatible with common LLM APIs for text generation.
|
||||
It provides full support for generative queries and limited support for
|
||||
agentic queries.
|
||||
|
||||
Attributes:
|
||||
_db_connection (MySQLConnectionAbstract):
|
||||
Underlying MySQL connector database connection.
|
||||
"""
|
||||
|
||||
_db_connection: MySQLConnectionAbstract = PrivateAttr()
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Pydantic config for the model.
|
||||
|
||||
By default, LangChain (through Pydantic BaseModel) does not allow
|
||||
setting or storing undeclared attributes such as _db_connection.
|
||||
Setting extra = "allow" makes it possible to store extra attributes
|
||||
on the class instance, which is required for MyLLM.
|
||||
"""
|
||||
|
||||
extra = "allow"
|
||||
|
||||
def __init__(self, db_connection: MySQLConnectionAbstract):
|
||||
"""
|
||||
Initialize the MyLLM instance with an active MySQL database connection.
|
||||
|
||||
Args:
|
||||
db_connection: A MySQL connection object used to run LLM queries.
|
||||
|
||||
Notes:
|
||||
The db_connection is stored as a private attribute via object.__setattr__,
|
||||
which is compatible with Pydantic models.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self._db_connection = db_connection
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a text completion from the LLM for a given input prompt.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwgenai-ml-generate.html
|
||||
A full list of supported options (specified by kwargs) can be found under "options"
|
||||
|
||||
Args:
|
||||
prompt: The input prompt string for the language model.
|
||||
stop: Optional list of stop strings to support agentic and chain-of-thought
|
||||
reasoning workflows.
|
||||
**kwargs: Additional keyword arguments providing generation options to
|
||||
the LLM (these are serialized to JSON and passed to the HeatWave syscall).
|
||||
|
||||
Returns:
|
||||
The generated model output as a string.
|
||||
(The actual completion does NOT include the input prompt.)
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported.
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Implementation Notes:
|
||||
- Serializes kwargs into a SQL-compatible JSON string.
|
||||
- Calls the LLM stored procedure using a database cursor context.
|
||||
- Uses `sys.ML_GENERATE` on the server to produce the model output.
|
||||
- Expects the server response to be a JSON object with a 'text' key.
|
||||
"""
|
||||
options = kwargs.copy()
|
||||
if stop is not None:
|
||||
options["stop_sequences"] = stop
|
||||
|
||||
options_placeholder, options_params = format_value_sql(options)
|
||||
with atomic_transaction(self._db_connection) as cursor:
|
||||
# The prompt is passed as a parameterized argument (avoids SQL injection).
|
||||
generate_query = f"""SELECT sys.ML_GENERATE("%s", {options_placeholder});"""
|
||||
execute_sql(cursor, generate_query, params=(prompt, *options_params))
|
||||
# Expect a JSON-encoded result from MySQL; parse to extract the output.
|
||||
llm_response = json.loads(cursor.fetchone()[0])["text"]
|
||||
|
||||
return llm_response
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> dict:
|
||||
"""
|
||||
Return a dictionary of params that uniquely identify this LLM instance.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary of identifier parameters (should include
|
||||
model_name for tracing/caching).
|
||||
"""
|
||||
return {
|
||||
"model_name": "mysql_heatwave_llm",
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""
|
||||
Get the type name of this LLM implementation.
|
||||
|
||||
Returns:
|
||||
A string identifying the LLM provider (used for logging or metrics).
|
||||
"""
|
||||
return "mysql_heatwave_llm"
|
||||
520
venv/lib/python3.12/site-packages/mysql/ai/genai/vector_store.py
Normal file
520
venv/lib/python3.12/site-packages/mysql/ai/genai/vector_store.py
Normal file
@@ -0,0 +1,520 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""MySQL-backed vector store for embeddings and semantic document retrieval.
|
||||
|
||||
Provides a VectorStore implementation persisting documents, metadata, and
|
||||
embeddings in MySQL, plus similarity search utilities.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
from typing import Any, Iterable, List, Optional, Sequence, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from mysql.ai.genai.embedding import MyEmbeddings
|
||||
from mysql.ai.utils import (
|
||||
VAR_NAME_SPACE,
|
||||
atomic_transaction,
|
||||
delete_sql_table,
|
||||
execute_sql,
|
||||
extend_sql_table,
|
||||
format_value_sql,
|
||||
get_random_name,
|
||||
is_table_empty,
|
||||
source_schema,
|
||||
table_exists,
|
||||
)
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
BASIC_EMBEDDING_QUERY = "Hello world!"
|
||||
EMBEDDING_SOURCE = "external_source"
|
||||
|
||||
VAR_EMBEDDING = f"{VAR_NAME_SPACE}.embedding"
|
||||
VAR_CONTEXT = f"{VAR_NAME_SPACE}.context"
|
||||
VAR_CONTEXT_MAP = f"{VAR_NAME_SPACE}.context_map"
|
||||
VAR_RETRIEVAL_INFO = f"{VAR_NAME_SPACE}.retrieval_info"
|
||||
VAR_OPTIONS = f"{VAR_NAME_SPACE}.options"
|
||||
|
||||
ID_SPACE = "internal_ai_id_"
|
||||
|
||||
|
||||
class MyVectorStore(VectorStore):
|
||||
"""
|
||||
MySQL-backed vector store for handling embeddings and semantic document retrieval.
|
||||
|
||||
Supports adding, deleting, and searching high-dimensional vector representations
|
||||
of documents using efficient storage and HeatWave ML similarity search procedures.
|
||||
|
||||
Supports use as a context manager: when used in a `with` statement, all backing
|
||||
tables/data are deleted automatically when the block exits (even on exception).
|
||||
|
||||
Attributes:
|
||||
db_connection (MySQLConnectionAbstract): Active MySQL database connection.
|
||||
embedder (Embeddings): Embeddings generator for computing vector representations.
|
||||
schema_name (str): SQL schema for table storage.
|
||||
table_name (Optional[str]): Name of the active table backing the store
|
||||
(or None until created).
|
||||
embedding_dimension (int): Size of embedding vectors stored.
|
||||
next_id (int): Internal counter for unique document ID generation.
|
||||
"""
|
||||
|
||||
_db_connection: MySQLConnectionAbstract = PrivateAttr()
|
||||
_embedder: Embeddings = PrivateAttr()
|
||||
_schema_name: str = PrivateAttr()
|
||||
_table_name: Optional[str] = PrivateAttr()
|
||||
_embedding_dimension: int = PrivateAttr()
|
||||
_next_id: int = PrivateAttr()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_connection: MySQLConnectionAbstract,
|
||||
embedder: Optional[Embeddings] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a MyVectorStore with a database connection and embedding generator.
|
||||
|
||||
Args:
|
||||
db_connection: MySQL database connection for all vector operations.
|
||||
embedder: Embeddings generator used for creating and querying embeddings.
|
||||
|
||||
Raises:
|
||||
ValueError: If the schema name is not valid
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
super().__init__()
|
||||
self._next_id = 0
|
||||
|
||||
self._schema_name = source_schema(db_connection)
|
||||
self._embedder = embedder or MyEmbeddings(db_connection)
|
||||
self._db_connection = db_connection
|
||||
self._table_name: Optional[str] = None
|
||||
|
||||
# Embedding dimension determined using an example call.
|
||||
# Assumes embeddings have fixed length.
|
||||
self._embedding_dimension = len(
|
||||
self._embedder.embed_query(BASIC_EMBEDDING_QUERY)
|
||||
)
|
||||
|
||||
def _get_ids(self, num_ids: int) -> list[str]:
|
||||
"""
|
||||
Generate a batch of unique internal document IDs for vector storage.
|
||||
|
||||
Args:
|
||||
num_ids: Number of IDs to create.
|
||||
|
||||
Returns:
|
||||
List of sequentially numbered internal string IDs.
|
||||
"""
|
||||
ids = [
|
||||
f"internal_ai_id_{i}" for i in range(self._next_id, self._next_id + num_ids)
|
||||
]
|
||||
self._next_id += num_ids
|
||||
return ids
|
||||
|
||||
def _make_vector_store(self) -> None:
|
||||
"""
|
||||
Create a backing SQL table for storing vectors if not already created.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Notes:
|
||||
The table name is randomized to avoid collisions.
|
||||
Schema includes content, metadata, and embedding vector.
|
||||
"""
|
||||
if self._table_name is None:
|
||||
with atomic_transaction(self._db_connection) as cursor:
|
||||
table_name = get_random_name(
|
||||
lambda table_name: not table_exists(
|
||||
cursor, self._schema_name, table_name
|
||||
)
|
||||
)
|
||||
|
||||
create_table_stmt = f"""
|
||||
CREATE TABLE {self._schema_name}.{table_name} (
|
||||
`id` VARCHAR(128) NOT NULL,
|
||||
`content` TEXT,
|
||||
`metadata` JSON DEFAULT NULL,
|
||||
`embed` vector(%s),
|
||||
PRIMARY KEY (`id`)
|
||||
) ENGINE=InnoDB;
|
||||
"""
|
||||
execute_sql(
|
||||
cursor, create_table_stmt, params=(self._embedding_dimension,)
|
||||
)
|
||||
|
||||
self._table_name = table_name
|
||||
|
||||
def delete(self, ids: Optional[Sequence[str]] = None, **_: Any) -> None:
|
||||
"""
|
||||
Delete documents by ID. Optionally deletes the vector table if empty after deletions.
|
||||
|
||||
Args:
|
||||
ids: Optional sequence of document IDs to delete. If None, no action is taken.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Notes:
|
||||
If the backing table is empty after deletions, the table is dropped and
|
||||
table_name is set to None.
|
||||
"""
|
||||
with atomic_transaction(self._db_connection) as cursor:
|
||||
if ids:
|
||||
for _id in ids:
|
||||
execute_sql(
|
||||
cursor,
|
||||
f"DELETE FROM {self._schema_name}.{self._table_name} WHERE id = %s",
|
||||
params=(_id,),
|
||||
)
|
||||
|
||||
if is_table_empty(cursor, self._schema_name, self._table_name):
|
||||
self.delete_all()
|
||||
|
||||
def delete_all(self) -> None:
|
||||
"""
|
||||
Delete and drop the entire vector store table.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if self._table_name is not None:
|
||||
with atomic_transaction(self._db_connection) as cursor:
|
||||
delete_sql_table(cursor, self._schema_name, self._table_name)
|
||||
self._table_name = None
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
**_: dict,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Add a batch of text strings and corresponding metadata to the vector store.
|
||||
|
||||
Args:
|
||||
texts: List of strings to embed and store.
|
||||
metadatas: Optional list of metadata dicts (one per text).
|
||||
ids: Optional custom document IDs.
|
||||
|
||||
Returns:
|
||||
List of document IDs corresponding to the added texts.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Notes:
|
||||
If metadatas is None, an empty dict is assigned to each document.
|
||||
"""
|
||||
texts = list(texts)
|
||||
|
||||
documents = [
|
||||
Document(page_content=text, metadata=meta)
|
||||
for text, meta in zip(texts, metadatas or [{}] * len(texts))
|
||||
]
|
||||
return self.add_documents(documents, ids=ids)
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: Iterable[str],
|
||||
embedder: Embeddings,
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
db_connection: MySQLConnectionAbstract = None,
|
||||
) -> VectorStore:
|
||||
"""
|
||||
Construct and populate a MyVectorStore instance from raw texts and metadata.
|
||||
|
||||
Args:
|
||||
texts: List of strings to vectorize and store.
|
||||
embedder: Embeddings generator to use.
|
||||
metadatas: Optional list of metadata dicts per text.
|
||||
db_connection: Active MySQL connection.
|
||||
|
||||
Returns:
|
||||
Instance of MyVectorStore containing the added texts.
|
||||
|
||||
Raises:
|
||||
ValueError: If db_connection is not provided.
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
if db_connection is None:
|
||||
raise ValueError(
|
||||
"db_connection must be specified to create a MyVectorStore object"
|
||||
)
|
||||
|
||||
texts = list(texts)
|
||||
|
||||
instance = cls(db_connection=db_connection, embedder=embedder)
|
||||
instance.add_texts(texts, metadatas=metadatas)
|
||||
|
||||
return instance
|
||||
|
||||
def add_documents(
    self, documents: list[Document], ids: list[str] = None
) -> list[str]:
    """
    Embed the given documents and persist them, with metadata, as vector rows.

    Args:
        documents: Document objects to store; each supplies 'page_content'
            and 'metadata'.
        ids: Optional explicit document IDs; length must match documents.

    Returns:
        The IDs under which the documents were stored.

    Raises:
        ValueError: If the number of provided IDs differs from the number of
            documents.
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.

    Notes:
        The backing table is created automatically when it does not exist.
    """
    if ids and len(ids) != len(documents):
        raise ValueError(
            "ids must be the same length as documents. "
            f"Got {len(ids)} ids and {len(documents)} documents."
        )

    # Nothing to store: skip table creation entirely.
    if not documents:
        return []
    self._make_vector_store()

    if ids is None:
        ids = self._get_ids(len(documents))

    page_contents = [doc.page_content for doc in documents]
    rows = pd.DataFrame(
        {
            "id": ids,
            "content": page_contents,
            "embed": self._embedder.embed_documents(page_contents),
            "metadata": [doc.metadata for doc in documents],
        }
    )

    with atomic_transaction(self._db_connection) as cursor:
        # "embed" values are serialized through string_to_vector() so MySQL
        # stores them in its native VECTOR representation.
        extend_sql_table(
            cursor,
            self._schema_name,
            self._table_name,
            rows,
            col_name_to_placeholder_string={"embed": "string_to_vector(%s)"},
        )

    return ids
|
||||
|
||||
def similarity_search(
    self,
    query: str,
    k: int = 3,
    **kwargs: Any,
) -> list[Document]:
    """
    Search for and return the most similar documents in the store to the given query.

    Args:
        query: String query to embed and use for similarity search.
        k: Number of top documents to return.
        kwargs: Options to pass to ML_SIMILARITY_SEARCH. Currently supports
            distance_metric, max_distance, percentage_distance, and
            segment_overlap.

    Returns:
        List of Document objects, ordered from most to least similar.

    Raises:
        DatabaseError:
            If provided kwargs are invalid or unsupported.
            If a database connection issue occurs.
            If an operational error occurs during execution.

    Implementation Notes:
        - Calls ML similarity search within MySQL using stored procedures.
        - Retrieves IDs, content, and metadata for search matches.
        - Parsing and retrieval for context results are handled via
          intermediate JSON session variables.
    """
    # No backing table means nothing was ever stored, so nothing can match.
    if self._table_name is None:
        return []

    embedding = self._embedder.embed_query(query)

    with atomic_transaction(self._db_connection) as cursor:
        # Set the embedding variable for the similarity search SP
        execute_sql(
            cursor,
            f"SET @{VAR_EMBEDDING} = string_to_vector(%s)",
            params=[str(embedding)],
        )

        distance_metric = kwargs.get("distance_metric", "COSINE")
        retrieval_options = {
            "max_distance": kwargs.get("max_distance", 0.6),
            "percentage_distance": kwargs.get("percentage_distance", 20.0),
            "segment_overlap": kwargs.get("segment_overlap", 0),
        }

        # format_value_sql produces a placeholder expression plus matching
        # parameter list so the options are bound safely, not inlined.
        retrieval_options_placeholder, retrieval_options_params = format_value_sql(
            retrieval_options
        )
        similarity_search_query = f"""
            CALL sys.ML_SIMILARITY_SEARCH(
                @{VAR_EMBEDDING},
                JSON_ARRAY(
                    '{self._schema_name}.{self._table_name}'
                ),
                JSON_OBJECT(
                    "segment", "content",
                    "segment_embedding", "embed",
                    "document_name", "id"
                ),
                {k},
                %s,
                NULL,
                NULL,
                {retrieval_options_placeholder},
                @{VAR_CONTEXT},
                @{VAR_CONTEXT_MAP},
                @{VAR_RETRIEVAL_INFO}
            )
        """

        execute_sql(
            cursor,
            similarity_search_query,
            params=[distance_metric, *retrieval_options_params],
        )
        # The SP writes its matches into session variables; read the map back.
        execute_sql(cursor, f"SELECT @{VAR_CONTEXT_MAP}")

        results = []

        context_maps = json.loads(cursor.fetchone()[0])
        for context in context_maps:
            # Each context entry references a stored row by document id;
            # fetch the full row to rebuild the Document.
            execute_sql(
                cursor,
                (
                    "SELECT id, content, metadata "
                    f"FROM {self._schema_name}.{self._table_name} "
                    "WHERE id = %s"
                ),
                params=(context["document_name"],),
            )
            doc_id, content, metadata = cursor.fetchone()

            doc_args = {
                "id": doc_id,
                "page_content": content,
            }
            # Stored metadata is a JSON string; decode only when present.
            if metadata is not None:
                doc_args["metadata"] = json.loads(metadata)

            doc = Document(**doc_args)
            results.append(doc)

    return results
|
||||
|
||||
def __enter__(self) -> "VectorStore":
    """
    Enter the runtime context related to this vector store instance.

    Returns:
        The current vector store object, allowing use within a `with` block.

    Usage Notes:
        - No setup work is performed on entry; the method exists so the
          store can be scoped with `with`, which triggers cleanup on exit
          (see `__exit__`).

    Example:
        with MyVectorStore(db_connection, embedder) as vectorstore:
            vectorstore.add_texts([...])
            # Vector store is active within this block.
        # All storage and resources are now cleaned up.
    """
    return self
|
||||
|
||||
def __exit__(
    self,
    exc_type: Union[type, None],
    exc_val: Union[BaseException, None],
    exc_tb: Union[object, None],
) -> None:
    """
    Exit the runtime context for the vector store, cleaning up all storage.

    Args:
        exc_type: The exception type, if any exception occurred in the context block.
        exc_val: The exception value, if any exception occurred in the context block.
        exc_tb: The traceback object, if any exception occurred in the context block.

    Returns:
        None: A falsy return means exceptions are never suppressed; they
        propagate as normal.

    Implementation Notes:
        - Deletes all vector store data and backing tables via `delete_all()`
          upon exiting the context, whether the block exits normally or due
          to an exception.
        - Use when the vector store lifecycle is intended to be temporary or
          scoped.

    Example:
        with MyVectorStore(db_connection, embedder) as vectorstore:
            vectorstore.add_texts([...])
            # Vector store is active within this block.
        # All storage and resources are now cleaned up.
    """
    self.delete_all()
    # No return, so exceptions are never suppressed
|
||||
48
venv/lib/python3.12/site-packages/mysql/ai/ml/__init__.py
Normal file
48
venv/lib/python3.12/site-packages/mysql/ai/ml/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""ML package for MySQL Connector/Python.
|
||||
|
||||
Performs optional dependency checks and exposes ML utilities:
|
||||
- ML_TASK, MyModel
|
||||
- MyClassifier, MyRegressor, MyGenericTransformer
|
||||
- MyAnomalyDetector
|
||||
"""
|
||||
from mysql.ai.utils import check_dependencies as _check_dependencies
|
||||
|
||||
_check_dependencies(["ML"])
|
||||
del _check_dependencies
|
||||
|
||||
# Sklearn models
|
||||
from .classifier import MyClassifier
|
||||
|
||||
# Minimal interface
|
||||
from .model import ML_TASK, MyModel
|
||||
from .outlier import MyAnomalyDetector
|
||||
from .regressor import MyRegressor
|
||||
from .transformer import MyGenericTransformer
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
142
venv/lib/python3.12/site-packages/mysql/ai/ml/base.py
Normal file
142
venv/lib/python3.12/site-packages/mysql/ai/ml/base.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Base classes for MySQL HeatWave ML estimators for Connector/Python.
|
||||
|
||||
Implements a scikit-learn-compatible base estimator wrapping server-side ML.
|
||||
"""
|
||||
from typing import Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from sklearn.base import BaseEstimator
|
||||
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
from mysql.ai.ml.model import ML_TASK, MyModel
|
||||
from mysql.ai.utils import copy_dict
|
||||
|
||||
|
||||
class MyBaseMLModel(BaseEstimator):
    """
    Base class for MySQL HeatWave machine learning estimators.

    Implements the scikit-learn API and core model management logic,
    including fit, explain, serialization, and dynamic option handling.
    For use as a base class by classifiers, regressors, transformers, and
    outlier models.

    Args:
        db_connection (MySQLConnectionAbstract): An active MySQL connector
            database connection.
        task (str): ML task type, e.g. "classification" or "regression".
        model_name (str, optional): Custom name for the deployed model.
        fit_extra_options (dict, optional): Extra options for fitting.

    Attributes:
        _model: Underlying database helper for fit/predict/explain.
        fit_extra_options: User-provided options for fitting.
    """

    def __init__(
        self,
        db_connection: MySQLConnectionAbstract,
        task: Union[str, ML_TASK],
        model_name: Optional[str] = None,
        fit_extra_options: Optional[dict] = None,
    ):
        """
        Initialize a MyBaseMLModel with connection, task, and option parameters.

        Args:
            db_connection: Active MySQL connector database connection.
            task: String label or ML_TASK member (e.g. "classification").
            model_name: Optional custom model name.
            fit_extra_options: Optional extra fit options.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        self._model = MyModel(db_connection, task=task, model_name=model_name)
        # Copy so later caller-side mutation of the dict cannot affect fits.
        self.fit_extra_options = copy_dict(fit_extra_options)

    def fit(
        self,
        X: pd.DataFrame,  # pylint: disable=invalid-name
        y: Optional[pd.DataFrame] = None,
    ) -> "MyBaseMLModel":
        """
        Fit the underlying ML model using pandas DataFrames.
        Delegates to the wrapped MyModel's fit.

        Args:
            X: Features DataFrame.
            y: (Optional) Target labels DataFrame or Series.

        Returns:
            self

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported.
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Additional temp SQL resources may be created and cleaned up during
            the operation.
        """
        self._model.fit(X, y, self.fit_extra_options)
        return self

    def _delete_model(self) -> bool:
        """
        Delete the model from the model catalog if present.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Returns:
            Whether the model was deleted.
        """
        return self._model._delete_model()

    def get_model_info(self) -> Optional[dict]:
        """
        Retrieve this model's catalog entry. Model info will only be present
        in the catalog if the model has previously been fitted.

        Returns:
            The catalog info dict for the model, or None if the model is not
            part of the model catalog.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        return self._model.get_model_info()
|
||||
194
venv/lib/python3.12/site-packages/mysql/ai/ml/classifier.py
Normal file
194
venv/lib/python3.12/site-packages/mysql/ai/ml/classifier.py
Normal file
@@ -0,0 +1,194 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Classifier utilities for MySQL Connector/Python.
|
||||
|
||||
Provides a scikit-learn compatible classifier backed by HeatWave ML.
|
||||
"""
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.base import ClassifierMixin
|
||||
|
||||
from mysql.ai.ml.base import MyBaseMLModel
|
||||
from mysql.ai.ml.model import ML_TASK
|
||||
from mysql.ai.utils import copy_dict
|
||||
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
|
||||
class MyClassifier(MyBaseMLModel, ClassifierMixin):
    """
    MySQL HeatWave scikit-learn compatible classifier estimator.

    Provides prediction and probability output from a model deployed in MySQL,
    and manages fit, explain, and prediction options as per HeatWave ML
    interface.

    Attributes:
        predict_extra_options (dict): Dictionary of optional parameters passed
            through to the MySQL backend for prediction and probability
            inference.
        explain_extra_options (dict): Optional parameters passed through for
            prediction explanations.
        _model (MyModel): Underlying interface for database model operations.
        fit_extra_options (dict): See MyBaseMLModel.

    Args:
        db_connection (MySQLConnectionAbstract): Active MySQL connector DB
            connection.
        model_name (str, optional): Custom name for the model.
        fit_extra_options (dict, optional): Extra options for fitting.
        explain_extra_options (dict, optional): Extra options for explanations.
        predict_extra_options (dict, optional): Extra options for
            predict/predict_proba.

    Methods:
        predict(X): Predict class labels.
        predict_proba(X): Predict class probabilities.
        explain_predictions(X): Explain predictions for the given samples.
    """

    def __init__(
        self,
        db_connection: MySQLConnectionAbstract,
        model_name: Optional[str] = None,
        fit_extra_options: Optional[dict] = None,
        explain_extra_options: Optional[dict] = None,
        predict_extra_options: Optional[dict] = None,
    ):
        """
        Initialize a MyClassifier.

        Args:
            db_connection: Active MySQL connector database connection.
            model_name: Optional, custom model name.
            fit_extra_options: Optional fit options.
            explain_extra_options: Optional explain options.
            predict_extra_options: Optional predict/predict_proba options.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        MyBaseMLModel.__init__(
            self,
            db_connection,
            ML_TASK.CLASSIFICATION,
            model_name=model_name,
            fit_extra_options=fit_extra_options,
        )
        # Copies guard against callers mutating the option dicts afterwards.
        self.predict_extra_options = copy_dict(predict_extra_options)
        self.explain_extra_options = copy_dict(explain_extra_options)

    def predict(
        self, X: Union[pd.DataFrame, np.ndarray]
    ) -> np.ndarray:  # pylint: disable=invalid-name
        """
        Predict class labels for the input features using the MySQL model.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
            A full list of supported options can be found under
            "ML_PREDICT_TABLE Options"

        Args:
            X: Input samples as a numpy array or pandas DataFrame.

        Returns:
            ndarray: Array of predicted class labels, shape (n_samples,).

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has
                not been called.
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        result = self._model.predict(X, options=self.predict_extra_options)
        return result["Prediction"].to_numpy()

    def predict_proba(
        self, X: Union[pd.DataFrame, np.ndarray]
    ) -> np.ndarray:  # pylint: disable=invalid-name
        """
        Predict class probabilities for the input features using the MySQL model.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
            A full list of supported options can be found under
            "ML_PREDICT_TABLE Options"

        Args:
            X: Input samples as a numpy array or pandas DataFrame.

        Returns:
            ndarray: Array of shape (n_samples, n_classes) with class
            probabilities; columns follow the sorted class-name order.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has
                not been called.
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        result = self._model.predict(X, options=self.predict_extra_options)

        # Sort class names so the column order is deterministic across rows.
        classes = sorted(result["ml_results"].iloc[0]["probabilities"].keys())

        return np.stack(
            result["ml_results"].map(
                lambda ml_result: [
                    ml_result["probabilities"][class_name] for class_name in classes
                ]
            )
        )

    def explain_predictions(
        self, X: Union[pd.DataFrame, np.ndarray]
    ) -> pd.DataFrame:  # pylint: disable=invalid-name
        """
        Explain model predictions using provided data.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
            A full list of supported options can be found under
            "ML_EXPLAIN_TABLE Options"

        Args:
            X: DataFrame for which predictions should be explained.

        Returns:
            DataFrame containing explanation details (feature attributions, etc.)

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has
                not been called.
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Temporary input/output tables are cleaned up after explanation.
        """
        # Bug fix: the explanation DataFrame was previously computed but
        # discarded; return it as documented.
        return self._model.explain_predictions(X, options=self.explain_extra_options)
|
||||
780
venv/lib/python3.12/site-packages/mysql/ai/ml/model.py
Normal file
780
venv/lib/python3.12/site-packages/mysql/ai/ml/model.py
Normal file
@@ -0,0 +1,780 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
"""HeatWave ML model utilities for MySQL Connector/Python.
|
||||
|
||||
Provides classes to manage training, prediction, scoring, and explanations
|
||||
via MySQL HeatWave stored procedures.
|
||||
"""
|
||||
import copy
|
||||
import json
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from mysql.ai.utils import (
|
||||
VAR_NAME_SPACE,
|
||||
atomic_transaction,
|
||||
convert_to_df,
|
||||
execute_sql,
|
||||
format_value_sql,
|
||||
get_random_name,
|
||||
source_schema,
|
||||
sql_response_to_df,
|
||||
sql_table_from_df,
|
||||
sql_table_to_df,
|
||||
table_exists,
|
||||
temporary_sql_tables,
|
||||
validate_name,
|
||||
)
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
|
||||
class ML_TASK(Enum):  # pylint: disable=invalid-name
    """Enumeration of the ML task types supported by HeatWave."""

    CLASSIFICATION = "classification"
    REGRESSION = "regression"
    FORECASTING = "forecasting"
    ANOMALY_DETECTION = "anomaly_detection"
    LOG_ANOMALY_DETECTION = "log_anomaly_detection"
    RECOMMENDATION = "recommendation"
    TOPIC_MODELING = "topic_modeling"

    @staticmethod
    def get_task_string(task: Union[str, "ML_TASK"]) -> str:
        """
        Normalize a task specifier to its plain string form.

        Args:
            task (Union[str, ML_TASK]): Either an ML_TASK member or the
                task's string value; strings are passed through unchanged.

        Returns:
            str: The string value of the ML task.
        """
        return task if isinstance(task, str) else task.value
|
||||
|
||||
|
||||
class _MyModelCommon:
|
||||
"""
|
||||
Common utilities and workflow for MySQL HeatWave ML models.
|
||||
|
||||
This class handles model lifecycle steps such as loading, fitting, scoring,
|
||||
making predictions, and explaining models or predictions. Not intended for
|
||||
direct instantiation, but as a superclass for heatwave model wrappers.
|
||||
|
||||
Attributes:
|
||||
db_connection: MySQL connector database connection.
|
||||
task: ML task, e.g., "classification" or "regression".
|
||||
model_name: Identifier of model in MySQL.
|
||||
schema_name: Database schema used for operations and temp tables.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    db_connection: MySQLConnectionAbstract,
    task: Union[str, ML_TASK] = ML_TASK.CLASSIFICATION,
    model_name: Optional[str] = None,
):
    """
    Instantiate _MyModelCommon.

    References:
        https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-train.html
        A full list of supported tasks can be found under
        "Common ML_TRAIN Options"

    Args:
        db_connection: MySQL database connection.
        task: ML task type (default: classification).
        model_name: Name to register the model within MySQL; a random unused
            name is generated when None (default: None).

    Raises:
        ValueError: If the model name is not valid.
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.

    Returns:
        None
    """
    self.db_connection = db_connection
    self.task = ML_TASK.get_task_string(task)
    self.schema_name = source_schema(db_connection)

    # Ensure the per-user model catalog exists before any catalog access.
    with atomic_transaction(self.db_connection) as cursor:
        execute_sql(cursor, "CALL sys.ML_CREATE_OR_UPGRADE_CATALOG();")

    if model_name is None:
        # Pick a random name that is not already in the model catalog.
        model_name = get_random_name(self._is_model_name_available)

    # Session-variable names used to address this model in SQL statements.
    self.model_var = f"{VAR_NAME_SPACE}.{model_name}"
    self.model_var_score = f"{self.model_var}.score"

    self.model_name = model_name
    validate_name(model_name)

    # Bind the model handle into a session variable for later SP calls.
    with atomic_transaction(self.db_connection) as cursor:
        execute_sql(cursor, f"SET @{self.model_var} = %s;", params=(model_name,))
|
||||
|
||||
def _delete_model(self) -> bool:
    """
    Remove this model's entry from the user's model catalog, if present.

    Raises:
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.

    Returns:
        Whether the model was deleted.
    """
    catalog_table = f"ML_SCHEMA_{self._get_user()}.MODEL_CATALOG"
    delete_stmt = (
        f"DELETE FROM {catalog_table} "
        f"WHERE model_handle = @{self.model_var}"
    )

    with atomic_transaction(self.db_connection) as cursor:
        execute_sql(cursor, delete_stmt)
        # A positive rowcount means the handle existed and was removed.
        return cursor.rowcount > 0
|
||||
|
||||
def _get_model_info(self, model_name: str) -> Optional[dict]:
    """
    Retrieve the model info from the model catalog.

    Args:
        model_name: The model handle to look up in the catalog.

    Returns:
        The model info from the model catalog (None if the model is not
        present in the catalog).

    Raises:
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.
    """

    def process_col(elem: Any) -> Any:
        # Best-effort decode: catalog columns may hold JSON documents stored
        # as strings; non-JSON strings are returned unchanged.
        if isinstance(elem, str):
            try:
                elem = json.loads(elem)
            except json.JSONDecodeError:
                pass
        return elem

    current_user = self._get_user()

    # Each user has a dedicated ML_SCHEMA_<user> schema holding the catalog.
    qualified_model_catalog = f"ML_SCHEMA_{current_user}.MODEL_CATALOG"
    model_exists = (
        f"SELECT * FROM {qualified_model_catalog} WHERE model_handle = %s"
    )

    with atomic_transaction(self.db_connection) as cursor:
        execute_sql(cursor, model_exists, params=(model_name,))
        model_info_df = sql_response_to_df(cursor)

    if model_info_df.empty:
        result = None
    else:
        # Round-trip through JSON to normalize pandas types, then decode any
        # JSON-typed columns individually.
        unprocessed_result = model_info_df.to_json(orient="records")
        unprocessed_result_json = json.loads(unprocessed_result)[0]
        result = {
            key: process_col(elem)
            for key, elem in unprocessed_result_json.items()
        }

    return result
|
||||
|
||||
def get_model_info(self) -> Optional[dict]:
    """
    Retrieve this model's catalog entry.
    Model info is present in the catalog only if the model was previously
    fitted.

    Returns:
        The catalog info dict for the model, or None if the model is not
        part of the model catalog.

    Raises:
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.
    """
    return self._get_model_info(self.model_name)
|
||||
|
||||
def _is_model_name_available(self, model_name: str) -> bool:
    """
    Report whether `model_name` is free to use.

    A name is available when no entry for it exists in the model catalog.

    Args:
        model_name: Candidate model handle to check.

    Returns:
        True if the model name is not part of the model catalog.

    Raises:
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.
    """
    existing_info = self._get_model_info(model_name)
    return existing_info is None
||||
|
||||
def _load_model(self) -> None:
    """
    Load the model specified by `self.model_name` into MySQL.

    After loading, the model is ready to handle ML operations.

    References:
        https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-model-load.html

    Raises:
        DatabaseError:
            If the model is not initialized, i.e., fit or import has not been called
            If a database connection issue occurs.
            If an operational error occurs during execution.

    Returns:
        None
    """
    # Second argument (ML options) is intentionally NULL.
    load_call = f"CALL sys.ML_MODEL_LOAD(@{self.model_var}, NULL);"
    with atomic_transaction(self.db_connection) as cursor:
        execute_sql(cursor, load_call)
||||
|
||||
def _get_user(self) -> str:
    """
    Fetch the current database user (without host).

    Returns:
        The username string associated with the connection.

    Raises:
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError: If the user name includes unsupported characters
    """
    with atomic_transaction(self.db_connection) as cursor:
        cursor.execute("SELECT CURRENT_USER()")
        account = cursor.fetchone()[0]

    # CURRENT_USER() yields 'user@host'; keep only the user part.
    user_name = account.split("@")[0]
    return validate_name(user_name)
||||
|
||||
def explain_model(self) -> dict:
    """
    Get model explanations, such as detailed feature importances.

    Returns:
        dict: Feature importances and model explainability data.

    References:
        https://dev.mysql.com/doc/heatwave/en/mys-hwaml-model-explanations.html

    Raises:
        DatabaseError:
            If the model is not initialized, i.e., fit or import has not been called
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError:
            If the model does not exist in the model catalog.
            Should only occur if model was not fitted or was deleted.
    """
    self._load_model()

    # Resolve the user BEFORE opening the transaction: _get_user() starts
    # its own atomic_transaction, so calling it inside the block below
    # would nest transactions. This also matches _get_model_info, which
    # resolves the user first.
    current_user = self._get_user()
    qualified_model_catalog = f"ML_SCHEMA_{current_user}.MODEL_CATALOG"

    with atomic_transaction(self.db_connection) as cursor:
        explain_query = (
            f"SELECT model_explanation FROM {qualified_model_catalog} "
            f"WHERE model_handle = @{self.model_var}"
        )

        execute_sql(cursor, explain_query)
        df = sql_response_to_df(cursor)

    # Single-cell result: the explanation payload for this model.
    return df.iloc[0, 0]
||||
|
||||
def _fit(
    self,
    table_name: str,
    target_column_name: Optional[str],
    options: Optional[dict],
) -> None:
    """
    Fit an ML model using a referenced SQL table and target column.

    References:
        https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-train.html
        A full list of supported options can be found under "Common ML_TRAIN Options"

    Args:
        table_name: Name of the training data table.
        target_column_name: Name of the target/label column.
        options: Additional fit/config options (may override defaults).

    Raises:
        DatabaseError:
            If provided options are invalid or unsupported.
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError: If the table or target_column name is not valid

    Returns:
        None
    """
    validate_name(table_name)
    if target_column_name is None:
        target_col_string = "NULL"
    else:
        validate_name(target_column_name)
        target_col_string = f"'{target_column_name}'"

    # Work on a private deep copy so the caller's dict is never mutated.
    options = copy.deepcopy(options) if options is not None else {}
    options["task"] = self.task

    # Clear any existing model under this name before training anew.
    self._delete_model()

    with atomic_transaction(self.db_connection) as cursor:
        value_slots, bound_params = format_value_sql(options)
        train_call = (
            "CALL sys.ML_TRAIN("
            f"'{self.schema_name}.{table_name}', "
            f"{target_col_string}, "
            f"{value_slots}, "
            f"@{self.model_var}"
            ")"
        )
        execute_sql(cursor, train_call, params=bound_params)
||||
|
||||
def _predict(
    self, table_name: str, output_table_name: str, options: Optional[dict]
) -> None:
    """
    Predict on a given data table and write results to an output table.

    References:
        https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
        A full list of supported options can be found under "ML_PREDICT_TABLE Options"

    Args:
        table_name: Name of the SQL table with input data.
        output_table_name: Name for the SQL output table to contain predictions.
        options: Optional prediction options.

    Returns:
        None

    Raises:
        DatabaseError:
            If provided options are invalid or unsupported,
            or if the model is not initialized, i.e., fit or import has not
            been called
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError: If the table or output_table name is not valid
    """
    # Reject identifiers with unsupported characters before interpolation.
    for identifier in (table_name, output_table_name):
        validate_name(identifier)

    self._load_model()
    with atomic_transaction(self.db_connection) as cursor:
        value_slots, bound_params = format_value_sql(options)
        predict_call = (
            "CALL sys.ML_PREDICT_TABLE("
            f"'{self.schema_name}.{table_name}', "
            f"@{self.model_var}, "
            f"'{self.schema_name}.{output_table_name}', "
            f"{value_slots}"
            ")"
        )
        execute_sql(cursor, predict_call, params=bound_params)
||||
|
||||
def _score(
    self,
    table_name: str,
    target_column_name: str,
    metric: str,
    options: Optional[dict],
) -> float:
    """
    Evaluate model performance with a scoring metric.

    References:
        https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-score.html
        A full list of supported options can be found under
        "Options for Recommendation Models" and
        "Options for Anomaly Detection Models"

    Args:
        table_name: Table with features and ground truth.
        target_column_name: Column of true target labels.
        metric: String name of the metric to compute.
        options: Optional dictionary of further scoring options.

    Returns:
        float: Computed score from the ML system.

    Raises:
        DatabaseError:
            If provided options are invalid or unsupported,
            or if the model is not initialized, i.e., fit or import has not
            been called
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError: If the table or target_column name or metric is not valid
    """
    # Reject identifiers with unsupported characters before interpolation.
    for identifier in (table_name, target_column_name, metric):
        validate_name(identifier)

    self._load_model()
    with atomic_transaction(self.db_connection) as cursor:
        value_slots, bound_params = format_value_sql(options)
        score_call = (
            "CALL sys.ML_SCORE("
            f"'{self.schema_name}.{table_name}', "
            f"'{target_column_name}', "
            f"@{self.model_var}, "
            "%s, "
            f"@{self.model_var_score}, "
            f"{value_slots}"
            ")"
        )
        # The metric is bound first, followed by the option values.
        execute_sql(cursor, score_call, params=[metric, *bound_params])
        # ML_SCORE writes the result into a session variable; read it back.
        execute_sql(cursor, f"SELECT @{self.model_var_score}")
        score_df = sql_response_to_df(cursor)

    return score_df.iloc[0, 0]
||||
|
||||
def _explain_predictions(
    self, table_name: str, output_table_name: str, options: Optional[dict]
) -> pd.DataFrame:
    """
    Produce explanations for model predictions on provided data.

    References:
        https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
        A full list of supported options can be found under "ML_EXPLAIN_TABLE Options"

    Args:
        table_name: Name of the SQL table with input data.
        output_table_name: Name for the SQL table to store explanations.
        options: Optional dictionary (default:
            {"prediction_explainer": "permutation_importance"}).

    Returns:
        DataFrame: Prediction explanations from the output SQL table.

    Raises:
        DatabaseError:
            If provided options are invalid or unsupported,
            or if the model is not initialized, i.e., fit or import has not
            been called
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError: If the table or output_table name is not valid
    """
    validate_name(table_name)
    validate_name(output_table_name)

    # Default to permutation importance when no explainer is chosen.
    if options is None:
        options = {"prediction_explainer": "permutation_importance"}

    self._load_model()

    with atomic_transaction(self.db_connection) as cursor:
        value_slots, bound_params = format_value_sql(options)
        explain_call = (
            "CALL sys.ML_EXPLAIN_TABLE("
            f"'{self.schema_name}.{table_name}', "
            f"@{self.model_var}, "
            f"'{self.schema_name}.{output_table_name}', "
            f"{value_slots}"
            ")"
        )
        execute_sql(cursor, explain_call, params=bound_params)
        # Read the explanations back out of the output table.
        execute_sql(cursor, f"SELECT * FROM {self.schema_name}.{output_table_name}")
        explanations_df = sql_response_to_df(cursor)

    return explanations_df
||||
|
||||
|
||||
class MyModel(_MyModelCommon):
    """
    Convenience class for managing the ML workflow using pandas DataFrames.

    Methods convert in-memory DataFrames into temp SQL tables before delegating to the
    _MyModelCommon routines, and automatically clean up temp resources.
    """

    def fit(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
        y: Optional[Union[pd.DataFrame, np.ndarray]],
        options: Optional[dict] = None,
    ) -> None:
        """
        Fit a model using DataFrame inputs.

        If an 'id' column is defined in either dataframe, it will be used as the primary key.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-train.html
            A full list of supported options can be found under "Common ML_TRAIN Options"

        Args:
            X: Features DataFrame.
            y: (Optional) Target labels DataFrame or Series. If None, only X is used.
            options: Additional options to pass to training.

        Returns:
            None

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported.
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError:
                If the target column name from y already exists in X.

        Notes:
            Combines X and y as necessary. Creates a temporary table in the schema for training,
            and deletes it afterward.
        """
        X, y = convert_to_df(X), convert_to_df(y)

        with (
            atomic_transaction(self.db_connection) as cursor,
            temporary_sql_tables(self.db_connection) as temporary_tables,
        ):
            if y is not None:
                if isinstance(y, pd.DataFrame):
                    # keep column name if it exists
                    target_column_name = y.columns[0]
                else:
                    # pick a name guaranteed not to collide with X's columns
                    target_column_name = get_random_name(
                        lambda name: name not in X.columns
                    )

                if target_column_name in X.columns:
                    raise ValueError(
                        f"Target column y with name {target_column_name} already present "
                        "in feature dataframe X"
                    )

                # Merge features and target into one table for ML_TRAIN.
                df_combined = X.copy()
                df_combined[target_column_name] = y
                final_df = df_combined
            else:
                target_column_name = None
                final_df = X

            _, table_name = sql_table_from_df(cursor, self.schema_name, final_df)
            # Register for cleanup when the context manager exits.
            temporary_tables.append((self.schema_name, table_name))

            self._fit(table_name, target_column_name, options)

    def predict(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
        options: Optional[dict] = None,
    ) -> pd.DataFrame:
        """
        Generate model predictions using DataFrame input.

        If an 'id' column is defined in either dataframe, it will be used as the primary key.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
            A full list of supported options can be found under "ML_PREDICT_TABLE Options"

        Args:
            X: DataFrame containing prediction features (no labels).
            options: Additional prediction settings.

        Returns:
            DataFrame with prediction results as returned by HeatWave.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Temporary SQL tables are created and deleted for input/output.
        """
        X = convert_to_df(X)

        with (
            atomic_transaction(self.db_connection) as cursor,
            temporary_sql_tables(self.db_connection) as temporary_tables,
        ):
            _, table_name = sql_table_from_df(cursor, self.schema_name, X)
            temporary_tables.append((self.schema_name, table_name))

            # Pick an output table name that does not collide with an
            # existing table in the schema.
            output_table_name = get_random_name(
                lambda table_name: not table_exists(
                    cursor, self.schema_name, table_name
                )
            )
            temporary_tables.append((self.schema_name, output_table_name))

            self._predict(table_name, output_table_name, options)
            predictions = sql_table_to_df(cursor, self.schema_name, output_table_name)

        # ml_results is text but known to always follow JSON format
        predictions["ml_results"] = predictions["ml_results"].map(json.loads)

        return predictions

    def score(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
        y: Union[pd.DataFrame, np.ndarray],
        metric: str,
        options: Optional[dict] = None,
    ) -> float:
        """
        Score the model using X/y data and a selected metric.

        If an 'id' column is defined in either dataframe, it will be used as the primary key.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-score.html
            A full list of supported options can be found under
            "Options for Recommendation Models" and
            "Options for Anomaly Detection Models"

        Args:
            X: DataFrame of features.
            y: DataFrame or Series of labels.
            metric: Metric name (e.g., "balanced_accuracy").
            options: Optional ml scoring options.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Returns:
            float: Computed score.
        """
        X, y = convert_to_df(X), convert_to_df(y)

        with (
            atomic_transaction(self.db_connection) as cursor,
            temporary_sql_tables(self.db_connection) as temporary_tables,
        ):
            # Attach y to X under a synthetic, collision-free column name.
            target_column_name = get_random_name(lambda name: name not in X.columns)
            df_combined = X.copy()
            df_combined[target_column_name] = y
            final_df = df_combined

            _, table_name = sql_table_from_df(cursor, self.schema_name, final_df)
            temporary_tables.append((self.schema_name, table_name))

            score = self._score(table_name, target_column_name, metric, options)

        return score

    def explain_predictions(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
        options: Optional[Dict] = None,
    ) -> pd.DataFrame:
        """
        Explain model predictions using provided data.

        If an 'id' column is defined in either dataframe, it will be used as the primary key.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
            A full list of supported options can be found under
            "ML_EXPLAIN_TABLE Options"

        Args:
            X: DataFrame for which predictions should be explained.
            options: Optional dictionary of explainability options.

        Returns:
            DataFrame containing explanation details (feature attributions, etc.)

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported, or if the model is not initialized,
                i.e., fit or import has not been called
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Temporary input/output tables are cleaned up after explanation.
        """
        X = convert_to_df(X)

        with (
            atomic_transaction(self.db_connection) as cursor,
            temporary_sql_tables(self.db_connection) as temporary_tables,
        ):

            _, table_name = sql_table_from_df(cursor, self.schema_name, X)
            temporary_tables.append((self.schema_name, table_name))

            # Pick an output table name that does not collide with an
            # existing table in the schema.
            output_table_name = get_random_name(
                lambda table_name: not table_exists(
                    cursor, self.schema_name, table_name
                )
            )
            temporary_tables.append((self.schema_name, output_table_name))

            explanations = self._explain_predictions(
                table_name, output_table_name, options
            )

        return explanations
||||
221
venv/lib/python3.12/site-packages/mysql/ai/ml/outlier.py
Normal file
221
venv/lib/python3.12/site-packages/mysql/ai/ml/outlier.py
Normal file
@@ -0,0 +1,221 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Outlier/anomaly detection utilities for MySQL Connector/Python.
|
||||
|
||||
Provides a scikit-learn compatible wrapper using HeatWave to score anomalies.
|
||||
"""
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.base import OutlierMixin
|
||||
|
||||
from mysql.ai.ml.base import MyBaseMLModel
|
||||
from mysql.ai.ml.model import ML_TASK
|
||||
from mysql.ai.utils import copy_dict
|
||||
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
EPS = 1e-5
|
||||
|
||||
|
||||
def _get_logits(prob: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
|
||||
"""
|
||||
Compute logit (logodds) for a probability, clipping to avoid numerical overflow.
|
||||
|
||||
Args:
|
||||
prob: Scalar or array of probability values in (0,1).
|
||||
|
||||
Returns:
|
||||
logit-transformed probabilities.
|
||||
"""
|
||||
result = np.clip(prob, EPS, 1 - EPS)
|
||||
return np.log(result / (1 - result))
|
||||
|
||||
|
||||
class MyAnomalyDetector(MyBaseMLModel, OutlierMixin):
    """
    MySQL HeatWave scikit-learn compatible anomaly/outlier detector.

    Flags samples as outliers when the probability of being an anomaly
    exceeds a user-tunable threshold.
    Includes helpers to obtain decision scores and anomaly probabilities
    for ranking.

    Args:
        db_connection (MySQLConnectionAbstract): Active MySQL DB connection.
        model_name (str, optional): Custom model name in the database.
        fit_extra_options (dict, optional): Extra options for fitting.
        score_extra_options (dict, optional): Extra options for scoring/prediction.

    Attributes:
        boundary: Decision threshold boundary in logit space. Derived from
            trained model's catalog info

    Methods:
        predict(X): Predict outlier/inlier labels.
        score_samples(X): Compute anomaly (normal class) logit scores.
        decision_function(X): Compute signed score above/below threshold for ranking.
    """

    def __init__(
        self,
        db_connection: MySQLConnectionAbstract,
        model_name: Optional[str] = None,
        fit_extra_options: Optional[dict] = None,
        score_extra_options: Optional[dict] = None,
    ):
        """
        Initialize an anomaly detector instance with extra options.

        Args:
            db_connection: Active MySQL DB connection.
            model_name: Optional model name in DB.
            fit_extra_options: Optional extra fit options.
            score_extra_options: Optional extra scoring options.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        MyBaseMLModel.__init__(
            self,
            db_connection,
            ML_TASK.ANOMALY_DETECTION,
            model_name=model_name,
            fit_extra_options=fit_extra_options,
        )
        self.score_extra_options = copy_dict(score_extra_options)
        # Lazily resolved by decision_function() from the model catalog.
        self.boundary: Optional[float] = None

    def predict(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
    ) -> np.ndarray:
        """
        Predict outlier/inlier binary labels for input samples.

        Args:
            X: Samples to predict on.

        Returns:
            ndarray: Values are -1 for outliers, +1 for inliers, as per scikit-learn convention.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError:
                If the model does not exist in the catalog or does not
                provide a threshold (propagated from decision_function).
        """
        return np.where(self.decision_function(X) < 0.0, -1, 1)

    def decision_function(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
    ) -> np.ndarray:
        """
        Compute signed distance to the outlier threshold.

        Args:
            X: Samples to predict on.

        Returns:
            ndarray: Score > 0 means inlier, < 0 means outlier; |value| gives margin.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError:
                If the provided model info does not provide threshold
        """
        sample_scores = self.score_samples(X)

        # Resolve and cache the decision boundary on first use.
        if self.boundary is None:
            model_info = self.get_model_info()
            if model_info is None:
                raise ValueError("Model does not exist in catalog.")

            threshold = model_info["model_metadata"]["training_params"].get(
                "anomaly_detection_threshold", None
            )
            if threshold is None:
                raise ValueError(
                    "Trained model is outdated and does not support threshold. "
                    "Try retraining or using an existing, trained model with MyModel."
                )

            # scikit-learn uses large positive values as inlier
            # and negative as outlier, so we need to flip our threshold
            self.boundary = _get_logits(1.0 - threshold)

        return sample_scores - self.boundary

    def score_samples(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
    ) -> np.ndarray:
        """
        Compute normal probability logit score for each sample.
        Used for ranking, thresholding.

        Args:
            X: Samples to score.

        Returns:
            ndarray: Logit scores based on "normal" class probability.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        # Each ml_results entry carries per-class probabilities; extract
        # the "normal" class probability and map it to logit space.
        result = self._model.predict(X, options=self.score_extra_options)

        return _get_logits(
            result["ml_results"]
            .apply(lambda x: x["probabilities"]["normal"])
            .to_numpy()
        )
||||
154
venv/lib/python3.12/site-packages/mysql/ai/ml/regressor.py
Normal file
154
venv/lib/python3.12/site-packages/mysql/ai/ml/regressor.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Regressor utilities for MySQL Connector/Python.
|
||||
|
||||
Provides a scikit-learn compatible regressor backed by HeatWave ML.
|
||||
"""
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.base import RegressorMixin
|
||||
|
||||
from mysql.ai.ml.base import MyBaseMLModel
|
||||
from mysql.ai.ml.model import ML_TASK
|
||||
from mysql.ai.utils import copy_dict
|
||||
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
|
||||
class MyRegressor(MyBaseMLModel, RegressorMixin):
    """
    MySQL HeatWave scikit-learn compatible regressor estimator.

    Provides prediction output from a regression model deployed in MySQL,
    and manages fit, explain, and prediction options as per HeatWave ML interface.

    Attributes:
        predict_extra_options (dict): Optional parameter dict passed to the backend for prediction.
        _model (MyModel): Underlying interface for database model operations.
        fit_extra_options (dict): See MyBaseMLModel.
        explain_extra_options (dict): See MyBaseMLModel.

    Args:
        db_connection (MySQLConnectionAbstract): Active MySQL connector DB connection.
        model_name (str, optional): Custom name for the model.
        fit_extra_options (dict, optional): Extra options for fitting.
        explain_extra_options (dict, optional): Extra options for explanations.
        predict_extra_options (dict, optional): Extra options for predictions.

    Methods:
        predict(X): Predict regression target.
        explain_predictions(X): Explain predictions for given samples.
    """

    def __init__(
        self,
        db_connection: MySQLConnectionAbstract,
        model_name: Optional[str] = None,
        fit_extra_options: Optional[dict] = None,
        explain_extra_options: Optional[dict] = None,
        predict_extra_options: Optional[dict] = None,
    ):
        """
        Initialize a MyRegressor.

        Args:
            db_connection: Active MySQL connector database connection.
            model_name: Optional, custom model name.
            fit_extra_options: Optional fit options.
            explain_extra_options: Optional explain options.
            predict_extra_options: Optional prediction options.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        MyBaseMLModel.__init__(
            self,
            db_connection,
            ML_TASK.REGRESSION,
            model_name=model_name,
            fit_extra_options=fit_extra_options,
        )

        # copy_dict shields the stored options from caller-side mutation.
        self.predict_extra_options = copy_dict(predict_extra_options)
        self.explain_extra_options = copy_dict(explain_extra_options)

    def predict(
        self, X: Union[pd.DataFrame, np.ndarray]
    ) -> np.ndarray:  # pylint: disable=invalid-name
        """
        Predict a continuous target for the input features using the MySQL model.

        Args:
            X: Input samples as a numpy array or pandas DataFrame.

        Returns:
            ndarray: Array of predicted target values, shape (n_samples,).

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        result = self._model.predict(X, options=self.predict_extra_options)
        return result["Prediction"].to_numpy()

    def explain_predictions(
        self, X: Union[pd.DataFrame, np.ndarray]
    ) -> pd.DataFrame:  # pylint: disable=invalid-name
        """
        Explain model predictions using provided data.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
            A full list of supported options can be found under "ML_EXPLAIN_TABLE Options"

        Args:
            X: DataFrame for which predictions should be explained.

        Returns:
            DataFrame containing explanation details (feature attributions, etc.)

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Temporary input/output tables are cleaned up after explanation.
        """
        # BUG FIX: the original discarded the result, so a method annotated
        # `-> pd.DataFrame` always returned None. Return the explanations.
        return self._model.explain_predictions(X, options=self.explain_extra_options)
||||
164
venv/lib/python3.12/site-packages/mysql/ai/ml/transformer.py
Normal file
164
venv/lib/python3.12/site-packages/mysql/ai/ml/transformer.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
"""Generic transformer utilities for MySQL Connector/Python.
|
||||
|
||||
Provides a scikit-learn compatible Transformer using HeatWave for fit/transform
|
||||
and scoring operations.
|
||||
"""
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.base import TransformerMixin
|
||||
|
||||
from mysql.ai.ml.base import MyBaseMLModel
|
||||
from mysql.ai.ml.model import ML_TASK
|
||||
from mysql.ai.utils import copy_dict
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
|
||||
class MyGenericTransformer(MyBaseMLModel, TransformerMixin):
    """
    scikit-learn compatible generic transformer backed by MySQL HeatWave.

    Usable as the transformation step of an sklearn pipeline. Fitting,
    transformation, and scoring are delegated to the server-side model,
    with per-operation option dictionaries forwarded to the backend.

    Args:
        db_connection (MySQLConnectionAbstract): Active MySQL connector database connection.
        task (str): ML task type for transformer (default: "classification").
        score_metric (str): Scoring metric to request from backend (default: "balanced_accuracy").
        model_name (str, optional): Custom name for the deployed model.
        fit_extra_options (dict, optional): Extra fit options.
        transform_extra_options (dict, optional): Extra options for transformations.
        score_extra_options (dict, optional): Extra options for scoring.

    Attributes:
        score_metric (str): Backend metric name used by score().
        score_extra_options (dict): Optional scoring parameters passed to the backend.
        transform_extra_options (dict): Optional inference parameters passed to the backend.
        fit_extra_options (dict): See MyBaseMLModel.
        _model (MyModel): Underlying interface for database model operations.

    Methods:
        fit(X, y): Fit the underlying model using the provided features/targets.
        transform(X): Transform features using the backend model.
        score(X, y): Score data using backend metric and options.
    """

    def __init__(
        self,
        db_connection: MySQLConnectionAbstract,
        task: Union[str, ML_TASK] = ML_TASK.CLASSIFICATION,
        score_metric: str = "balanced_accuracy",
        model_name: Optional[str] = None,
        fit_extra_options: Optional[dict] = None,
        transform_extra_options: Optional[dict] = None,
        score_extra_options: Optional[dict] = None,
    ):
        """
        Initialize the transformer.

        Args:
            db_connection: Active MySQL backend database connection.
            task: ML task type for transformer.
            score_metric: Requested backend scoring metric.
            model_name: Optional model name for storage.
            fit_extra_options: Optional extra options for fitting.
            transform_extra_options: Optional extra options for transformation/inference.
            score_extra_options: Optional extra scoring options.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        MyBaseMLModel.__init__(
            self,
            db_connection,
            task,
            model_name=model_name,
            fit_extra_options=fit_extra_options,
        )

        self.score_metric = score_metric
        # Defensive copies so the caller's dicts cannot be mutated later.
        self.transform_extra_options = copy_dict(transform_extra_options)
        self.score_extra_options = copy_dict(score_extra_options)

    def transform(
        self, X: pd.DataFrame
    ) -> pd.DataFrame:  # pylint: disable=invalid-name
        """
        Transform input data via backend model prediction.

        Args:
            X: DataFrame of features to predict/transform.

        Returns:
            pd.DataFrame: Results of transformation as returned by backend.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        return self._model.predict(X, options=self.transform_extra_options)

    def score(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
        y: Union[pd.DataFrame, np.ndarray],
    ) -> float:
        """
        Score the data with the backend metric configured at construction.

        Args:
            X: Transformed features.
            y: Target labels or data for scoring.

        Returns:
            float: Score based on backend metric.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        return self._model.score(
            X, y, self.score_metric, options=self.score_extra_options
        )
|
||||
44
venv/lib/python3.12/site-packages/mysql/ai/utils/__init__.py
Normal file
44
venv/lib/python3.12/site-packages/mysql/ai/utils/__init__.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Utilities for AI-related helpers in MySQL Connector/Python.
|
||||
|
||||
This package exposes:
|
||||
- check_dependencies(): runtime dependency guard for optional AI features
|
||||
- atomic_transaction(): context manager ensuring atomic DB transactions
|
||||
- utils: general-purpose helpers used by AI integrations
|
||||
|
||||
Importing this package validates base dependencies required for AI utilities.
|
||||
"""
|
||||
|
||||
# Validate the minimal runtime dependencies before exposing any AI helpers;
# check_dependencies raises ImportError when a requirement is missing.
from .dependencies import check_dependencies

check_dependencies(["BASE"])

# Imported only after the dependency check so that a missing package surfaces
# as a clear ImportError rather than an obscure secondary failure.
from .atomic_cursor import atomic_transaction
from .utils import *  # re-export general-purpose helpers at package level
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,94 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Atomic transaction context manager utilities for MySQL Connector/Python.
|
||||
|
||||
Provides context manager atomic_transaction() that ensures commit on success
|
||||
and rollback on error without obscuring the original exception.
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator
|
||||
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
from mysql.connector.cursor import MySQLCursorAbstract
|
||||
|
||||
|
||||
@contextmanager
def atomic_transaction(
    conn: MySQLConnectionAbstract,
) -> Iterator[MySQLCursorAbstract]:
    """
    Context manager that wraps a MySQL database cursor and ensures transaction
    rollback in case of exception.

    NOTE: DDL statements such as CREATE TABLE cause implicit commits. These cannot
    be managed by a cursor object. Changes made at or before a DDL statement will
    be committed and not rolled back. Callers are responsible for any cleanup of
    this type.

    All operations executed within the context block are part of the same
    transaction; `connection.rollback()` is called automatically if an
    exception occurs. On normal completion the transaction is committed and
    the cursor closed. Exceptions are always propagated to the caller.

    Args:
        conn: A MySQLConnectionAbstract instance.

    Yields:
        MySQLCursorAbstract: Cursor scoped to the managed transaction.
    """
    # Remember the connection's autocommit mode so it can be restored on exit.
    old_autocommit = conn.autocommit
    cursor = conn.cursor()

    # Tracks whether the body raised, so a cursor.close() failure in the
    # finally block does not mask the original exception.
    exception_raised = False
    try:
        # Disable autocommit for the duration so every statement joins a
        # single transaction; the original mode is restored in finally.
        if old_autocommit:
            conn.autocommit = False

        yield cursor  # provide cursor to block

        # Reaching this point means the body completed without raising.
        conn.commit()
    except Exception:  # pylint: disable=broad-exception-caught
        exception_raised = True
        try:
            conn.rollback()
        except Exception:  # pylint: disable=broad-exception-caught
            # Don't obscure original exception
            pass

        # Raise original exception
        raise
    finally:
        conn.autocommit = old_autocommit

        try:
            cursor.close()
        except Exception:  # pylint: disable=broad-exception-caught
            # don't obscure original exception if exists
            if not exception_raised:
                raise
|
||||
@@ -0,0 +1,88 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Dependency checking utilities for AI features in MySQL Connector/Python.
|
||||
|
||||
Provides check_dependencies() to assert required optional packages are present
|
||||
with acceptable minimum versions at runtime.
|
||||
"""
|
||||
|
||||
import importlib.metadata
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
def check_dependencies(tasks: List[str]) -> None:
|
||||
"""
|
||||
Check required runtime dependencies and minimum versions; raise an error
|
||||
if any are missing or version-incompatible.
|
||||
|
||||
This verifies the presence and minimum version of essential Python packages.
|
||||
Missing or insufficient versions cause an ImportError listing the packages
|
||||
and a suggested install command.
|
||||
|
||||
Args:
|
||||
tasks (List[str]): Task types to check requirements for.
|
||||
|
||||
Raises:
|
||||
ImportError: If any required dependencies are missing or below the
|
||||
minimum version.
|
||||
"""
|
||||
task_set = set(tasks)
|
||||
task_set.add("BASE")
|
||||
|
||||
# Requirements: (import_name, min_version)
|
||||
task_to_requirement = {
|
||||
"BASE": [("pandas", "1.5.0")],
|
||||
"GENAI": [
|
||||
("langchain", "0.1.11"),
|
||||
("langchain_core", "0.1.11"),
|
||||
("pydantic", "1.10.0"),
|
||||
],
|
||||
"ML": [("scikit-learn", "1.3.0")],
|
||||
}
|
||||
requirements = []
|
||||
for task in task_set:
|
||||
requirements.extend(task_to_requirement[task])
|
||||
requirements_set = set(requirements)
|
||||
|
||||
problems = []
|
||||
for name, min_version in requirements_set:
|
||||
try:
|
||||
installed_version = importlib.metadata.version(name)
|
||||
# Version comparison uses simple string comparison to avoid extra
|
||||
# dependencies. This is valid for the dependencies defined above;
|
||||
# reconsider if adding packages with version schemes that do not
|
||||
# compare correctly as strings.
|
||||
error = installed_version < min_version
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
error = True
|
||||
if error:
|
||||
problems.append(f"{name} v{min_version} (or later)")
|
||||
if problems:
|
||||
raise ImportError("Please install " + ", ".join(problems) + ".")
|
||||
573
venv/lib/python3.12/site-packages/mysql/ai/utils/utils.py
Normal file
573
venv/lib/python3.12/site-packages/mysql/ai/utils/utils.py
Normal file
@@ -0,0 +1,573 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
"""General utilities for AI features in MySQL Connector/Python.
|
||||
|
||||
Includes helpers for:
|
||||
- defensive dict copying
|
||||
- temporary table lifecycle management
|
||||
- SQL execution and result conversions
|
||||
- DataFrame to/from SQL table utilities
|
||||
- schema/table/column name validation
|
||||
- array-like to DataFrame conversion
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from mysql.ai.utils.atomic_cursor import atomic_transaction
|
||||
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
from mysql.connector.cursor import MySQLCursorAbstract
|
||||
from mysql.connector.types import ParamsSequenceOrDictType
|
||||
|
||||
VAR_NAME_SPACE = "mysql_ai"
|
||||
RANDOM_TABLE_NAME_LENGTH = 32
|
||||
|
||||
PD_TO_SQL_DTYPE_MAPPING = {
|
||||
"int64": "BIGINT",
|
||||
"float64": "DOUBLE",
|
||||
"object": "LONGTEXT",
|
||||
"bool": "BOOLEAN",
|
||||
"datetime64[ns]": "DATETIME",
|
||||
}
|
||||
|
||||
DEFAULT_SCHEMA = "mysql_ai"
|
||||
|
||||
# Misc Utilities
|
||||
|
||||
|
||||
def copy_dict(options: Optional[dict]) -> dict:
    """
    Return a deep copy of *options*, or a fresh empty dict when it is None.

    Args:
        options: param dict or None

    Returns:
        dict: Independent copy that the caller may mutate freely.
    """
    return {} if options is None else copy.deepcopy(options)
|
||||
|
||||
|
||||
@contextmanager
def temporary_sql_tables(
    db_connection: MySQLConnectionAbstract,
) -> Iterator[list[tuple[str, str]]]:
    """
    Context manager that tracks temporary SQL tables and drops them on exit.

    Args:
        db_connection: Database connection object used to create and delete tables.

    Raises:
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.

    Yields:
        List of (schema_name, table_name) tuples registered during the
        context. Every table in the list is deleted on context exit.
    """
    registered: List[Tuple[str, str]] = []
    try:
        yield registered
    finally:
        # Drop everything that was registered, within a single transaction,
        # even when the body raised.
        with atomic_transaction(db_connection) as cursor:
            for schema, table in registered:
                delete_sql_table(cursor, schema, table)
|
||||
|
||||
|
||||
def execute_sql(
    cursor: MySQLCursorAbstract, query: str, params: ParamsSequenceOrDictType = None
) -> None:
    """
    Execute an SQL query, binding the optional parameters, on the cursor.

    Args:
        cursor: MySQLCursorAbstract object to execute the query.
        query: SQL query string to execute.
        params: Optional sequence or dict providing parameters for the query.

    Raises:
        DatabaseError:
            If the provided SQL query/params are invalid
            If the query is valid but the sql raises as an exception
            If a database connection issue occurs.
            If an operational error occurs during execution.

    Returns:
        None
    """
    # Normalize a missing/empty params value to an empty tuple for execute().
    bound_params = params if params else ()
    cursor.execute(query, bound_params)
|
||||
|
||||
|
||||
def _get_name() -> str:
    """
    Return a random uppercase string of RANDOM_TABLE_NAME_LENGTH characters,
    suitable for use as a generated table name.
    """
    alphabet = string.ascii_uppercase
    picked = random.choices(alphabet, k=RANDOM_TABLE_NAME_LENGTH)
    return "".join(picked)
|
||||
|
||||
|
||||
def get_random_name(condition: Callable[[str], bool], max_calls: int = 100) -> str:
    """
    Generate a random name satisfying *condition*, retrying up to *max_calls* times.

    Args:
        condition: Callable that takes a generated name and returns True if it is valid.
        max_calls: Maximum number of attempts before giving up (default 100).

    Returns:
        A random string that fulfills the provided condition.

    Raises:
        RuntimeError: If the maximum number of attempts is reached without success.
    """
    attempts = 0
    while attempts < max_calls:
        candidate = _get_name()
        if condition(candidate):
            return candidate
        attempts += 1
    # Every attempt failed the condition.
    raise RuntimeError("Reached max tries without successfully finding a unique name")
|
||||
|
||||
|
||||
# Format conversions
|
||||
|
||||
|
||||
def format_value_sql(value: Any) -> Tuple[str, List[Any]]:
    """
    Map a Python value to a SQL placeholder string plus its bound parameters.

    dicts and lists are serialized to JSON and cast server-side; empty
    containers become SQL NULL. Everything else is bound directly.

    Args:
        value: The value to format.

    Returns:
        Tuple containing:
            - A string for substitution into a SQL query.
            - A list of parameters to be bound into the query.
    """
    if not isinstance(value, (dict, list)):
        return "%s", [value]
    if not value:
        # Empty containers are stored as NULL rather than empty JSON.
        return "%s", [None]
    return "CAST(%s as JSON)", [json.dumps(value)]
|
||||
|
||||
|
||||
def sql_response_to_df(cursor: MySQLCursorAbstract) -> pd.DataFrame:
    """
    Build a pandas DataFrame from the rows of the cursor's last executed query.

    JSON-typed columns are decoded into Python objects; SQL NULLs in those
    columns are kept as None. Other columns are passed through unchanged.

    Args:
        cursor: MySQLCursorAbstract with a completed query.

    Returns:
        DataFrame with data from the cursor.

    Raises:
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.
            If a compatible SELECT query wasn't the last statement ran
    """
    # 245 is the MySQL type code for JSON; remember which result columns
    # need decoding.
    json_column_indexes = {
        idx for idx, column in enumerate(cursor.description) if column[1] == 245
    }

    def _decode(idx: int, cell: Any) -> Any:
        # Decode JSON cells; NULLs stay None, non-JSON columns untouched.
        if idx in json_column_indexes and cell is not None:
            return json.loads(cell)
        return cell

    decoded_rows = [
        [_decode(idx, cell) for idx, cell in enumerate(row)]
        for row in cursor.fetchall()
    ]

    return pd.DataFrame(decoded_rows, columns=cursor.column_names)
|
||||
|
||||
|
||||
def sql_table_to_df(
    cursor: MySQLCursorAbstract, schema_name: str, table_name: str
) -> pd.DataFrame:
    """
    Fetch the full contents of a SQL table into a pandas DataFrame.

    Args:
        cursor: MySQLCursorAbstract to execute the query.
        schema_name: Name of the schema containing the table.
        table_name: Name of the table to fetch.

    Returns:
        DataFrame containing all rows from the specified table.

    Raises:
        DatabaseError:
            If the table does not exist
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError: If the schema or table name is not valid
    """
    # Identifiers cannot be bound as parameters, so validate before
    # interpolating them into the statement.
    for identifier in (schema_name, table_name):
        validate_name(identifier)

    execute_sql(cursor, f"SELECT * FROM {schema_name}.{table_name}")
    return sql_response_to_df(cursor)
|
||||
|
||||
|
||||
# Table operations
|
||||
|
||||
|
||||
def table_exists(
    cursor: MySQLCursorAbstract, schema_name: str, table_name: str
) -> bool:
    """
    Report whether a table exists in the given schema.

    Args:
        cursor: MySQLCursorAbstract object to execute the query.
        schema_name: Name of the database schema.
        table_name: Name of the table.

    Returns:
        True if the table exists, False otherwise.

    Raises:
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError: If the schema or table name is not valid
    """
    validate_name(schema_name)
    validate_name(table_name)

    # Probe information_schema with bound parameters; LIMIT 1 is enough
    # since only existence matters.
    existence_query = """
        SELECT 1
        FROM information_schema.tables
        WHERE table_schema = %s AND table_name = %s
        LIMIT 1
        """
    cursor.execute(existence_query, (schema_name, table_name))
    match = cursor.fetchone()
    return match is not None
|
||||
|
||||
|
||||
def delete_sql_table(
    cursor: MySQLCursorAbstract, schema_name: str, table_name: str
) -> None:
    """
    Drop the given table if it exists; a missing table is not an error.

    Args:
        cursor: MySQLCursorAbstract to execute the drop command.
        schema_name: Name of the schema.
        table_name: Name of the table to delete.

    Returns:
        None

    Raises:
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError: If the schema or table name is not valid
    """
    # Identifiers cannot be bound as parameters; validate before interpolation.
    validate_name(schema_name)
    validate_name(table_name)

    drop_statement = f"DROP TABLE IF EXISTS {schema_name}.{table_name}"
    execute_sql(cursor, drop_statement)
|
||||
|
||||
|
||||
def extend_sql_table(
    cursor: MySQLCursorAbstract,
    schema_name: str,
    table_name: str,
    df: pd.DataFrame,
    col_name_to_placeholder_string: Optional[Dict[str, str]] = None,
) -> None:
    """
    Insert all rows from a pandas DataFrame into an existing SQL table.

    Args:
        cursor: MySQLCursorAbstract for execution.
        schema_name: Name of the database schema.
        table_name: Table to insert new rows into.
        df: DataFrame containing the rows to insert.
        col_name_to_placeholder_string:
            Optional mapping of column names to custom SQL value/placeholder
            strings. Values of columns in this mapping are bound as str(elem).

    Returns:
        None

    Raises:
        DatabaseError:
            If the rows could not be inserted into the table, e.g., a type or shape issue
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError: If the schema or table name is not valid
    """
    if col_name_to_placeholder_string is None:
        col_name_to_placeholder_string = {}

    # Identifiers (schema, table, columns) cannot be bound as parameters,
    # so validate each before interpolating into the INSERT statement.
    validate_name(schema_name)
    validate_name(table_name)
    for col in df.columns:
        validate_name(str(col))

    qualified_table_name = f"{schema_name}.{table_name}"

    # Iterate over all rows in the DataFrame to build insert statements row by row
    for row in df.values:
        placeholders, params = [], []
        for elem, col in zip(row, df.columns):
            # Unwrap numpy scalars to native Python values for the driver.
            elem = elem.item() if hasattr(elem, "item") else elem

            if col in col_name_to_placeholder_string:
                # Caller-supplied placeholder: value is bound as its string form.
                elem_placeholder, elem_params = col_name_to_placeholder_string[col], [
                    str(elem)
                ]
            else:
                elem_placeholder, elem_params = format_value_sql(elem)

            placeholders.append(elem_placeholder)
            params.extend(elem_params)

        cols_sql = ", ".join([str(col) for col in df.columns])
        placeholders_sql = ", ".join(placeholders)
        insert_sql = (
            f"INSERT INTO {qualified_table_name} "
            f"({cols_sql}) VALUES ({placeholders_sql})"
        )
        execute_sql(cursor, insert_sql, params=params)
|
||||
|
||||
|
||||
def sql_table_from_df(
    cursor: MySQLCursorAbstract, schema_name: str, df: pd.DataFrame
) -> Tuple[str, str]:
    """
    Create a new SQL table with a random name, and populate it with data from a DataFrame.

    If an 'id' column is defined in the dataframe, it will be used as the primary key.

    Args:
        cursor: MySQLCursorAbstract for executing SQL.
        schema_name: Schema in which to create the table.
        df: DataFrame containing the data to be inserted.

    Returns:
        Tuple (qualified_table_name, table_name): The schema-qualified and
        unqualified table names.

    Raises:
        RuntimeError: If a random available table name could not be found.
        ValueError: If any schema, table, or a column name is invalid.
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.
    """
    # Validate all caller-supplied identifiers BEFORE they are interpolated
    # into any SQL (previously the schema/column names were used in
    # table_exists() probes before being validated).
    validate_name(schema_name)
    for col in df.columns:
        validate_name(str(col))

    table_name = get_random_name(
        lambda table_name: not table_exists(cursor, schema_name, table_name)
    )
    validate_name(table_name)
    qualified_table_name = f"{schema_name}.{table_name}"

    # Map each pandas dtype to a SQL column type; fallback is LONGTEXT.
    # (Column names were already validated above; no need to re-validate.)
    columns_sql = [
        f"{col} {PD_TO_SQL_DTYPE_MAPPING.get(str(dtype), 'LONGTEXT')}"
        for col, dtype in df.dtypes.items()
    ]
    columns_str = ", ".join(columns_sql)

    # Use an 'id' column (case-insensitive match) as the primary key, if any.
    # str() guards against non-string column labels.
    has_id_col = any(str(col).lower() == "id" for col in df.columns)
    if has_id_col:
        columns_str += ", PRIMARY KEY (id)"

    # Create table with generated columns
    create_table_sql = f"CREATE TABLE {qualified_table_name} ({columns_str})"
    execute_sql(cursor, create_table_sql)

    try:
        # Insert provided data into new table
        extend_sql_table(cursor, schema_name, table_name, df)
    except Exception:  # pylint: disable=broad-exception-caught
        # Delete table before we lose access to it
        delete_sql_table(cursor, schema_name, table_name)
        raise
    return qualified_table_name, table_name
|
||||
|
||||
|
||||
def validate_name(name: str) -> str:
    """
    Validate that the string is a legal SQL identifier (letters, digits, underscores).

    Args:
        name: Name (schema, table, or column) to validate.

    Returns:
        The validated name.

    Raises:
        ValueError: If the name does not meet format requirements.
    """
    # Accepts only letters, digits, and underscores; change as needed.
    # re.fullmatch is used instead of re.match(r"^...$") because `$` also
    # matches just before a trailing newline, so the old pattern accepted
    # names like "abc\n" — a bypass for an injection guard.
    if not (isinstance(name, str) and re.fullmatch(r"[A-Za-z0-9_]+", name)):
        raise ValueError(f"Unsupported name format {name}")

    return name
|
||||
|
||||
|
||||
def source_schema(db_connection: MySQLConnectionAbstract) -> str:
    """
    Retrieve the name of the currently selected schema, or set and ensure the default schema.

    Args:
        db_connection: MySQL connector database connection object.

    Returns:
        Name of the schema (database in use).

    Raises:
        ValueError: If the schema name is not valid
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.
    """
    schema = db_connection.database
    if schema is None:
        schema = DEFAULT_SCHEMA

    # Validate BEFORE interpolating the name into SQL. Previously the raw
    # name was executed in CREATE DATABASE first and only validated after,
    # leaving an injection window for a crafted schema name.
    validate_name(schema)

    with atomic_transaction(db_connection) as cursor:
        create_database_stmt = f"CREATE DATABASE IF NOT EXISTS {schema}"
        execute_sql(cursor, create_database_stmt)

    return schema
|
||||
|
||||
|
||||
def is_table_empty(
    cursor: MySQLCursorAbstract, schema_name: str, table_name: str
) -> bool:
    """
    Determine if a given SQL table is empty.

    Args:
        cursor: MySQLCursorAbstract with access to the database.
        schema_name: Name of the schema containing the table.
        table_name: Name of the table to check.

    Returns:
        True if the table has no rows, False otherwise.

    Raises:
        DatabaseError:
            If the table does not exist
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError: If the schema or table name is not valid
    """
    # Both identifiers are interpolated into SQL below, so validate first.
    for identifier in (schema_name, table_name):
        validate_name(identifier)

    # Fetching at most one row is enough to decide emptiness.
    probe_query = f"SELECT 1 FROM {schema_name}.{table_name} LIMIT 1"
    cursor.execute(probe_query)
    first_row = cursor.fetchone()
    return first_row is None
|
||||
|
||||
|
||||
def convert_to_df(
    arr: Optional[Union[pd.DataFrame, pd.Series, np.ndarray]],
    col_prefix: str = "feature",
) -> Optional[pd.DataFrame]:
    """
    Convert input data to a pandas DataFrame if necessary.

    Args:
        arr: Input data as a pandas DataFrame, NumPy ndarray, pandas Series, or None.
        col_prefix: Prefix used to synthesize column names ("<col_prefix>_<idx>")
            when the input is an ndarray and therefore has no column labels.

    Returns:
        If the input is None, returns None.
        Otherwise, returns a DataFrame backed by the same underlying data whenever
        possible (except in cases where pandas or NumPy must copy, such as for
        certain views or non-contiguous arrays).

    Notes:
        - If an ndarray is passed, column names are generated as
          "<col_prefix>_0", "<col_prefix>_1", ... — one per column; a 1-D array
          is treated as a single column. (The previous docstring incorrectly
          stated integer indices were used.)
        - If a DataFrame is passed, column names and indices are preserved.
        - If a Series is passed, it becomes a single-column DataFrame named
          after the Series.
        - The returned DataFrame is a shallow copy and shares data with the original
          input when possible; however, copies may still occur for certain input
          types or memory layouts.
    """
    if arr is None:
        return None

    if isinstance(arr, pd.DataFrame):
        # Shallow re-wrap; shares underlying data where pandas allows.
        return pd.DataFrame(arr)
    if isinstance(arr, pd.Series):
        return arr.to_frame()

    # ndarray input: normalize to 2-D so each column gets a generated name.
    if arr.ndim == 1:
        arr = arr.reshape(-1, 1)
    col_names = [f"{col_prefix}_{idx}" for idx in range(arr.shape[1])]

    return pd.DataFrame(arr, columns=col_names, copy=False)
|
||||
127
venv/lib/python3.12/site-packages/mysql/connector/__init__.py
Normal file
127
venv/lib/python3.12/site-packages/mysql/connector/__init__.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# Copyright (c) 2009, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""MySQL Connector/Python - MySQL driver written in Python."""
|
||||
|
||||
try:
|
||||
from .connection_cext import CMySQLConnection
|
||||
except ImportError:
|
||||
HAVE_CEXT = False
|
||||
else:
|
||||
HAVE_CEXT = True
|
||||
|
||||
|
||||
from . import version
|
||||
from .connection import MySQLConnection
|
||||
from .constants import CharacterSet, ClientFlag, FieldFlag, FieldType, RefreshOption
|
||||
from .dbapi import (
|
||||
BINARY,
|
||||
DATETIME,
|
||||
NUMBER,
|
||||
ROWID,
|
||||
STRING,
|
||||
Binary,
|
||||
Date,
|
||||
DateFromTicks,
|
||||
Time,
|
||||
TimeFromTicks,
|
||||
Timestamp,
|
||||
TimestampFromTicks,
|
||||
apilevel,
|
||||
paramstyle,
|
||||
threadsafety,
|
||||
)
|
||||
from .errors import ( # pylint: disable=redefined-builtin
|
||||
DatabaseError,
|
||||
DataError,
|
||||
Error,
|
||||
IntegrityError,
|
||||
InterfaceError,
|
||||
InternalError,
|
||||
NotSupportedError,
|
||||
OperationalError,
|
||||
PoolError,
|
||||
ProgrammingError,
|
||||
Warning,
|
||||
custom_error_exception,
|
||||
)
|
||||
from .pooling import connect
|
||||
|
||||
Connect = connect
|
||||
|
||||
__version_info__ = version.VERSION
|
||||
"""This attribute indicates the Connector/Python version as an array
|
||||
of version components."""
|
||||
|
||||
__version__ = version.VERSION_TEXT
|
||||
"""This attribute indicates the Connector/Python version as a string."""
|
||||
|
||||
__all__ = [
|
||||
"MySQLConnection",
|
||||
"Connect",
|
||||
"custom_error_exception",
|
||||
# Some useful constants
|
||||
"FieldType",
|
||||
"FieldFlag",
|
||||
"ClientFlag",
|
||||
"CharacterSet",
|
||||
"RefreshOption",
|
||||
"HAVE_CEXT",
|
||||
# Error handling
|
||||
"Error",
|
||||
"Warning",
|
||||
"InterfaceError",
|
||||
"DatabaseError",
|
||||
"NotSupportedError",
|
||||
"DataError",
|
||||
"IntegrityError",
|
||||
"PoolError",
|
||||
"ProgrammingError",
|
||||
"OperationalError",
|
||||
"InternalError",
|
||||
# DBAPI PEP 249 required exports
|
||||
"connect",
|
||||
"apilevel",
|
||||
"threadsafety",
|
||||
"paramstyle",
|
||||
"Date",
|
||||
"Time",
|
||||
"Timestamp",
|
||||
"Binary",
|
||||
"DateFromTicks",
|
||||
"DateFromTicks",
|
||||
"TimestampFromTicks",
|
||||
"TimeFromTicks",
|
||||
"STRING",
|
||||
"BINARY",
|
||||
"NUMBER",
|
||||
"DATETIME",
|
||||
"ROWID",
|
||||
# C Extension
|
||||
"CMySQLConnection",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
111
venv/lib/python3.12/site-packages/mysql/connector/_decorating.py
Normal file
111
venv/lib/python3.12/site-packages/mysql/connector/_decorating.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# Copyright (c) 2009, 2025, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Decorators Hub."""
|
||||
|
||||
import functools
|
||||
import warnings
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from .constants import RefreshOption
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .abstracts import MySQLConnectionAbstract
|
||||
|
||||
|
||||
def cmd_refresh_verify_options() -> Callable:
    """Decorator verifying which options are relevant and which aren't based on
    the server version the client is connecting to."""

    def decorator(cmd_refresh: Callable) -> Callable:
        @functools.wraps(cmd_refresh)
        def wrapper(
            cnx: "MySQLConnectionAbstract", *args: Any, **kwargs: Any
        ) -> Callable:
            options: int = args[0]
            # Refreshing grant tables is deprecated starting with server 9.2.0;
            # emit a warning but still forward the call unchanged.
            grant_requested = bool(options & RefreshOption.GRANT)
            if grant_requested and cnx.server_version >= (9, 2, 0):
                warnings.warn(
                    "As of MySQL Server 9.2.0, refreshing grant tables is not needed "
                    "if you use statements GRANT, REVOKE, CREATE, DROP, or ALTER. "
                    "You should expect this option to be unsupported in a future "
                    "version of MySQL Connector/Python when MySQL Server removes it.",
                    category=DeprecationWarning,
                    stacklevel=1,
                )

            return cmd_refresh(cnx, options, **kwargs)

        return wrapper

    return decorator
|
||||
|
||||
|
||||
def handle_read_write_timeout() -> Callable:
    """
    Decorator to close the current connection if a read or a write timeout
    is raised by the method passed via the func parameter.
    """

    def decorator(cnx_method: Callable) -> Callable:
        @functools.wraps(cnx_method)
        def handle_cnx_method(
            cnx: "MySQLConnectionAbstract", *args: Any, **kwargs: Any
        ) -> Any:
            # Catch only the exception we act on, instead of the original
            # `except Exception` + isinstance() check that re-raised
            # everything anyway. Non-timeout exceptions propagate untouched.
            try:
                return cnx_method(cnx, *args, **kwargs)
            except TimeoutError:
                # A read/write timeout leaves the connection in an unusable
                # state; close it before propagating the error.
                cnx.close()
                raise

        return handle_cnx_method

    return decorator
|
||||
|
||||
|
||||
def deprecated(reason: str) -> Callable:
    """Use it to decorate deprecated methods."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Callable:
            # Warn once per call, pointing at the caller's frame.
            message = f"Call to deprecated function {func.__name__}. Reason: {reason}"
            warnings.warn(
                message,
                category=DeprecationWarning,
                stacklevel=2,
            )
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
||||
406
venv/lib/python3.12/site-packages/mysql/connector/_scripting.py
Normal file
406
venv/lib/python3.12/site-packages/mysql/connector/_scripting.py
Normal file
@@ -0,0 +1,406 @@
|
||||
# Copyright (c) 2024, 2025, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Classes and methods utilized to work with MySQL Scripts."""
|
||||
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
from collections import deque
|
||||
from typing import Deque, Generator, Optional
|
||||
|
||||
from .errors import InterfaceError
|
||||
from .types import MySQLScriptPartition
|
||||
|
||||
DEFAULT_DELIMITER = b";"
"""The default delimiter of MySQL Client and the only one
recognized by the MySQL server protocol."""

# Escaped replacements for delimiter characters that are regex
# metacharacters; used when a custom delimiter is compiled into a
# statement-splitting pattern (see MySQLScriptSplitter._split_statement).
DELIMITER_RESERVED_SYMBOLS = {
    "$": rb"\$",
    "^": rb"\^",
    "?": rb"\?",
    "(": rb"\(",
    ")": rb"\)",
    "[": rb"\[",
    "]": rb"\]",
    "{": rb"\{",
    "}": rb"\}",
    ".": rb"\.",
    "|": rb"\|",
    "+": rb"\+",
    "-": rb"\-",
    "*": rb"\*",
}
"""Symbols with a special meaning in regular expression contexts."""

# Matches a `DELIMITER ` command (case-insensitive, anywhere in a line);
# the lookahead requires the remainder of the line to contain balanced
# single, double, and backtick quotes, so occurrences inside quoted
# literals are not treated as commands.
DELIMITER_PATTERN: re.Pattern = re.compile(
    rb"""(delimiter\s+)(?=(?:[^"'`]*(?:"[^"]*"|'[^']*'|`[^`]*`))*[^"'`]*$)""",
    flags=re.IGNORECASE | re.MULTILINE,
)
"""Regular expression pattern recognizing the delimiter command."""
|
||||
|
||||
|
||||
class MySQLScriptSplitter:
    """Breaks a MySQL script into single statements.

    It strips custom delimiters and comments along the way, except for comments
    representing a MySQL extension or optimizer hint.
    """

    # Lookahead requiring balanced single/double/backtick quotes after the
    # split point, so delimiters inside quoted literals are not split on.
    _regex_sql_split_stmts = b"""(?=(?:[^"'`]*(?:"[^"]*"|'[^']*'|`[^`]*`))*[^"'`]*$)"""

    def __init__(self, sql_script: bytes) -> None:
        """Constructor."""
        self._code = sql_script
        # Cache of the statements computed by split_script().
        self._single_stmts: Optional[list[bytes]] = None
        self._mappable_stmts: Optional[list[bytes]] = None
        # Per-delimiter cache of compiled splitting patterns.
        self._re_sql_split_stmts: dict[bytes, re.Pattern] = {}

    def _split_statement(self, code: bytes, delimiter: bytes) -> list[bytes]:
        """Split code context by delimiter."""
        snippets = []

        # Compile (and cache) a quote-aware pattern for this delimiter.
        if delimiter not in self._re_sql_split_stmts:
            if b"\\" in delimiter:
                raise InterfaceError(
                    "The backslash (\\) character is not a valid delimiter."
                )
            # Escape any regex metacharacters in the delimiter itself.
            delimiter_pattern = [
                DELIMITER_RESERVED_SYMBOLS.get(char, char.encode())
                for char in delimiter.decode()
            ]
            self._re_sql_split_stmts[delimiter] = re.compile(
                b"".join(delimiter_pattern) + self._regex_sql_split_stmts
            )

        # Keep only non-empty snippets, stripped of surrounding whitespace.
        for snippet in self._re_sql_split_stmts[delimiter].split(code):
            snippet_strip = snippet.strip()
            if snippet_strip:
                snippets.append(snippet_strip)

        return snippets

    @staticmethod
    def is_white_space_char(char: int) -> bool:
        """Validates whether `char` is a white-space character or not."""
        # Unicode category "Z*" covers space separators.
        return unicodedata.category(chr(char))[0] in {"Z"}

    @staticmethod
    def is_control_char(char: int) -> bool:
        """Validates whether `char` is a control character or not."""
        # Unicode category "C*" covers control/format characters (\n, \t, ...).
        return unicodedata.category(chr(char))[0] in {"C"}

    @staticmethod
    def split_by_control_char_or_white_space(string: bytes) -> list[bytes]:
        """Split `string` by any control character or whitespace."""
        return re.split(rb"[\s\x00-\x1f\x7f-\x9f]", string)

    @staticmethod
    def has_delimiter(code: bytes) -> bool:
        """Validates whether `code` has the delimiter command pattern or not."""
        return re.search(DELIMITER_PATTERN, code) is not None

    @staticmethod
    def remove_comments(code: bytes) -> bytes:
        """Remove MySQL comments which include `--`-style, `#`-style
        and `C`-style comments.

        A `--`-style comment spans from `--` to the end of the line.
        It requires the second dash to be
        followed by at least one whitespace or control character
        (such as a space, tab, newline, and so on).

        A `#`-style comment spans from `#` to the end of the line.

        A C-style comment spans from a `/*` sequence to the following `*/`
        sequence, as in the C programming language. This syntax enables a
        comment to extend over multiple lines because the beginning and
        closing sequences need not be on the same line.

        **NOTE: Only C-style comments representing MySQL extensions or
        optimizer hints are preserved**. E.g.,

        ```
        /*! MySQL-specific code */

        /*+ MySQL-specific code */
        ```

        *For Reference Manual- MySQL Comments*, see
        https://dev.mysql.com/doc/refman/en/comments.html.
        """

        def is_dash_style(b_str: bytes, b_char: int) -> bool:
            # `--` must be followed by whitespace/control to be a comment.
            return b_str == b"--" and (
                MySQLScriptSplitter.is_control_char(b_char)
                or MySQLScriptSplitter.is_white_space_char(b_char)
            )

        def is_hash_style(b_str: bytes) -> bool:
            return b_str == b"#"

        def is_c_style(b_str: bytes, b_char: int) -> bool:
            # `/*!` and `/*+` are MySQL extensions/hints — keep those.
            return b_str == b"/*" and b_char not in {ord("!"), ord("+")}

        # Byte-by-byte scan: `buf` accumulates kept bytes, `literal_ctx`
        # holds the active quote byte while inside a string literal.
        buf = bytearray(b"")
        i, literal_ctx = 0, None
        line_break, single_quote, double_quote = ord("\n"), ord("'"), ord('"')
        while i < len(code):
            if literal_ctx is None:
                # Comment openers are detected from the last bytes already
                # written to `buf` plus the current input byte.
                style = None
                if is_dash_style(buf[-2:], code[i]):
                    style = "--"
                elif is_hash_style(buf[-1:]):
                    style = "#"
                elif is_c_style(buf[-2:], code[i]):
                    style = "/*"
                if style is not None:
                    # Skip to the end of the comment in the input...
                    if style in ("--", "#"):
                        while i < len(code) and code[i] != line_break:
                            i += 1
                    else:
                        while i + 1 < len(code) and code[i : i + 2] != b"*/":
                            i += 1
                        i += 2

                    # ...drop the comment opener already copied into `buf`...
                    for _ in range(len(style)):
                        buf.pop()

                    # ...and trim trailing whitespace left before the opener.
                    while buf and (
                        MySQLScriptSplitter.is_control_char(buf[-1])
                        or MySQLScriptSplitter.is_white_space_char(buf[-1])
                    ):
                        buf.pop()

                    continue

            # Track entering/leaving string literals; an unterminated
            # literal is closed at end of line.
            if literal_ctx is None and code[i] in [single_quote, double_quote]:
                literal_ctx = code[i]
            elif literal_ctx is not None and code[i] in {literal_ctx, line_break}:
                literal_ctx = None

            buf.append(code[i])
            i += 1

        return bytes(buf)

    def split_script(self, remove_comments: bool = True) -> list[bytes]:
        """Splits the given script text into a sequence of individual statements.

        The word DELIMITER and any of its lower and upper case combinations
        such as delimiter, DeLiMiter, etc., are considered reserved words by
        the connector. Users must quote these when included in multi statements
        for other purposes different from declaring an actual statement delimiter;
        e.g., as names for tables, columns, variables, in comments, etc.

        ```
        CREATE TABLE `delimiter` (begin INT, end INT); -- I am a `DELimiTer` comment
        ```

        If they are not quoted, the statement-mapping will not produce the expected
        experience.

        See https://dev.mysql.com/doc/refman/8.0/en/keywords.html to know more
        about quoting a reserved word.

        *Note that comments are always ignored as they are not considered to be
        part of statements, with one exception; **C-style comments representing
        MySQL extensions or optimizer hints are preserved***.
        """
        # If it was already computed, then skip computation and use the cache
        if self._single_stmts is not None:
            return self._single_stmts

        # initialize variables
        self._single_stmts = []
        delimiter = DEFAULT_DELIMITER
        buf: list[bytes] = []
        prev = b""

        # remove comments
        if remove_comments:
            code = MySQLScriptSplitter.remove_comments(code=self._code)
        else:
            code = self._code

        # let's split the script by `delimiter pattern` - the pattern is also
        # included in the returned list.
        for curr in re.split(pattern=DELIMITER_PATTERN, string=code):
            # Checking if the previous substring is a "switch of context
            # (delimiter)" point.
            if re.search(DELIMITER_PATTERN, prev):
                # The next delimiter must be the sequence of chars until
                # reaching a control char or whitespace
                next_delimiter = self.split_by_control_char_or_white_space(curr)[0]

                # We shall remove the delimiter command from the code
                buf.pop()

                # At this point buf includes all the code where `delimiter` applies.
                self._single_stmts.extend(
                    self._split_statement(code=b" ".join(buf), delimiter=delimiter)
                )

                # From the current substring, let's take everything but the
                # "next delimiter" portion. Also, let's update the delimiter
                delimiter, buf = next_delimiter, [curr[len(next_delimiter) :]]
            else:
                # Let's accumulate
                buf.append(curr)

            # track the previous substring
            prev = curr

        # Ensure there are no loose ends
        if buf:
            self._single_stmts.extend(
                self._split_statement(code=b" ".join(buf), delimiter=delimiter)
            )

        return self._single_stmts

    def __repr__(self) -> str:
        # Human-readable form: the original script decoded as UTF-8.
        return self._code.decode("utf-8")
|
||||
|
||||
|
||||
def split_multi_statement(
    sql_code: bytes,
    map_results: bool = False,
) -> Generator[MySQLScriptPartition, None, None]:
    """Breaks a MySQL script into sub-scripts.

    If the given script uses `DELIMITER` statements (which are not recognized
    by MySQL Server), the connector will parse such statements to remove them
    from the script and substitute delimiters as needed. This pre-processing
    may cause a performance hit when using long scripts. Note that when enabling
    `map_results`, the script is expected to use `DELIMITER` statements in order
    to split the script into multiple query strings.

    Args:
        sql_code: MySQL script.
        map_results: If True, each sub-script is `statement-result` mappable.

    Returns:
        A generator of typed dictionaries with keys `single_stmts` and `mappable_stmts`.

        If mapping disabled and no delimiters detected, it returns a 1-item generator,
        the field `single_stmts` is an empty list and the `mappable_stmt` field
        corresponds to the unmodified script, that may be mappable.

        If mapping disabled and delimiters detected, it returns a 1-item generator,
        the field `single_stmts` is a list including all the single statements
        found in the script and the `mappable_stmt` field corresponds to the processed
        script (delimiters are stripped) that may be mappable.

        If mapping enabled, the script is broken into mappable partitions. It returns
        an N-item generator (as many items as computed partitions), the field
        `single_stmts` is a list including all the single statements of the partition
        and the `mappable_stmt` field corresponds to the sub-script (partition) that
        is guaranteed to be mappable.

    Raises:
        `InterfaceError` if an invalid delimiter string is found.
    """
    if not MySQLScriptSplitter.has_delimiter(sql_code) and not map_results:
        # For those users executing single statements or scripts with no delimiters,
        # they can get a performance boost by bypassing the multi statement splitter.

        # Simply wrap the multi statement up (so it can be processed correctly
        # downstream) and return it as it is.
        yield MySQLScriptPartition(single_stmts=deque([]), mappable_stmt=sql_code)
        return

    tok = MySQLScriptSplitter(sql_script=sql_code)

    # The splitter splits the sql code into many single statements
    # while also getting rid of the delimiters (if any).
    stmts = tok.split_script()

    # if there are no statements to execute
    if not stmts:
        # Simply wrap the multi statement up (so it can be processed correctly
        # downstream).
        yield MySQLScriptPartition(single_stmts=deque([b""]), mappable_stmt=b"")
        return

    if not map_results:
        # group single statements into a unique and possibly non-mappable
        # multi statement.
        yield MySQLScriptPartition(
            single_stmts=deque(stmts), mappable_stmt=b";\n".join(stmts)
        )
        return

    # group single statements into one or more mappable multi statements.
    # Every `CALL ...` statement is emitted as its own single-statement
    # partition; the statements between CALLs are grouped together.
    i = 0
    partition_ids = (j for j, stmt in enumerate(stmts) if stmt[:5].upper() == b"CALL ")
    for j in partition_ids:
        if j > i:
            # Statements accumulated before this CALL form one partition.
            yield (
                MySQLScriptPartition(
                    mappable_stmt=b";\n".join(stmts[i:j]),
                    single_stmts=deque(stmts[i:j]),
                )
            )
        yield MySQLScriptPartition(
            mappable_stmt=stmts[j], single_stmts=deque([stmts[j]])
        )
        i = j + 1

    # Flush any trailing statements after the last CALL.
    if i < len(stmts):
        yield (
            MySQLScriptPartition(
                mappable_stmt=b";\n".join(stmts[i : len(stmts)]),
                single_stmts=deque(stmts[i : len(stmts)]),
            )
        )
|
||||
|
||||
|
||||
def get_local_infile_filenames(script: bytes) -> Deque[str]:
    """Scans the MySQL script looking for `filenames` (one for each
    `LOCAL INFILE` statement found).

    Arguments:
        script: a MySQL script that may include one or more `LOCAL INFILE` statements.

    Returns:
        filenames: a list of filenames (one for each `LOCAL INFILE` statement found).
        An empty list is returned if no matches are found.
    """
    # Strip comments first so commented-out LOCAL INFILE lines are ignored.
    cleaned_script = MySQLScriptSplitter.remove_comments(script)
    matches = re.findall(
        pattern=rb"""LOCAL\s+INFILE\s+(["'])((?:\\\1|(?:(?!\1)).)*)(\1)""",
        string=cleaned_script,
        flags=re.IGNORECASE,
    )

    filenames: Deque[str] = deque()
    if matches and len(matches[0]) == 3:
        # Each match is ("'", "filename", "'"): the 1st and 3rd groups are
        # the quote symbols, the 2nd is the actual filename.
        for _open_quote, filename, _close_quote in matches:
            filenames.append(filename.decode("utf-8"))
    return filenames
|
||||
3185
venv/lib/python3.12/site-packages/mysql/connector/abstracts.py
Normal file
3185
venv/lib/python3.12/site-packages/mysql/connector/abstracts.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,40 @@
|
||||
# Copyright (c) 2023, 2025, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""MySQL Connector/Python - MySQL driver written in Python."""
|
||||
|
||||
from .connection import MySQLConnection, MySQLConnectionAbstract
|
||||
from .pooling import MySQLConnectionPool, PooledMySQLConnection, connect
|
||||
|
||||
__all__ = [
|
||||
"MySQLConnection",
|
||||
"connect",
|
||||
"MySQLConnectionAbstract",
|
||||
"MySQLConnectionPool",
|
||||
"PooledMySQLConnection",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,112 @@
|
||||
# Copyright (c) 2009, 2025, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Decorators Hub."""
|
||||
|
||||
import functools
|
||||
import warnings
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from ..constants import RefreshOption
|
||||
from ..errors import ReadTimeoutError, WriteTimeoutError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .abstracts import MySQLConnectionAbstract
|
||||
|
||||
|
||||
def cmd_refresh_verify_options() -> Callable:
|
||||
"""Decorator verifying which options are relevant and which aren't based on
|
||||
the server version the client is connecting to."""
|
||||
|
||||
def decorator(cmd_refresh: Callable) -> Callable:
|
||||
@functools.wraps(cmd_refresh)
|
||||
async def wrapper(
|
||||
cnx: "MySQLConnectionAbstract", *args: Any, **kwargs: Any
|
||||
) -> Callable:
|
||||
options: int = args[0]
|
||||
if (options & RefreshOption.GRANT) and cnx.server_version >= (
|
||||
9,
|
||||
2,
|
||||
0,
|
||||
):
|
||||
warnings.warn(
|
||||
"As of MySQL Server 9.2.0, refreshing grant tables is not needed "
|
||||
"if you use statements GRANT, REVOKE, CREATE, DROP, or ALTER. "
|
||||
"You should expect this option to be unsupported in a future "
|
||||
"version of MySQL Connector/Python when MySQL Server removes it.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
return await cmd_refresh(cnx, options, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def deprecated(reason: str) -> Callable:
|
||||
"""Use it to decorate deprecated methods."""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Callable:
|
||||
warnings.warn(
|
||||
f"Call to deprecated function {func.__name__}. Reason: {reason}",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def handle_read_write_timeout() -> Callable:
|
||||
"""
|
||||
Decorator to close the current connection if a read or a write timeout
|
||||
is raised by the method passed via the func parameter.
|
||||
"""
|
||||
|
||||
def decorator(cnx_method: Callable) -> Callable:
|
||||
@functools.wraps(cnx_method)
|
||||
async def handle_cnx_method(
|
||||
cnx: "MySQLConnectionAbstract", *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
try:
|
||||
return await cnx_method(cnx, *args, **kwargs)
|
||||
except Exception as err:
|
||||
if isinstance(err, (ReadTimeoutError, WriteTimeoutError)):
|
||||
await cnx.close()
|
||||
raise err
|
||||
|
||||
return handle_cnx_method
|
||||
|
||||
return decorator
|
||||
2703
venv/lib/python3.12/site-packages/mysql/connector/aio/abstracts.py
Normal file
2703
venv/lib/python3.12/site-packages/mysql/connector/aio/abstracts.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,335 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Implementing support for MySQL Authentication Plugins."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = ["MySQLAuthenticator"]
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from ..errors import InterfaceError, NotSupportedError, get_exception
|
||||
from ..protocol import (
|
||||
AUTH_SWITCH_STATUS,
|
||||
DEFAULT_CHARSET_ID,
|
||||
DEFAULT_MAX_ALLOWED_PACKET,
|
||||
ERR_STATUS,
|
||||
EXCHANGE_FURTHER_STATUS,
|
||||
MFA_STATUS,
|
||||
OK_STATUS,
|
||||
)
|
||||
from ..types import HandShakeType
|
||||
from .logger import logger
|
||||
from .plugins import MySQLAuthPlugin, get_auth_plugin
|
||||
from .protocol import MySQLProtocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .network import MySQLSocket
|
||||
|
||||
|
||||
class MySQLAuthenticator:
|
||||
"""Implements the authentication phase."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Constructor."""
|
||||
self._username: str = ""
|
||||
self._passwords: Dict[int, str] = {}
|
||||
self._plugin_config: Dict[str, Any] = {}
|
||||
self._ssl_enabled: bool = False
|
||||
self._auth_strategy: Optional[MySQLAuthPlugin] = None
|
||||
self._auth_plugin_class: Optional[str] = None
|
||||
|
||||
@property
|
||||
def ssl_enabled(self) -> bool:
|
||||
"""Signals whether or not SSL is enabled."""
|
||||
return self._ssl_enabled
|
||||
|
||||
@property
|
||||
def plugin_config(self) -> Dict[str, Any]:
|
||||
"""Custom arguments that are being provided to the authentication plugin.
|
||||
|
||||
The parameters defined here will override the ones defined in the
|
||||
auth plugin itself.
|
||||
|
||||
The plugin config is a read-only property - the plugin configuration
|
||||
provided when invoking `authenticate()` is recorded and can be queried
|
||||
by accessing this property.
|
||||
|
||||
Returns:
|
||||
dict: The latest plugin configuration provided when invoking
|
||||
`authenticate()`.
|
||||
"""
|
||||
return self._plugin_config
|
||||
|
||||
def update_plugin_config(self, config: Dict[str, Any]) -> None:
|
||||
"""Update the 'plugin_config' instance variable"""
|
||||
self._plugin_config.update(config)
|
||||
|
||||
def _switch_auth_strategy(
|
||||
self,
|
||||
new_strategy_name: str,
|
||||
strategy_class: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
password_factor: int = 1,
|
||||
) -> None:
|
||||
"""Switch the authorization plugin.
|
||||
|
||||
Args:
|
||||
new_strategy_name: New authorization plugin name to switch to.
|
||||
strategy_class: New authorization plugin class to switch to
|
||||
(has higher precedence than the authorization plugin name).
|
||||
username: Username to be used - if not defined, the username
|
||||
provided when `authentication()` was invoked is used.
|
||||
password_factor: Up to three levels of authentication (MFA) are allowed,
|
||||
hence you can choose the password corresponding to the 1st,
|
||||
2nd, or 3rd factor - 1st is the default.
|
||||
"""
|
||||
if username is None:
|
||||
username = self._username
|
||||
|
||||
if strategy_class is None:
|
||||
strategy_class = self._auth_plugin_class
|
||||
|
||||
logger.debug("Switching to strategy %s", new_strategy_name)
|
||||
self._auth_strategy = get_auth_plugin(
|
||||
plugin_name=new_strategy_name, auth_plugin_class=strategy_class
|
||||
)(
|
||||
username,
|
||||
self._passwords.get(password_factor, ""),
|
||||
ssl_enabled=self.ssl_enabled,
|
||||
)
|
||||
|
||||
async def _mfa_n_factor(
|
||||
self,
|
||||
sock: MySQLSocket,
|
||||
pkt: bytes,
|
||||
) -> Optional[bytes]:
|
||||
"""Handle MFA (Multi-Factor Authentication) response.
|
||||
|
||||
Up to three levels of authentication (MFA) are allowed.
|
||||
|
||||
Args:
|
||||
sock: Pointer to the socket connection.
|
||||
pkt: MFA response.
|
||||
|
||||
Returns:
|
||||
ok_packet: If last server's response is an OK packet.
|
||||
None: If last server's response isn't an OK packet and no ERROR was raised.
|
||||
|
||||
Raises:
|
||||
InterfaceError: If got an invalid N factor.
|
||||
errors.ErrorTypes: If got an ERROR response.
|
||||
"""
|
||||
n_factor = 2
|
||||
while pkt[4] == MFA_STATUS:
|
||||
if n_factor not in self._passwords:
|
||||
raise InterfaceError(
|
||||
"Failed Multi Factor Authentication (invalid N factor)"
|
||||
)
|
||||
|
||||
new_strategy_name, auth_data = MySQLProtocol.parse_auth_next_factor(pkt)
|
||||
self._switch_auth_strategy(new_strategy_name, password_factor=n_factor)
|
||||
logger.debug("MFA %i factor %s", n_factor, self._auth_strategy.name)
|
||||
|
||||
pkt = await self._auth_strategy.auth_switch_response(
|
||||
sock, auth_data, **self._plugin_config
|
||||
)
|
||||
|
||||
if pkt[4] == EXCHANGE_FURTHER_STATUS:
|
||||
auth_data = MySQLProtocol.parse_auth_more_data(pkt)
|
||||
pkt = await self._auth_strategy.auth_more_response(
|
||||
sock, auth_data, **self._plugin_config
|
||||
)
|
||||
|
||||
if pkt[4] == OK_STATUS:
|
||||
logger.debug("MFA completed succesfully")
|
||||
return pkt
|
||||
|
||||
if pkt[4] == ERR_STATUS:
|
||||
raise get_exception(pkt)
|
||||
|
||||
n_factor += 1
|
||||
|
||||
logger.warning("MFA terminated with a no ok packet")
|
||||
return None
|
||||
|
||||
async def _handle_server_response(
|
||||
self,
|
||||
sock: MySQLSocket,
|
||||
pkt: bytes,
|
||||
) -> Optional[bytes]:
|
||||
"""Handle server's response.
|
||||
|
||||
Args:
|
||||
sock: Pointer to the socket connection.
|
||||
pkt: Server's response after completing the `HandShakeResponse`.
|
||||
|
||||
Returns:
|
||||
ok_packet: If last server's response is an OK packet.
|
||||
None: If last server's response isn't an OK packet and no ERROR was raised.
|
||||
|
||||
Raises:
|
||||
errors.ErrorTypes: If got an ERROR response.
|
||||
NotSupportedError: If got Authentication with old (insecure) passwords.
|
||||
"""
|
||||
if pkt[4] == AUTH_SWITCH_STATUS and len(pkt) == 5:
|
||||
raise NotSupportedError(
|
||||
"Authentication with old (insecure) passwords "
|
||||
"is not supported. For more information, lookup "
|
||||
"Password Hashing in the latest MySQL manual"
|
||||
)
|
||||
|
||||
if pkt[4] == AUTH_SWITCH_STATUS:
|
||||
logger.debug("Server's response is an auth switch request")
|
||||
new_strategy_name, auth_data = MySQLProtocol.parse_auth_switch_request(pkt)
|
||||
self._switch_auth_strategy(new_strategy_name)
|
||||
pkt = await self._auth_strategy.auth_switch_response(
|
||||
sock, auth_data, **self._plugin_config
|
||||
)
|
||||
|
||||
if pkt[4] == EXCHANGE_FURTHER_STATUS:
|
||||
logger.debug("Exchanging further packets")
|
||||
auth_data = MySQLProtocol.parse_auth_more_data(pkt)
|
||||
pkt = await self._auth_strategy.auth_more_response(
|
||||
sock, auth_data, **self._plugin_config
|
||||
)
|
||||
|
||||
if pkt[4] == OK_STATUS:
|
||||
logger.debug("%s completed succesfully", self._auth_strategy.name)
|
||||
return pkt
|
||||
|
||||
if pkt[4] == MFA_STATUS:
|
||||
logger.debug("Starting multi-factor authentication")
|
||||
logger.debug("MFA 1 factor %s", self._auth_strategy.name)
|
||||
return await self._mfa_n_factor(sock, pkt)
|
||||
|
||||
if pkt[4] == ERR_STATUS:
|
||||
raise get_exception(pkt)
|
||||
|
||||
return None
|
||||
|
||||
async def authenticate(
|
||||
self,
|
||||
sock: MySQLSocket,
|
||||
handshake: HandShakeType,
|
||||
username: str = "",
|
||||
password1: str = "",
|
||||
password2: str = "",
|
||||
password3: str = "",
|
||||
database: Optional[str] = None,
|
||||
charset: int = DEFAULT_CHARSET_ID,
|
||||
client_flags: int = 0,
|
||||
ssl_enabled: bool = False,
|
||||
max_allowed_packet: int = DEFAULT_MAX_ALLOWED_PACKET,
|
||||
auth_plugin: Optional[str] = None,
|
||||
auth_plugin_class: Optional[str] = None,
|
||||
conn_attrs: Optional[Dict[str, str]] = None,
|
||||
is_change_user_request: bool = False,
|
||||
read_timeout: Optional[int] = None,
|
||||
write_timeout: Optional[int] = None,
|
||||
) -> bytes:
|
||||
"""Perform the authentication phase.
|
||||
|
||||
During re-authentication you must set `is_change_user_request` to True.
|
||||
|
||||
Args:
|
||||
sock: Pointer to the socket connection.
|
||||
handshake: Initial handshake.
|
||||
username: Account's username.
|
||||
password1: Account's password factor 1.
|
||||
password2: Account's password factor 2.
|
||||
password3: Account's password factor 3.
|
||||
database: Initial database name for the connection.
|
||||
charset: Client charset (see [1]), only the lower 8-bits.
|
||||
client_flags: Integer representing client capabilities flags.
|
||||
ssl_enabled: Boolean indicating whether SSL is enabled,
|
||||
max_allowed_packet: Maximum packet size.
|
||||
auth_plugin: Authorization plugin name.
|
||||
auth_plugin_class: Authorization plugin class (has higher precedence
|
||||
than the authorization plugin name).
|
||||
conn_attrs: Connection attributes.
|
||||
is_change_user_request: Whether is a `change user request` operation or not.
|
||||
read_timeout: Timeout in seconds upto which the connector should wait for
|
||||
the server to reply back before raising an ReadTimeoutError.
|
||||
write_timeout: Timeout in seconds upto which the connector should spend to
|
||||
send data to the server before raising an WriteTimeoutError.
|
||||
|
||||
Returns:
|
||||
ok_packet: OK packet.
|
||||
|
||||
Raises:
|
||||
InterfaceError: If OK packet is NULL.
|
||||
ReadTimeoutError: If the time taken for the server to reply back exceeds
|
||||
'read_timeout' (if set).
|
||||
WriteTimeoutError: If the time taken to send data packets to the server
|
||||
exceeds 'write_timeout' (if set).
|
||||
|
||||
References:
|
||||
[1]: https://dev.mysql.com/doc/dev/mysql-server/latest/\
|
||||
page_protocol_basic_character_set.html#a_protocol_character_set
|
||||
"""
|
||||
# update credentials, plugin config and plugin class
|
||||
self._username = username
|
||||
self._passwords = {1: password1, 2: password2, 3: password3}
|
||||
self._ssl_enabled = ssl_enabled
|
||||
self._auth_plugin_class = auth_plugin_class
|
||||
|
||||
# client's handshake response
|
||||
response_payload, self._auth_strategy = MySQLProtocol.make_auth(
|
||||
handshake=handshake,
|
||||
username=username,
|
||||
password=password1,
|
||||
database=database,
|
||||
charset=charset,
|
||||
client_flags=client_flags,
|
||||
max_allowed_packet=max_allowed_packet,
|
||||
auth_plugin=auth_plugin,
|
||||
auth_plugin_class=auth_plugin_class,
|
||||
conn_attrs=conn_attrs,
|
||||
is_change_user_request=is_change_user_request,
|
||||
ssl_enabled=self.ssl_enabled,
|
||||
plugin_config=self.plugin_config,
|
||||
)
|
||||
|
||||
# client sends transaction response
|
||||
send_args = (
|
||||
(0, 0, write_timeout)
|
||||
if is_change_user_request
|
||||
else (None, None, write_timeout)
|
||||
)
|
||||
await sock.write(response_payload, *send_args)
|
||||
|
||||
# server replies back
|
||||
pkt = bytes(await sock.read(read_timeout))
|
||||
|
||||
ok_pkt = await self._handle_server_response(sock, pkt)
|
||||
if ok_pkt is None:
|
||||
raise InterfaceError("Got a NULL ok_pkt") from None
|
||||
|
||||
return ok_pkt
|
||||
@@ -0,0 +1,686 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""This module contains the MySQL Server Character Sets."""
|
||||
|
||||
__all__ = ["Charset", "charsets"]
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import DefaultDict, Dict, Optional, Sequence, Tuple
|
||||
|
||||
from ..errors import ProgrammingError
|
||||
|
||||
|
||||
@dataclass
|
||||
class Charset:
|
||||
"""Dataclass representing a character set."""
|
||||
|
||||
charset_id: int
|
||||
name: str
|
||||
collation: str
|
||||
is_default: bool
|
||||
|
||||
|
||||
class Charsets:
|
||||
"""MySQL supported character sets and collations class.
|
||||
|
||||
This class holds the list of character sets with their collations supported by
|
||||
MySQL, making available methods to get character sets by name, collation, or ID.
|
||||
It uses a sparse matrix or tree-like representation using a dict in a dict to hold
|
||||
the character set name and collations combinations.
|
||||
The list is hardcoded, so we avoid a database query when getting the name of the
|
||||
used character set or collation.
|
||||
|
||||
The call of ``charsets.set_mysql_major_version()`` should be done before using any
|
||||
of the retrieval methods.
|
||||
|
||||
Usage:
|
||||
>>> from mysql.connector.aio.charsets import charsets
|
||||
>>> charsets.set_mysql_major_version(8)
|
||||
>>> charsets.get_by_name("utf-8")
|
||||
Charset(charset_id=255,
|
||||
name='utf8mb4',
|
||||
collation='utf8mb4_0900_ai_ci',
|
||||
is_default=True)
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._charset_id_store: Dict[int, Charset] = {}
|
||||
self._collation_store: Dict[str, Charset] = {}
|
||||
self._name_store: DefaultDict[str, Dict[str, Charset]] = defaultdict(dict)
|
||||
self._mysql_major_version: Optional[int] = None
|
||||
|
||||
def set_mysql_major_version(self, version: int) -> None:
|
||||
"""Set the MySQL major version.
|
||||
|
||||
Sets what tuple should be used based on the MySQL major version to store the
|
||||
list of character sets and collations.
|
||||
|
||||
Args:
|
||||
version: The MySQL major version (i.e. 8 or 5)
|
||||
"""
|
||||
self._mysql_major_version = version
|
||||
self._charset_id_store.clear()
|
||||
self._collation_store.clear()
|
||||
self._name_store.clear()
|
||||
|
||||
charsets_tuple: Sequence[Tuple[int, str, str, bool]] = None
|
||||
if version >= 8:
|
||||
charsets_tuple = MYSQL_8_CHARSETS
|
||||
elif version == 5:
|
||||
charsets_tuple = MYSQL_5_CHARSETS
|
||||
else:
|
||||
raise ProgrammingError("Invalid MySQL major version")
|
||||
|
||||
for charset_id, name, collation, is_default in charsets_tuple:
|
||||
charset = Charset(charset_id, name, collation, is_default)
|
||||
self._charset_id_store[charset_id] = charset
|
||||
self._collation_store[collation] = charset
|
||||
self._name_store[name][collation] = charset
|
||||
|
||||
def get_by_id(self, charset_id: int) -> Charset:
|
||||
"""Get character set by ID.
|
||||
|
||||
Args:
|
||||
charset_id: The charset ID.
|
||||
|
||||
Returns:
|
||||
Charset: The Charset dataclass instance.
|
||||
"""
|
||||
try:
|
||||
return self._charset_id_store[charset_id]
|
||||
except KeyError as err:
|
||||
raise ProgrammingError(f"Character set ID {charset_id} unknown") from err
|
||||
|
||||
def get_by_collation(self, collation: str) -> Charset:
|
||||
"""Get character set by collation.
|
||||
|
||||
Args:
|
||||
collation: The collation name.
|
||||
|
||||
Returns:
|
||||
Charset: The Charset dataclass instance.
|
||||
"""
|
||||
try:
|
||||
return self._collation_store[collation]
|
||||
except KeyError as err:
|
||||
raise ProgrammingError(f"Collation {collation} unknown") from err
|
||||
|
||||
def get_by_name(self, name: str) -> Charset:
|
||||
"""Get character set by name.
|
||||
|
||||
Args:
|
||||
name: The charset name.
|
||||
|
||||
Returns:
|
||||
Charset: The Charset dataclass instance.
|
||||
"""
|
||||
try:
|
||||
if name in ("utf8", "utf-8") and self._mysql_major_version == 8:
|
||||
name = "utf8mb4"
|
||||
for charset in self._name_store[name].values():
|
||||
if charset.is_default:
|
||||
return charset
|
||||
except KeyError as err:
|
||||
raise ProgrammingError(f"Character set name {name} unknown") from err
|
||||
raise ProgrammingError(f"No default was found for character set '{name}'")
|
||||
|
||||
def get_by_name_and_collation(self, name: str, collation: str) -> Charset:
    """Look up a character set by the (name, collation) pair.

    Args:
        name: The charset name.
        collation: The collation name.

    Returns:
        Charset: The Charset dataclass instance.

    Raises:
        ProgrammingError: If the pair is not registered.
    """
    try:
        charset = self._name_store[name][collation]
    except KeyError as lookup_err:
        raise ProgrammingError(
            f"Character set name '{name}' with collation '{collation}' not found"
        ) from lookup_err
    return charset
||||
|
||||
# Character sets / collations shipped with MySQL 8.
# Each entry is (charset_id, charset_name, collation_name, is_default), where
# `is_default` marks the default collation for that charset name — the one
# returned by `get_by_name`. IDs are not contiguous (gaps match the server's
# collation ID assignments).
MYSQL_8_CHARSETS = (
    (1, "big5", "big5_chinese_ci", True),
    (2, "latin2", "latin2_czech_cs", False),
    (3, "dec8", "dec8_swedish_ci", True),
    (4, "cp850", "cp850_general_ci", True),
    (5, "latin1", "latin1_german1_ci", False),
    (6, "hp8", "hp8_english_ci", True),
    (7, "koi8r", "koi8r_general_ci", True),
    (8, "latin1", "latin1_swedish_ci", True),
    (9, "latin2", "latin2_general_ci", True),
    (10, "swe7", "swe7_swedish_ci", True),
    (11, "ascii", "ascii_general_ci", True),
    (12, "ujis", "ujis_japanese_ci", True),
    (13, "sjis", "sjis_japanese_ci", True),
    (14, "cp1251", "cp1251_bulgarian_ci", False),
    (15, "latin1", "latin1_danish_ci", False),
    (16, "hebrew", "hebrew_general_ci", True),
    (18, "tis620", "tis620_thai_ci", True),
    (19, "euckr", "euckr_korean_ci", True),
    (20, "latin7", "latin7_estonian_cs", False),
    (21, "latin2", "latin2_hungarian_ci", False),
    (22, "koi8u", "koi8u_general_ci", True),
    (23, "cp1251", "cp1251_ukrainian_ci", False),
    (24, "gb2312", "gb2312_chinese_ci", True),
    (25, "greek", "greek_general_ci", True),
    (26, "cp1250", "cp1250_general_ci", True),
    (27, "latin2", "latin2_croatian_ci", False),
    (28, "gbk", "gbk_chinese_ci", True),
    (29, "cp1257", "cp1257_lithuanian_ci", False),
    (30, "latin5", "latin5_turkish_ci", True),
    (31, "latin1", "latin1_german2_ci", False),
    (32, "armscii8", "armscii8_general_ci", True),
    (33, "utf8mb3", "utf8mb3_general_ci", True),
    (34, "cp1250", "cp1250_czech_cs", False),
    (35, "ucs2", "ucs2_general_ci", True),
    (36, "cp866", "cp866_general_ci", True),
    (37, "keybcs2", "keybcs2_general_ci", True),
    (38, "macce", "macce_general_ci", True),
    (39, "macroman", "macroman_general_ci", True),
    (40, "cp852", "cp852_general_ci", True),
    (41, "latin7", "latin7_general_ci", True),
    (42, "latin7", "latin7_general_cs", False),
    (43, "macce", "macce_bin", False),
    (44, "cp1250", "cp1250_croatian_ci", False),
    # NOTE: in MySQL 8 the utf8mb4 default collation is utf8mb4_0900_ai_ci
    # (ID 255 below), so utf8mb4_general_ci is not flagged as default here.
    (45, "utf8mb4", "utf8mb4_general_ci", False),
    (46, "utf8mb4", "utf8mb4_bin", False),
    (47, "latin1", "latin1_bin", False),
    (48, "latin1", "latin1_general_ci", False),
    (49, "latin1", "latin1_general_cs", False),
    (50, "cp1251", "cp1251_bin", False),
    (51, "cp1251", "cp1251_general_ci", True),
    (52, "cp1251", "cp1251_general_cs", False),
    (53, "macroman", "macroman_bin", False),
    (54, "utf16", "utf16_general_ci", True),
    (55, "utf16", "utf16_bin", False),
    (56, "utf16le", "utf16le_general_ci", True),
    (57, "cp1256", "cp1256_general_ci", True),
    (58, "cp1257", "cp1257_bin", False),
    (59, "cp1257", "cp1257_general_ci", True),
    (60, "utf32", "utf32_general_ci", True),
    (61, "utf32", "utf32_bin", False),
    (62, "utf16le", "utf16le_bin", False),
    (63, "binary", "binary", True),
    (64, "armscii8", "armscii8_bin", False),
    (65, "ascii", "ascii_bin", False),
    (66, "cp1250", "cp1250_bin", False),
    (67, "cp1256", "cp1256_bin", False),
    (68, "cp866", "cp866_bin", False),
    (69, "dec8", "dec8_bin", False),
    (70, "greek", "greek_bin", False),
    (71, "hebrew", "hebrew_bin", False),
    (72, "hp8", "hp8_bin", False),
    (73, "keybcs2", "keybcs2_bin", False),
    (74, "koi8r", "koi8r_bin", False),
    (75, "koi8u", "koi8u_bin", False),
    (76, "utf8mb3", "utf8mb3_tolower_ci", False),
    (77, "latin2", "latin2_bin", False),
    (78, "latin5", "latin5_bin", False),
    (79, "latin7", "latin7_bin", False),
    (80, "cp850", "cp850_bin", False),
    (81, "cp852", "cp852_bin", False),
    (82, "swe7", "swe7_bin", False),
    (83, "utf8mb3", "utf8mb3_bin", False),
    (84, "big5", "big5_bin", False),
    (85, "euckr", "euckr_bin", False),
    (86, "gb2312", "gb2312_bin", False),
    (87, "gbk", "gbk_bin", False),
    (88, "sjis", "sjis_bin", False),
    (89, "tis620", "tis620_bin", False),
    (90, "ucs2", "ucs2_bin", False),
    (91, "ujis", "ujis_bin", False),
    (92, "geostd8", "geostd8_general_ci", True),
    (93, "geostd8", "geostd8_bin", False),
    (94, "latin1", "latin1_spanish_ci", False),
    (95, "cp932", "cp932_japanese_ci", True),
    (96, "cp932", "cp932_bin", False),
    (97, "eucjpms", "eucjpms_japanese_ci", True),
    (98, "eucjpms", "eucjpms_bin", False),
    (99, "cp1250", "cp1250_polish_ci", False),
    (101, "utf16", "utf16_unicode_ci", False),
    (102, "utf16", "utf16_icelandic_ci", False),
    (103, "utf16", "utf16_latvian_ci", False),
    (104, "utf16", "utf16_romanian_ci", False),
    (105, "utf16", "utf16_slovenian_ci", False),
    (106, "utf16", "utf16_polish_ci", False),
    (107, "utf16", "utf16_estonian_ci", False),
    (108, "utf16", "utf16_spanish_ci", False),
    (109, "utf16", "utf16_swedish_ci", False),
    (110, "utf16", "utf16_turkish_ci", False),
    (111, "utf16", "utf16_czech_ci", False),
    (112, "utf16", "utf16_danish_ci", False),
    (113, "utf16", "utf16_lithuanian_ci", False),
    (114, "utf16", "utf16_slovak_ci", False),
    (115, "utf16", "utf16_spanish2_ci", False),
    (116, "utf16", "utf16_roman_ci", False),
    (117, "utf16", "utf16_persian_ci", False),
    (118, "utf16", "utf16_esperanto_ci", False),
    (119, "utf16", "utf16_hungarian_ci", False),
    (120, "utf16", "utf16_sinhala_ci", False),
    (121, "utf16", "utf16_german2_ci", False),
    (122, "utf16", "utf16_croatian_ci", False),
    (123, "utf16", "utf16_unicode_520_ci", False),
    (124, "utf16", "utf16_vietnamese_ci", False),
    (128, "ucs2", "ucs2_unicode_ci", False),
    (129, "ucs2", "ucs2_icelandic_ci", False),
    (130, "ucs2", "ucs2_latvian_ci", False),
    (131, "ucs2", "ucs2_romanian_ci", False),
    (132, "ucs2", "ucs2_slovenian_ci", False),
    (133, "ucs2", "ucs2_polish_ci", False),
    (134, "ucs2", "ucs2_estonian_ci", False),
    (135, "ucs2", "ucs2_spanish_ci", False),
    (136, "ucs2", "ucs2_swedish_ci", False),
    (137, "ucs2", "ucs2_turkish_ci", False),
    (138, "ucs2", "ucs2_czech_ci", False),
    (139, "ucs2", "ucs2_danish_ci", False),
    (140, "ucs2", "ucs2_lithuanian_ci", False),
    (141, "ucs2", "ucs2_slovak_ci", False),
    (142, "ucs2", "ucs2_spanish2_ci", False),
    (143, "ucs2", "ucs2_roman_ci", False),
    (144, "ucs2", "ucs2_persian_ci", False),
    (145, "ucs2", "ucs2_esperanto_ci", False),
    (146, "ucs2", "ucs2_hungarian_ci", False),
    (147, "ucs2", "ucs2_sinhala_ci", False),
    (148, "ucs2", "ucs2_german2_ci", False),
    (149, "ucs2", "ucs2_croatian_ci", False),
    (150, "ucs2", "ucs2_unicode_520_ci", False),
    (151, "ucs2", "ucs2_vietnamese_ci", False),
    (159, "ucs2", "ucs2_general_mysql500_ci", False),
    (160, "utf32", "utf32_unicode_ci", False),
    (161, "utf32", "utf32_icelandic_ci", False),
    (162, "utf32", "utf32_latvian_ci", False),
    (163, "utf32", "utf32_romanian_ci", False),
    (164, "utf32", "utf32_slovenian_ci", False),
    (165, "utf32", "utf32_polish_ci", False),
    (166, "utf32", "utf32_estonian_ci", False),
    (167, "utf32", "utf32_spanish_ci", False),
    (168, "utf32", "utf32_swedish_ci", False),
    (169, "utf32", "utf32_turkish_ci", False),
    (170, "utf32", "utf32_czech_ci", False),
    (171, "utf32", "utf32_danish_ci", False),
    (172, "utf32", "utf32_lithuanian_ci", False),
    (173, "utf32", "utf32_slovak_ci", False),
    (174, "utf32", "utf32_spanish2_ci", False),
    (175, "utf32", "utf32_roman_ci", False),
    (176, "utf32", "utf32_persian_ci", False),
    (177, "utf32", "utf32_esperanto_ci", False),
    (178, "utf32", "utf32_hungarian_ci", False),
    (179, "utf32", "utf32_sinhala_ci", False),
    (180, "utf32", "utf32_german2_ci", False),
    (181, "utf32", "utf32_croatian_ci", False),
    (182, "utf32", "utf32_unicode_520_ci", False),
    (183, "utf32", "utf32_vietnamese_ci", False),
    (192, "utf8mb3", "utf8mb3_unicode_ci", False),
    (193, "utf8mb3", "utf8mb3_icelandic_ci", False),
    (194, "utf8mb3", "utf8mb3_latvian_ci", False),
    (195, "utf8mb3", "utf8mb3_romanian_ci", False),
    (196, "utf8mb3", "utf8mb3_slovenian_ci", False),
    (197, "utf8mb3", "utf8mb3_polish_ci", False),
    (198, "utf8mb3", "utf8mb3_estonian_ci", False),
    (199, "utf8mb3", "utf8mb3_spanish_ci", False),
    (200, "utf8mb3", "utf8mb3_swedish_ci", False),
    (201, "utf8mb3", "utf8mb3_turkish_ci", False),
    (202, "utf8mb3", "utf8mb3_czech_ci", False),
    (203, "utf8mb3", "utf8mb3_danish_ci", False),
    (204, "utf8mb3", "utf8mb3_lithuanian_ci", False),
    (205, "utf8mb3", "utf8mb3_slovak_ci", False),
    (206, "utf8mb3", "utf8mb3_spanish2_ci", False),
    (207, "utf8mb3", "utf8mb3_roman_ci", False),
    (208, "utf8mb3", "utf8mb3_persian_ci", False),
    (209, "utf8mb3", "utf8mb3_esperanto_ci", False),
    (210, "utf8mb3", "utf8mb3_hungarian_ci", False),
    (211, "utf8mb3", "utf8mb3_sinhala_ci", False),
    (212, "utf8mb3", "utf8mb3_german2_ci", False),
    (213, "utf8mb3", "utf8mb3_croatian_ci", False),
    (214, "utf8mb3", "utf8mb3_unicode_520_ci", False),
    (215, "utf8mb3", "utf8mb3_vietnamese_ci", False),
    (223, "utf8mb3", "utf8mb3_general_mysql500_ci", False),
    (224, "utf8mb4", "utf8mb4_unicode_ci", False),
    (225, "utf8mb4", "utf8mb4_icelandic_ci", False),
    (226, "utf8mb4", "utf8mb4_latvian_ci", False),
    (227, "utf8mb4", "utf8mb4_romanian_ci", False),
    (228, "utf8mb4", "utf8mb4_slovenian_ci", False),
    (229, "utf8mb4", "utf8mb4_polish_ci", False),
    (230, "utf8mb4", "utf8mb4_estonian_ci", False),
    (231, "utf8mb4", "utf8mb4_spanish_ci", False),
    (232, "utf8mb4", "utf8mb4_swedish_ci", False),
    (233, "utf8mb4", "utf8mb4_turkish_ci", False),
    (234, "utf8mb4", "utf8mb4_czech_ci", False),
    (235, "utf8mb4", "utf8mb4_danish_ci", False),
    (236, "utf8mb4", "utf8mb4_lithuanian_ci", False),
    (237, "utf8mb4", "utf8mb4_slovak_ci", False),
    (238, "utf8mb4", "utf8mb4_spanish2_ci", False),
    (239, "utf8mb4", "utf8mb4_roman_ci", False),
    (240, "utf8mb4", "utf8mb4_persian_ci", False),
    (241, "utf8mb4", "utf8mb4_esperanto_ci", False),
    (242, "utf8mb4", "utf8mb4_hungarian_ci", False),
    (243, "utf8mb4", "utf8mb4_sinhala_ci", False),
    (244, "utf8mb4", "utf8mb4_german2_ci", False),
    (245, "utf8mb4", "utf8mb4_croatian_ci", False),
    (246, "utf8mb4", "utf8mb4_unicode_520_ci", False),
    (247, "utf8mb4", "utf8mb4_vietnamese_ci", False),
    (248, "gb18030", "gb18030_chinese_ci", True),
    (249, "gb18030", "gb18030_bin", False),
    (250, "gb18030", "gb18030_unicode_520_ci", False),
    (255, "utf8mb4", "utf8mb4_0900_ai_ci", True),
    (256, "utf8mb4", "utf8mb4_de_pb_0900_ai_ci", False),
    (257, "utf8mb4", "utf8mb4_is_0900_ai_ci", False),
    (258, "utf8mb4", "utf8mb4_lv_0900_ai_ci", False),
    (259, "utf8mb4", "utf8mb4_ro_0900_ai_ci", False),
    (260, "utf8mb4", "utf8mb4_sl_0900_ai_ci", False),
    (261, "utf8mb4", "utf8mb4_pl_0900_ai_ci", False),
    (262, "utf8mb4", "utf8mb4_et_0900_ai_ci", False),
    (263, "utf8mb4", "utf8mb4_es_0900_ai_ci", False),
    (264, "utf8mb4", "utf8mb4_sv_0900_ai_ci", False),
    (265, "utf8mb4", "utf8mb4_tr_0900_ai_ci", False),
    (266, "utf8mb4", "utf8mb4_cs_0900_ai_ci", False),
    (267, "utf8mb4", "utf8mb4_da_0900_ai_ci", False),
    (268, "utf8mb4", "utf8mb4_lt_0900_ai_ci", False),
    (269, "utf8mb4", "utf8mb4_sk_0900_ai_ci", False),
    (270, "utf8mb4", "utf8mb4_es_trad_0900_ai_ci", False),
    (271, "utf8mb4", "utf8mb4_la_0900_ai_ci", False),
    (273, "utf8mb4", "utf8mb4_eo_0900_ai_ci", False),
    (274, "utf8mb4", "utf8mb4_hu_0900_ai_ci", False),
    (275, "utf8mb4", "utf8mb4_hr_0900_ai_ci", False),
    (277, "utf8mb4", "utf8mb4_vi_0900_ai_ci", False),
    (278, "utf8mb4", "utf8mb4_0900_as_cs", False),
    (279, "utf8mb4", "utf8mb4_de_pb_0900_as_cs", False),
    (280, "utf8mb4", "utf8mb4_is_0900_as_cs", False),
    (281, "utf8mb4", "utf8mb4_lv_0900_as_cs", False),
    (282, "utf8mb4", "utf8mb4_ro_0900_as_cs", False),
    (283, "utf8mb4", "utf8mb4_sl_0900_as_cs", False),
    (284, "utf8mb4", "utf8mb4_pl_0900_as_cs", False),
    (285, "utf8mb4", "utf8mb4_et_0900_as_cs", False),
    (286, "utf8mb4", "utf8mb4_es_0900_as_cs", False),
    (287, "utf8mb4", "utf8mb4_sv_0900_as_cs", False),
    (288, "utf8mb4", "utf8mb4_tr_0900_as_cs", False),
    (289, "utf8mb4", "utf8mb4_cs_0900_as_cs", False),
    (290, "utf8mb4", "utf8mb4_da_0900_as_cs", False),
    (291, "utf8mb4", "utf8mb4_lt_0900_as_cs", False),
    (292, "utf8mb4", "utf8mb4_sk_0900_as_cs", False),
    (293, "utf8mb4", "utf8mb4_es_trad_0900_as_cs", False),
    (294, "utf8mb4", "utf8mb4_la_0900_as_cs", False),
    (296, "utf8mb4", "utf8mb4_eo_0900_as_cs", False),
    (297, "utf8mb4", "utf8mb4_hu_0900_as_cs", False),
    (298, "utf8mb4", "utf8mb4_hr_0900_as_cs", False),
    (300, "utf8mb4", "utf8mb4_vi_0900_as_cs", False),
    (303, "utf8mb4", "utf8mb4_ja_0900_as_cs", False),
    (304, "utf8mb4", "utf8mb4_ja_0900_as_cs_ks", False),
    (305, "utf8mb4", "utf8mb4_0900_as_ci", False),
    (306, "utf8mb4", "utf8mb4_ru_0900_ai_ci", False),
    (307, "utf8mb4", "utf8mb4_ru_0900_as_cs", False),
    (308, "utf8mb4", "utf8mb4_zh_0900_as_cs", False),
    (309, "utf8mb4", "utf8mb4_0900_bin", False),
    (310, "utf8mb4", "utf8mb4_nb_0900_ai_ci", False),
    (311, "utf8mb4", "utf8mb4_nb_0900_as_cs", False),
    (312, "utf8mb4", "utf8mb4_nn_0900_ai_ci", False),
    (313, "utf8mb4", "utf8mb4_nn_0900_as_cs", False),
    (314, "utf8mb4", "utf8mb4_sr_latn_0900_ai_ci", False),
    (315, "utf8mb4", "utf8mb4_sr_latn_0900_as_cs", False),
    (316, "utf8mb4", "utf8mb4_bs_0900_ai_ci", False),
    (317, "utf8mb4", "utf8mb4_bs_0900_as_cs", False),
    (318, "utf8mb4", "utf8mb4_bg_0900_ai_ci", False),
    (319, "utf8mb4", "utf8mb4_bg_0900_as_cs", False),
    (320, "utf8mb4", "utf8mb4_gl_0900_ai_ci", False),
    (321, "utf8mb4", "utf8mb4_gl_0900_as_cs", False),
    (322, "utf8mb4", "utf8mb4_mn_cyrl_0900_ai_ci", False),
    (323, "utf8mb4", "utf8mb4_mn_cyrl_0900_as_cs", False),
)
||||
|
||||
# Character sets / collations shipped with MySQL 5.x.
# Each entry is (charset_id, charset_name, collation_name, is_default), where
# `is_default` marks the default collation for that charset name. Differs from
# the MySQL 8 table mainly in that the 3-byte UTF-8 charset is named "utf8"
# (not "utf8mb3") and utf8mb4's default collation is utf8mb4_general_ci.
MYSQL_5_CHARSETS = (
    (1, "big5", "big5_chinese_ci", True),
    (2, "latin2", "latin2_czech_cs", False),
    (3, "dec8", "dec8_swedish_ci", True),
    (4, "cp850", "cp850_general_ci", True),
    (5, "latin1", "latin1_german1_ci", False),
    (6, "hp8", "hp8_english_ci", True),
    (7, "koi8r", "koi8r_general_ci", True),
    (8, "latin1", "latin1_swedish_ci", True),
    (9, "latin2", "latin2_general_ci", True),
    (10, "swe7", "swe7_swedish_ci", True),
    (11, "ascii", "ascii_general_ci", True),
    (12, "ujis", "ujis_japanese_ci", True),
    (13, "sjis", "sjis_japanese_ci", True),
    (14, "cp1251", "cp1251_bulgarian_ci", False),
    (15, "latin1", "latin1_danish_ci", False),
    (16, "hebrew", "hebrew_general_ci", True),
    (18, "tis620", "tis620_thai_ci", True),
    (19, "euckr", "euckr_korean_ci", True),
    (20, "latin7", "latin7_estonian_cs", False),
    (21, "latin2", "latin2_hungarian_ci", False),
    (22, "koi8u", "koi8u_general_ci", True),
    (23, "cp1251", "cp1251_ukrainian_ci", False),
    (24, "gb2312", "gb2312_chinese_ci", True),
    (25, "greek", "greek_general_ci", True),
    (26, "cp1250", "cp1250_general_ci", True),
    (27, "latin2", "latin2_croatian_ci", False),
    (28, "gbk", "gbk_chinese_ci", True),
    (29, "cp1257", "cp1257_lithuanian_ci", False),
    (30, "latin5", "latin5_turkish_ci", True),
    (31, "latin1", "latin1_german2_ci", False),
    (32, "armscii8", "armscii8_general_ci", True),
    (33, "utf8", "utf8_general_ci", True),
    (34, "cp1250", "cp1250_czech_cs", False),
    (35, "ucs2", "ucs2_general_ci", True),
    (36, "cp866", "cp866_general_ci", True),
    (37, "keybcs2", "keybcs2_general_ci", True),
    (38, "macce", "macce_general_ci", True),
    (39, "macroman", "macroman_general_ci", True),
    (40, "cp852", "cp852_general_ci", True),
    (41, "latin7", "latin7_general_ci", True),
    (42, "latin7", "latin7_general_cs", False),
    (43, "macce", "macce_bin", False),
    (44, "cp1250", "cp1250_croatian_ci", False),
    (45, "utf8mb4", "utf8mb4_general_ci", True),
    (46, "utf8mb4", "utf8mb4_bin", False),
    (47, "latin1", "latin1_bin", False),
    (48, "latin1", "latin1_general_ci", False),
    (49, "latin1", "latin1_general_cs", False),
    (50, "cp1251", "cp1251_bin", False),
    (51, "cp1251", "cp1251_general_ci", True),
    (52, "cp1251", "cp1251_general_cs", False),
    (53, "macroman", "macroman_bin", False),
    (54, "utf16", "utf16_general_ci", True),
    (55, "utf16", "utf16_bin", False),
    (56, "utf16le", "utf16le_general_ci", True),
    (57, "cp1256", "cp1256_general_ci", True),
    (58, "cp1257", "cp1257_bin", False),
    (59, "cp1257", "cp1257_general_ci", True),
    (60, "utf32", "utf32_general_ci", True),
    (61, "utf32", "utf32_bin", False),
    (62, "utf16le", "utf16le_bin", False),
    (63, "binary", "binary", True),
    (64, "armscii8", "armscii8_bin", False),
    (65, "ascii", "ascii_bin", False),
    (66, "cp1250", "cp1250_bin", False),
    (67, "cp1256", "cp1256_bin", False),
    (68, "cp866", "cp866_bin", False),
    (69, "dec8", "dec8_bin", False),
    (70, "greek", "greek_bin", False),
    (71, "hebrew", "hebrew_bin", False),
    (72, "hp8", "hp8_bin", False),
    (73, "keybcs2", "keybcs2_bin", False),
    (74, "koi8r", "koi8r_bin", False),
    (75, "koi8u", "koi8u_bin", False),
    (77, "latin2", "latin2_bin", False),
    (78, "latin5", "latin5_bin", False),
    (79, "latin7", "latin7_bin", False),
    (80, "cp850", "cp850_bin", False),
    (81, "cp852", "cp852_bin", False),
    (82, "swe7", "swe7_bin", False),
    (83, "utf8", "utf8_bin", False),
    (84, "big5", "big5_bin", False),
    (85, "euckr", "euckr_bin", False),
    (86, "gb2312", "gb2312_bin", False),
    (87, "gbk", "gbk_bin", False),
    (88, "sjis", "sjis_bin", False),
    (89, "tis620", "tis620_bin", False),
    (90, "ucs2", "ucs2_bin", False),
    (91, "ujis", "ujis_bin", False),
    (92, "geostd8", "geostd8_general_ci", True),
    (93, "geostd8", "geostd8_bin", False),
    (94, "latin1", "latin1_spanish_ci", False),
    (95, "cp932", "cp932_japanese_ci", True),
    (96, "cp932", "cp932_bin", False),
    (97, "eucjpms", "eucjpms_japanese_ci", True),
    (98, "eucjpms", "eucjpms_bin", False),
    (99, "cp1250", "cp1250_polish_ci", False),
    (101, "utf16", "utf16_unicode_ci", False),
    (102, "utf16", "utf16_icelandic_ci", False),
    (103, "utf16", "utf16_latvian_ci", False),
    (104, "utf16", "utf16_romanian_ci", False),
    (105, "utf16", "utf16_slovenian_ci", False),
    (106, "utf16", "utf16_polish_ci", False),
    (107, "utf16", "utf16_estonian_ci", False),
    (108, "utf16", "utf16_spanish_ci", False),
    (109, "utf16", "utf16_swedish_ci", False),
    (110, "utf16", "utf16_turkish_ci", False),
    (111, "utf16", "utf16_czech_ci", False),
    (112, "utf16", "utf16_danish_ci", False),
    (113, "utf16", "utf16_lithuanian_ci", False),
    (114, "utf16", "utf16_slovak_ci", False),
    (115, "utf16", "utf16_spanish2_ci", False),
    (116, "utf16", "utf16_roman_ci", False),
    (117, "utf16", "utf16_persian_ci", False),
    (118, "utf16", "utf16_esperanto_ci", False),
    (119, "utf16", "utf16_hungarian_ci", False),
    (120, "utf16", "utf16_sinhala_ci", False),
    (121, "utf16", "utf16_german2_ci", False),
    (122, "utf16", "utf16_croatian_ci", False),
    (123, "utf16", "utf16_unicode_520_ci", False),
    (124, "utf16", "utf16_vietnamese_ci", False),
    (128, "ucs2", "ucs2_unicode_ci", False),
    (129, "ucs2", "ucs2_icelandic_ci", False),
    (130, "ucs2", "ucs2_latvian_ci", False),
    (131, "ucs2", "ucs2_romanian_ci", False),
    (132, "ucs2", "ucs2_slovenian_ci", False),
    (133, "ucs2", "ucs2_polish_ci", False),
    (134, "ucs2", "ucs2_estonian_ci", False),
    (135, "ucs2", "ucs2_spanish_ci", False),
    (136, "ucs2", "ucs2_swedish_ci", False),
    (137, "ucs2", "ucs2_turkish_ci", False),
    (138, "ucs2", "ucs2_czech_ci", False),
    (139, "ucs2", "ucs2_danish_ci", False),
    (140, "ucs2", "ucs2_lithuanian_ci", False),
    (141, "ucs2", "ucs2_slovak_ci", False),
    (142, "ucs2", "ucs2_spanish2_ci", False),
    (143, "ucs2", "ucs2_roman_ci", False),
    (144, "ucs2", "ucs2_persian_ci", False),
    (145, "ucs2", "ucs2_esperanto_ci", False),
    (146, "ucs2", "ucs2_hungarian_ci", False),
    (147, "ucs2", "ucs2_sinhala_ci", False),
    (148, "ucs2", "ucs2_german2_ci", False),
    (149, "ucs2", "ucs2_croatian_ci", False),
    (150, "ucs2", "ucs2_unicode_520_ci", False),
    (151, "ucs2", "ucs2_vietnamese_ci", False),
    (159, "ucs2", "ucs2_general_mysql500_ci", False),
    (160, "utf32", "utf32_unicode_ci", False),
    (161, "utf32", "utf32_icelandic_ci", False),
    (162, "utf32", "utf32_latvian_ci", False),
    (163, "utf32", "utf32_romanian_ci", False),
    (164, "utf32", "utf32_slovenian_ci", False),
    (165, "utf32", "utf32_polish_ci", False),
    (166, "utf32", "utf32_estonian_ci", False),
    (167, "utf32", "utf32_spanish_ci", False),
    (168, "utf32", "utf32_swedish_ci", False),
    (169, "utf32", "utf32_turkish_ci", False),
    (170, "utf32", "utf32_czech_ci", False),
    (171, "utf32", "utf32_danish_ci", False),
    (172, "utf32", "utf32_lithuanian_ci", False),
    (173, "utf32", "utf32_slovak_ci", False),
    (174, "utf32", "utf32_spanish2_ci", False),
    (175, "utf32", "utf32_roman_ci", False),
    (176, "utf32", "utf32_persian_ci", False),
    (177, "utf32", "utf32_esperanto_ci", False),
    (178, "utf32", "utf32_hungarian_ci", False),
    (179, "utf32", "utf32_sinhala_ci", False),
    (180, "utf32", "utf32_german2_ci", False),
    (181, "utf32", "utf32_croatian_ci", False),
    (182, "utf32", "utf32_unicode_520_ci", False),
    (183, "utf32", "utf32_vietnamese_ci", False),
    (192, "utf8", "utf8_unicode_ci", False),
    (193, "utf8", "utf8_icelandic_ci", False),
    (194, "utf8", "utf8_latvian_ci", False),
    (195, "utf8", "utf8_romanian_ci", False),
    (196, "utf8", "utf8_slovenian_ci", False),
    (197, "utf8", "utf8_polish_ci", False),
    (198, "utf8", "utf8_estonian_ci", False),
    (199, "utf8", "utf8_spanish_ci", False),
    (200, "utf8", "utf8_swedish_ci", False),
    (201, "utf8", "utf8_turkish_ci", False),
    (202, "utf8", "utf8_czech_ci", False),
    (203, "utf8", "utf8_danish_ci", False),
    (204, "utf8", "utf8_lithuanian_ci", False),
    (205, "utf8", "utf8_slovak_ci", False),
    (206, "utf8", "utf8_spanish2_ci", False),
    (207, "utf8", "utf8_roman_ci", False),
    (208, "utf8", "utf8_persian_ci", False),
    (209, "utf8", "utf8_esperanto_ci", False),
    (210, "utf8", "utf8_hungarian_ci", False),
    (211, "utf8", "utf8_sinhala_ci", False),
    (212, "utf8", "utf8_german2_ci", False),
    (213, "utf8", "utf8_croatian_ci", False),
    (214, "utf8", "utf8_unicode_520_ci", False),
    (215, "utf8", "utf8_vietnamese_ci", False),
    (223, "utf8", "utf8_general_mysql500_ci", False),
    (224, "utf8mb4", "utf8mb4_unicode_ci", False),
    (225, "utf8mb4", "utf8mb4_icelandic_ci", False),
    (226, "utf8mb4", "utf8mb4_latvian_ci", False),
    (227, "utf8mb4", "utf8mb4_romanian_ci", False),
    (228, "utf8mb4", "utf8mb4_slovenian_ci", False),
    (229, "utf8mb4", "utf8mb4_polish_ci", False),
    (230, "utf8mb4", "utf8mb4_estonian_ci", False),
    (231, "utf8mb4", "utf8mb4_spanish_ci", False),
    (232, "utf8mb4", "utf8mb4_swedish_ci", False),
    (233, "utf8mb4", "utf8mb4_turkish_ci", False),
    (234, "utf8mb4", "utf8mb4_czech_ci", False),
    (235, "utf8mb4", "utf8mb4_danish_ci", False),
    (236, "utf8mb4", "utf8mb4_lithuanian_ci", False),
    (237, "utf8mb4", "utf8mb4_slovak_ci", False),
    (238, "utf8mb4", "utf8mb4_spanish2_ci", False),
    (239, "utf8mb4", "utf8mb4_roman_ci", False),
    (240, "utf8mb4", "utf8mb4_persian_ci", False),
    (241, "utf8mb4", "utf8mb4_esperanto_ci", False),
    (242, "utf8mb4", "utf8mb4_hungarian_ci", False),
    (243, "utf8mb4", "utf8mb4_sinhala_ci", False),
    (244, "utf8mb4", "utf8mb4_german2_ci", False),
    (245, "utf8mb4", "utf8mb4_croatian_ci", False),
    (246, "utf8mb4", "utf8mb4_unicode_520_ci", False),
    (247, "utf8mb4", "utf8mb4_vietnamese_ci", False),
    (248, "gb18030", "gb18030_chinese_ci", True),
    (249, "gb18030", "gb18030_bin", False),
    (250, "gb18030", "gb18030_unicode_520_ci", False),
)
||||
|
||||
charsets = Charsets()
|
||||
1580
venv/lib/python3.12/site-packages/mysql/connector/aio/connection.py
Normal file
1580
venv/lib/python3.12/site-packages/mysql/connector/aio/connection.py
Normal file
File diff suppressed because it is too large
Load Diff
1392
venv/lib/python3.12/site-packages/mysql/connector/aio/cursor.py
Normal file
1392
venv/lib/python3.12/site-packages/mysql/connector/aio/cursor.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,33 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Setup of the `mysql.connector.aio` logger."""
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("mysql.connector.aio")
|
||||
765
venv/lib/python3.12/site-packages/mysql/connector/aio/network.py
Normal file
765
venv/lib/python3.12/site-packages/mysql/connector/aio/network.py
Normal file
@@ -0,0 +1,765 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
|
||||
"""Module implementing low-level socket communication with MySQL servers."""
|
||||
|
||||
|
||||
__all__ = ["MySQLTcpSocket", "MySQLUnixSocket"]
|
||||
|
||||
import asyncio
|
||||
import struct
|
||||
import zlib
|
||||
|
||||
# The `ssl` module can be missing when Python was built without SSL support;
# in that case `ssl` is set to None so callers can detect that TLS is
# unavailable instead of failing at import time.
try:
    import ssl

    # Map configuration-level TLS version names to `ssl` protocol constants.
    # NOTE(review): "TLSv1.3" maps to PROTOCOL_TLS (auto-negotiated) because
    # the stdlib `ssl` module exposes no dedicated PROTOCOL_TLSv1_3 constant.
    TLS_VERSIONS = {
        "TLSv1": ssl.PROTOCOL_TLSv1,
        "TLSv1.1": ssl.PROTOCOL_TLSv1_1,
        "TLSv1.2": ssl.PROTOCOL_TLSv1_2,
        "TLSv1.3": ssl.PROTOCOL_TLS,
    }
except ImportError:
    ssl = None
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from typing import Any, Deque, List, Optional, Tuple
|
||||
|
||||
from ..errors import (
|
||||
InterfaceError,
|
||||
NotSupportedError,
|
||||
OperationalError,
|
||||
ProgrammingError,
|
||||
ReadTimeoutError,
|
||||
WriteTimeoutError,
|
||||
)
|
||||
from ..network import (
|
||||
COMPRESSED_PACKET_HEADER_LENGTH,
|
||||
MAX_PAYLOAD_LENGTH,
|
||||
MIN_COMPRESS_LENGTH,
|
||||
PACKET_HEADER_LENGTH,
|
||||
)
|
||||
from .utils import StreamWriter, open_connection
|
||||
|
||||
|
||||
def _strioerror(err: IOError) -> str:
    """Return a compact, readable message for an I/O error.

    When the error carries an `errno`, the result is "<errno> <strerror>";
    otherwise it falls back to `str(err)`.
    """
    if err.errno:
        return f"{err.errno} {err.strerror}"
    return str(err)
|
||||
|
||||
|
||||
class NetworkBroker(ABC):
    """Broker class interface.

    The network object is a broker used as a delegate by a socket object. Whenever the
    socket wants to deliver or get packets to or from the MySQL server it needs to rely
    on its network broker (netbroker).

    The netbroker sends `payloads` and receives `packets`.

    A packet is a bytes sequence, it has a header and body (referred to as payload).
    The first `PACKET_HEADER_LENGTH` or `COMPRESSED_PACKET_HEADER_LENGTH`
    (as appropriate) bytes correspond to the `header`, the remaining ones represent the
    `payload`.

    The maximum payload length allowed to be sent per packet to the server is
    `MAX_PAYLOAD_LENGTH`. When `send` is called with a payload whose length is greater
    than `MAX_PAYLOAD_LENGTH` the netbroker breaks it down into packets, so the caller
    of `send` can provide payloads of arbitrary length.

    Finally, data received by the netbroker comes directly from the server, expect to
    get a packet for each call to `recv`. The received packet contains a header and
    payload, the latter respecting `MAX_PAYLOAD_LENGTH`.
    """

    @abstractmethod
    async def write(
        self,
        writer: StreamWriter,
        address: str,
        payload: bytes,
        packet_number: Optional[int] = None,
        compressed_packet_number: Optional[int] = None,
        write_timeout: Optional[int] = None,
    ) -> None:
        """Send `payload` to the MySQL server.

        If provided a payload whose length is greater than `MAX_PAYLOAD_LENGTH`, it is
        broken down into packets.

        Args:
            writer: Stream writer holding the socket connection.
            address: Socket's location (used only for error reporting).
            payload: Packet's body to send.
            packet_number: Sequence id (packet ID) to attach to the header when sending
                plain packets.
            compressed_packet_number: Same as `packet_number` but used when sending
                compressed packets.
            write_timeout: Timeout in seconds before which sending a packet to the
                server should finish, else WriteTimeoutError is raised.

        Raises:
            :class:`OperationalError`: If something goes wrong while sending packets to
                the MySQL server.
        """

    @abstractmethod
    async def read(
        self,
        reader: asyncio.StreamReader,
        address: str,
        read_timeout: Optional[int] = None,
    ) -> bytearray:
        """Get the next available packet from the MySQL server.

        Args:
            reader: Stream reader holding the socket connection.
            address: Socket's location (used only for error reporting).
            read_timeout: Timeout in seconds before which reading a packet from the
                server should finish.

        Returns:
            packet: A packet from the MySQL server.

        Raises:
            :class:`OperationalError`: If something goes wrong while receiving packets
                from the MySQL server.
            :class:`ReadTimeoutError`: If the time to receive a packet from the server
                takes longer than `read_timeout`.
            :class:`InterfaceError`: If something goes wrong while receiving packets
                from the MySQL server.
        """
|
||||
|
||||
|
||||
class NetworkBrokerPlain(NetworkBroker):
    """Broker sending/receiving plain (uncompressed) MySQL packets."""

    def __init__(self) -> None:
        # Sequence id of the last packet handled; starts at -1 so the first
        # call to `_set_next_pktnr()` yields 0.
        self._pktnr: int = -1  # packet number

    @staticmethod
    def get_header(pkt: bytes) -> Tuple[int, int]:
        """Recover the header information from a packet.

        Returns:
            Tuple of (payload length, sequence id) parsed from the first
            `PACKET_HEADER_LENGTH` bytes.

        Raises:
            ValueError: If `pkt` is shorter than a full header.
        """
        if len(pkt) < PACKET_HEADER_LENGTH:
            raise ValueError("Can't recover header info from an incomplete packet")

        # Header layout: 3-byte little-endian payload length + 1-byte sequence
        # id. The 3-byte length is padded with a zero byte so struct can
        # decode it as a 4-byte unsigned int.
        pll, seqid = (
            struct.unpack("<I", pkt[0:3] + b"\x00")[0],
            pkt[3],
        )
        # payload length, sequence id
        return pll, seqid

    def _set_next_pktnr(self, next_id: Optional[int] = None) -> None:
        """Set the given packet id, if any, else increment packet id."""
        if next_id is None:
            self._pktnr += 1
        else:
            self._pktnr = next_id
        # Sequence ids are a single byte on the wire, hence modulo 256.
        self._pktnr %= 256

    async def _write_pkt(
        self,
        writer: StreamWriter,
        address: str,
        pkt: bytes,
    ) -> None:
        """Write packet to the comm channel.

        Raises:
            OperationalError: errno 2055 on I/O failure (with `address` in the
                message), errno 2006 when the writer is gone (server gone away).
        """
        try:
            writer.write(pkt)
            await writer.drain()
        except IOError as err:
            raise OperationalError(
                errno=2055, values=(address, _strioerror(err))
            ) from err
        except AttributeError as err:
            # `writer` is None (connection already closed) -> server gone away.
            raise OperationalError(errno=2006) from err

    async def _read_chunk(
        self,
        reader: asyncio.StreamReader,
        size: int = 0,
        read_timeout: Optional[int] = None,
    ) -> bytearray:
        """Read `size` bytes from the comm channel.

        Loops until exactly `size` bytes have been accumulated; a zero-length
        read means the server closed the connection.

        Raises:
            InterfaceError: errno 2013 (lost connection) on EOF.
            ReadTimeoutError: If `read_timeout` elapses before `size` bytes arrive.
        """
        try:
            pkt = bytearray(b"")
            while len(pkt) < size:
                chunk = await asyncio.wait_for(
                    reader.read(size - len(pkt)), read_timeout
                )
                if not chunk:
                    raise InterfaceError(errno=2013)
                pkt += chunk
            return pkt
        except asyncio.TimeoutError as err:
            raise ReadTimeoutError(errno=3024) from err
        except asyncio.CancelledError as err:
            # Propagate task cancellation untouched.
            raise err

    async def write(
        self,
        writer: StreamWriter,
        address: str,
        payload: bytes,
        packet_number: Optional[int] = None,
        compressed_packet_number: Optional[int] = None,
        write_timeout: Optional[int] = None,
    ) -> None:
        """Send payload to the MySQL server.

        If provided a payload whose length is greater than `MAX_PAYLOAD_LENGTH`, it is
        broken down into packets.

        Args:
            writer: Stream writer holding the socket connection.
            address: Socket's location (used only for error reporting).
            payload: Packet's body to send.
            packet_number: Sequence id to start from; `None` means "increment".
            compressed_packet_number: Unused here; accepted to honor the
                `NetworkBroker` interface.
            write_timeout: Per-packet timeout in seconds.

        Raises:
            WriteTimeoutError: If a packet cannot be sent within `write_timeout`.
        """
        self._set_next_pktnr(packet_number)
        # If the payload is larger than or equal to MAX_PAYLOAD_LENGTH the length is
        # set to 2^24 - 1 (ff ff ff) and additional packets are sent with the rest of
        # the payload until the payload of a packet is less than MAX_PAYLOAD_LENGTH.
        offset = 0
        try:
            for _ in range(len(payload) // MAX_PAYLOAD_LENGTH):
                # payload_len, sequence_id, payload
                await asyncio.wait_for(
                    self._write_pkt(
                        writer,
                        address,
                        b"\xff" * 3
                        + struct.pack("<B", self._pktnr)
                        + payload[offset : offset + MAX_PAYLOAD_LENGTH],
                    ),
                    write_timeout,
                )
                self._set_next_pktnr()
                offset += MAX_PAYLOAD_LENGTH
            # Final (possibly empty) packet carrying the remaining bytes; an
            # exact multiple of MAX_PAYLOAD_LENGTH is terminated by an
            # empty-payload packet, as the protocol requires.
            await asyncio.wait_for(
                self._write_pkt(
                    writer,
                    address,
                    struct.pack("<I", len(payload) - offset)[0:3]
                    + struct.pack("<B", self._pktnr)
                    + payload[offset:],
                ),
                write_timeout,
            )
        except asyncio.TimeoutError as err:
            raise WriteTimeoutError(errno=3024) from err
        except asyncio.CancelledError as err:
            raise err

    async def read(
        self,
        reader: asyncio.StreamReader,
        address: str,
        read_timeout: Optional[int] = None,
    ) -> bytearray:
        """Receive `one` packet from the MySQL server.

        Returns:
            The full packet (header + payload) as a bytearray. The sequence id
            found in the header is remembered in `self._pktnr`.

        Raises:
            OperationalError: errno 2055 on I/O failure.
        """
        try:
            # Read the header of the MySQL packet.
            header = await self._read_chunk(reader, PACKET_HEADER_LENGTH, read_timeout)

            # Pull the payload length and sequence id.
            payload_len, self._pktnr = self.get_header(header)

            # Read the payload, and return packet.
            return header + await self._read_chunk(reader, payload_len, read_timeout)
        except IOError as err:
            raise OperationalError(
                errno=2055, values=(address, _strioerror(err))
            ) from err
|
||||
|
||||
|
||||
class NetworkBrokerCompressed(NetworkBrokerPlain):
    """Broker sending/receiving zlib-compressed MySQL packets.

    Compressed packets wrap one or more plain packets; received plain packets
    are buffered in a FIFO queue and handed out one per `read()` call.
    """

    def __init__(self) -> None:
        super().__init__()
        # Sequence id of the compressed-packet stream (independent from the
        # plain-packet sequence id tracked by the parent class).
        self._compressed_pktnr = -1
        # FIFO of plain packets already extracted from compressed packets.
        self._queue_read: Deque[bytearray] = deque()

    @staticmethod
    def _prepare_packets(payload: bytes, pktnr: int) -> List[bytes]:
        """Split `payload` into plain packets (header + body), ready to wrap.

        Args:
            payload: Raw payload of arbitrary length.
            pktnr: Sequence id for the first produced packet; incremented
                (mod 256) for each subsequent one.
        """
        offset = 0
        pkts = []

        # If the payload is larger than or equal to MAX_PAYLOAD_LENGTH the length is
        # set to 2^24 - 1 (ff ff ff) and additional packets are sent with the rest of
        # the payload until the payload of a packet is less than MAX_PAYLOAD_LENGTH.
        for _ in range(len(payload) // MAX_PAYLOAD_LENGTH):
            # payload length + sequence id + payload
            pkts.append(
                b"\xff" * 3
                + struct.pack("<B", pktnr)
                + payload[offset : offset + MAX_PAYLOAD_LENGTH]
            )
            pktnr = (pktnr + 1) % 256
            offset += MAX_PAYLOAD_LENGTH
        # Trailing packet with the remaining (< MAX_PAYLOAD_LENGTH) bytes.
        pkts.append(
            struct.pack("<I", len(payload) - offset)[0:3]
            + struct.pack("<B", pktnr)
            + payload[offset:]
        )
        return pkts

    @staticmethod
    def get_header(pkt: bytes) -> Tuple[int, int, int]:  # type: ignore[override]
        """Recover the header information from a compressed packet.

        Returns:
            Tuple of (compressed payload length, sequence id, uncompressed
            payload length). An uncompressed length of 0 means the body was
            sent uncompressed.

        Raises:
            ValueError: If `pkt` is shorter than a full compressed header.
        """
        if len(pkt) < COMPRESSED_PACKET_HEADER_LENGTH:
            raise ValueError("Can't recover header info from an incomplete packet")

        # Header layout: 3-byte compressed length + 1-byte sequence id +
        # 3-byte uncompressed length, all little-endian.
        compressed_pll, seqid, uncompressed_pll = (
            struct.unpack("<I", pkt[0:3] + b"\x00")[0],
            pkt[3],
            struct.unpack("<I", pkt[4:7] + b"\x00")[0],
        )
        # compressed payload length, sequence id, uncompressed payload length
        return compressed_pll, seqid, uncompressed_pll

    def _set_next_compressed_pktnr(self, next_id: Optional[int] = None) -> None:
        """Set the given packet id, if any, else increment packet id."""
        if next_id is None:
            self._compressed_pktnr += 1
        else:
            self._compressed_pktnr = next_id
        # Sequence ids are a single byte on the wire, hence modulo 256.
        self._compressed_pktnr %= 256

    async def _write_pkt(
        self,
        writer: StreamWriter,
        address: str,
        pkt: bytes,
    ) -> None:
        """Compress packet and write it to the comm channel."""
        compressed_pkt = zlib.compress(pkt)
        # Prepend the compressed header: compressed length, compressed
        # sequence id, original (uncompressed) length.
        pkt = (
            struct.pack("<I", len(compressed_pkt))[0:3]
            + struct.pack("<B", self._compressed_pktnr)
            + struct.pack("<I", len(pkt))[0:3]
            + compressed_pkt
        )
        return await super()._write_pkt(writer, address, pkt)

    async def write(
        self,
        writer: StreamWriter,
        address: str,
        payload: bytes,
        packet_number: Optional[int] = None,
        compressed_packet_number: Optional[int] = None,
        write_timeout: Optional[int] = None,
    ) -> None:
        """Send `payload` as compressed packets to the MySQL server.

        If provided a payload whose length is greater than `MAX_PAYLOAD_LENGTH`, it is
        broken down into packets.

        Raises:
            WriteTimeoutError: If a packet cannot be sent within `write_timeout`.
        """
        # Get next packet numbers.
        self._set_next_pktnr(packet_number)
        self._set_next_compressed_pktnr(compressed_packet_number)
        try:
            # Concatenation of all plain packets (headers included) that will
            # be wrapped into compressed packets.
            payload_prep = bytearray(b"").join(
                self._prepare_packets(payload, self._pktnr)
            )
            if len(payload) >= MAX_PAYLOAD_LENGTH - PACKET_HEADER_LENGTH:
                # Sending a MySQL payload of a size greater than or equal to
                # 2^24 - 5 via compression leads to at least one extra
                # compressed packet. Why? Say len(payload) is
                # MAX_PAYLOAD_LENGTH - 3; when preparing the payload, a
                # header of size PACKET_HEADER_LENGTH is pre-appended to the
                # payload. This means that len(payload_prep) is
                # MAX_PAYLOAD_LENGTH - 3 + PACKET_HEADER_LENGTH = MAX_PAYLOAD_LENGTH + 1
                # surpassing the maximum allowed payload size per packet.
                offset = 0

                # Send several MySQL packets.
                for _ in range(len(payload_prep) // MAX_PAYLOAD_LENGTH):
                    await asyncio.wait_for(
                        self._write_pkt(
                            writer,
                            address,
                            payload_prep[offset : offset + MAX_PAYLOAD_LENGTH],
                        ),
                        write_timeout,
                    )
                    self._set_next_compressed_pktnr()
                    offset += MAX_PAYLOAD_LENGTH
                await asyncio.wait_for(
                    self._write_pkt(writer, address, payload_prep[offset:]),
                    write_timeout,
                )
            else:
                # Send one MySQL packet.
                # For small packets it may be too costly to compress the packet.
                # Usually payloads less than 50 bytes (MIN_COMPRESS_LENGTH) aren't
                # compressed (see MySQL source code Documentation).
                if len(payload) > MIN_COMPRESS_LENGTH:
                    # Perform compression.
                    await asyncio.wait_for(
                        self._write_pkt(writer, address, payload_prep), write_timeout
                    )
                else:
                    # Skip compression: send a compressed-style header with
                    # uncompressed-length 0 followed by the plain bytes.
                    await asyncio.wait_for(
                        super()._write_pkt(
                            writer,
                            address,
                            struct.pack("<I", len(payload_prep))[0:3]
                            + struct.pack("<B", self._compressed_pktnr)
                            + struct.pack("<I", 0)[0:3]
                            + payload_prep,
                        ),
                        write_timeout,
                    )
        # NOTE(review): unlike NetworkBrokerPlain.write, CancelledError is
        # also mapped to WriteTimeoutError here — confirm this is intended.
        except (asyncio.CancelledError, asyncio.TimeoutError) as err:
            raise WriteTimeoutError(errno=3024) from err

    async def _read_compressed_pkt(
        self,
        reader: asyncio.StreamReader,
        compressed_pll: int,
        read_timeout: Optional[int] = None,
    ) -> None:
        """Handle reading of a compressed packet.

        Decompresses the body, splits it back into plain packets and appends
        each of them to `self._queue_read`. Reads further compressed packets
        from the stream when a plain packet straddles a compressed boundary.
        """
        # compressed_pll stands for compressed payload length.
        pkt = bytearray(
            zlib.decompress(
                await super()._read_chunk(reader, compressed_pll, read_timeout)
            )
        )
        offset = 0
        while offset < len(pkt):
            # pll stands for payload length
            pll = struct.unpack(
                "<I", pkt[offset : offset + PACKET_HEADER_LENGTH - 1] + b"\x00"
            )[0]
            if PACKET_HEADER_LENGTH + pll > len(pkt) - offset:
                # More bytes need to be consumed.
                # Read the header of the next MySQL packet.
                header = await super()._read_chunk(
                    reader, COMPRESSED_PACKET_HEADER_LENGTH, read_timeout
                )

                # compressed payload length, sequence id, uncompressed payload length.
                (
                    compressed_pll,
                    self._compressed_pktnr,
                    uncompressed_pll,
                ) = self.get_header(header)
                compressed_pkt = await super()._read_chunk(
                    reader, compressed_pll, read_timeout
                )

                # Recalling that if uncompressed payload length == 0, the packet comes
                # in uncompressed, so no decompression is needed.
                pkt += (
                    compressed_pkt
                    if uncompressed_pll == 0
                    else zlib.decompress(compressed_pkt)
                )

            self._queue_read.append(pkt[offset : offset + PACKET_HEADER_LENGTH + pll])
            offset += PACKET_HEADER_LENGTH + pll

    async def read(
        self,
        reader: asyncio.StreamReader,
        address: str,
        read_timeout: Optional[int] = None,
    ) -> bytearray:
        """Receive `one` or `several` packets from the MySQL server, enqueue them, and
        return the packet at the head.

        Raises:
            OperationalError: errno 2055 on I/O failure.
        """

        if not self._queue_read:
            try:
                # Read the header of the next MySQL packet.
                header = await super()._read_chunk(
                    reader, COMPRESSED_PACKET_HEADER_LENGTH, read_timeout
                )

                # compressed payload length, sequence id, uncompressed payload length
                (
                    compressed_pll,
                    self._compressed_pktnr,
                    uncompressed_pll,
                ) = self.get_header(header)

                if uncompressed_pll == 0:
                    # Packet is not compressed, so just store it.
                    self._queue_read.append(
                        await super()._read_chunk(reader, compressed_pll, read_timeout)
                    )
                else:
                    # Packet comes in compressed, further action is needed.
                    await self._read_compressed_pkt(
                        reader, compressed_pll, read_timeout
                    )
            except IOError as err:
                raise OperationalError(
                    errno=2055, values=(address, _strioerror(err))
                ) from err

        # NOTE(review): returning None here contradicts the declared
        # `-> bytearray` return type; callers appear to rely on a falsy
        # result — confirm before tightening.
        if not self._queue_read:
            return None

        pkt = self._queue_read.popleft()
        # Keep the plain sequence id in sync with the packet handed out.
        self._pktnr = pkt[3]

        return pkt
|
||||
|
||||
|
||||
class MySQLSocket(ABC):
    """MySQL socket communication interface.

    Holds the stream reader/writer pair and delegates packet framing to a
    `NetworkBroker` (plain by default, compressed after
    `switch_to_compressed_mode()`).

    Examples:
        Subclasses: network.MySQLTCPSocket and network.MySQLUnixSocket.
    """

    def __init__(self) -> None:
        """Network layer where transactions are made with plain (uncompressed) packets
        is enabled by default.
        """
        self._reader: Optional[asyncio.StreamReader] = None
        self._writer: Optional[StreamWriter] = None
        self._connection_timeout: Optional[int] = None
        self._address: Optional[str] = None
        self._netbroker: NetworkBroker = NetworkBrokerPlain()
        self._is_connected: bool = False

    @property
    def address(self) -> str:
        """Socket location."""
        return self._address

    @abstractmethod
    async def open_connection(self, **kwargs: Any) -> None:
        """Open the socket."""

    async def close_connection(self) -> None:
        """Close the connection."""
        if self._writer:
            try:
                self._writer.close()
                # Without transport.abort(), an error is raised when using SSL
                if self._writer.transport is not None:
                    self._writer.transport.abort()
                await self._writer.wait_closed()
            except Exception as _:  # pylint: disable=broad-exception-caught)
                # we can ignore issues like ConnectionRefused or ConnectionAborted
                # as these instances might popup if the connection was closed due
                # to timeout issues
                pass
        self._is_connected = False

    def is_connected(self) -> bool:
        """Check if the socket is connected.

        Return:
            bool: Returns `True` if the socket is connected to MySQL server.
        """
        return self._is_connected

    def set_connection_timeout(self, timeout: int) -> None:
        """Set the connection timeout."""
        self._connection_timeout = timeout

    def switch_to_compressed_mode(self) -> None:
        """Enable network layer where transactions are made with compressed packets."""
        self._netbroker = NetworkBrokerCompressed()

    async def switch_to_ssl(self, ssl_context: ssl.SSLContext) -> None:
        """Upgrade an existing stream-based connection to TLS.

        The `start_tls()` method from `asyncio.streams.StreamWriter` is only available
        in Python 3.11. This method is used as a workaround.

        The MySQL TLS negotiation happens in the middle of the TCP connection.
        Therefore, passing a socket to open connection will cause it to negotiate
        TLS on an existing connection.

        Args:
            ssl_context: The SSL Context to be used.

        Raises:
            ProgrammingError: If the underlying socket is a Unix socket.
        """
        # Ensure that self._writer is already created
        assert self._writer is not None

        socket = self._writer.transport.get_extra_info("socket")
        if socket.family == 1:  # socket.AF_UNIX
            raise ProgrammingError("SSL is not supported when using Unix sockets")

        await self._writer.start_tls(ssl_context)

    async def write(
        self,
        payload: bytes,
        packet_number: Optional[int] = None,
        compressed_packet_number: Optional[int] = None,
        write_timeout: Optional[int] = None,
    ) -> None:
        """Send packets to the MySQL server."""
        await self._netbroker.write(
            self._writer,
            self.address,
            payload,
            packet_number=packet_number,
            compressed_packet_number=compressed_packet_number,
            write_timeout=write_timeout,
        )

    async def read(self, read_timeout: Optional[int] = None) -> bytearray:
        """Read packets from the MySQL server."""
        return await self._netbroker.read(self._reader, self.address, read_timeout)

    def build_ssl_context(
        self,
        ssl_ca: Optional[str] = None,
        ssl_cert: Optional[str] = None,
        ssl_key: Optional[str] = None,
        ssl_verify_cert: Optional[bool] = False,
        ssl_verify_identity: Optional[bool] = False,
        tls_versions: Optional[List[str]] = None,
        tls_cipher_suites: Optional[List[str]] = None,
    ) -> ssl.SSLContext:
        """Build a SSLContext.

        Args:
            ssl_ca: Path to a CA certificate file.
            ssl_cert: Path to the client certificate file.
            ssl_key: Path to the client private key file.
            ssl_verify_cert: Require a valid server certificate.
            ssl_verify_identity: Also verify the server's hostname.
            tls_versions: TLS versions to allow; the newest one drives the
                protocol selection. `None`/empty means use the defaults.
            tls_cipher_suites: Cipher suites to enable (TLSv1.2 only).

        Returns:
            A configured `ssl.SSLContext`.

        Raises:
            InterfaceError: If no connection is open (errno 2048) or the
                SSL setup fails.
            NotSupportedError: If the Python installation has no SSL support.

        Note: the previous implementation used mutable default arguments
        (`[]`) and sorted the caller's `tls_versions` list in place; both
        issues are fixed here without changing the observable TLS setup.
        """
        tls_version: Optional[str] = None

        if not self._reader:
            raise InterfaceError(errno=2048)

        if ssl is None:
            raise RuntimeError("Python installation has no SSL support")

        try:
            if tls_versions:
                # Sort a copy (newest first) so the caller's list — or a
                # shared default — is never mutated.
                versions = sorted(tls_versions, reverse=True)
                tls_version = versions[0]
                ssl_protocol = TLS_VERSIONS[tls_version]
                context = ssl.SSLContext(ssl_protocol)

                if tls_version == "TLSv1.3":
                    # PROTOCOL_TLS allows older versions too; explicitly
                    # disable the ones not requested.
                    if "TLSv1.2" not in versions:
                        context.options |= ssl.OP_NO_TLSv1_2
                    if "TLSv1.1" not in versions:
                        context.options |= ssl.OP_NO_TLSv1_1
                    if "TLSv1" not in versions:
                        context.options |= ssl.OP_NO_TLSv1
            else:
                context = ssl.create_default_context()

            context.check_hostname = ssl_verify_identity

            if ssl_verify_cert:
                context.verify_mode = ssl.CERT_REQUIRED
            elif ssl_verify_identity:
                context.verify_mode = ssl.CERT_OPTIONAL
            else:
                context.verify_mode = ssl.CERT_NONE

            context.load_default_certs()

            if ssl_ca:
                try:
                    context.load_verify_locations(ssl_ca)
                except (IOError, ssl.SSLError) as err:
                    raise InterfaceError(f"Invalid CA Certificate: {err}") from err
            if ssl_cert:
                try:
                    context.load_cert_chain(ssl_cert, ssl_key)
                except (IOError, ssl.SSLError) as err:
                    raise InterfaceError(f"Invalid Certificate/Key: {err}") from err

            # TLSv1.3 ciphers cannot be disabled with `SSLContext.set_ciphers(...)`,
            # see https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers.
            if tls_cipher_suites and tls_version == "TLSv1.2":
                context.set_ciphers(":".join(tls_cipher_suites))

            return context
        except NameError as err:
            raise NotSupportedError("Python installation has no SSL support") from err
        except (
            IOError,
            NotImplementedError,
            ssl.CertificateError,
            ssl.SSLError,
        ) as err:
            raise InterfaceError(str(err)) from err
|
||||
|
||||
|
||||
class MySQLTcpSocket(MySQLSocket):
    """TCP/IP flavor of :class:`MySQLSocket`.

    Args:
        host: MySQL host name.
        port: MySQL port.
        force_ipv6: Force IPv6 usage.
    """

    def __init__(
        self, host: str = "127.0.0.1", port: int = 3306, force_ipv6: bool = False
    ):
        super().__init__()
        # The reported address is always "host:port".
        self._address: str = f"{host}:{port}"
        self._host: str = host
        self._port: int = port
        self._force_ipv6: bool = force_ipv6

    async def open_connection(self, **kwargs: Any) -> None:
        """Open TCP/IP connection."""
        streams = await open_connection(host=self._host, port=self._port, **kwargs)
        self._reader, self._writer = streams
        self._is_connected = True
|
||||
|
||||
|
||||
class MySQLUnixSocket(MySQLSocket):
    """UNIX-domain-socket flavor of :class:`MySQLSocket`.

    Args:
        unix_socket: UNIX socket file path.
    """

    def __init__(self, unix_socket: str = "/tmp/mysql.sock"):
        super().__init__()
        # The socket path doubles as the reported address.
        self._address: str = unix_socket

    async def open_connection(self, **kwargs: Any) -> None:
        """Open UNIX socket connection."""
        streams = await asyncio.open_unix_connection(  # type: ignore[assignment]
            path=self._address, **kwargs
        )
        self._reader, self._writer = streams
        self._is_connected = True
|
||||
@@ -0,0 +1,162 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Base Authentication Plugin class."""
|
||||
|
||||
__all__ = ["MySQLAuthPlugin", "get_auth_plugin"]
|
||||
|
||||
import importlib
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, Optional, Type
|
||||
|
||||
from mysql.connector.errors import NotSupportedError, ProgrammingError
|
||||
from mysql.connector.logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..network import MySQLSocket
|
||||
|
||||
DEFAULT_PLUGINS_PKG = "mysql.connector.aio.plugins"
|
||||
|
||||
|
||||
class MySQLAuthPlugin(ABC):
    """Authorization plugin interface.

    Concrete plugins implement one MySQL client-side authentication method;
    instances carry the credentials and whether the channel is TLS-protected.
    """

    def __init__(
        self,
        username: str,
        password: str,
        ssl_enabled: bool = False,
    ) -> None:
        """Constructor.

        Args:
            username: Account user name; `None` is normalized to "".
            password: Account password; `None` is normalized to "".
            ssl_enabled: Whether the connection is TLS-protected.
        """
        self._username: str = "" if username is None else username
        self._password: str = "" if password is None else password
        self._ssl_enabled: bool = ssl_enabled

    @property
    def ssl_enabled(self) -> bool:
        """Signals whether or not SSL is enabled."""
        return self._ssl_enabled

    @property
    @abstractmethod
    def requires_ssl(self) -> bool:
        """Signals whether or not SSL is required."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Plugin official name."""

    @abstractmethod
    def auth_response(self, auth_data: bytes, **kwargs: Any) -> Optional[bytes]:
        """Make the client's authorization response.

        Args:
            auth_data: Authorization data.
            kwargs: Custom configuration to be passed to the auth plugin
                when invoked. The parameters defined here will override the ones
                defined in the auth plugin itself.

        Returns:
            packet: Client's authorization response.
        """

    async def auth_more_response(
        self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
    ) -> bytes:
        """Handles server's `auth more data` response.

        Args:
            sock: Pointer to the socket connection.
            auth_data: Authentication method data (from a packet representing
                an `auth more data` response).
            kwargs: Custom configuration to be passed to the auth plugin
                when invoked. The parameters defined here will override the ones
                defined in the auth plugin itself.

        Returns:
            packet: Last server's response after back-and-forth communication.

        Raises:
            NotImplementedError: Always, unless a subclass overrides it —
                not every auth method has a "more data" phase.
        """
        raise NotImplementedError

    @abstractmethod
    async def auth_switch_response(
        self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
    ) -> bytes:
        """Handles server's `auth switch request` response.

        Args:
            sock: Pointer to the socket connection.
            auth_data: Plugin provided data (extracted from a packet
                representing an `auth switch request` response).
            kwargs: Custom configuration to be passed to the auth plugin
                when invoked. The parameters defined here will override the ones
                defined in the auth plugin itself.

        Returns:
            packet: Last server's response after back-and-forth communication.
        """
||||
|
||||
|
||||
@lru_cache(maxsize=10, typed=False)
def get_auth_plugin(
    plugin_name: str,
    auth_plugin_class: Optional[str] = None,
) -> Type[MySQLAuthPlugin]:
    """Return authentication class based on plugin name.

    Looks up the module named `plugin_name` inside the default plugins
    package and returns the plugin class from it — either the requested
    `auth_plugin_class` (when present in the module) or the module's
    declared `AUTHENTICATION_PLUGIN_CLASS`.

    Args:
        plugin_name (str): Authentication plugin name.
        auth_plugin_class (str): Authentication plugin class name.

    Raises:
        ProgrammingError: When `plugin_name` is not a valid module name.
        NotSupportedError: When plugin_name is empty or cannot be resolved.

    Returns:
        Subclass of `MySQLAuthPlugin`.
    """
    if plugin_name:
        package = DEFAULT_PLUGINS_PKG
        logger.info("package: %s", package)
        logger.info("plugin_name: %s", plugin_name)
        try:
            module = importlib.import_module(f".{plugin_name}", package)
        except ModuleNotFoundError as err:
            # Unknown plugin: fall through to the NotSupportedError below.
            logger.warning("Requested Module was not found: %s", err)
        except ValueError as err:
            raise ProgrammingError(f"Invalid module name: {err}") from err
        else:
            class_name = auth_plugin_class
            if not class_name or not hasattr(module, class_name):
                class_name = module.AUTHENTICATION_PLUGIN_CLASS
            logger.info("AUTHENTICATION_PLUGIN_CLASS: %s", class_name)
            return getattr(module, class_name)
    raise NotSupportedError(f"Authentication plugin '{plugin_name}' is not supported")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,577 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
# mypy: disable-error-code="str-bytes-safe,misc"
|
||||
|
||||
"""Kerberos Authentication Plugin."""
|
||||
|
||||
import getpass
|
||||
import os
|
||||
import struct
|
||||
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Optional, Tuple
|
||||
|
||||
from mysql.connector.errors import InterfaceError, ProgrammingError
|
||||
from mysql.connector.logger import logger
|
||||
|
||||
from ..authentication import ERR_STATUS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..network import MySQLSocket
|
||||
|
||||
try:
|
||||
import gssapi
|
||||
except ImportError:
|
||||
gssapi = None
|
||||
if os.name != "nt":
|
||||
raise ProgrammingError(
|
||||
"Module gssapi is required for GSSAPI authentication "
|
||||
"mechanism but was not found. Unable to authenticate "
|
||||
"with the server"
|
||||
) from None
|
||||
|
||||
try:
|
||||
import sspi
|
||||
import sspicon
|
||||
except ImportError:
|
||||
sspi = None
|
||||
sspicon = None
|
||||
|
||||
from . import MySQLAuthPlugin
|
||||
|
||||
AUTHENTICATION_PLUGIN_CLASS = (
|
||||
"MySQLSSPIKerberosAuthPlugin" if os.name == "nt" else "MySQLKerberosAuthPlugin"
|
||||
)
|
||||
|
||||
|
||||
class MySQLBaseKerberosAuthPlugin(MySQLAuthPlugin):
    """Base class for the MySQL Kerberos authentication plugin.

    Provides the plugin identity and the async server handshake loop;
    concrete subclasses supply the platform-specific ``auth_continue``
    (GSSAPI on POSIX, SSPI on Windows).
    """

    @property
    def name(self) -> str:
        """Plugin official name."""
        return "authentication_kerberos_client"

    @property
    def requires_ssl(self) -> bool:
        """Signals whether or not SSL is required."""
        return False

    @abstractmethod
    def auth_continue(
        self, tgt_auth_challenge: Optional[bytes]
    ) -> Tuple[Optional[bytes], bool]:
        """Continue with the Kerberos TGT service request.

        With the TGT authentication service given response generate a TGT
        service request. This method must be invoked sequentially (in a loop)
        until the security context is completed and an empty response needs to
        be send to acknowledge the server.

        Args:
            tgt_auth_challenge: the challenge for the negotiation.

        Returns:
            tuple (bytearray TGS service request,
                   bool True if context is completed otherwise False).
        """

    async def auth_switch_response(
        self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
    ) -> bytes:
        """Handles server's `auth switch request` response.

        Sends the initial client token, then drives the GSSAPI/SSPI
        negotiation loop (at most 5 round trips) until the security
        context reports completion, and finally reads the server's OK
        packet.

        Args:
            sock: Pointer to the socket connection.
            auth_data: Plugin provided data (extracted from a packet
                       representing an `auth switch request` response).
            kwargs: Custom configuration to be passed to the auth plugin
                    when invoked. The parameters defined here will override the ones
                    defined in the auth plugin itself.

        Returns:
            packet: Last server's response after back-and-forth
                    communication.

        Raises:
            InterfaceError: If the plugin produced no response, or the
                            negotiation did not complete within 5 attempts.
        """
        logger.debug("# auth_data: %s", auth_data)
        response = self.auth_response(auth_data, ignore_auth_data=False, **kwargs)
        if response is None:
            raise InterfaceError("Got a NULL auth response")

        logger.debug("# request: %s size: %s", response, len(response))
        await sock.write(response)

        packet = await sock.read()
        logger.debug("# server response packet: %s", packet)

        # NOTE(review): the whole packet is compared against ERR_STATUS —
        # presumably ERR_STATUS marks an error packet; on error the raw
        # packet is returned unchanged for the caller to handle. Confirm
        # ERR_STATUS semantics against ..authentication.
        if packet != ERR_STATUS:
            rcode_size = 5  # Reader size for the response status code
            logger.debug("# Continue with GSSAPI authentication")
            logger.debug("# Response header: %s", packet[: rcode_size + 1])
            logger.debug("# Response size: %s", len(packet))
            logger.debug("# Negotiate a service request")
            complete = False
            tries = 0

            # Bounded back-and-forth: each iteration feeds the server's
            # challenge (past the status-code header) into auth_continue
            # and sends back the produced token, if any.
            while not complete and tries < 5:
                logger.debug("%s Attempt %s %s", "-" * 20, tries + 1, "-" * 20)
                logger.debug("<< Server response: %s", packet)
                logger.debug("# Response code: %s", packet[: rcode_size + 1])
                token, complete = self.auth_continue(packet[rcode_size:])
                if token:
                    await sock.write(token)
                if complete:
                    break
                packet = await sock.read()

                logger.debug(">> Response to server: %s", token)
                tries += 1

            if not complete:
                raise InterfaceError(
                    f"Unable to fulfill server request after {tries} "
                    f"attempts. Last server response: {packet}"
                )

            logger.debug(
                "Last response from server: %s length: %d",
                packet,
                len(packet),
            )

            # Receive OK packet from server.
            packet = await sock.read()
            logger.debug("<< Ok packet from server: %s", packet)

        return bytes(packet)
|
||||
|
||||
|
||||
# pylint: disable=c-extension-no-member,no-member
class MySQLKerberosAuthPlugin(MySQLBaseKerberosAuthPlugin):
    """Implement the MySQL Kerberos authentication plugin.

    POSIX implementation on top of python-gssapi: obtains Kerberos
    credentials (from the ccache, or with the user's password as a
    fallback) and drives a GSSAPI security context against the SPN
    announced by the server.
    """

    # GSSAPI security context, established lazily by auth_response().
    context: Optional[gssapi.SecurityContext] = None

    @staticmethod
    def get_user_from_credentials() -> str:
        """Get user from credentials without realm.

        Falls back to the OS login name when no initiate credentials
        are available in the cache.
        """
        try:
            creds = gssapi.Credentials(usage="initiate")
            user = str(creds.name)
            # Strip the realm suffix ("user@REALM" -> "user").
            if user.find("@") != -1:
                user, _ = user.split("@", 1)
            return user
        except gssapi.raw.misc.GSSError:
            return getpass.getuser()

    @staticmethod
    def get_store() -> dict:
        """Get a credentials store dictionary.

        Returns:
            dict: Credentials store dictionary with the krb5 ccache name.

        Raises:
            InterfaceError: If 'KRB5CCNAME' environment variable is empty.
        """
        # NOTE(review): on non-POSIX the default is the literal path
        # "%TEMP%/krb5cc" — pathlib does not expand environment
        # variables. Confirm whether os.path.expandvars was intended.
        krb5ccname = os.environ.get(
            "KRB5CCNAME",
            (
                f"/tmp/krb5cc_{os.getuid()}"
                if os.name == "posix"
                else Path("%TEMP%").joinpath("krb5cc")
            ),
        )
        if not krb5ccname:
            raise InterfaceError(
                "The 'KRB5CCNAME' environment variable is set to empty"
            )
        logger.debug("Using krb5 ccache name: FILE:%s", krb5ccname)
        store = {b"ccache": f"FILE:{krb5ccname}".encode("utf-8")}
        return store

    def _acquire_cred_with_password(self, upn: str) -> gssapi.raw.creds.Creds:
        """Acquire and store credentials through provided password.

        The freshly acquired credentials are persisted into the ccache
        (see get_store) as the new default entry.

        Args:
            upn (str): User Principal Name.

        Returns:
            gssapi.raw.creds.Creds: GSSAPI credentials.

        Raises:
            ProgrammingError: If credentials cannot be acquired with the
                              configured password.
        """
        logger.debug("Attempt to acquire credentials through provided password")
        user = gssapi.Name(upn, gssapi.NameType.user)
        password = self._password.encode("utf-8")

        try:
            acquire_cred_result = gssapi.raw.acquire_cred_with_password(
                user, password, usage="initiate"
            )
            creds = acquire_cred_result.creds
            gssapi.raw.store_cred_into(
                self.get_store(),
                creds=creds,
                mech=gssapi.MechType.kerberos,
                overwrite=True,
                set_default=True,
            )
        except gssapi.raw.misc.GSSError as err:
            raise ProgrammingError(
                f"Unable to acquire credentials with the given password: {err}"
            ) from err
        return creds

    @staticmethod
    def _parse_auth_data(packet: bytes) -> Tuple[str, str]:
        """Parse authentication data.

        Get the SPN and REALM from the authentication data packet.

        Format:
            SPN string length two bytes <B1> <B2> +
            SPN string +
            UPN realm string length two bytes <B1> <B2> +
            UPN realm string

        Returns:
            tuple: With 'spn' and 'realm'.

        Raises:
            struct.error: If the packet is shorter than the declared lengths.
        """
        # NOTE(review): duplicated verbatim in MySQLSSPIKerberosAuthPlugin —
        # a candidate for hoisting into MySQLBaseKerberosAuthPlugin.
        spn_len = struct.unpack("<H", packet[:2])[0]
        packet = packet[2:]

        spn = struct.unpack(f"<{spn_len}s", packet[:spn_len])[0]
        packet = packet[spn_len:]

        realm_len = struct.unpack("<H", packet[:2])[0]
        realm = struct.unpack(f"<{realm_len}s", packet[2:])[0]

        return spn.decode(), realm.decode()

    def auth_response(
        self, auth_data: Optional[bytes] = None, **kwargs: Any
    ) -> Optional[bytes]:
        """Prepare the first message to the server.

        Resolution order for credentials: cached ccache entry first;
        if the cached user/realm does not match the configured one, or
        the cache is missing/expired, fall back to acquiring credentials
        with the configured password (when one is set).

        Args:
            auth_data: SPN/realm packet from the server (see
                       _parse_auth_data); ignored unless
                       kwargs["ignore_auth_data"] is False.
            kwargs:
                ignore_auth_data (bool): if True (default), the provided
                    auth data is ignored and the password is sent instead.

        Returns:
            The initial GSSAPI client token (may be None per step()).
        """
        spn = None
        realm = None

        if auth_data and not kwargs.get("ignore_auth_data", True):
            try:
                spn, realm = self._parse_auth_data(auth_data)
            except struct.error as err:
                # NOTE(review): InterruptedError (an OSError subclass)
                # looks unintended here — likely meant InterfaceError,
                # as used for every other failure in this module.
                raise InterruptedError(f"Invalid authentication data: {err}") from err

        # No SPN means the server did not request GSSAPI negotiation:
        # reply with the NUL-terminated password.
        if spn is None:
            return self._password.encode() + b"\x00"

        upn = f"{self._username}@{realm}" if self._username else None

        logger.debug("Service Principal: %s", spn)
        logger.debug("Realm: %s", realm)

        try:
            # Attempt to retrieve credentials from cache file
            creds: Any = gssapi.Credentials(usage="initiate")
            creds_upn = str(creds.name)

            logger.debug("Cached credentials found")
            logger.debug("Cached credentials UPN: %s", creds_upn)

            # Remove the realm from user
            if creds_upn.find("@") != -1:
                creds_user, creds_realm = creds_upn.split("@", 1)
            else:
                creds_user = creds_upn
                creds_realm = None

            upn = f"{self._username}@{realm}" if self._username else creds_upn

            # The user from cached credentials matches with the given user?
            if self._username and self._username != creds_user:
                logger.debug(
                    "The user from cached credentials doesn't match with the "
                    "given user"
                )
                if self._password is not None:
                    creds = self._acquire_cred_with_password(upn)
            # Realm mismatch also forces re-acquisition with the password.
            if creds_realm and creds_realm != realm and self._password is not None:
                creds = self._acquire_cred_with_password(upn)
        except gssapi.raw.exceptions.ExpiredCredentialsError as err:
            if upn and self._password is not None:
                creds = self._acquire_cred_with_password(upn)
            else:
                raise InterfaceError(f"Credentials has expired: {err}") from err
        except gssapi.raw.misc.GSSError as err:
            if upn and self._password is not None:
                creds = self._acquire_cred_with_password(upn)
            else:
                raise InterfaceError(
                    f"Unable to retrieve cached credentials error: {err}"
                ) from err

        # RequirementFlag members are ints; sum() ORs them into the
        # combined flag value expected by SecurityContext.
        flags = (
            gssapi.RequirementFlag.mutual_authentication,
            gssapi.RequirementFlag.extended_error,
            gssapi.RequirementFlag.delegate_to_peer,
        )
        name = gssapi.Name(spn, name_type=gssapi.NameType.kerberos_principal)
        cname = name.canonicalize(gssapi.MechType.kerberos)
        self.context = gssapi.SecurityContext(
            name=cname, creds=creds, flags=sum(flags), usage="initiate"
        )

        try:
            initial_client_token: Optional[bytes] = self.context.step()
        except gssapi.raw.misc.GSSError as err:
            raise InterfaceError(f"Unable to initiate security context: {err}") from err

        logger.debug("Initial client token: %s", initial_client_token)
        return initial_client_token

    def auth_continue(
        self, tgt_auth_challenge: Optional[bytes]
    ) -> Tuple[Optional[bytes], bool]:
        """Continue with the Kerberos TGT service request.

        With the TGT authentication service given response generate a TGT
        service request. This method must be invoked sequentially (in a loop)
        until the security context is completed and an empty response needs to
        be send to acknowledge the server.

        Args:
            tgt_auth_challenge: the challenge for the negotiation.

        Returns:
            tuple (bytearray TGS service request,
                   bool True if context is completed otherwise False).
        """
        logger.debug("tgt_auth challenge: %s", tgt_auth_challenge)

        resp: Optional[bytes] = self.context.step(tgt_auth_challenge)

        logger.debug("Context step response: %s", resp)
        logger.debug("Context completed?: %s", self.context.complete)

        return resp, self.context.complete

    def auth_accept_close_handshake(self, message: bytes) -> bytes:
        """Accept handshake and generate closing handshake message for server.

        This method verifies the server authenticity from the given message
        and included signature and generates the closing handshake for the
        server.

        When this method is invoked the security context is already established
        and the client and server can send GSSAPI formated secure messages.

        To finish the authentication handshake the server sends a message
        with the security layer availability and the maximum buffer size.

        Since the connector only uses the GSSAPI authentication mechanism to
        authenticate the user with the server, the server will verify clients
        message signature and terminate the GSSAPI authentication and send two
        messages; an authentication acceptance b'\x01\x00\x00\x08\x01' and a
        OK packet (that must be received after sent the returned message from
        this method).

        Args:
            message: a wrapped gssapi message from the server.

        Returns:
            bytearray (closing handshake message to be send to the server).

        Raises:
            ProgrammingError: If the security context is not yet complete.
            InterfaceError: If the server message fails MIC verification.
        """
        if not self.context.complete:
            raise ProgrammingError("Security context is not completed")
        logger.debug("Server message: %s", message)
        logger.debug("GSSAPI flags in use: %s", self.context.actual_flags)
        try:
            unwraped = self.context.unwrap(message)
            logger.debug("Unwraped: %s", unwraped)
        except gssapi.raw.exceptions.BadMICError as err:
            logger.debug("Unable to unwrap server message: %s", err)
            raise InterfaceError(f"Unable to unwrap server message: {err}") from err

        logger.debug("Unwrapped server message: %s", unwraped)
        # The message contents for the clients closing message:
        # - security level 1 byte, must be always 1.
        # - conciliated buffer size 3 bytes, without importance as no
        # further GSSAPI messages will be sends.
        # (Note: "\00" is an octal escape equal to "\x00", so this is the
        # 4-byte sequence 01 00 00 00.)
        response = bytearray(b"\x01\x00\x00\00")
        # Closing handshake must not be encrypted.
        logger.debug("Message response: %s", response)
        wraped = self.context.wrap(response, encrypt=False)
        logger.debug(
            "Wrapped message response: %s, length: %d",
            wraped[0],
            len(wraped[0]),
        )

        return wraped.message
|
||||
|
||||
|
||||
class MySQLSSPIKerberosAuthPlugin(MySQLBaseKerberosAuthPlugin):
    """Implement the MySQL Kerberos authentication plugin with Windows SSPI."""

    # NOTE(review): `context` appears unused in this class (SSPI state
    # lives in `clientauth`); kept for interface parity with the GSSAPI
    # implementation — confirm before removing.
    context: Any = None
    # sspi.ClientAuth instance, created lazily by auth_response().
    clientauth: Any = None

    @staticmethod
    def _parse_auth_data(packet: bytes) -> Tuple[str, str]:
        """Parse authentication data.

        Get the SPN and REALM from the authentication data packet.

        Format:
            SPN string length two bytes <B1> <B2> +
            SPN string +
            UPN realm string length two bytes <B1> <B2> +
            UPN realm string

        Returns:
            tuple: With 'spn' and 'realm'.

        Raises:
            struct.error: If the packet is shorter than the declared lengths.
        """
        # NOTE(review): duplicated verbatim in MySQLKerberosAuthPlugin —
        # a candidate for hoisting into MySQLBaseKerberosAuthPlugin.
        spn_len = struct.unpack("<H", packet[:2])[0]
        packet = packet[2:]

        spn = struct.unpack(f"<{spn_len}s", packet[:spn_len])[0]
        packet = packet[spn_len:]

        realm_len = struct.unpack("<H", packet[:2])[0]
        realm = struct.unpack(f"<{realm_len}s", packet[2:])[0]

        return spn.decode(), realm.decode()

    def auth_response(
        self, auth_data: Optional[bytes] = None, **kwargs: Any
    ) -> Optional[bytes]:
        """Prepare the first message to the server.

        Args:
            auth_data: SPN/realm packet from the server (see
                       _parse_auth_data); ignored unless
                       kwargs["ignore_auth_data"] is False.
            kwargs:
                ignore_auth_data (bool): if True, the provided auth data is ignored.

        Returns:
            The initial SSPI client token.

        Raises:
            ProgrammingError: If pywin32 is not installed.
            InterfaceError: If the security context cannot be initiated.
        """
        logger.debug("auth_response for sspi")
        spn = None
        realm = None

        if auth_data and not kwargs.get("ignore_auth_data", True):
            try:
                spn, realm = self._parse_auth_data(auth_data)
            except struct.error as err:
                # NOTE(review): InterruptedError (an OSError subclass)
                # looks unintended here — likely meant InterfaceError,
                # mirroring the GSSAPI implementation's oddity.
                raise InterruptedError(f"Invalid authentication data: {err}") from err

        logger.debug("Service Principal: %s", spn)
        logger.debug("Realm: %s", realm)

        if sspicon is None or sspi is None:
            raise ProgrammingError(
                'Package "pywin32" (Python for Win32 (pywin32) extensions)'
                " is not installed."
            )

        flags = (sspicon.ISC_REQ_MUTUAL_AUTH, sspicon.ISC_REQ_DELEGATE)

        # Explicit credentials only when both username and password were
        # configured; otherwise SSPI uses the logged-on user's context.
        if self._username and self._password:
            _auth_info = (self._username, realm, self._password)
        else:
            _auth_info = None

        targetspn = spn
        logger.debug("targetspn: %s", targetspn)
        logger.debug("_auth_info is None: %s", _auth_info is None)

        # The Security Support Provider Interface (SSPI) is an interface
        # that allows us to choose from a set of SSPs available in the
        # system; the idea of SSPI is to keep interface consistent no
        # matter what back end (a.k.a., SSP) we choose.

        # When using SSPI we should not use Kerberos directly as SSP,
        # as remarked in [2], but we can use it indirectly via another
        # SSP named Negotiate that acts as an application layer between
        # SSPI and the other SSPs [1].

        # Negotiate can select between Kerberos and NTLM on the fly;
        # it chooses Kerberos unless it cannot be used by one of the
        # systems involved in the authentication or the calling
        # application did not provide sufficient information to use
        # Kerberos.

        # prefix: https://docs.microsoft.com/en-us/windows/win32/secauthn
        # [1] prefix/microsoft-negotiate?source=recommendations
        # [2] prefix/microsoft-kerberos?source=recommendations
        self.clientauth = sspi.ClientAuth(
            "Negotiate",
            targetspn=targetspn,
            auth_info=_auth_info,
            scflags=sum(flags),
            datarep=sspicon.SECURITY_NETWORK_DREP,
        )

        try:
            # First authorize() call with no input produces the initial
            # client token. The broad `except Exception` wraps pywin32's
            # own error types, which do not share a common base usable here.
            data = None
            err, out_buf = self.clientauth.authorize(data)
            logger.debug("Context step err: %s", err)
            logger.debug("Context step out_buf: %s", out_buf)
            logger.debug("Context completed?: %s", self.clientauth.authenticated)
            initial_client_token = out_buf[0].Buffer
            logger.debug("pkg_info: %s", self.clientauth.pkg_info)
        except Exception as err:
            raise InterfaceError(f"Unable to initiate security context: {err}") from err

        logger.debug("Initial client token: %s", initial_client_token)
        return initial_client_token

    def auth_continue(
        self, tgt_auth_challenge: Optional[bytes]
    ) -> Tuple[Optional[bytes], bool]:
        """Continue with the Kerberos TGT service request.

        With the TGT authentication service given response generate a TGT
        service request. This method must be invoked sequentially (in a loop)
        until the security context is completed and an empty response needs to
        be send to acknowledge the server.

        Args:
            tgt_auth_challenge: the challenge for the negotiation.

        Returns:
            tuple (bytearray TGS service request,
                   bool True if context is completed otherwise False).
        """
        logger.debug("tgt_auth challenge: %s", tgt_auth_challenge)

        err, out_buf = self.clientauth.authorize(tgt_auth_challenge)

        logger.debug("Context step err: %s", err)
        logger.debug("Context step out_buf: %s", out_buf)
        resp = out_buf[0].Buffer
        logger.debug("Context step resp: %s", resp)
        logger.debug("Context completed?: %s", self.clientauth.authenticated)

        return resp, self.clientauth.authenticated
|
||||
@@ -0,0 +1,595 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""LDAP SASL Authentication Plugin."""
|
||||
|
||||
import hmac
|
||||
|
||||
from base64 import b64decode, b64encode
|
||||
from hashlib import sha1, sha256
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
from mysql.connector.authentication import ERR_STATUS
|
||||
from mysql.connector.errors import InterfaceError, ProgrammingError
|
||||
from mysql.connector.logger import logger
|
||||
from mysql.connector.types import StrOrBytes
|
||||
from mysql.connector.utils import (
|
||||
normalize_unicode_string as norm_ustr,
|
||||
validate_normalized_unicode_string as valid_norm,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..network import MySQLSocket
|
||||
|
||||
try:
|
||||
import gssapi
|
||||
except ImportError:
|
||||
raise ProgrammingError(
|
||||
"Module gssapi is required for GSSAPI authentication "
|
||||
"mechanism but was not found. Unable to authenticate "
|
||||
"with the server"
|
||||
) from None
|
||||
|
||||
from . import MySQLAuthPlugin
|
||||
|
||||
AUTHENTICATION_PLUGIN_CLASS = "MySQLLdapSaslPasswordAuthPlugin"
|
||||
|
||||
|
||||
# pylint: disable=c-extension-no-member,no-member
|
||||
class MySQLLdapSaslPasswordAuthPlugin(MySQLAuthPlugin):
|
||||
"""Class implementing the MySQL ldap sasl authentication plugin.
|
||||
|
||||
The MySQL's ldap sasl authentication plugin support two authentication
|
||||
methods SCRAM-SHA-1 and GSSAPI (using Kerberos). This implementation only
|
||||
support SCRAM-SHA-1 and SCRAM-SHA-256.
|
||||
|
||||
SCRAM-SHA-1 amd SCRAM-SHA-256
|
||||
This method requires 2 messages from client and 2 responses from
|
||||
server.
|
||||
|
||||
The first message from client will be generated by prepare_password(),
|
||||
after receive the response from the server, it is required that this
|
||||
response is passed back to auth_continue() which will return the
|
||||
second message from the client. After send this second message to the
|
||||
server, the second server respond needs to be passed to auth_finalize()
|
||||
to finish the authentication process.
|
||||
"""
|
||||
|
||||
    # SASL mechanisms the server may announce; only the SCRAM variants are
    # implemented here in full — GSSAPI is delegated to the *_krb methods.
    sasl_mechanisms: List[str] = ["SCRAM-SHA-1", "SCRAM-SHA-256", "GSSAPI"]
    # Digest constructor for the SCRAM computations; sha1 by default,
    # switched to sha256 when the server requests SCRAM-SHA-256.
    def_digest_mode: Callable = sha1
    # Client-generated nonce for the current SCRAM exchange.
    client_nonce: Optional[str] = None
    # NOTE(review): these are class-level attributes shared by all
    # instances until shadowed per-instance — confirm each connection
    # gets its own plugin instance.
    client_salt: Any = None
    # Base64-encoded salt received in the server-first message.
    server_salt: Optional[str] = None
    # Kerberos service principal override for the GSSAPI path.
    krb_service_principal: Optional[str] = None
    # SCRAM iteration count from the server-first message.
    iterations: int = 0
    # Server signature to verify in the final step.
    server_auth_var: Optional[str] = None
    # GSSAPI target (service) name.
    target_name: Optional[gssapi.Name] = None
    # GSSAPI security context for the GSSAPI mechanism.
    ctx: gssapi.SecurityContext = None
    # Raw server-first message, kept for the auth-message concatenation.
    servers_first: Optional[str] = None
    # Combined client+server nonce from the server-first message.
    server_nonce: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def _xor(bytes1: bytes, bytes2: bytes) -> bytes:
|
||||
return bytes([b1 ^ b2 for b1, b2 in zip(bytes1, bytes2)])
|
||||
|
||||
def _hmac(self, password: bytes, salt: bytes) -> bytes:
|
||||
digest_maker = hmac.new(password, salt, self.def_digest_mode)
|
||||
return digest_maker.digest()
|
||||
|
||||
def _hi(self, password: str, salt: bytes, count: int) -> bytes:
|
||||
"""Prepares Hi
|
||||
Hi(password, salt, iterations) where Hi(p,s,i) is defined as
|
||||
PBKDF2 (HMAC, p, s, i, output length of H).
|
||||
"""
|
||||
pw = password.encode()
|
||||
hi = self._hmac(pw, salt + b"\x00\x00\x00\x01")
|
||||
aux = hi
|
||||
for _ in range(count - 1):
|
||||
aux = self._hmac(pw, aux)
|
||||
hi = self._xor(hi, aux)
|
||||
return hi
|
||||
|
||||
@staticmethod
|
||||
def _normalize(string: str) -> str:
|
||||
norm_str = norm_ustr(string)
|
||||
broken_rule = valid_norm(norm_str)
|
||||
if broken_rule is not None:
|
||||
raise InterfaceError(f"broken_rule: {broken_rule}")
|
||||
return norm_str
|
||||
|
||||
def _first_message(self) -> bytes:
|
||||
"""This method generates the first message to the server to start the
|
||||
|
||||
The client-first message consists of a gs2-header,
|
||||
the desired username, and a randomly generated client nonce cnonce.
|
||||
|
||||
The first message from the server has the form:
|
||||
b'n,a=<user_name>,n=<user_name>,r=<client_nonce>
|
||||
|
||||
Returns client's first message
|
||||
"""
|
||||
cfm_fprnat = "n,a={user_name},n={user_name},r={client_nonce}"
|
||||
self.client_nonce = str(uuid4()).replace("-", "")
|
||||
cfm: StrOrBytes = cfm_fprnat.format(
|
||||
user_name=self._normalize(self._username),
|
||||
client_nonce=self.client_nonce,
|
||||
)
|
||||
|
||||
if isinstance(cfm, str):
|
||||
cfm = cfm.encode("utf8")
|
||||
return cfm
|
||||
|
||||
    def _first_message_krb(self) -> Optional[bytes]:
        """Get a TGT Authentication request and initiate the security context.

        This method will contact the Kerberos KDC in order to obtain a TGT.
        Credential resolution order: the default credential cache first;
        if unavailable and a password is configured, acquire fresh
        credentials with the password.

        Returns:
            The initial GSSAPI client token (may be None, per step()).

        Raises:
            InterfaceError: If cached credentials are expired/unavailable
                and no password is configured, or context initiation fails.
            ProgrammingError: If acquiring credentials with the password fails.
        """
        user_name = gssapi.raw.names.import_name(
            self._username.encode("utf8"), name_type=gssapi.NameType.user
        )

        # Use defaults store = {'ccache': 'FILE:/tmp/krb5cc_1000'}#,
        # 'keytab':'/etc/some.keytab' }
        # Attempt to retrieve credential from default cache file.
        try:
            cred: Any = gssapi.Credentials()
            logger.debug(
                "# Stored credentials found, if password was given it will be ignored."
            )
            try:
                # validate credentials has not expired.
                # (Bare attribute access is intentional: `lifetime` raises
                # ExpiredCredentialsError when the cached TGT has expired.)
                cred.lifetime
            except gssapi.raw.exceptions.ExpiredCredentialsError as err:
                logger.warning(" Credentials has expired: %s", err)
                # NOTE(review): acquire() is attempted and then the error is
                # raised regardless — confirm whether the re-acquired
                # credential was meant to be used instead of failing.
                cred.acquire(user_name)
                raise InterfaceError(f"Credentials has expired: {err}") from err
        except gssapi.raw.misc.GSSError as err:
            if not self._password:
                raise InterfaceError(
                    f"Unable to retrieve stored credentials error: {err}"
                ) from err
            try:
                logger.debug("# Attempt to retrieve credentials with given password")
                acquire_cred_result = gssapi.raw.acquire_cred_with_password(
                    user_name,
                    self._password.encode("utf8"),
                    usage="initiate",
                )
                cred = acquire_cred_result[0]
            except gssapi.raw.misc.GSSError as err2:
                raise ProgrammingError(
                    f"Unable to retrieve credentials with the given password: {err2}"
                ) from err

        # RequirementFlag members are ints; sum() combines them into the
        # flags value SecurityContext expects.
        flags_l = (
            gssapi.RequirementFlag.mutual_authentication,
            gssapi.RequirementFlag.extended_error,
            gssapi.RequirementFlag.delegate_to_peer,
        )

        if self.krb_service_principal:
            service_principal = self.krb_service_principal
        else:
            service_principal = "ldap/ldapauth"
        logger.debug("# service principal: %s", service_principal)
        servk = gssapi.Name(
            service_principal, name_type=gssapi.NameType.kerberos_principal
        )
        self.target_name = servk
        self.ctx = gssapi.SecurityContext(
            name=servk, creds=cred, flags=sum(flags_l), usage="initiate"
        )

        try:
            # step() returns bytes | None, see documentation,
            # so this method could return a NULL payload.
            # ref: https://pythongssapi.github.io/<suffix>
            # suffix: python-gssapi/latest/gssapi.html#gssapi.sec_contexts.SecurityContext
            initial_client_token = self.ctx.step()
        except gssapi.raw.misc.GSSError as err:
            raise InterfaceError(f"Unable to initiate security context: {err}") from err

        logger.debug("# initial client token: %s", initial_client_token)
        return initial_client_token
|
||||
|
||||
def auth_continue_krb(
|
||||
self, tgt_auth_challenge: Optional[bytes]
|
||||
) -> Tuple[Optional[bytes], bool]:
|
||||
"""Continue with the Kerberos TGT service request.
|
||||
|
||||
With the TGT authentication service given response generate a TGT
|
||||
service request. This method must be invoked sequentially (in a loop)
|
||||
until the security context is completed and an empty response needs to
|
||||
be send to acknowledge the server.
|
||||
|
||||
Args:
|
||||
tgt_auth_challenge the challenge for the negotiation.
|
||||
|
||||
Returns: tuple (bytearray TGS service request,
|
||||
bool True if context is completed otherwise False).
|
||||
"""
|
||||
logger.debug("tgt_auth challenge: %s", tgt_auth_challenge)
|
||||
|
||||
resp = self.ctx.step(tgt_auth_challenge)
|
||||
logger.debug("# context step response: %s", resp)
|
||||
logger.debug("# context completed?: %s", self.ctx.complete)
|
||||
|
||||
return resp, self.ctx.complete
|
||||
|
||||
    def auth_accept_close_handshake(self, message: bytes) -> bytes:
        """Accept handshake and generate closing handshake message for server.

        This method verifies the server authenticity from the given message
        and included signature and generates the closing handshake for the
        server.

        When this method is invoked the security context is already established
        and the client and server can send GSSAPI formatted secure messages.

        To finish the authentication handshake the server sends a message
        with the security layer availability and the maximum buffer size.

        Since the connector only uses the GSSAPI authentication mechanism to
        authenticate the user with the server, the server will verify clients
        message signature and terminate the GSSAPI authentication and send two
        messages; an authentication acceptance b'\x01\x00\x00\x08\x01' and a
        OK packet (that must be received after sent the returned message from
        this method).

        Args:
            message: a wrapped gssapi message from the server.

        Returns:
            bytearray closing handshake message to be send to the server.

        Raises:
            ProgrammingError: If the security context is not yet complete.
            InterfaceError: If the server message fails MIC verification.
        """
        if not self.ctx.complete:
            raise ProgrammingError("Security context is not completed.")
        logger.debug("# servers message: %s", message)
        logger.debug("# GSSAPI flags in use: %s", self.ctx.actual_flags)
        try:
            unwraped = self.ctx.unwrap(message)
            logger.debug("# unwraped: %s", unwraped)
        except gssapi.raw.exceptions.BadMICError as err:
            raise InterfaceError(f"Unable to unwrap server message: {err}") from err

        logger.debug("# unwrapped server message: %s", unwraped)
        # The message contents for the clients closing message:
        # - security level 1 byte, must be always 1.
        # - conciliated buffer size 3 bytes, without importance as no
        # further GSSAPI messages will be sends.
        # (Note: "\00" is an octal escape equal to "\x00", so this is the
        # 4-byte sequence 01 00 00 00.)
        response = bytearray(b"\x01\x00\x00\00")
        # Closing handshake must not be encrypted.
        logger.debug("# message response: %s", response)
        wraped = self.ctx.wrap(response, encrypt=False)
        logger.debug(
            "# wrapped message response: %s, length: %d",
            wraped[0],
            len(wraped[0]),
        )

        return wraped.message
|
||||
|
||||
def auth_response(
|
||||
self,
|
||||
auth_data: bytes,
|
||||
**kwargs: Any,
|
||||
) -> Optional[bytes]:
|
||||
"""This method will prepare the fist message to the server.
|
||||
|
||||
Returns bytes to send to the server as the first message.
|
||||
"""
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
self._auth_data = auth_data
|
||||
|
||||
auth_mechanism = self._auth_data.decode()
|
||||
logger.debug("read_method_name_from_server: %s", auth_mechanism)
|
||||
if auth_mechanism not in self.sasl_mechanisms:
|
||||
auth_mechanisms = '", "'.join(self.sasl_mechanisms[:-1])
|
||||
raise InterfaceError(
|
||||
f'The sasl authentication method "{auth_mechanism}" requested '
|
||||
f'from the server is not supported. Only "{auth_mechanisms}" '
|
||||
f'and "{self.sasl_mechanisms[-1]}" are supported'
|
||||
)
|
||||
|
||||
if b"GSSAPI" in self._auth_data:
|
||||
return self._first_message_krb()
|
||||
|
||||
if self._auth_data == b"SCRAM-SHA-256":
|
||||
self.def_digest_mode = sha256
|
||||
|
||||
return self._first_message()
|
||||
|
||||
    def _second_message(self) -> bytes:
        """This method generates the second message to the server.

        Second message consists of the concatenation of the client and the
        server nonce, and the client proof:

        c=<n,a=<user_name>>,r=<server_nonce>,p=<client_proof>
        where:
            <client_proof>: xor(<client_key>, <client_signature>)
            <client_key>: hmac(salted_password, b"Client Key")
            <client_signature>: hmac(<stored_key>, <auth_msg>)
            <stored_key>: h(<client_key>)
            <auth_msg>: <client_first_no_header>,<servers_first>,
                        c=<client_header>,r=<server_nonce>
            <client_first_no_header>: n=<username>,r=<client_nonce>

        Returns:
            bytes: SCRAM client-final message.

        Raises:
            InterfaceError: If no authentication data (seed) was stored.
        """
        if not self._auth_data:
            raise InterfaceError("Missing authentication data (seed)")

        # SCRAM (RFC 5802): SaltedPassword := Hi(Normalize(password), salt, i)
        passw = self._normalize(self._password)
        salted_password = self._hi(passw, b64decode(self.server_salt), self.iterations)
        logger.debug("salted_password: %s", b64encode(salted_password).decode())

        # ClientKey := HMAC(SaltedPassword, "Client Key")
        client_key = self._hmac(salted_password, b"Client Key")
        logger.debug("client_key: %s", b64encode(client_key).decode())

        # StoredKey := H(ClientKey) — digest picked in auth_response
        stored_key = self.def_digest_mode(client_key).digest()
        logger.debug("stored_key: %s", b64encode(stored_key).decode())

        # ServerKey := HMAC(SaltedPassword, "Server Key")
        server_key = self._hmac(salted_password, b"Server Key")
        logger.debug("server_key: %s", b64encode(server_key).decode())

        # client-first-message-bare: n=<username>,r=<client_nonce>
        client_first_no_header = ",".join(
            [
                f"n={self._normalize(self._username)}",
                f"r={self.client_nonce}",
            ]
        )
        logger.debug("client_first_no_header: %s", client_first_no_header)

        # GS2 header (channel-binding "n" plus authzid), base64 encoded.
        client_header = b64encode(
            f"n,a={self._normalize(self._username)},".encode()
        ).decode()

        # AuthMessage := client-first-bare , server-first , client-final-without-proof
        auth_msg = ",".join(
            [
                client_first_no_header,
                self.servers_first,
                f"c={client_header}",
                f"r={self.server_nonce}",
            ]
        )
        logger.debug("auth_msg: %s", auth_msg)

        # ClientSignature := HMAC(StoredKey, AuthMessage)
        client_signature = self._hmac(stored_key, auth_msg.encode())
        logger.debug("client_signature: %s", b64encode(client_signature).decode())

        # ClientProof := ClientKey XOR ClientSignature
        client_proof = self._xor(client_key, client_signature)
        logger.debug("client_proof: %s", b64encode(client_proof).decode())

        # Expected server proof, checked later in _validate_second_reponse.
        self.server_auth_var = b64encode(
            self._hmac(server_key, auth_msg.encode())
        ).decode()
        logger.debug("server_auth_var: %s", self.server_auth_var)

        msg = ",".join(
            [
                f"c={client_header}",
                f"r={self.server_nonce}",
                f"p={b64encode(client_proof).decode()}",
            ]
        )
        logger.debug("second_message: %s", msg)
        return msg.encode()
|
||||
|
||||
def _validate_first_reponse(self, servers_first: bytes) -> None:
|
||||
"""Validates first message from the server.
|
||||
|
||||
Extracts the server's salt and iterations from the servers 1st response.
|
||||
First message from the server is in the form:
|
||||
<server_salt>,i=<iterations>
|
||||
"""
|
||||
if not servers_first or not isinstance(servers_first, (bytearray, bytes)):
|
||||
raise InterfaceError(f"Unexpected server message: {repr(servers_first)}")
|
||||
try:
|
||||
servers_first_str = servers_first.decode()
|
||||
self.servers_first = servers_first_str
|
||||
r_server_nonce, s_salt, i_counter = servers_first_str.split(",")
|
||||
except ValueError:
|
||||
raise InterfaceError(
|
||||
f"Unexpected server message: {servers_first_str}"
|
||||
) from None
|
||||
if (
|
||||
not r_server_nonce.startswith("r=")
|
||||
or not s_salt.startswith("s=")
|
||||
or not i_counter.startswith("i=")
|
||||
):
|
||||
raise InterfaceError(
|
||||
f"Incomplete reponse from the server: {servers_first_str}"
|
||||
)
|
||||
if self.client_nonce in r_server_nonce:
|
||||
self.server_nonce = r_server_nonce[2:]
|
||||
logger.debug("server_nonce: %s", self.server_nonce)
|
||||
else:
|
||||
raise InterfaceError(
|
||||
"Unable to authenticate response: response not well formed "
|
||||
f"{servers_first_str}"
|
||||
)
|
||||
self.server_salt = s_salt[2:]
|
||||
logger.debug(
|
||||
"server_salt: %s length: %s",
|
||||
self.server_salt,
|
||||
len(self.server_salt),
|
||||
)
|
||||
try:
|
||||
i_counter = i_counter[2:]
|
||||
logger.debug("iterations: %s", i_counter)
|
||||
self.iterations = int(i_counter)
|
||||
except Exception as err:
|
||||
raise InterfaceError(
|
||||
f"Unable to authenticate: iterations not found {servers_first_str}"
|
||||
) from err
|
||||
|
||||
def auth_continue(self, servers_first_response: bytes) -> bytes:
|
||||
"""return the second message from the client.
|
||||
|
||||
Returns bytes to send to the server as the second message.
|
||||
"""
|
||||
self._validate_first_reponse(servers_first_response)
|
||||
return self._second_message()
|
||||
|
||||
def _validate_second_reponse(self, servers_second: bytearray) -> bool:
|
||||
"""Validates second message from the server.
|
||||
|
||||
The client and the server prove to each other they have the same Auth
|
||||
variable.
|
||||
|
||||
The second message from the server consist of the server's proof:
|
||||
server_proof = HMAC(<server_key>, <auth_msg>)
|
||||
where:
|
||||
<server_key>: hmac(<salted_password>, b"Server Key")
|
||||
<auth_msg>: <client_first_no_header>,<servers_first>,
|
||||
c=<client_header>,r=<server_nonce>
|
||||
|
||||
Our server_proof must be equal to the Auth variable send on this second
|
||||
response.
|
||||
"""
|
||||
if (
|
||||
not servers_second
|
||||
or not isinstance(servers_second, bytearray)
|
||||
or len(servers_second) <= 2
|
||||
or not servers_second.startswith(b"v=")
|
||||
):
|
||||
raise InterfaceError("The server's proof is not well formated")
|
||||
server_var = servers_second[2:].decode()
|
||||
logger.debug("server auth variable: %s", server_var)
|
||||
return self.server_auth_var == server_var
|
||||
|
||||
def auth_finalize(self, servers_second_response: bytearray) -> bool:
|
||||
"""finalize the authentication process.
|
||||
|
||||
Raises InterfaceError if the ervers_second_response is invalid.
|
||||
|
||||
Returns True in successful authentication False otherwise.
|
||||
"""
|
||||
if not self._validate_second_reponse(servers_second_response):
|
||||
raise InterfaceError(
|
||||
"Authentication failed: Unable to proof server identity"
|
||||
)
|
||||
return True
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Plugin official name."""
|
||||
return "authentication_ldap_sasl_client"
|
||||
|
||||
@property
|
||||
def requires_ssl(self) -> bool:
|
||||
"""Signals whether or not SSL is required."""
|
||||
return False
|
||||
|
||||
async def auth_switch_response(
|
||||
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
|
||||
) -> bytes:
|
||||
"""Handles server's `auth switch request` response.
|
||||
|
||||
Args:
|
||||
sock: Pointer to the socket connection.
|
||||
auth_data: Plugin provided data (extracted from a packet
|
||||
representing an `auth switch request` response).
|
||||
kwargs: Custom configuration to be passed to the auth plugin
|
||||
when invoked. The parameters defined here will override the ones
|
||||
defined in the auth plugin itself.
|
||||
|
||||
Returns:
|
||||
packet: Last server's response after back-and-forth
|
||||
communication.
|
||||
"""
|
||||
logger.debug("# auth_data: %s", auth_data)
|
||||
self.krb_service_principal = kwargs.get("krb_service_principal")
|
||||
|
||||
response = self.auth_response(auth_data, **kwargs)
|
||||
if response is None:
|
||||
raise InterfaceError("Got a NULL auth response")
|
||||
|
||||
logger.debug("# request: %s size: %s", response, len(response))
|
||||
await sock.write(response)
|
||||
|
||||
packet = await sock.read()
|
||||
logger.debug("# server response packet: %s", packet)
|
||||
|
||||
if len(packet) >= 6 and packet[5] == 114 and packet[6] == 61: # 'r' and '='
|
||||
# Continue with sasl authentication
|
||||
dec_response = packet[5:]
|
||||
cresponse = self.auth_continue(dec_response)
|
||||
await sock.write(cresponse)
|
||||
packet = await sock.read()
|
||||
if packet[5] == 118 and packet[6] == 61: # 'v' and '='
|
||||
if self.auth_finalize(packet[5:]):
|
||||
# receive packed OK
|
||||
packet = await sock.read()
|
||||
elif auth_data == b"GSSAPI" and packet[4] != ERR_STATUS:
|
||||
rcode_size = 5 # header size for the response status code.
|
||||
logger.debug("# Continue with sasl GSSAPI authentication")
|
||||
logger.debug("# response header: %s", packet[: rcode_size + 1])
|
||||
logger.debug("# response size: %s", len(packet))
|
||||
|
||||
logger.debug("# Negotiate a service request")
|
||||
complete = False
|
||||
tries = 0 # To avoid a infinite loop attempt no more than feedback messages
|
||||
while not complete and tries < 5:
|
||||
logger.debug("%s Attempt %s %s", "-" * 20, tries + 1, "-" * 20)
|
||||
logger.debug("<< server response: %s", packet)
|
||||
logger.debug("# response code: %s", packet[: rcode_size + 1])
|
||||
step, complete = self.auth_continue_krb(packet[rcode_size:])
|
||||
logger.debug(" >> response to server: %s", step)
|
||||
await sock.write(step or b"")
|
||||
packet = await sock.read()
|
||||
tries += 1
|
||||
if not complete:
|
||||
raise InterfaceError(
|
||||
f"Unable to fulfill server request after {tries} "
|
||||
f"attempts. Last server response: {packet}"
|
||||
)
|
||||
logger.debug(
|
||||
" last GSSAPI response from server: %s length: %d",
|
||||
packet,
|
||||
len(packet),
|
||||
)
|
||||
last_step = self.auth_accept_close_handshake(packet[rcode_size:])
|
||||
logger.debug(
|
||||
" >> last response to server: %s length: %d",
|
||||
last_step,
|
||||
len(last_step),
|
||||
)
|
||||
await sock.write(last_step)
|
||||
# Receive final handshake from server
|
||||
packet = await sock.read()
|
||||
logger.debug("<< final handshake from server: %s", packet)
|
||||
|
||||
# receive OK packet from server.
|
||||
packet = await sock.read()
|
||||
logger.debug("<< ok packet from server: %s", packet)
|
||||
|
||||
return bytes(packet)
|
||||
|
||||
|
||||
# pylint: enable=c-extension-no-member,no-member
|
||||
@@ -0,0 +1,234 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
# mypy: disable-error-code="arg-type,union-attr,call-arg"
|
||||
|
||||
"""OCI Authentication Plugin."""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from base64 import b64encode
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from mysql.connector import errors
|
||||
from mysql.connector.logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..network import MySQLSocket
|
||||
|
||||
try:
|
||||
from cryptography.exceptions import UnsupportedAlgorithm
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
from cryptography.hazmat.primitives.asymmetric.types import PRIVATE_KEY_TYPES
|
||||
except ImportError:
|
||||
raise errors.ProgrammingError("Package 'cryptography' is not installed") from None
|
||||
|
||||
try:
|
||||
from oci import config, exceptions
|
||||
except ImportError:
|
||||
raise errors.ProgrammingError(
|
||||
"Package 'oci' (Oracle Cloud Infrastructure Python SDK) is not installed"
|
||||
) from None
|
||||
|
||||
from . import MySQLAuthPlugin
|
||||
|
||||
AUTHENTICATION_PLUGIN_CLASS = "MySQLOCIAuthPlugin"
|
||||
OCI_SECURITY_TOKEN_MAX_SIZE = 10 * 1024 # In bytes
|
||||
OCI_SECURITY_TOKEN_TOO_LARGE = "Ephemeral security token is too large (10KB max)"
|
||||
OCI_SECURITY_TOKEN_FILE_NOT_AVAILABLE = (
|
||||
"Ephemeral security token file ('security_token_file') could not be read"
|
||||
)
|
||||
OCI_PROFILE_MISSING_PROPERTIES = (
|
||||
"OCI configuration file does not contain a 'fingerprint' or 'key_file' entry"
|
||||
)
|
||||
|
||||
|
||||
class MySQLOCIAuthPlugin(MySQLAuthPlugin):
    """Implement the MySQL OCI IAM authentication plugin."""

    # Security context placeholder (not used by the visible code paths).
    context: Any = None
    # Profile name inside the OCI configuration file.
    oci_config_profile: str = "DEFAULT"
    # Filesystem path of the OCI configuration file.
    oci_config_file: str = config.DEFAULT_LOCATION

    @staticmethod
    def _prepare_auth_response(signature: bytes, oci_config: Dict[str, Any]) -> str:
        """Prepare client's authentication response.

        Prepares client's authentication response in JSON format.

        Args:
            signature (bytes): server's nonce signed by the client.
            oci_config (dict): OCI configuration object.

        Returns:
            str: JSON string with the following format:
                 {"fingerprint": str, "signature": str, "token": "Base64.Base64.Base64"}

        Raises:
            ProgrammingError: If the ephemeral security token file can't be
                read or the token is too large.
        """
        signature_64 = b64encode(signature)
        auth_response = {
            "fingerprint": oci_config["fingerprint"],
            "signature": signature_64.decode(),
        }

        # The security token, if it exists, should be a JWT (JSON Web Token):
        # a base64-encoded header, body, and signature, separated by '.',
        # stored in the file named by the 'security_token_file' property.
        if oci_config.get("security_token_file"):
            try:
                security_token_file = Path(oci_config["security_token_file"])
                # Check if token exceeds the maximum size
                if security_token_file.stat().st_size > OCI_SECURITY_TOKEN_MAX_SIZE:
                    raise errors.ProgrammingError(OCI_SECURITY_TOKEN_TOO_LARGE)
                auth_response["token"] = security_token_file.read_text(encoding="utf-8")
            except (OSError, UnicodeError) as err:
                raise errors.ProgrammingError(
                    OCI_SECURITY_TOKEN_FILE_NOT_AVAILABLE
                ) from err
        # Compact separators: the server expects a minified JSON payload.
        return json.dumps(auth_response, separators=(",", ":"))

    @staticmethod
    def _get_private_key(key_path: str) -> PRIVATE_KEY_TYPES:
        """Load the PEM private key from the given location.

        Raises:
            ProgrammingError: If the key file can't be opened or parsed.
        """
        try:
            with open(os.path.expanduser(key_path), "rb") as key_file:
                private_key = serialization.load_pem_private_key(
                    key_file.read(),
                    password=None,
                )
        except (TypeError, OSError, ValueError, UnsupportedAlgorithm) as err:
            # FIX: chain the original exception (`from err`) so the root
            # cause is preserved; the original code dropped the context.
            raise errors.ProgrammingError(
                "An error occurred while reading the API_KEY from "
                f'"{key_path}": {err}'
            ) from err

        return private_key

    def _get_valid_oci_config(self) -> Dict[str, Any]:
        """Get a valid OCI config from the given configuration file path.

        Raises:
            ProgrammingError: If the configuration file is missing/invalid
                or lacks a usable 'fingerprint' or 'key_file' entry.
        """
        error_list = []
        # Required entries and their validity predicates.
        req_keys = {
            "fingerprint": (lambda x: len(x) > 32),
            "key_file": (lambda x: os.path.exists(os.path.expanduser(x))),
        }

        oci_config: Dict[str, Any] = {}
        try:
            # key_file is validated by oci.config if present
            oci_config = config.from_file(
                self.oci_config_file or config.DEFAULT_LOCATION,
                self.oci_config_profile or "DEFAULT",
            )
            for req_key, req_value in req_keys.items():
                try:
                    # Verify parameter in req_key is present and valid
                    if oci_config[req_key] and not req_value(oci_config[req_key]):
                        error_list.append(f'Parameter "{req_key}" is invalid')
                except KeyError:
                    error_list.append(f"Does not contain parameter {req_key}")
        except (
            exceptions.ConfigFileNotFound,
            exceptions.InvalidConfig,
            exceptions.InvalidKeyFilePath,
            exceptions.InvalidPrivateKey,
            exceptions.ProfileNotFound,
        ) as err:
            error_list.append(str(err))

        # Raise errors if any
        if error_list:
            raise errors.ProgrammingError(
                f"Invalid oci-config-file: {self.oci_config_file}. "
                f"Errors found: {error_list}"
            )

        return oci_config

    @property
    def name(self) -> str:
        """Plugin official name."""
        return "authentication_oci_client"

    @property
    def requires_ssl(self) -> bool:
        """Signals whether or not SSL is required."""
        return False

    def auth_response(self, auth_data: bytes, **kwargs: Any) -> Optional[bytes]:
        """Prepare authentication string for the server.

        Signs the server's nonce with the RSA private key referenced by the
        OCI profile and returns the JSON payload the server expects.
        """
        logger.debug("server nonce: %s, len %d", auth_data, len(auth_data))

        oci_config = self._get_valid_oci_config()

        private_key = self._get_private_key(oci_config["key_file"])
        signature = private_key.sign(auth_data, padding.PKCS1v15(), hashes.SHA256())

        auth_response = self._prepare_auth_response(signature, oci_config)
        logger.debug("authentication response: %s", auth_response)
        return auth_response.encode()

    async def auth_switch_response(
        self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
    ) -> bytes:
        """Handles server's `auth switch request` response.

        Args:
            sock: Pointer to the socket connection.
            auth_data: Plugin provided data (extracted from a packet
                       representing an `auth switch request` response).
            kwargs: Custom configuration to be passed to the auth plugin
                    when invoked. The parameters defined here will override
                    the ones defined in the auth plugin itself.

        Returns:
            packet: Last server's response after back-and-forth
                    communication.
        """
        # FIX: the two kwargs defaults were swapped — 'oci_config_file'
        # defaulted to the profile name "DEFAULT" and 'oci_config_profile'
        # to the file location, contradicting the class-attribute defaults
        # and breaking config.from_file() when the kwargs were omitted.
        self.oci_config_file = kwargs.get("oci_config_file", config.DEFAULT_LOCATION)
        self.oci_config_profile = kwargs.get("oci_config_profile", "DEFAULT")
        logger.debug("# oci configuration file path: %s", self.oci_config_file)

        response = self.auth_response(auth_data, **kwargs)
        if response is None:
            raise errors.InterfaceError("Got a NULL auth response")

        logger.debug("# request: %s size: %s", response, len(response))
        await sock.write(response)

        packet = await sock.read()
        logger.debug("# server response packet: %s", packet)

        return bytes(packet)
|
||||
@@ -0,0 +1,172 @@
|
||||
# Copyright (c) 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""OpenID Authentication Plugin."""
|
||||
|
||||
import re
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, List
|
||||
|
||||
from mysql.connector import errors, utils
|
||||
from mysql.connector.aio.network import MySQLSocket
|
||||
from mysql.connector.logger import logger
|
||||
|
||||
from . import MySQLAuthPlugin
|
||||
|
||||
AUTHENTICATION_PLUGIN_CLASS = "MySQLOpenIDConnectAuthPlugin"
|
||||
OPENID_TOKEN_MAX_SIZE = 10 * 1024 # In bytes
|
||||
|
||||
|
||||
class MySQLOpenIDConnectAuthPlugin(MySQLAuthPlugin):
    """Class implementing the MySQL OpenID Connect Authentication Plugin."""

    # Capability flag sent as the first byte of the auth response (int<1> = 1).
    _openid_capability_flag: bytes = utils.int1store(1)

    @property
    def name(self) -> str:
        """Plugin official name."""
        return "authentication_openid_connect_client"

    @property
    def requires_ssl(self) -> bool:
        """Signals whether or not SSL is required."""
        return True

    @staticmethod
    def _validate_openid_token(token: str) -> bool:
        """Helper method used to validate OpenID Connect token

        The Token is represented as a JSON Web Token (JWT) consists of a
        base64-encoded header, body, and signature, separated by '.' e.g.,
        "Base64url.Base64url.Base64url". The First part of the token contains
        the header, the second part contains payload and the third part contains
        signature. These token parts should be Base64 URLSafe i.e., Token cannot
        contain characters other than a-z, A-Z, 0-9 and special characters '-', '_'.

        Args:
            token (str): Base64url-encoded OpenID connect token fetched from
                         the file path passed via `openid_token_file` connection
                         argument.

        Returns:
            bool: Signal indicating whether the token is valid or not.
        """
        header_payload_sig: List[str] = token.split(".")
        if len(header_payload_sig) != 3:
            # invalid structure
            return False
        # Base64url alphabet only: a-z, A-Z, 0-9, '-' and '_'.
        urlsafe_pattern = re.compile("^[a-zA-Z0-9-_]*$")
        # Every part must be non-empty and match the URL-safe alphabet.
        return all(
            (
                len(token_part) and urlsafe_pattern.search(token_part) is not None
                for token_part in header_payload_sig
            )
        )

    def auth_response(self, auth_data: bytes, **kwargs: Any) -> bytes:
        """Prepares authentication string for the server.
        Args:
            auth_data: Authorization data.
            kwargs: Custom configuration to be passed to the auth plugin
                    when invoked.

        Returns:
            packet: Client's authorization response.
            The OpenID Connect authorization response follows the pattern :-
            int<1>          capability flag
            string<lenenc>  id token

        Raises:
            InterfaceError: If the connection is insecure or the OpenID Token is too large,
                            invalid or non-existent.
            ProgrammingError: If the OpenID Token file could not be read.
        """
        try:
            # Check if the connection is secure
            if self.requires_ssl and not self._ssl_enabled:
                # NOTE(review): InterfaceError is expected NOT to be a
                # subclass of the types in the except clause below, so it
                # propagates to the caller unchanged — confirm.
                raise errors.InterfaceError(f"{self.name} requires SSL")

            # Validate the file
            token_file_path: str = kwargs.get("openid_token_file", None)
            # If the kwarg is missing, Path(None) raises TypeError, which the
            # except clause below converts into the "could not be read" error.
            openid_token_file: Path = Path(token_file_path)
            # Check if token exceeds the maximum size
            if openid_token_file.stat().st_size > OPENID_TOKEN_MAX_SIZE:
                raise errors.InterfaceError(
                    "The OpenID Connect token file size is too large (> 10KB)"
                )
            openid_token: str = openid_token_file.read_text(encoding="utf-8")
            openid_token = openid_token.strip()
            # Validate the JWT Token
            if not self._validate_openid_token(openid_token):
                raise errors.InterfaceError("The OpenID Connect Token is invalid")

            # build the auth_response packet
            auth_response: List[bytes] = [
                self._openid_capability_flag,
                utils.lc_int(len(openid_token)),
                openid_token.encode(),
            ]
            return b"".join(auth_response)
        except (SyntaxError, TypeError, OSError, UnicodeError) as err:
            raise errors.ProgrammingError(
                "The OpenID Connect Token File (openid_token_file) could not be read"
            ) from err

    async def auth_switch_response(
        self, sock: MySQLSocket, auth_data: bytes, **kwargs: Any
    ) -> bytes:
        """Handles server's `auth switch request` response.

        Args:
            sock: Pointer to the socket connection.
            auth_data: Plugin provided data (extracted from a packet
                       representing an `auth switch request` response).
            kwargs: Custom configuration to be passed to the auth plugin
                    when invoked. The parameters defined here will override the ones
                    defined in the auth plugin itself.

        Returns:
            packet: Last server's response after back-and-forth
                    communication.

        Raises:
            InterfaceError: If a NULL auth response is received from auth_response method.
        """
        response = self.auth_response(auth_data, **kwargs)

        if response is None:
            raise errors.InterfaceError("Got a NULL auth response")

        logger.debug("# request: %s size: %s", response, len(response))
        await sock.write(response)

        packet = await sock.read()
        logger.debug("# server response packet: %s", packet)

        return bytes(packet)
|
||||
@@ -0,0 +1,291 @@
|
||||
# Copyright (c) 2023, 2025, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""WebAuthn Authentication Plugin."""
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
|
||||
from mysql.connector import errors, utils
|
||||
|
||||
from ..logger import logger
|
||||
from . import MySQLAuthPlugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..network import MySQLSocket
|
||||
|
||||
try:
|
||||
from fido2.cbor import dump_bytes as cbor_dump_bytes
|
||||
from fido2.client import Fido2Client, UserInteraction
|
||||
from fido2.hid import CtapHidDevice
|
||||
from fido2.webauthn import PublicKeyCredentialRequestOptions
|
||||
except ImportError as import_err:
|
||||
raise errors.ProgrammingError(
|
||||
"Module fido2 is required for WebAuthn authentication mechanism but was "
|
||||
"not found. Unable to authenticate with the server"
|
||||
) from import_err
|
||||
|
||||
try:
|
||||
from fido2.pcsc import CtapPcscDevice
|
||||
|
||||
CTAP_PCSC_DEVICE_AVAILABLE = True
|
||||
except ModuleNotFoundError:
|
||||
CTAP_PCSC_DEVICE_AVAILABLE = False
|
||||
|
||||
|
||||
AUTHENTICATION_PLUGIN_CLASS = "MySQLWebAuthnAuthPlugin"
|
||||
|
||||
|
||||
class ClientInteraction(UserInteraction):
    """Supplies the user-facing prompt for FIDO device interaction.

    When a `callback` is given it receives the prompt text; otherwise the
    prompt is printed to stdout.
    """

    def __init__(self, callback: Optional[Callable] = None):
        self.callback = callback
        self.msg = (
            "Please insert FIDO device and perform gesture action for authentication "
            "to complete."
        )

    def prompt_up(self) -> None:
        """Prompt message for the user interaction with the FIDO device."""
        notify = print if self.callback is None else self.callback
        notify(self.msg)
|
||||
|
||||
|
||||
class MySQLWebAuthnAuthPlugin(MySQLAuthPlugin):
    """Class implementing the MySQL WebAuthn authentication plugin.

    Drives a FIDO2 authenticator (USB HID, or PC/SC when available) to
    answer the server's WebAuthn challenge during authentication.
    """

    # NOTE(review): these are class-level attributes, so the mutable
    # `options` dict is shared by every instance of this plugin — confirm
    # the connector uses a single instance per authentication exchange.
    client: Optional[Fido2Client] = None  # created lazily in auth_response()
    callback: Optional[Callable] = None  # user-prompt hook, set in auth_switch_response()
    options: dict = {"rpId": None, "challenge": None, "allowCredentials": []}

    @property
    def name(self) -> str:
        """Plugin official name."""
        return "authentication_webauthn_client"

    @property
    def requires_ssl(self) -> bool:
        """Signals whether or not SSL is required."""
        return False

    def get_assertion_response(
        self, credential_id: Optional[bytearray] = None
    ) -> bytes:
        """Get assertion from authenticator and return the response.

        Args:
            credential_id (Optional[bytearray]): The credential ID; supplied
                only when the authenticator lacks resident-key support and
                the server provided the credential to use.

        Raises:
            InterfaceError: If no Fido2Client has been created yet
                (i.e. auth_response() was not called first).

        Returns:
            bytes: The response packet with the data from the assertion.
        """
        if self.client is None:
            raise errors.InterfaceError("No WebAuthn client found")

        if credential_id is not None:
            # If credential_id is not None, it's because the FIDO device does not
            # support resident keys and the credential_id was requested from the server
            self.options["allowCredentials"] = [
                {
                    "id": credential_id,
                    "type": "public-key",
                }
            ]

        # Get assertion from authenticator (blocks on the user gesture;
        # ClientInteraction provides the prompt)
        assertion = self.client.get_assertion(
            PublicKeyCredentialRequestOptions.from_dict(self.options)
        )
        number_of_assertions = len(assertion.get_assertions())
        client_data_json = b""  # stays empty when there are no assertions

        # Build response packet
        #
        # Format:
        # int<1>        0x02 (2)              status tag
        # int<lenenc>   number of assertions  length encoded number of assertions
        # string        authenticator data    variable length raw binary string
        # string        signed challenge      variable length raw binary string
        # ...
        # ...
        # string        authenticator data    variable length raw binary string
        # string        signed challenge      variable length raw binary string
        # string        ClientDataJSON        variable length raw binary string
        packet = utils.lc_int(2)
        packet += utils.lc_int(number_of_assertions)

        # Add authenticator data and signed challenge for each assertion
        for i in range(number_of_assertions):
            assertion_response = assertion.get_response(i)

            # string<lenenc> authenticator_data
            authenticator_data = cbor_dump_bytes(assertion_response.authenticator_data)

            # string<lenenc> signed_challenge
            signature = assertion_response.signature

            packet += utils.lc_int(len(authenticator_data))
            packet += authenticator_data
            packet += utils.lc_int(len(signature))
            packet += signature

            # string<lenenc> client_data_json
            # (re-assigned each iteration; the packet carries it only once,
            # after all assertion pairs)
            client_data_json = assertion_response.client_data

        packet += utils.lc_int(len(client_data_json))
        packet += client_data_json

        logger.debug("WebAuthn - payload response packet: %s", packet)
        return packet

    def auth_response(self, auth_data: bytes, **kwargs: Any) -> Optional[bytes]:
        """Find authenticator device and check if supports resident keys.

        It also creates a Fido2Client using the relying party ID from the server.

        Args:
            auth_data: Server payload: a 1-byte capability flag followed by
                length-encoded challenge and relying party ID strings.
            kwargs: Unused here; present for interface compatibility.

        Raises:
            InterfaceError: When the FIDO device is not found, or when
                `auth_data` cannot be parsed.

        Returns:
            bytes: 2 if the authenticator supports resident keys else 1.
        """
        try:
            # auth_data layout: int<1> capability, then two length-encoded
            # strings (challenge, relying party id)
            packets, capability = utils.read_int(auth_data, 1)
            challenge, rp_id = utils.read_lc_string_list(packets)
            self.options["challenge"] = challenge
            self.options["rpId"] = rp_id.decode()
            logger.debug("WebAuthn - capability: %d", capability)
            logger.debug("WebAuthn - challenge: %s", self.options["challenge"])
            logger.debug("WebAuthn - relying party id: %s", self.options["rpId"])
        except ValueError as err:
            raise errors.InterfaceError(
                "Unable to parse MySQL WebAuthn authentication data"
            ) from err

        # Locate a device: prefer USB HID, fall back to PC/SC when installed
        device = next(CtapHidDevice.list_devices(), None)
        if device is not None:
            logger.debug("WebAuthn - Use USB HID channel")
        elif CTAP_PCSC_DEVICE_AVAILABLE:
            device = next(CtapPcscDevice.list_devices(), None)

        if device is None:
            raise errors.InterfaceError("No FIDO device found")

        # Set up a FIDO 2 client using the origin relying party id
        self.client = Fido2Client(
            device,
            f"https://{self.options['rpId']}",
            user_interaction=ClientInteraction(self.callback),
        )

        # "rk" is the authenticator's resident-key capability flag
        if not self.client.info.options.get("rk"):
            logger.debug("WebAuthn - Authenticator doesn't support resident keys")
            return b"1"

        logger.debug("WebAuthn - Authenticator with support for resident key found")
        return b"2"

    async def auth_more_response(
        self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
    ) -> bytes:
        """Handles server's `auth more data` response.

        Args:
            sock: Pointer to the socket connection.
            auth_data: Authentication method data (from a packet representing
                an `auth more data` response).
            kwargs: Custom configuration to be passed to the auth plugin
                when invoked. The parameters defined here will override the ones
                defined in the auth plugin itself.

        Returns:
            packet: Last server's response after back-and-forth
                communication.
        """
        # The server sends the credential_id here because the authenticator
        # reported no resident-key support in auth_switch_response()
        _, credential_id = utils.read_lc_string(auth_data)

        response = self.get_assertion_response(credential_id)

        logger.debug("WebAuthn - request: %s size: %s", response, len(response))
        await sock.write(response)

        pkt = bytes(await sock.read())
        logger.debug("WebAuthn - server response packet: %s", pkt)

        return pkt

    async def auth_switch_response(
        self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
    ) -> bytes:
        """Handles server's `auth switch request` response.

        Args:
            sock: Pointer to the socket connection.
            auth_data: Plugin provided data (extracted from a packet
                representing an `auth switch request` response).
            kwargs: Custom configuration to be passed to the auth plugin
                when invoked. The parameters defined here will override the ones
                defined in the auth plugin itself.

        Returns:
            packet: Last server's response after back-and-forth
                communication.
        """
        # "fido_callback" is accepted as a fallback spelling for the prompt
        # callback — NOTE(review): looks like a legacy alias; confirm.
        webauth_callback = kwargs.get("webauthn_callback") or kwargs.get(
            "fido_callback"
        )
        # The callback may be given as a dotted-path string; resolve it here
        self.callback = (
            utils.import_object(webauth_callback)
            if isinstance(webauth_callback, str)
            else webauth_callback
        )

        response = self.auth_response(auth_data)
        credential_id = None

        if response == b"1":
            # Authenticator doesn't support resident keys, request credential_id
            logger.debug("WebAuthn - request credential_id")
            await sock.write(utils.lc_int(int(response)))

            # return a packet representing an `auth more data` response
            return bytes(await sock.read())

        response = self.get_assertion_response(credential_id)

        logger.debug("WebAuthn - request: %s size: %s", response, len(response))
        await sock.write(response)

        pkt = bytes(await sock.read())
        logger.debug("WebAuthn - server response packet: %s", pkt)

        return pkt
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user