Añadiendo todos los archivos del proyecto (incluidos secretos y venv)
This commit is contained in:
27
venv/lib/python3.12/site-packages/mysql/ai/__init__.py
Normal file
27
venv/lib/python3.12/site-packages/mysql/ai/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
Binary file not shown.
43
venv/lib/python3.12/site-packages/mysql/ai/genai/__init__.py
Normal file
43
venv/lib/python3.12/site-packages/mysql/ai/genai/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""GenAI package for MySQL Connector/Python.
|
||||
|
||||
Performs optional dependency checks and exposes public classes:
|
||||
- MyEmbeddings
|
||||
- MyLLM
|
||||
- MyVectorStore
|
||||
"""
|
||||
from mysql.ai.utils import check_dependencies as _check_dependencies
|
||||
|
||||
_check_dependencies(["GENAI"])
|
||||
del _check_dependencies
|
||||
|
||||
from .embedding import MyEmbeddings
|
||||
from .generation import MyLLM
|
||||
from .vector_store import MyVectorStore
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
197
venv/lib/python3.12/site-packages/mysql/ai/genai/embedding.py
Normal file
197
venv/lib/python3.12/site-packages/mysql/ai/genai/embedding.py
Normal file
@@ -0,0 +1,197 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Embeddings integration utilities for MySQL Connector/Python.
|
||||
|
||||
Provides MyEmbeddings class to generate embeddings via MySQL HeatWave
|
||||
using ML_EMBED_TABLE and ML_EMBED_ROW.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from mysql.ai.utils import (
|
||||
atomic_transaction,
|
||||
execute_sql,
|
||||
format_value_sql,
|
||||
source_schema,
|
||||
sql_table_from_df,
|
||||
sql_table_to_df,
|
||||
temporary_sql_tables,
|
||||
)
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
|
||||
class MyEmbeddings(Embeddings):
    """
    Embedding generator that uses a MySQL HeatWave database to compute
    embeddings for input text.

    Input text is batched into a temporary SQL table, MySQL's
    ML_EMBED_TABLE routine is invoked to generate embeddings, and the
    results are fetched back as lists of floats.

    Attributes:
        _db_connection (MySQLConnectionAbstract): MySQL connection used for
            all database operations.
        schema_name (str): Name of the database schema to use.
        options_placeholder (str): SQL-ready placeholder string for
            ML_EMBED_TABLE / ML_EMBED_ROW options.
        options_params: Concrete option values bound as SQL parameters
            alongside options_placeholder (exact shape is defined by
            format_value_sql — it is unpacked with ``*`` below, so it is
            expected to be iterable).
    """

    # Declared as a pydantic private attribute so the Embeddings base model
    # accepts an undeclared, non-validated field.
    _db_connection: MySQLConnectionAbstract = PrivateAttr()

    def __init__(
        self, db_connection: MySQLConnectionAbstract, options: Optional[Dict] = None
    ):
        """
        Initialize MyEmbeddings with a database connection and optional
        embedding parameters.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwgenai-ml-embed-row.html
            A full list of supported options can be found under "options".

        NOTE: The supported "options" are the intersection of the options
        provided in
            https://dev.mysql.com/doc/heatwave/en/mys-hwgenai-ml-embed-row.html
            https://dev.mysql.com/doc/heatwave/en/mys-hwgenai-ml-embed-table.html

        Args:
            db_connection: Active MySQL connector database connection.
            options: Optional dictionary of options for embedding operations.

        Raises:
            ValueError: If the schema name is not valid.
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        super().__init__()
        self._db_connection = db_connection
        # Schema is resolved once from the connection and reused for every
        # temporary table created by embed_documents.
        self.schema_name = source_schema(db_connection)
        options = options or {}
        # format_value_sql splits the options into a placeholder fragment
        # (spliced into the SQL text) and the values bound at execution time.
        self.options_placeholder, self.options_params = format_value_sql(options)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for a list of input texts using the MySQL ML
        embedding procedure.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwgenai-ml-embed-table.html

        Args:
            texts: List of input strings to embed.

        Returns:
            List of lists of floats, one inner list per input text.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported.
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError:
                If one or more text entries could not be embedded.

        Implementation notes:
            - A temporary table carries the input text to the embedding
              service; it is registered with temporary_sql_tables so it is
              dropped on exit (even on error).
            - Row order is keyed by the "id" column written below; the
              result's "embeddings" column supplies the vectors.
        """
        # Empty input: nothing to embed, avoid creating any tables.
        if not texts:
            return []

        # "id" preserves input order when results are read back.
        df = pd.DataFrame({"id": range(len(texts)), "text": texts})

        with (
            atomic_transaction(self._db_connection) as cursor,
            temporary_sql_tables(self._db_connection) as temporary_tables,
        ):
            qualified_table_name, table_name = sql_table_from_df(
                cursor, self.schema_name, df
            )
            # Register for automatic cleanup when the context exits.
            temporary_tables.append((self.schema_name, table_name))

            # ML_EMBED_TABLE takes input column, output column, and options.
            # Table/column names cannot be bound as parameters, so they are
            # interpolated; the option VALUES are still bound via
            # options_params.
            embed_query = (
                "CALL sys.ML_EMBED_TABLE("
                f"'{qualified_table_name}.text', "
                f"'{qualified_table_name}.embeddings', "
                f"{self.options_placeholder}"
                ")"
            )
            execute_sql(cursor, embed_query, params=self.options_params)

            # Read back all columns, including the generated "embeddings".
            df_embeddings = sql_table_to_df(cursor, self.schema_name, table_name)

            # A NULL/None embedding means the server failed for that row.
            if df_embeddings["embeddings"].isnull().any() or any(
                e is None for e in df_embeddings["embeddings"]
            ):
                raise ValueError(
                    "Failure to generate embeddings for one or more text entry."
                )

            # Normalize whatever sequence type the driver returned into
            # plain Python lists of floats.
            embeddings = df_embeddings["embeddings"].tolist()
            embeddings = [list(e) for e in embeddings]

        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """
        Generate an embedding for a single text string.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwgenai-ml-embed-row.html

        Args:
            text: The input string to embed.

        Returns:
            List of floats representing the embedding vector.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported.
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Example:
            >>> MyEmbeddings(db_conn).embed_query("Hello world")
            [0.1, 0.2, ...]
        """
        with atomic_transaction(self._db_connection) as cursor:
            # NOTE(review): the %s placeholder sits inside double quotes in
            # the SQL text, while embed_documents splices its placeholder
            # unquoted — confirm execute_sql's substitution yields valid SQL
            # for both forms.
            execute_sql(
                cursor,
                f'SELECT sys.ML_EMBED_ROW("%s", {self.options_placeholder})',
                params=(text, *self.options_params),
            )
            # Single-row result; first column holds the vector.
            return list(cursor.fetchone()[0])
|
||||
162
venv/lib/python3.12/site-packages/mysql/ai/genai/generation.py
Normal file
162
venv/lib/python3.12/site-packages/mysql/ai/genai/generation.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""GenAI LLM integration utilities for MySQL Connector/Python.
|
||||
|
||||
Provides MyLLM wrapper that issues ML_GENERATE calls via SQL.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
from typing import Any, List, Optional
|
||||
|
||||
try:
|
||||
from langchain_core.language_models.llms import LLM
|
||||
except ImportError:
|
||||
from langchain.llms.base import LLM
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from mysql.ai.utils import atomic_transaction, execute_sql, format_value_sql
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
|
||||
class MyLLM(LLM):
    """
    LangChain-compatible LLM wrapper around MySQL HeatWave text generation.

    Completions are produced server-side through the ``sys.ML_GENERATE``
    stored function, issued over the supplied database connection. The
    wrapper fully supports generative queries and offers limited support
    for agentic queries.

    Attributes:
        _db_connection (MySQLConnectionAbstract):
            Underlying MySQL connector database connection.
    """

    _db_connection: MySQLConnectionAbstract = PrivateAttr()

    class Config:
        """
        Pydantic configuration for this model.

        LangChain LLMs are Pydantic BaseModels, which by default reject
        undeclared attributes such as _db_connection. Allowing "extra"
        attributes makes it possible to keep the connection on the
        instance, which MyLLM requires.
        """

        extra = "allow"

    def __init__(self, db_connection: MySQLConnectionAbstract):
        """
        Create a MyLLM bound to an active MySQL connection.

        Args:
            db_connection: MySQL connection object used to run LLM queries.
        """
        super().__init__()

        self._db_connection = db_connection

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        """
        Run a single generation request and return the completion text.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwgenai-ml-generate.html
            A full list of supported options (specified by kwargs) can be
            found under "options".

        Args:
            prompt: Input prompt string for the language model.
            stop: Optional stop strings (forwarded as "stop_sequences") to
                support agentic / chain-of-thought workflows.
            **kwargs: Additional generation options forwarded to ML_GENERATE.

        Returns:
            The generated model output as a string; the input prompt is NOT
            included in the completion.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported.
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        # Work on a copy so the caller's kwargs are never mutated.
        generation_options = dict(kwargs)
        if stop is not None:
            generation_options["stop_sequences"] = stop

        placeholder, option_params = format_value_sql(generation_options)
        with atomic_transaction(self._db_connection) as cursor:
            # The prompt travels as a bound parameter rather than being
            # interpolated into the SQL text (avoids SQL injection).
            sql = f"""SELECT sys.ML_GENERATE("%s", {placeholder});"""
            execute_sql(cursor, sql, params=(prompt, *option_params))
            # The server replies with a JSON document; the completion lives
            # under its 'text' key.
            completion = json.loads(cursor.fetchone()[0])["text"]

        return completion

    @property
    def _identifying_params(self) -> dict:
        """
        Parameters uniquely identifying this LLM instance.

        Returns:
            dict: Identifier parameters (includes model_name for
            tracing/caching).
        """
        return {
            "model_name": "mysql_heatwave_llm",
        }

    @property
    def _llm_type(self) -> str:
        """
        Type name of this LLM implementation.

        Returns:
            A string identifying the LLM provider (used for logging or
            metrics).
        """
        return "mysql_heatwave_llm"
|
||||
520
venv/lib/python3.12/site-packages/mysql/ai/genai/vector_store.py
Normal file
520
venv/lib/python3.12/site-packages/mysql/ai/genai/vector_store.py
Normal file
@@ -0,0 +1,520 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""MySQL-backed vector store for embeddings and semantic document retrieval.
|
||||
|
||||
Provides a VectorStore implementation persisting documents, metadata, and
|
||||
embeddings in MySQL, plus similarity search utilities.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
from typing import Any, Iterable, List, Optional, Sequence, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from mysql.ai.genai.embedding import MyEmbeddings
|
||||
from mysql.ai.utils import (
|
||||
VAR_NAME_SPACE,
|
||||
atomic_transaction,
|
||||
delete_sql_table,
|
||||
execute_sql,
|
||||
extend_sql_table,
|
||||
format_value_sql,
|
||||
get_random_name,
|
||||
is_table_empty,
|
||||
source_schema,
|
||||
table_exists,
|
||||
)
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
BASIC_EMBEDDING_QUERY = "Hello world!"
|
||||
EMBEDDING_SOURCE = "external_source"
|
||||
|
||||
VAR_EMBEDDING = f"{VAR_NAME_SPACE}.embedding"
|
||||
VAR_CONTEXT = f"{VAR_NAME_SPACE}.context"
|
||||
VAR_CONTEXT_MAP = f"{VAR_NAME_SPACE}.context_map"
|
||||
VAR_RETRIEVAL_INFO = f"{VAR_NAME_SPACE}.retrieval_info"
|
||||
VAR_OPTIONS = f"{VAR_NAME_SPACE}.options"
|
||||
|
||||
ID_SPACE = "internal_ai_id_"
|
||||
|
||||
|
||||
class MyVectorStore(VectorStore):
|
||||
"""
|
||||
MySQL-backed vector store for handling embeddings and semantic document retrieval.
|
||||
|
||||
Supports adding, deleting, and searching high-dimensional vector representations
|
||||
of documents using efficient storage and HeatWave ML similarity search procedures.
|
||||
|
||||
Supports use as a context manager: when used in a `with` statement, all backing
|
||||
tables/data are deleted automatically when the block exits (even on exception).
|
||||
|
||||
Attributes:
|
||||
db_connection (MySQLConnectionAbstract): Active MySQL database connection.
|
||||
embedder (Embeddings): Embeddings generator for computing vector representations.
|
||||
schema_name (str): SQL schema for table storage.
|
||||
table_name (Optional[str]): Name of the active table backing the store
|
||||
(or None until created).
|
||||
embedding_dimension (int): Size of embedding vectors stored.
|
||||
next_id (int): Internal counter for unique document ID generation.
|
||||
"""
|
||||
|
||||
_db_connection: MySQLConnectionAbstract = PrivateAttr()
|
||||
_embedder: Embeddings = PrivateAttr()
|
||||
_schema_name: str = PrivateAttr()
|
||||
_table_name: Optional[str] = PrivateAttr()
|
||||
_embedding_dimension: int = PrivateAttr()
|
||||
_next_id: int = PrivateAttr()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_connection: MySQLConnectionAbstract,
|
||||
embedder: Optional[Embeddings] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a MyVectorStore with a database connection and embedding generator.
|
||||
|
||||
Args:
|
||||
db_connection: MySQL database connection for all vector operations.
|
||||
embedder: Embeddings generator used for creating and querying embeddings.
|
||||
|
||||
Raises:
|
||||
ValueError: If the schema name is not valid
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
super().__init__()
|
||||
self._next_id = 0
|
||||
|
||||
self._schema_name = source_schema(db_connection)
|
||||
self._embedder = embedder or MyEmbeddings(db_connection)
|
||||
self._db_connection = db_connection
|
||||
self._table_name: Optional[str] = None
|
||||
|
||||
# Embedding dimension determined using an example call.
|
||||
# Assumes embeddings have fixed length.
|
||||
self._embedding_dimension = len(
|
||||
self._embedder.embed_query(BASIC_EMBEDDING_QUERY)
|
||||
)
|
||||
|
||||
def _get_ids(self, num_ids: int) -> list[str]:
|
||||
"""
|
||||
Generate a batch of unique internal document IDs for vector storage.
|
||||
|
||||
Args:
|
||||
num_ids: Number of IDs to create.
|
||||
|
||||
Returns:
|
||||
List of sequentially numbered internal string IDs.
|
||||
"""
|
||||
ids = [
|
||||
f"internal_ai_id_{i}" for i in range(self._next_id, self._next_id + num_ids)
|
||||
]
|
||||
self._next_id += num_ids
|
||||
return ids
|
||||
|
||||
def _make_vector_store(self) -> None:
|
||||
"""
|
||||
Create a backing SQL table for storing vectors if not already created.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Notes:
|
||||
The table name is randomized to avoid collisions.
|
||||
Schema includes content, metadata, and embedding vector.
|
||||
"""
|
||||
if self._table_name is None:
|
||||
with atomic_transaction(self._db_connection) as cursor:
|
||||
table_name = get_random_name(
|
||||
lambda table_name: not table_exists(
|
||||
cursor, self._schema_name, table_name
|
||||
)
|
||||
)
|
||||
|
||||
create_table_stmt = f"""
|
||||
CREATE TABLE {self._schema_name}.{table_name} (
|
||||
`id` VARCHAR(128) NOT NULL,
|
||||
`content` TEXT,
|
||||
`metadata` JSON DEFAULT NULL,
|
||||
`embed` vector(%s),
|
||||
PRIMARY KEY (`id`)
|
||||
) ENGINE=InnoDB;
|
||||
"""
|
||||
execute_sql(
|
||||
cursor, create_table_stmt, params=(self._embedding_dimension,)
|
||||
)
|
||||
|
||||
self._table_name = table_name
|
||||
|
||||
def delete(self, ids: Optional[Sequence[str]] = None, **_: Any) -> None:
|
||||
"""
|
||||
Delete documents by ID. Optionally deletes the vector table if empty after deletions.
|
||||
|
||||
Args:
|
||||
ids: Optional sequence of document IDs to delete. If None, no action is taken.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Notes:
|
||||
If the backing table is empty after deletions, the table is dropped and
|
||||
table_name is set to None.
|
||||
"""
|
||||
with atomic_transaction(self._db_connection) as cursor:
|
||||
if ids:
|
||||
for _id in ids:
|
||||
execute_sql(
|
||||
cursor,
|
||||
f"DELETE FROM {self._schema_name}.{self._table_name} WHERE id = %s",
|
||||
params=(_id,),
|
||||
)
|
||||
|
||||
if is_table_empty(cursor, self._schema_name, self._table_name):
|
||||
self.delete_all()
|
||||
|
||||
def delete_all(self) -> None:
|
||||
"""
|
||||
Delete and drop the entire vector store table.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if self._table_name is not None:
|
||||
with atomic_transaction(self._db_connection) as cursor:
|
||||
delete_sql_table(cursor, self._schema_name, self._table_name)
|
||||
self._table_name = None
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
**_: dict,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Add a batch of text strings and corresponding metadata to the vector store.
|
||||
|
||||
Args:
|
||||
texts: List of strings to embed and store.
|
||||
metadatas: Optional list of metadata dicts (one per text).
|
||||
ids: Optional custom document IDs.
|
||||
|
||||
Returns:
|
||||
List of document IDs corresponding to the added texts.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Notes:
|
||||
If metadatas is None, an empty dict is assigned to each document.
|
||||
"""
|
||||
texts = list(texts)
|
||||
|
||||
documents = [
|
||||
Document(page_content=text, metadata=meta)
|
||||
for text, meta in zip(texts, metadatas or [{}] * len(texts))
|
||||
]
|
||||
return self.add_documents(documents, ids=ids)
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: Iterable[str],
|
||||
embedder: Embeddings,
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
db_connection: MySQLConnectionAbstract = None,
|
||||
) -> VectorStore:
|
||||
"""
|
||||
Construct and populate a MyVectorStore instance from raw texts and metadata.
|
||||
|
||||
Args:
|
||||
texts: List of strings to vectorize and store.
|
||||
embedder: Embeddings generator to use.
|
||||
metadatas: Optional list of metadata dicts per text.
|
||||
db_connection: Active MySQL connection.
|
||||
|
||||
Returns:
|
||||
Instance of MyVectorStore containing the added texts.
|
||||
|
||||
Raises:
|
||||
ValueError: If db_connection is not provided.
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
if db_connection is None:
|
||||
raise ValueError(
|
||||
"db_connection must be specified to create a MyVectorStore object"
|
||||
)
|
||||
|
||||
texts = list(texts)
|
||||
|
||||
instance = cls(db_connection=db_connection, embedder=embedder)
|
||||
instance.add_texts(texts, metadatas=metadatas)
|
||||
|
||||
return instance
|
||||
|
||||
def add_documents(
|
||||
self, documents: list[Document], ids: list[str] = None
|
||||
) -> list[str]:
|
||||
"""
|
||||
Embed and store Document objects as high-dimensional vectors with metadata.
|
||||
|
||||
Args:
|
||||
documents: List of Document objects (each with 'page_content' and 'metadata').
|
||||
ids: Optional list of explicit document IDs. Must match the length of documents.
|
||||
|
||||
Returns:
|
||||
List of document IDs stored.
|
||||
|
||||
Raises:
|
||||
ValueError: If provided IDs do not match the number of documents.
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Notes:
|
||||
Automatically creates the backing table if it does not exist.
|
||||
"""
|
||||
if ids and len(ids) != len(documents):
|
||||
msg = (
|
||||
"ids must be the same length as documents. "
|
||||
f"Got {len(ids)} ids and {len(documents)} documents."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
if len(documents) > 0:
|
||||
self._make_vector_store()
|
||||
else:
|
||||
return []
|
||||
|
||||
if ids is None:
|
||||
ids = self._get_ids(len(documents))
|
||||
|
||||
content = [doc.page_content for doc in documents]
|
||||
vectors = self._embedder.embed_documents(content)
|
||||
|
||||
df = pd.DataFrame()
|
||||
df["id"] = ids
|
||||
df["content"] = content
|
||||
df["embed"] = vectors
|
||||
df["metadata"] = [doc.metadata for doc in documents]
|
||||
|
||||
with atomic_transaction(self._db_connection) as cursor:
|
||||
extend_sql_table(
|
||||
cursor,
|
||||
self._schema_name,
|
||||
self._table_name,
|
||||
df,
|
||||
col_name_to_placeholder_string={"embed": "string_to_vector(%s)"},
|
||||
)
|
||||
|
||||
return ids
|
||||
|
||||
def similarity_search(
    self,
    query: str,
    k: int = 3,
    **kwargs: Any,
) -> list[Document]:
    """
    Search for and return the most similar documents in the store to the given query.

    Args:
        query: String query to embed and use for similarity search.
        k: Number of top documents to return. Defaults to 3.
        kwargs: options to pass to ML_SIMILARITY_SEARCH. Currently supports
            distance_metric (default "COSINE"), max_distance (default 0.6),
            percentage_distance (default 20.0), and segment_overlap (default 0).

    Returns:
        List of Document objects, ordered from most to least similar.

    Raises:
        DatabaseError:
            If provided kwargs are invalid or unsupported.
            If a database connection issue occurs.
            If an operational error occurs during execution.

    Implementation Notes:
        - Calls ML similarity search within MySQL using stored procedures.
        - Retrieves IDs, content, and metadata for search matches.
        - Parsing and retrieval for context results are handled via intermediate JSONs.
    """
    # No backing table yet means nothing has ever been added to the store.
    if self._table_name is None:
        return []

    embedding = self._embedder.embed_query(query)

    with atomic_transaction(self._db_connection) as cursor:
        # Set the embedding variable for the similarity search SP
        execute_sql(
            cursor,
            f"SET @{VAR_EMBEDDING} = string_to_vector(%s)",
            params=[str(embedding)],
        )

        distance_metric = kwargs.get("distance_metric", "COSINE")
        retrieval_options = {
            "max_distance": kwargs.get("max_distance", 0.6),
            "percentage_distance": kwargs.get("percentage_distance", 20.0),
            "segment_overlap": kwargs.get("segment_overlap", 0),
        }

        # Render the options dict as a SQL value expression plus bind params.
        retrieval_options_placeholder, retrieval_options_params = format_value_sql(
            retrieval_options
        )
        # NOTE(review): schema/table names are interpolated directly into the
        # SQL text; presumably both were validated when the table was created
        # -- confirm against the table-creation path.
        similarity_search_query = f"""
        CALL sys.ML_SIMILARITY_SEARCH(
            @{VAR_EMBEDDING},
            JSON_ARRAY(
                '{self._schema_name}.{self._table_name}'
            ),
            JSON_OBJECT(
                "segment", "content",
                "segment_embedding", "embed",
                "document_name", "id"
            ),
            {k},
            %s,
            NULL,
            NULL,
            {retrieval_options_placeholder},
            @{VAR_CONTEXT},
            @{VAR_CONTEXT_MAP},
            @{VAR_RETRIEVAL_INFO}
        )
        """

        execute_sql(
            cursor,
            similarity_search_query,
            params=[distance_metric, *retrieval_options_params],
        )
        # The SP writes its matches into session variables; read the
        # context map back as JSON.
        execute_sql(cursor, f"SELECT @{VAR_CONTEXT_MAP}")

        results = []

        context_maps = json.loads(cursor.fetchone()[0])
        for context in context_maps:
            # One lookup per match: fetch the stored row for the matched id.
            execute_sql(
                cursor,
                (
                    "SELECT id, content, metadata "
                    f"FROM {self._schema_name}.{self._table_name} "
                    "WHERE id = %s"
                ),
                params=(context["document_name"],),
            )
            doc_id, content, metadata = cursor.fetchone()

            doc_args = {
                "id": doc_id,
                "page_content": content,
            }
            # metadata column is nullable; only JSON-decode when present so
            # Document keeps its own default otherwise.
            if metadata is not None:
                doc_args["metadata"] = json.loads(metadata)

            doc = Document(**doc_args)
            results.append(doc)

    return results
|
||||
|
||||
def __enter__(self) -> "VectorStore":
    """
    Enter the runtime context for this vector store.

    Returns:
        This vector store instance, so it can be bound by a ``with``
        statement (``with MyVectorStore(conn, embedder) as vs: ...``).

    Notes:
        No setup work happens on entry; pairing with ``__exit__`` simply
        guarantees that backing storage is cleaned up when the block ends.
    """
    return self
|
||||
|
||||
def __exit__(
    self,
    exc_type: Union[type, None],
    exc_val: Union[BaseException, None],
    exc_tb: Union[object, None],
) -> None:
    """
    Leave the runtime context, dropping all vector-store storage.

    Args:
        exc_type: Exception class raised inside the ``with`` block, if any.
        exc_val: Exception instance raised inside the block, if any.
        exc_tb: Traceback of that exception, if any.

    Returns:
        None, so any in-flight exception keeps propagating (nothing is
        ever suppressed).

    Notes:
        Cleanup runs on every exit path — normal completion or exception —
        by delegating to ``delete_all()``, which removes the store's data
        and backing tables. Use the context-manager form when the store's
        lifetime should be scoped to a block.
    """
    # Unconditional cleanup; exceptions from the block still propagate
    # because this method implicitly returns None.
    self.delete_all()
|
||||
48
venv/lib/python3.12/site-packages/mysql/ai/ml/__init__.py
Normal file
48
venv/lib/python3.12/site-packages/mysql/ai/ml/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""ML package for MySQL Connector/Python.

Performs optional dependency checks and exposes ML utilities:
- ML_TASK, MyModel
- MyClassifier, MyRegressor, MyGenericTransformer
- MyAnomalyDetector
"""
from mysql.ai.utils import check_dependencies as _check_dependencies

# Fail fast at package-import time if the optional "ML" dependency group is
# missing, before any submodule below tries to use it.
_check_dependencies(["ML"])
del _check_dependencies  # keep the helper out of the public namespace

# Sklearn models
from .classifier import MyClassifier

# Minimal interface
from .model import ML_TASK, MyModel
from .outlier import MyAnomalyDetector
from .regressor import MyRegressor
from .transformer import MyGenericTransformer
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
142
venv/lib/python3.12/site-packages/mysql/ai/ml/base.py
Normal file
142
venv/lib/python3.12/site-packages/mysql/ai/ml/base.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Base classes for MySQL HeatWave ML estimators for Connector/Python.
|
||||
|
||||
Implements a scikit-learn-compatible base estimator wrapping server-side ML.
|
||||
"""
|
||||
from typing import Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from sklearn.base import BaseEstimator
|
||||
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
from mysql.ai.ml.model import ML_TASK, MyModel
|
||||
from mysql.ai.utils import copy_dict
|
||||
|
||||
|
||||
class MyBaseMLModel(BaseEstimator):
    """
    Base class for MySQL HeatWave machine learning estimators.

    Implements the scikit-learn API and core model management logic,
    including fit, explain, serialization, and dynamic option handling.
    For use as a base class by classifiers, regressors, transformers, and outlier models.

    Args:
        db_connection (MySQLConnectionAbstract): An active MySQL connector database connection.
        task (str): ML task type, e.g. "classification" or "regression".
        model_name (str, optional): Custom name for the deployed model.
        fit_extra_options (dict, optional): Extra options for fitting.

    Attributes:
        _model: Underlying database helper for fit/predict/explain.
        fit_extra_options: User-provided options for fitting.
    """

    def __init__(
        self,
        db_connection: MySQLConnectionAbstract,
        task: Union[str, ML_TASK],
        model_name: Optional[str] = None,
        fit_extra_options: Optional[dict] = None,
    ):
        """
        Initialize a MyBaseMLModel with connection, task, and option parameters.

        Args:
            db_connection: Active MySQL connector database connection.
            task: ML task as a string or ML_TASK member (e.g. "classification").
            model_name: Optional custom model name.
            fit_extra_options: Optional extra fit options.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        self._model = MyModel(db_connection, task=task, model_name=model_name)
        # Defensive copy so caller-side mutation of the passed dict cannot
        # silently change later fit behavior.
        self.fit_extra_options = copy_dict(fit_extra_options)

    def fit(
        self,
        X: pd.DataFrame,  # pylint: disable=invalid-name
        y: Optional[pd.DataFrame] = None,
    ) -> "MyBaseMLModel":
        """
        Fit the underlying ML model using pandas DataFrames.
        Delegates to the wrapped MyModel's fit with the stored
        fit_extra_options.

        Args:
            X: Features DataFrame.
            y: (Optional) Target labels DataFrame or Series.

        Returns:
            self

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported.
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Additional temp SQL resources may be created and cleaned up during the operation.
        """
        self._model.fit(X, y, self.fit_extra_options)
        return self

    def _delete_model(self) -> bool:
        """
        Deletes the model from the model catalog if present.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Returns:
            Whether the model was deleted.
        """
        return self._model._delete_model()

    def get_model_info(self) -> Optional[dict]:
        """
        Retrieve this model's entry from the model catalog. Model info will
        only be present in the catalog if the model has previously been
        fitted.

        Returns:
            The model's catalog entry as a dict, or None if the model is not
            in the catalog (i.e. it has never been fitted).

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        return self._model.get_model_info()
|
||||
194
venv/lib/python3.12/site-packages/mysql/ai/ml/classifier.py
Normal file
194
venv/lib/python3.12/site-packages/mysql/ai/ml/classifier.py
Normal file
@@ -0,0 +1,194 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Classifier utilities for MySQL Connector/Python.
|
||||
|
||||
Provides a scikit-learn compatible classifier backed by HeatWave ML.
|
||||
"""
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.base import ClassifierMixin
|
||||
|
||||
from mysql.ai.ml.base import MyBaseMLModel
|
||||
from mysql.ai.ml.model import ML_TASK
|
||||
from mysql.ai.utils import copy_dict
|
||||
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
|
||||
class MyClassifier(MyBaseMLModel, ClassifierMixin):
    """
    MySQL HeatWave scikit-learn compatible classifier estimator.

    Provides prediction and probability output from a model deployed in MySQL,
    and manages fit, explain, and prediction options as per HeatWave ML interface.

    Attributes:
        predict_extra_options (dict): Dictionary of optional parameters passed through
            to the MySQL backend for prediction and probability inference.
        explain_extra_options (dict): Optional parameters for explanation calls.
        _model (MyModel): Underlying interface for database model operations.
        fit_extra_options (dict): See MyBaseMLModel.

    Args:
        db_connection (MySQLConnectionAbstract): Active MySQL connector DB connection.
        model_name (str, optional): Custom name for the model.
        fit_extra_options (dict, optional): Extra options for fitting.
        explain_extra_options (dict, optional): Extra options for explanations.
        predict_extra_options (dict, optional): Extra options for predict/predict_proba.

    Methods:
        predict(X): Predict class labels.
        predict_proba(X): Predict class probabilities.
        explain_predictions(X): Explain predictions for the given rows.
    """

    def __init__(
        self,
        db_connection: MySQLConnectionAbstract,
        model_name: Optional[str] = None,
        fit_extra_options: Optional[dict] = None,
        explain_extra_options: Optional[dict] = None,
        predict_extra_options: Optional[dict] = None,
    ):
        """
        Initialize a MyClassifier.

        Args:
            db_connection: Active MySQL connector database connection.
            model_name: Optional, custom model name.
            fit_extra_options: Optional fit options.
            explain_extra_options: Optional explain options.
            predict_extra_options: Optional predict/predict_proba options.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        MyBaseMLModel.__init__(
            self,
            db_connection,
            ML_TASK.CLASSIFICATION,
            model_name=model_name,
            fit_extra_options=fit_extra_options,
        )
        # Defensive copies so later caller-side mutation cannot change
        # inference/explanation behavior.
        self.predict_extra_options = copy_dict(predict_extra_options)
        self.explain_extra_options = copy_dict(explain_extra_options)

    def predict(
        self, X: Union[pd.DataFrame, np.ndarray]
    ) -> np.ndarray:  # pylint: disable=invalid-name
        """
        Predict class labels for the input features using the MySQL model.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
            A full list of supported options can be found under "ML_PREDICT_TABLE Options"

        Args:
            X: Input samples as a numpy array or pandas DataFrame.

        Returns:
            ndarray: Array of predicted class labels, shape (n_samples,).

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        result = self._model.predict(X, options=self.predict_extra_options)
        return result["Prediction"].to_numpy()

    def predict_proba(
        self, X: Union[pd.DataFrame, np.ndarray]
    ) -> np.ndarray:  # pylint: disable=invalid-name
        """
        Predict class probabilities for the input features using the MySQL model.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
            A full list of supported options can be found under "ML_PREDICT_TABLE Options"

        Args:
            X: Input samples as a numpy array or pandas DataFrame.

        Returns:
            ndarray: Array of shape (n_samples, n_classes) with class probabilities.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        result = self._model.predict(X, options=self.predict_extra_options)

        # Column order is derived from the first row's probability keys,
        # sorted so the class ordering is deterministic across calls.
        # NOTE(review): assumes every row reports the same class set and X
        # is non-empty -- confirm against the backend contract.
        classes = sorted(result["ml_results"].iloc[0]["probabilities"].keys())

        return np.stack(
            result["ml_results"].map(
                lambda ml_result: [
                    ml_result["probabilities"][class_name] for class_name in classes
                ]
            )
        )

    def explain_predictions(
        self, X: Union[pd.DataFrame, np.ndarray]
    ) -> pd.DataFrame:  # pylint: disable=invalid-name
        """
        Explain model predictions using provided data.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
            A full list of supported options can be found under "ML_EXPLAIN_TABLE Options"

        Args:
            X: DataFrame for which predictions should be explained.

        Returns:
            DataFrame containing explanation details (feature attributions, etc.)

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Temporary input/output tables are cleaned up after explanation.
        """
        # Bug fix: the explanation result was previously computed but never
        # returned, so the documented DataFrame return value was always None.
        return self._model.explain_predictions(X, options=self.explain_extra_options)
|
||||
780
venv/lib/python3.12/site-packages/mysql/ai/ml/model.py
Normal file
780
venv/lib/python3.12/site-packages/mysql/ai/ml/model.py
Normal file
@@ -0,0 +1,780 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
"""HeatWave ML model utilities for MySQL Connector/Python.
|
||||
|
||||
Provides classes to manage training, prediction, scoring, and explanations
|
||||
via MySQL HeatWave stored procedures.
|
||||
"""
|
||||
import copy
|
||||
import json
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from mysql.ai.utils import (
|
||||
VAR_NAME_SPACE,
|
||||
atomic_transaction,
|
||||
convert_to_df,
|
||||
execute_sql,
|
||||
format_value_sql,
|
||||
get_random_name,
|
||||
source_schema,
|
||||
sql_response_to_df,
|
||||
sql_table_from_df,
|
||||
sql_table_to_df,
|
||||
table_exists,
|
||||
temporary_sql_tables,
|
||||
validate_name,
|
||||
)
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
|
||||
class ML_TASK(Enum):  # pylint: disable=invalid-name
    """The closed set of ML task names accepted by HeatWave ML."""

    CLASSIFICATION = "classification"
    REGRESSION = "regression"
    FORECASTING = "forecasting"
    ANOMALY_DETECTION = "anomaly_detection"
    LOG_ANOMALY_DETECTION = "log_anomaly_detection"
    RECOMMENDATION = "recommendation"
    TOPIC_MODELING = "topic_modeling"

    @staticmethod
    def get_task_string(task: Union[str, "ML_TASK"]) -> str:
        """
        Normalize a task argument to its plain-string form.

        Args:
            task (Union[str, ML_TASK]): Either an ML_TASK member or an
                already-plain task string; strings pass through unchanged.

        Returns:
            str: The string value of the ML task.
        """
        return task if isinstance(task, str) else task.value
|
||||
|
||||
|
||||
class _MyModelCommon:
|
||||
"""
|
||||
Common utilities and workflow for MySQL HeatWave ML models.
|
||||
|
||||
This class handles model lifecycle steps such as loading, fitting, scoring,
|
||||
making predictions, and explaining models or predictions. Not intended for
|
||||
direct instantiation, but as a superclass for heatwave model wrappers.
|
||||
|
||||
Attributes:
|
||||
db_connection: MySQL connector database connection.
|
||||
task: ML task, e.g., "classification" or "regression".
|
||||
model_name: Identifier of model in MySQL.
|
||||
schema_name: Database schema used for operations and temp tables.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_connection: MySQLConnectionAbstract,
|
||||
task: Union[str, ML_TASK] = ML_TASK.CLASSIFICATION,
|
||||
model_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Instantiate _MyMLModelCommon.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-train.html
|
||||
A full list of supported tasks can be found under "Common ML_TRAIN Options"
|
||||
|
||||
Args:
|
||||
db_connection: MySQL database connection.
|
||||
task: ML task type (default: "classification").
|
||||
model_name: Name to register the model within MySQL (default: None).
|
||||
|
||||
Raises:
|
||||
ValueError: If the schema name is not valid
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.db_connection = db_connection
|
||||
self.task = ML_TASK.get_task_string(task)
|
||||
self.schema_name = source_schema(db_connection)
|
||||
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
execute_sql(cursor, "CALL sys.ML_CREATE_OR_UPGRADE_CATALOG();")
|
||||
|
||||
if model_name is None:
|
||||
model_name = get_random_name(self._is_model_name_available)
|
||||
|
||||
self.model_var = f"{VAR_NAME_SPACE}.{model_name}"
|
||||
self.model_var_score = f"{self.model_var}.score"
|
||||
|
||||
self.model_name = model_name
|
||||
validate_name(model_name)
|
||||
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
execute_sql(cursor, f"SET @{self.model_var} = %s;", params=(model_name,))
|
||||
|
||||
def _delete_model(self) -> bool:
|
||||
"""
|
||||
Deletes the model from the model catalog if present
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Returns:
|
||||
Whether the model was deleted
|
||||
"""
|
||||
current_user = self._get_user()
|
||||
|
||||
qualified_model_catalog = f"ML_SCHEMA_{current_user}.MODEL_CATALOG"
|
||||
delete_model = (
|
||||
f"DELETE FROM {qualified_model_catalog} "
|
||||
f"WHERE model_handle = @{self.model_var}"
|
||||
)
|
||||
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
execute_sql(cursor, delete_model)
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def _get_model_info(self, model_name: str) -> Optional[dict]:
|
||||
"""
|
||||
Retrieves the model info from the model_catalog
|
||||
|
||||
Args:
|
||||
model_var: The model alias to retrieve
|
||||
|
||||
Returns:
|
||||
The model info from the model_catalog (None if the model is not present in the catalog)
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
|
||||
def process_col(elem: Any) -> Any:
|
||||
if isinstance(elem, str):
|
||||
try:
|
||||
elem = json.loads(elem)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return elem
|
||||
|
||||
current_user = self._get_user()
|
||||
|
||||
qualified_model_catalog = f"ML_SCHEMA_{current_user}.MODEL_CATALOG"
|
||||
model_exists = (
|
||||
f"SELECT * FROM {qualified_model_catalog} WHERE model_handle = %s"
|
||||
)
|
||||
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
execute_sql(cursor, model_exists, params=(model_name,))
|
||||
model_info_df = sql_response_to_df(cursor)
|
||||
|
||||
if model_info_df.empty:
|
||||
result = None
|
||||
else:
|
||||
unprocessed_result = model_info_df.to_json(orient="records")
|
||||
unprocessed_result_json = json.loads(unprocessed_result)[0]
|
||||
result = {
|
||||
key: process_col(elem)
|
||||
for key, elem in unprocessed_result_json.items()
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def get_model_info(self) -> Optional[dict]:
|
||||
"""
|
||||
Checks if the model name is available.
|
||||
Model info is present in the catalog only if the model was previously fitted.
|
||||
|
||||
Returns:
|
||||
True if the model name is not part of the model catalog
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
return self._get_model_info(self.model_name)
|
||||
|
||||
def _is_model_name_available(self, model_name: str) -> bool:
|
||||
"""
|
||||
Checks if the model name is available
|
||||
|
||||
Returns:
|
||||
True if the model name is not part of the model catalog
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
return self._get_model_info(model_name) is None
|
||||
|
||||
def _load_model(self) -> None:
|
||||
"""
|
||||
Loads the model specified by `self.model_name` into MySQL.
|
||||
After loading, the model is ready to handle ML operations.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-model-load.html
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If the model is not initialized, i.e., fit or import has not been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
load_model_query = f"CALL sys.ML_MODEL_LOAD(@{self.model_var}, NULL);"
|
||||
execute_sql(cursor, load_model_query)
|
||||
|
||||
def _get_user(self) -> str:
|
||||
"""
|
||||
Fetch the current database user (without host).
|
||||
|
||||
Returns:
|
||||
The username string associated with the connection.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
ValueError: If the user name includes unsupported characters
|
||||
"""
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
cursor.execute("SELECT CURRENT_USER()")
|
||||
current_user = cursor.fetchone()[0].split("@")[0]
|
||||
|
||||
return validate_name(current_user)
|
||||
|
||||
def explain_model(self) -> dict:
|
||||
"""
|
||||
Get model explanations, such as detailed feature importances.
|
||||
|
||||
Returns:
|
||||
dict: Feature importances and model explainability data.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-model-explanations.html
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If the model is not initialized, i.e., fit or import has not been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
ValueError:
|
||||
If the model does not exist in the model catalog.
|
||||
Should only occur if model was not fitted or was deleted.
|
||||
"""
|
||||
self._load_model()
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
current_user = self._get_user()
|
||||
|
||||
qualified_model_catalog = f"ML_SCHEMA_{current_user}.MODEL_CATALOG"
|
||||
explain_query = (
|
||||
f"SELECT model_explanation FROM {qualified_model_catalog} "
|
||||
f"WHERE model_handle = @{self.model_var}"
|
||||
)
|
||||
|
||||
execute_sql(cursor, explain_query)
|
||||
df = sql_response_to_df(cursor)
|
||||
|
||||
return df.iloc[0, 0]
|
||||
|
||||
def _fit(
|
||||
self,
|
||||
table_name: str,
|
||||
target_column_name: Optional[str],
|
||||
options: Optional[dict],
|
||||
) -> None:
|
||||
"""
|
||||
Fit an ML model using a referenced SQL table and target column.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-train.html
|
||||
A full list of supported options can be found under "Common ML_TRAIN Options"
|
||||
|
||||
Args:
|
||||
table_name: Name of the training data table.
|
||||
target_column_name: Name of the target/label column.
|
||||
options: Additional fit/config options (may override defaults).
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported.
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
ValueError: If the table or target_column name is not valid
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
validate_name(table_name)
|
||||
if target_column_name is not None:
|
||||
validate_name(target_column_name)
|
||||
target_col_string = f"'{target_column_name}'"
|
||||
else:
|
||||
target_col_string = "NULL"
|
||||
|
||||
if options is None:
|
||||
options = {}
|
||||
options = copy.deepcopy(options)
|
||||
options["task"] = self.task
|
||||
|
||||
self._delete_model()
|
||||
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
placeholders, parameters = format_value_sql(options)
|
||||
execute_sql(
|
||||
cursor,
|
||||
(
|
||||
"CALL sys.ML_TRAIN("
|
||||
f"'{self.schema_name}.{table_name}', "
|
||||
f"{target_col_string}, "
|
||||
f"{placeholders}, "
|
||||
f"@{self.model_var}"
|
||||
")"
|
||||
),
|
||||
params=parameters,
|
||||
)
|
||||
|
||||
def _predict(
|
||||
self, table_name: str, output_table_name: str, options: Optional[dict]
|
||||
) -> None:
|
||||
"""
|
||||
Predict on a given data table and write results to an output table.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
|
||||
A full list of supported options can be found under "ML_PREDICT_TABLE Options"
|
||||
|
||||
Args:
|
||||
table_name: Name of the SQL table with input data.
|
||||
output_table_name: Name for the SQL output table to contain predictions.
|
||||
options: Optional prediction options.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported,
|
||||
or if the model is not initialized, i.e., fit or import has not
|
||||
been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
ValueError: If the table or output_table name is not valid
|
||||
"""
|
||||
validate_name(table_name)
|
||||
validate_name(output_table_name)
|
||||
|
||||
self._load_model()
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
placeholders, parameters = format_value_sql(options)
|
||||
execute_sql(
|
||||
cursor,
|
||||
(
|
||||
"CALL sys.ML_PREDICT_TABLE("
|
||||
f"'{self.schema_name}.{table_name}', "
|
||||
f"@{self.model_var}, "
|
||||
f"'{self.schema_name}.{output_table_name}', "
|
||||
f"{placeholders}"
|
||||
")"
|
||||
),
|
||||
params=parameters,
|
||||
)
|
||||
|
||||
def _score(
|
||||
self,
|
||||
table_name: str,
|
||||
target_column_name: str,
|
||||
metric: str,
|
||||
options: Optional[dict],
|
||||
) -> float:
|
||||
"""
|
||||
Evaluate model performance with a scoring metric.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-score.html
|
||||
A full list of supported options can be found under
|
||||
"Options for Recommendation Models" and
|
||||
"Options for Anomaly Detection Models"
|
||||
|
||||
Args:
|
||||
table_name: Table with features and ground truth.
|
||||
target_column_name: Column of true target labels.
|
||||
metric: String name of the metric to compute.
|
||||
options: Optional dictionary of further scoring options.
|
||||
|
||||
Returns:
|
||||
float: Computed score from the ML system.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported,
|
||||
or if the model is not initialized, i.e., fit or import has not
|
||||
been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
ValueError: If the table or target_column name or metric is not valid
|
||||
"""
|
||||
validate_name(table_name)
|
||||
validate_name(target_column_name)
|
||||
validate_name(metric)
|
||||
|
||||
self._load_model()
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
placeholders, parameters = format_value_sql(options)
|
||||
execute_sql(
|
||||
cursor,
|
||||
(
|
||||
"CALL sys.ML_SCORE("
|
||||
f"'{self.schema_name}.{table_name}', "
|
||||
f"'{target_column_name}', "
|
||||
f"@{self.model_var}, "
|
||||
"%s, "
|
||||
f"@{self.model_var_score}, "
|
||||
f"{placeholders}"
|
||||
")"
|
||||
),
|
||||
params=[metric, *parameters],
|
||||
)
|
||||
execute_sql(cursor, f"SELECT @{self.model_var_score}")
|
||||
df = sql_response_to_df(cursor)
|
||||
|
||||
return df.iloc[0, 0]
|
||||
|
||||
def _explain_predictions(
|
||||
self, table_name: str, output_table_name: str, options: Optional[dict]
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Produce explanations for model predictions on provided data.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
|
||||
A full list of supported options can be found under "ML_EXPLAIN_TABLE Options"
|
||||
|
||||
Args:
|
||||
table_name: Name of the SQL table with input data.
|
||||
output_table_name: Name for the SQL table to store explanations.
|
||||
options: Optional dictionary (default:
|
||||
{"prediction_explainer": "permutation_importance"}).
|
||||
|
||||
Returns:
|
||||
DataFrame: Prediction explanations from the output SQL table.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported,
|
||||
or if the model is not initialized, i.e., fit or import has not
|
||||
been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
ValueError: If the table or output_table name is not valid
|
||||
"""
|
||||
validate_name(table_name)
|
||||
validate_name(output_table_name)
|
||||
|
||||
if options is None:
|
||||
options = {"prediction_explainer": "permutation_importance"}
|
||||
|
||||
self._load_model()
|
||||
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
placeholders, parameters = format_value_sql(options)
|
||||
execute_sql(
|
||||
cursor,
|
||||
(
|
||||
"CALL sys.ML_EXPLAIN_TABLE("
|
||||
f"'{self.schema_name}.{table_name}', "
|
||||
f"@{self.model_var}, "
|
||||
f"'{self.schema_name}.{output_table_name}', "
|
||||
f"{placeholders}"
|
||||
")"
|
||||
),
|
||||
params=parameters,
|
||||
)
|
||||
execute_sql(cursor, f"SELECT * FROM {self.schema_name}.{output_table_name}")
|
||||
df = sql_response_to_df(cursor)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
class MyModel(_MyModelCommon):
    """
    Convenience class for managing the ML workflow using pandas DataFrames.

    Methods convert in-memory DataFrames into temp SQL tables before delegating to the
    _MyModelCommon routines, and automatically clean up temp resources.
    """

    def fit(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
        y: Optional[Union[pd.DataFrame, np.ndarray]],
        options: Optional[dict] = None,
    ) -> None:
        """
        Fit a model using DataFrame inputs.

        If an 'id' column is defined in either dataframe, it will be used as the primary key.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-train.html
            A full list of supported options can be found under "Common ML_TRAIN Options"

        Args:
            X: Features DataFrame.
            y: (Optional) Target labels DataFrame or Series. If None, only X is used.
            options: Additional options to pass to training.

        Returns:
            None

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported.
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError: If y's column name is already present in X.

        Notes:
            Combines X and y as necessary. Creates a temporary table in the schema for training,
            and deletes it afterward.
        """
        X, y = convert_to_df(X), convert_to_df(y)

        with (
            atomic_transaction(self.db_connection) as cursor,
            temporary_sql_tables(self.db_connection) as temporary_tables,
        ):
            if y is not None:
                if isinstance(y, pd.DataFrame):
                    # keep column name if it exists
                    target_column_name = y.columns[0]
                else:
                    # no usable name: synthesize one that cannot collide with X
                    target_column_name = get_random_name(
                        lambda name: name not in X.columns
                    )

                if target_column_name in X.columns:
                    raise ValueError(
                        f"Target column y with name {target_column_name} already present "
                        "in feature dataframe X"
                    )

                df_combined = X.copy()
                df_combined[target_column_name] = y
                final_df = df_combined
            else:
                target_column_name = None
                final_df = X

            _, table_name = sql_table_from_df(cursor, self.schema_name, final_df)
            temporary_tables.append((self.schema_name, table_name))

            self._fit(table_name, target_column_name, options)

    def predict(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
        options: Optional[dict] = None,
    ) -> pd.DataFrame:
        """
        Generate model predictions using DataFrame input.

        If an 'id' column is defined in either dataframe, it will be used as the primary key.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
            A full list of supported options can be found under "ML_PREDICT_TABLE Options"

        Args:
            X: DataFrame containing prediction features (no labels).
            options: Additional prediction settings.

        Returns:
            DataFrame with prediction results as returned by HeatWave.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Temporary SQL tables are created and deleted for input/output.
        """
        X = convert_to_df(X)

        with (
            atomic_transaction(self.db_connection) as cursor,
            temporary_sql_tables(self.db_connection) as temporary_tables,
        ):
            _, table_name = sql_table_from_df(cursor, self.schema_name, X)
            temporary_tables.append((self.schema_name, table_name))

            output_table_name = get_random_name(
                lambda table_name: not table_exists(
                    cursor, self.schema_name, table_name
                )
            )
            temporary_tables.append((self.schema_name, output_table_name))

            self._predict(table_name, output_table_name, options)
            predictions = sql_table_to_df(cursor, self.schema_name, output_table_name)

            # ml_results is text but known to always follow JSON format
            predictions["ml_results"] = predictions["ml_results"].map(json.loads)

            return predictions

    def score(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
        y: Union[pd.DataFrame, np.ndarray],
        metric: str,
        options: Optional[dict] = None,
    ) -> float:
        """
        Score the model using X/y data and a selected metric.

        If an 'id' column is defined in either dataframe, it will be used as the primary key.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-score.html
            A full list of supported options can be found under
            "Options for Recommendation Models" and
            "Options for Anomaly Detection Models"

        Args:
            X: DataFrame of features.
            y: DataFrame or Series of labels.
            metric: Metric name (e.g., "balanced_accuracy").
            options: Optional ml scoring options.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Returns:
            float: Computed score.
        """
        X, y = convert_to_df(X), convert_to_df(y)

        with (
            atomic_transaction(self.db_connection) as cursor,
            temporary_sql_tables(self.db_connection) as temporary_tables,
        ):
            # Labels are merged into the feature frame under a fresh,
            # collision-free column name before upload.
            target_column_name = get_random_name(lambda name: name not in X.columns)
            df_combined = X.copy()
            df_combined[target_column_name] = y
            final_df = df_combined

            _, table_name = sql_table_from_df(cursor, self.schema_name, final_df)
            temporary_tables.append((self.schema_name, table_name))

            score = self._score(table_name, target_column_name, metric, options)

        return score

    def explain_predictions(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
        options: Optional[dict] = None,
    ) -> pd.DataFrame:
        """
        Explain model predictions using provided data.

        If an 'id' column is defined in either dataframe, it will be used as the primary key.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
            A full list of supported options can be found under
            "ML_EXPLAIN_TABLE Options"

        Args:
            X: DataFrame for which predictions should be explained.
            options: Optional dictionary of explainability options.

        Returns:
            DataFrame containing explanation details (feature attributions, etc.)

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported, or if the model is not initialized,
                i.e., fit or import has not been called
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Temporary input/output tables are cleaned up after explanation.
        """
        X = convert_to_df(X)

        with (
            atomic_transaction(self.db_connection) as cursor,
            temporary_sql_tables(self.db_connection) as temporary_tables,
        ):

            _, table_name = sql_table_from_df(cursor, self.schema_name, X)
            temporary_tables.append((self.schema_name, table_name))

            output_table_name = get_random_name(
                lambda table_name: not table_exists(
                    cursor, self.schema_name, table_name
                )
            )
            temporary_tables.append((self.schema_name, output_table_name))

            explanations = self._explain_predictions(
                table_name, output_table_name, options
            )

            return explanations
|
||||
221
venv/lib/python3.12/site-packages/mysql/ai/ml/outlier.py
Normal file
221
venv/lib/python3.12/site-packages/mysql/ai/ml/outlier.py
Normal file
@@ -0,0 +1,221 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Outlier/anomaly detection utilities for MySQL Connector/Python.
|
||||
|
||||
Provides a scikit-learn compatible wrapper using HeatWave to score anomalies.
|
||||
"""
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.base import OutlierMixin
|
||||
|
||||
from mysql.ai.ml.base import MyBaseMLModel
|
||||
from mysql.ai.ml.model import ML_TASK
|
||||
from mysql.ai.utils import copy_dict
|
||||
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
EPS = 1e-5
|
||||
|
||||
|
||||
def _get_logits(prob: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
|
||||
"""
|
||||
Compute logit (logodds) for a probability, clipping to avoid numerical overflow.
|
||||
|
||||
Args:
|
||||
prob: Scalar or array of probability values in (0,1).
|
||||
|
||||
Returns:
|
||||
logit-transformed probabilities.
|
||||
"""
|
||||
result = np.clip(prob, EPS, 1 - EPS)
|
||||
return np.log(result / (1 - result))
|
||||
|
||||
|
||||
class MyAnomalyDetector(MyBaseMLModel, OutlierMixin):
    """
    MySQL HeatWave scikit-learn compatible anomaly/outlier detector.

    Flags samples as outliers when the probability of being an anomaly
    exceeds the threshold stored with the trained model.
    Includes helpers to obtain decision scores and anomaly probabilities
    for ranking.

    Args:
        db_connection (MySQLConnectionAbstract): Active MySQL DB connection.
        model_name (str, optional): Custom model name in the database.
        fit_extra_options (dict, optional): Extra options for fitting.
        score_extra_options (dict, optional): Extra options for scoring/prediction.

    Attributes:
        boundary: Decision threshold boundary in logit space. Derived from
            trained model's catalog info; computed lazily on first use.

    Methods:
        predict(X): Predict outlier/inlier labels.
        score_samples(X): Compute anomaly (normal class) logit scores.
        decision_function(X): Compute signed score above/below threshold for ranking.
    """

    def __init__(
        self,
        db_connection: MySQLConnectionAbstract,
        model_name: Optional[str] = None,
        fit_extra_options: Optional[dict] = None,
        score_extra_options: Optional[dict] = None,
    ):
        """
        Initialize an anomaly detector instance with extra options.

        Args:
            db_connection: Active MySQL DB connection.
            model_name: Optional model name in DB.
            fit_extra_options: Optional extra fit options.
            score_extra_options: Optional extra scoring options.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        MyBaseMLModel.__init__(
            self,
            db_connection,
            ML_TASK.ANOMALY_DETECTION,
            model_name=model_name,
            fit_extra_options=fit_extra_options,
        )
        self.score_extra_options = copy_dict(score_extra_options)
        # Lazily resolved from the trained model's metadata in
        # decision_function; None until then.
        self.boundary: Optional[float] = None

    def predict(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
    ) -> np.ndarray:
        """
        Predict outlier/inlier binary labels for input samples.

        Args:
            X: Samples to predict on.

        Returns:
            ndarray: Values are -1 for outliers, +1 for inliers, as per scikit-learn convention.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        return np.where(self.decision_function(X) < 0.0, -1, 1)

    def decision_function(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
    ) -> np.ndarray:
        """
        Compute signed distance to the outlier threshold.

        Args:
            X: Samples to predict on.

        Returns:
            ndarray: Score > 0 means inlier, < 0 means outlier; |value| gives margin.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError:
                If the provided model info does not provide threshold
        """
        sample_scores = self.score_samples(X)

        if self.boundary is None:
            model_info = self.get_model_info()
            if model_info is None:
                raise ValueError("Model does not exist in catalog.")

            threshold = model_info["model_metadata"]["training_params"].get(
                "anomaly_detection_threshold", None
            )
            if threshold is None:
                raise ValueError(
                    "Trained model is outdated and does not support threshold. "
                    "Try retraining or using an existing, trained model with MyModel."
                )

            # scikit-learn uses large positive values as inlier
            # and negative as outlier, so we need to flip our threshold
            self.boundary = _get_logits(1.0 - threshold)

        return sample_scores - self.boundary

    def score_samples(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
    ) -> np.ndarray:
        """
        Compute normal probability logit score for each sample.
        Used for ranking, thresholding.

        Args:
            X: Samples to score.

        Returns:
            ndarray: Logit scores based on "normal" class probability.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        result = self._model.predict(X, options=self.score_extra_options)

        return _get_logits(
            result["ml_results"]
            .apply(lambda x: x["probabilities"]["normal"])
            .to_numpy()
        )
|
||||
154
venv/lib/python3.12/site-packages/mysql/ai/ml/regressor.py
Normal file
154
venv/lib/python3.12/site-packages/mysql/ai/ml/regressor.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Regressor utilities for MySQL Connector/Python.
|
||||
|
||||
Provides a scikit-learn compatible regressor backed by HeatWave ML.
|
||||
"""
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.base import RegressorMixin
|
||||
|
||||
from mysql.ai.ml.base import MyBaseMLModel
|
||||
from mysql.ai.ml.model import ML_TASK
|
||||
from mysql.ai.utils import copy_dict
|
||||
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
|
||||
class MyRegressor(MyBaseMLModel, RegressorMixin):
    """
    MySQL HeatWave scikit-learn compatible regressor estimator.

    Provides prediction output from a regression model deployed in MySQL,
    and manages fit, explain, and prediction options as per HeatWave ML interface.

    Attributes:
        predict_extra_options (dict): Optional parameter dict passed to the backend for prediction.
        _model (MyModel): Underlying interface for database model operations.
        fit_extra_options (dict): See MyBaseMLModel.
        explain_extra_options (dict): See MyBaseMLModel.

    Args:
        db_connection (MySQLConnectionAbstract): Active MySQL connector DB connection.
        model_name (str, optional): Custom name for the model.
        fit_extra_options (dict, optional): Extra options for fitting.
        explain_extra_options (dict, optional): Extra options for explanations.
        predict_extra_options (dict, optional): Extra options for predictions.

    Methods:
        predict(X): Predict regression target.
    """

    def __init__(
        self,
        db_connection: MySQLConnectionAbstract,
        model_name: Optional[str] = None,
        fit_extra_options: Optional[dict] = None,
        explain_extra_options: Optional[dict] = None,
        predict_extra_options: Optional[dict] = None,
    ):
        """
        Initialize a MyRegressor.

        Args:
            db_connection: Active MySQL connector database connection.
            model_name: Optional, custom model name.
            fit_extra_options: Optional fit options.
            explain_extra_options: Optional explain options.
            predict_extra_options: Optional prediction options.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        MyBaseMLModel.__init__(
            self,
            db_connection,
            ML_TASK.REGRESSION,
            model_name=model_name,
            fit_extra_options=fit_extra_options,
        )

        self.predict_extra_options = copy_dict(predict_extra_options)
        self.explain_extra_options = copy_dict(explain_extra_options)

    def predict(
        self, X: Union[pd.DataFrame, np.ndarray]
    ) -> np.ndarray:  # pylint: disable=invalid-name
        """
        Predict a continuous target for the input features using the MySQL model.

        Args:
            X: Input samples as a numpy array or pandas DataFrame.

        Returns:
            ndarray: Array of predicted target values, shape (n_samples,).

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        result = self._model.predict(X, options=self.predict_extra_options)
        return result["Prediction"].to_numpy()

    def explain_predictions(
        self, X: Union[pd.DataFrame, np.ndarray]
    ) -> pd.DataFrame:  # pylint: disable=invalid-name
        """
        Explain model predictions using provided data.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
            A full list of supported options can be found under "ML_EXPLAIN_TABLE Options"

        Args:
            X: DataFrame for which predictions should be explained.

        Returns:
            DataFrame containing explanation details (feature attributions, etc.)

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Temporary input/output tables are cleaned up after explanation.
        """
        # BUG FIX: the result was previously computed but never returned,
        # so the documented DataFrame return value was always None.
        return self._model.explain_predictions(X, options=self.explain_extra_options)
|
||||
164
venv/lib/python3.12/site-packages/mysql/ai/ml/transformer.py
Normal file
164
venv/lib/python3.12/site-packages/mysql/ai/ml/transformer.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
"""Generic transformer utilities for MySQL Connector/Python.
|
||||
|
||||
Provides a scikit-learn compatible Transformer using HeatWave for fit/transform
|
||||
and scoring operations.
|
||||
"""
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.base import TransformerMixin
|
||||
|
||||
from mysql.ai.ml.base import MyBaseMLModel
|
||||
from mysql.ai.ml.model import ML_TASK
|
||||
from mysql.ai.utils import copy_dict
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
|
||||
class MyGenericTransformer(MyBaseMLModel, TransformerMixin):
    """
    scikit-learn compatible generic transformer backed by MySQL HeatWave.

    Usable as the transformation step of an sklearn pipeline: fitting,
    transforming, and scoring are delegated to the server-side model, with
    per-operation option dictionaries forwarded to the backend.

    Args:
        db_connection (MySQLConnectionAbstract): Active MySQL connector database connection.
        task (str): ML task type for transformer (default: "classification").
        score_metric (str): Scoring metric to request from backend (default: "balanced_accuracy").
        model_name (str, optional): Custom name for the deployed model.
        fit_extra_options (dict, optional): Extra fit options.
        transform_extra_options (dict, optional): Extra options for transformations.
        score_extra_options (dict, optional): Extra options for scoring.

    Attributes:
        score_metric (str): Backend metric name used by score().
        score_extra_options (dict): Optional scoring parameters forwarded to the backend.
        transform_extra_options (dict): Optional inference (/predict) parameters
            forwarded to the backend.
        fit_extra_options (dict): See MyBaseMLModel.
        _model (MyModel): Underlying interface for database model operations.

    Methods:
        fit(X, y): Fit the underlying model using the provided features/targets.
        transform(X): Transform features using the backend model.
        score(X, y): Score data using backend metric and options.
    """

    def __init__(
        self,
        db_connection: MySQLConnectionAbstract,
        task: Union[str, ML_TASK] = ML_TASK.CLASSIFICATION,
        score_metric: str = "balanced_accuracy",
        model_name: Optional[str] = None,
        fit_extra_options: Optional[dict] = None,
        transform_extra_options: Optional[dict] = None,
        score_extra_options: Optional[dict] = None,
    ):
        """
        Initialize the transformer.

        Args:
            db_connection: Active MySQL backend database connection.
            task: ML task type for transformer.
            score_metric: Requested backend scoring metric.
            model_name: Optional model name for storage.
            fit_extra_options: Optional extra options for fitting.
            transform_extra_options: Optional extra options for transformation/inference.
            score_extra_options: Optional extra scoring options.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        MyBaseMLModel.__init__(
            self,
            db_connection,
            task,
            model_name=model_name,
            fit_extra_options=fit_extra_options,
        )

        # Defensive copies: later mutation of the caller's dicts must not
        # leak into this transformer's configuration.
        self.transform_extra_options = copy_dict(transform_extra_options)
        self.score_extra_options = copy_dict(score_extra_options)
        self.score_metric = score_metric

    def transform(
        self, X: pd.DataFrame
    ) -> pd.DataFrame:  # pylint: disable=invalid-name
        """
        Transform input data to model predictions via the backend model.

        Args:
            X: DataFrame of features to predict/transform.

        Returns:
            pd.DataFrame: Results of transformation as returned by backend.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        return self._model.predict(X, options=self.transform_extra_options)

    def score(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
        y: Union[pd.DataFrame, np.ndarray],
    ) -> float:
        """
        Score the transformed data using the backend scoring interface.

        Args:
            X: Transformed features.
            y: Target labels or data for scoring.

        Returns:
            float: Score based on backend metric.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        metric = self.score_metric
        return self._model.score(X, y, metric, options=self.score_extra_options)
|
||||
44
venv/lib/python3.12/site-packages/mysql/ai/utils/__init__.py
Normal file
44
venv/lib/python3.12/site-packages/mysql/ai/utils/__init__.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Utilities for AI-related helpers in MySQL Connector/Python.
|
||||
|
||||
This package exposes:
|
||||
- check_dependencies(): runtime dependency guard for optional AI features
|
||||
- atomic_transaction(): context manager ensuring atomic DB transactions
|
||||
- utils: general-purpose helpers used by AI integrations
|
||||
|
||||
Importing this package validates base dependencies required for AI utilities.
|
||||
"""
|
||||
|
||||
from .dependencies import check_dependencies
|
||||
|
||||
check_dependencies(["BASE"])
|
||||
|
||||
from .atomic_cursor import atomic_transaction
|
||||
from .utils import *
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,94 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Atomic transaction context manager utilities for MySQL Connector/Python.
|
||||
|
||||
Provides context manager atomic_transaction() that ensures commit on success
|
||||
and rollback on error without obscuring the original exception.
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator
|
||||
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
from mysql.connector.cursor import MySQLCursorAbstract
|
||||
|
||||
|
||||
@contextmanager
def atomic_transaction(
    conn: MySQLConnectionAbstract,
) -> Iterator[MySQLCursorAbstract]:
    """
    Context manager that wraps a MySQL database cursor and ensures transaction
    rollback in case of exception.

    NOTE: DDL statements such as CREATE TABLE cause implicit commits. These cannot
    be managed by a cursor object. Changes made at or before a DDL statement will
    be committed and not rolled back. Callers are responsible for any cleanup of
    this type.

    This class acts as a robust, PEP 343-compliant context manager for handling
    database cursor operations on a MySQL connection. It ensures that all operations
    executed within the context block are part of the same transaction, and
    automatically calls `connection.rollback()` if an exception occurs, helping
    to maintain database integrity. On normal completion (no exception), it simply
    closes the cursor after use. Exceptions are always propagated to the caller.

    Args:
        conn: A MySQLConnectionAbstract instance.

    Yields:
        A cursor bound to `conn`; all statements executed on it belong to a
        single transaction that is committed on normal exit.
    """
    # Remember the caller's autocommit mode so it can be restored on exit.
    old_autocommit = conn.autocommit
    cursor = conn.cursor()

    # Tracks whether the body raised, so a cursor-close failure in `finally`
    # never masks the original exception.
    exception_raised = False
    try:
        # Disable autocommit only when it was on, so the block's statements
        # accumulate into one transaction instead of committing one by one.
        if old_autocommit:
            conn.autocommit = False

        yield cursor  # provide cursor to block

        # Success path: make all of the block's work durable at once.
        conn.commit()
    except Exception:  # pylint: disable=broad-exception-caught
        exception_raised = True
        try:
            conn.rollback()
        except Exception:  # pylint: disable=broad-exception-caught
            # Don't obscure original exception
            pass

        # Raise original exception
        raise
    finally:
        # Restore the caller's autocommit mode before releasing the cursor.
        conn.autocommit = old_autocommit

        try:
            cursor.close()
        except Exception:  # pylint: disable=broad-exception-caught
            # don't obscure original exception if exists
            if not exception_raised:
                raise
|
||||
@@ -0,0 +1,88 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Dependency checking utilities for AI features in MySQL Connector/Python.
|
||||
|
||||
Provides check_dependencies() to assert required optional packages are present
|
||||
with acceptable minimum versions at runtime.
|
||||
"""
|
||||
|
||||
import importlib.metadata
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
def _version_tuple(version: str) -> tuple:
    """Parse the leading numeric dot-components of a version string.

    Non-numeric suffixes terminate parsing: "1.10.0" -> (1, 10, 0),
    "0.1.11rc1" -> (0, 1, 11).
    """
    parts = []
    for piece in version.split("."):
        digits = ""
        for char in piece:
            if char.isdigit():
                digits += char
            else:
                break
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def _is_older(installed: str, minimum: str) -> bool:
    """Return True if `installed` is numerically older than `minimum`.

    Components are compared as integers (so "0.10" > "0.9"), and the shorter
    tuple is zero-padded so "1.5" compares equal to "1.5.0".
    """
    left, right = _version_tuple(installed), _version_tuple(minimum)
    width = max(len(left), len(right))
    left += (0,) * (width - len(left))
    right += (0,) * (width - len(right))
    return left < right


def check_dependencies(tasks: List[str]) -> None:
    """
    Check required runtime dependencies and minimum versions; raise an error
    if any are missing or version-incompatible.

    This verifies the presence and minimum version of essential Python packages.
    Missing or insufficient versions cause an ImportError listing the packages
    and a suggested install command.

    Args:
        tasks (List[str]): Task types to check requirements for.

    Raises:
        ImportError: If any required dependencies are missing or below the
            minimum version.
    """
    task_set = set(tasks)
    task_set.add("BASE")

    # Requirements: (import_name, min_version)
    task_to_requirement = {
        "BASE": [("pandas", "1.5.0")],
        "GENAI": [
            ("langchain", "0.1.11"),
            ("langchain_core", "0.1.11"),
            ("pydantic", "1.10.0"),
        ],
        "ML": [("scikit-learn", "1.3.0")],
    }
    requirements = []
    for task in task_set:
        requirements.extend(task_to_requirement[task])
    requirements_set = set(requirements)

    problems = []
    for name, min_version in requirements_set:
        try:
            installed_version = importlib.metadata.version(name)
            # Bug fix: versions were previously compared as plain strings,
            # which mis-orders numeric components (e.g. "0.9.0" > "0.10.0").
            # Compare numerically instead.
            error = _is_older(installed_version, min_version)
        except importlib.metadata.PackageNotFoundError:
            error = True
        if error:
            problems.append(f"{name} v{min_version} (or later)")
    if problems:
        raise ImportError("Please install " + ", ".join(problems) + ".")
|
||||
573
venv/lib/python3.12/site-packages/mysql/ai/utils/utils.py
Normal file
573
venv/lib/python3.12/site-packages/mysql/ai/utils/utils.py
Normal file
@@ -0,0 +1,573 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
"""General utilities for AI features in MySQL Connector/Python.
|
||||
|
||||
Includes helpers for:
|
||||
- defensive dict copying
|
||||
- temporary table lifecycle management
|
||||
- SQL execution and result conversions
|
||||
- DataFrame to/from SQL table utilities
|
||||
- schema/table/column name validation
|
||||
- array-like to DataFrame conversion
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from mysql.ai.utils.atomic_cursor import atomic_transaction
|
||||
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
from mysql.connector.cursor import MySQLCursorAbstract
|
||||
from mysql.connector.types import ParamsSequenceOrDictType
|
||||
|
||||
VAR_NAME_SPACE = "mysql_ai"
|
||||
RANDOM_TABLE_NAME_LENGTH = 32
|
||||
|
||||
PD_TO_SQL_DTYPE_MAPPING = {
|
||||
"int64": "BIGINT",
|
||||
"float64": "DOUBLE",
|
||||
"object": "LONGTEXT",
|
||||
"bool": "BOOLEAN",
|
||||
"datetime64[ns]": "DATETIME",
|
||||
}
|
||||
|
||||
DEFAULT_SCHEMA = "mysql_ai"
|
||||
|
||||
# Misc Utilities
|
||||
|
||||
|
||||
def copy_dict(options: Optional[dict]) -> dict:
    """
    Return a defensive deep copy of *options*, or a fresh empty dict for None.

    Args:
        options: param dict or None

    Returns:
        dict
    """
    return {} if options is None else copy.deepcopy(options)
|
||||
|
||||
|
||||
@contextmanager
def temporary_sql_tables(
    db_connection: MySQLConnectionAbstract,
) -> Iterator[list[tuple[str, str]]]:
    """
    Track temporary SQL tables and drop them all when the context exits.

    Args:
        db_connection: Database connection object used to create and delete tables.

    Raises:
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.

    Yields:
        A list of (schema_name, table_name) tuples; callers append every table
        they create during the context. All registered tables are dropped on
        context exit, whether or not the body raised.
    """
    created_tables: List[Tuple[str, str]] = []
    try:
        yield created_tables
    finally:
        # Drop every registered table inside one transaction, even when the
        # body raised.
        with atomic_transaction(db_connection) as cursor:
            for schema, table in created_tables:
                delete_sql_table(cursor, schema, table)
|
||||
|
||||
|
||||
def execute_sql(
    cursor: MySQLCursorAbstract, query: str, params: ParamsSequenceOrDictType = None
) -> None:
    """
    Run an SQL statement on the given cursor, binding optional parameters.

    Args:
        cursor: MySQLCursorAbstract object to execute the query.
        query: SQL query string to execute.
        params: Optional sequence or dict providing parameters for the query.

    Raises:
        DatabaseError:
            If the provided SQL query/params are invalid
            If the query is valid but the sql raises as an exception
            If a database connection issue occurs.
            If an operational error occurs during execution.

    Returns:
        None
    """
    # Falsy params (None, empty sequence/dict) are normalized to an empty tuple.
    bound_params = params if params else ()
    cursor.execute(query, bound_params)
|
||||
|
||||
|
||||
def _get_name() -> str:
    """
    Build a random uppercase candidate name for a table.

    Returns:
        Random string of length RANDOM_TABLE_NAME_LENGTH.
    """
    letters = [
        random.choice(string.ascii_uppercase)
        for _ in range(RANDOM_TABLE_NAME_LENGTH)
    ]
    return "".join(letters)
|
||||
|
||||
|
||||
def get_random_name(condition: Callable[[str], bool], max_calls: int = 100) -> str:
    """
    Generate a random string name that satisfies a given condition.

    Args:
        condition: Callable that takes a generated name and returns True if it is valid.
        max_calls: Maximum number of attempts before giving up (default 100).

    Returns:
        A random string that fulfills the provided condition.

    Raises:
        RuntimeError: If the maximum number of attempts is reached without success.
    """
    attempts = 0
    while attempts < max_calls:
        candidate = _get_name()
        if condition(candidate):
            return candidate
        attempts += 1
    # condition never met within the attempt budget
    raise RuntimeError("Reached max tries without successfully finding a unique name")
|
||||
|
||||
|
||||
# Format conversions
|
||||
|
||||
|
||||
def format_value_sql(value: Any) -> Tuple[str, List[Any]]:
    """
    Map a Python value to its SQL placeholder string plus bound parameters.

    Args:
        value: The value to format.

    Returns:
        Tuple containing:
            - A string for substitution into a SQL query.
            - A list of parameters to be bound into the query.
    """
    # Scalars pass straight through as a plain placeholder.
    if not isinstance(value, (dict, list)):
        return "%s", [value]
    # Empty containers are stored as SQL NULL rather than empty JSON.
    if not value:
        return "%s", [None]
    # Non-empty dicts/lists are serialized and cast to the JSON column type.
    return "CAST(%s as JSON)", [json.dumps(value)]
|
||||
|
||||
|
||||
def sql_response_to_df(cursor: MySQLCursorAbstract) -> pd.DataFrame:
    """
    Convert the results of a cursor's last executed query to a pandas DataFrame.

    Args:
        cursor: MySQLCursorAbstract with a completed query.

    Returns:
        DataFrame with data from the cursor; JSON-typed columns are decoded
        into Python objects, NULLs stay None.

    Raises:
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.
            If a compatible SELECT query wasn't the last statement ran
    """
    # 245 is the MySQL type code for JSON columns.
    json_columns = {
        idx for idx, col in enumerate(cursor.description) if col[1] == 245
    }

    def _convert(idx: int, value: Any) -> Any:
        # Decode JSON text; everything else (including NULL) passes through.
        if idx in json_columns and value is not None:
            return json.loads(value)
        return value

    processed_rows = [
        [_convert(idx, value) for idx, value in enumerate(row)]
        for row in cursor.fetchall()
    ]

    return pd.DataFrame(processed_rows, columns=cursor.column_names)
|
||||
|
||||
|
||||
def sql_table_to_df(
    cursor: MySQLCursorAbstract, schema_name: str, table_name: str
) -> pd.DataFrame:
    """
    Read every row of ``schema_name.table_name`` into a pandas DataFrame.

    Args:
        cursor: MySQLCursorAbstract to execute the query.
        schema_name: Name of the schema containing the table.
        table_name: Name of the table to fetch.

    Returns:
        DataFrame containing all rows from the specified table.

    Raises:
        DatabaseError:
            If the table does not exist
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError: If the schema or table name is not valid
    """
    # Both identifiers are interpolated into SQL text, so validate them first.
    for identifier in (schema_name, table_name):
        validate_name(identifier)

    execute_sql(cursor, f"SELECT * FROM {schema_name}.{table_name}")
    return sql_response_to_df(cursor)
|
||||
|
||||
|
||||
# Table operations
|
||||
|
||||
|
||||
def table_exists(
    cursor: MySQLCursorAbstract, schema_name: str, table_name: str
) -> bool:
    """
    Check whether a table exists in a specific schema.

    Args:
        cursor: MySQLCursorAbstract object to execute the query.
        schema_name: Name of the database schema.
        table_name: Name of the table.

    Returns:
        True if the table exists, False otherwise.

    Raises:
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError: If the schema or table name is not valid
    """
    validate_name(schema_name)
    validate_name(table_name)

    # Consistency: route through the module's execute_sql wrapper like every
    # other helper here, rather than calling cursor.execute directly.
    # Names are bound as parameters and never reach the SQL text.
    execute_sql(
        cursor,
        """
        SELECT 1
        FROM information_schema.tables
        WHERE table_schema = %s AND table_name = %s
        LIMIT 1
        """,
        params=(schema_name, table_name),
    )
    return cursor.fetchone() is not None
|
||||
|
||||
|
||||
def delete_sql_table(
    cursor: MySQLCursorAbstract, schema_name: str, table_name: str
) -> None:
    """
    Drop ``schema_name.table_name`` if it exists; silently no-op otherwise.

    Args:
        cursor: MySQLCursorAbstract to execute the drop command.
        schema_name: Name of the schema.
        table_name: Name of the table to delete.

    Returns:
        None

    Raises:
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError: If the schema or table name is not valid
    """
    # Identifiers are interpolated into SQL text, so validate them first.
    for identifier in (schema_name, table_name):
        validate_name(identifier)

    drop_statement = f"DROP TABLE IF EXISTS {schema_name}.{table_name}"
    execute_sql(cursor, drop_statement)
|
||||
|
||||
|
||||
def extend_sql_table(
    cursor: MySQLCursorAbstract,
    schema_name: str,
    table_name: str,
    df: pd.DataFrame,
    col_name_to_placeholder_string: Optional[Dict[str, str]] = None,
) -> None:
    """
    Insert all rows from a pandas DataFrame into an existing SQL table.

    Args:
        cursor: MySQLCursorAbstract for execution.
        schema_name: Name of the database schema.
        table_name: Table to insert new rows into.
        df: DataFrame containing the rows to insert.
        col_name_to_placeholder_string:
            Optional mapping of column names to custom SQL value/placeholder
            strings; values of mapped columns are bound as str(elem).

    Returns:
        None

    Raises:
        DatabaseError:
            If the rows could not be inserted into the table, e.g., a type or shape issue
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError: If the schema or table name is not valid
    """
    # Fix: the parameter was annotated Dict[str, str] with a None default
    # (implicit Optional, disallowed by PEP 484); annotation is now explicit.
    if col_name_to_placeholder_string is None:
        col_name_to_placeholder_string = {}

    validate_name(schema_name)
    validate_name(table_name)
    for col in df.columns:
        validate_name(str(col))

    qualified_table_name = f"{schema_name}.{table_name}"
    # Hoisted out of the row loop: the column list is invariant across rows.
    cols_sql = ", ".join(str(col) for col in df.columns)

    # Build and execute one parameterized INSERT per DataFrame row.
    for row in df.values:
        placeholders, params = [], []
        for elem, col in zip(row, df.columns):
            # Unwrap numpy scalars to native Python values for the driver.
            elem = elem.item() if hasattr(elem, "item") else elem

            if col in col_name_to_placeholder_string:
                elem_placeholder = col_name_to_placeholder_string[col]
                elem_params = [str(elem)]
            else:
                elem_placeholder, elem_params = format_value_sql(elem)

            placeholders.append(elem_placeholder)
            params.extend(elem_params)

        placeholders_sql = ", ".join(placeholders)
        insert_sql = (
            f"INSERT INTO {qualified_table_name} "
            f"({cols_sql}) VALUES ({placeholders_sql})"
        )
        execute_sql(cursor, insert_sql, params=params)
|
||||
|
||||
|
||||
def sql_table_from_df(
    cursor: MySQLCursorAbstract, schema_name: str, df: pd.DataFrame
) -> Tuple[str, str]:
    """
    Create a new SQL table with a random name, and populate it with data from a DataFrame.

    If an 'id' column is defined in the dataframe, it will be used as the primary key.

    Args:
        cursor: MySQLCursorAbstract for executing SQL.
        schema_name: Schema in which to create the table.
        df: DataFrame containing the data to be inserted.

    Returns:
        Tuple (qualified_table_name, table_name): The schema-qualified and
        unqualified table names.

    Raises:
        RuntimeError: If a random available table name could not be found.
        ValueError: If any schema, table, or a column name is invalid.
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.
    """
    table_name = get_random_name(
        lambda candidate: not table_exists(cursor, schema_name, candidate)
    )
    qualified_table_name = f"{schema_name}.{table_name}"

    validate_name(schema_name)
    validate_name(table_name)
    for col in df.columns:
        validate_name(str(col))

    columns_sql = []
    for col, dtype in df.dtypes.items():
        # Map pandas dtype to SQL type; fallback is LONGTEXT.
        sql_type = PD_TO_SQL_DTYPE_MAPPING.get(str(dtype), "LONGTEXT")
        columns_sql.append(f"{col} {sql_type}")

    columns_str = ", ".join(columns_sql)

    # Fix: coerce labels to str before lower() so non-string column labels
    # (e.g. integers) do not raise AttributeError here.
    has_id_col = any(str(col).lower() == "id" for col in df.columns)
    if has_id_col:
        columns_str += ", PRIMARY KEY (id)"

    # Create table with generated columns
    create_table_sql = f"CREATE TABLE {qualified_table_name} ({columns_str})"
    execute_sql(cursor, create_table_sql)

    try:
        # Insert provided data into new table
        extend_sql_table(cursor, schema_name, table_name, df)
    except Exception:  # pylint: disable=broad-exception-caught
        # CREATE TABLE has already implicitly committed; drop the table
        # before we lose track of it, then re-raise the insert failure.
        delete_sql_table(cursor, schema_name, table_name)
        raise
    return qualified_table_name, table_name
|
||||
|
||||
|
||||
def validate_name(name: str) -> str:
    """
    Validate that the string is a legal SQL identifier (letters, digits, underscores).

    Args:
        name: Name (schema, table, or column) to validate.

    Returns:
        The validated name.

    Raises:
        ValueError: If the name does not meet format requirements.
    """
    # Accepts only letters, digits, and underscores; change as needed.
    # Use re.fullmatch instead of re.match with "^...$": the "$" anchor
    # matches before a trailing newline, so e.g. "users\n" would pass
    # validation and later be interpolated into SQL unquoted.
    if not (isinstance(name, str) and re.fullmatch(r"[A-Za-z0-9_]+", name)):
        raise ValueError(f"Unsupported name format {name}")

    return name
|
||||
|
||||
|
||||
def source_schema(db_connection: MySQLConnectionAbstract) -> str:
    """
    Retrieve the name of the currently selected schema, or set and ensure the default schema.

    Args:
        db_connection: MySQL connector database connection object.

    Returns:
        Name of the schema (database in use).

    Raises:
        ValueError: If the schema name is not valid.
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.
    """
    schema = db_connection.database
    if schema is None:
        schema = DEFAULT_SCHEMA

    # Validate BEFORE interpolating into SQL. The previous version validated
    # only after executing CREATE DATABASE with the raw schema name, leaving
    # an injection window for a malicious connection.database value.
    validate_name(schema)

    with atomic_transaction(db_connection) as cursor:
        create_database_stmt = f"CREATE DATABASE IF NOT EXISTS {schema}"
        execute_sql(cursor, create_database_stmt)

    return schema
|
||||
|
||||
|
||||
def is_table_empty(
    cursor: MySQLCursorAbstract, schema_name: str, table_name: str
) -> bool:
    """
    Determine if a given SQL table is empty.

    Args:
        cursor: MySQLCursorAbstract with access to the database.
        schema_name: Name of the schema containing the table.
        table_name: Name of the table to check.

    Returns:
        True if the table has no rows, False otherwise.

    Raises:
        DatabaseError:
            If the table does not exist.
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError: If the schema or table name is not valid.
    """
    # Reject anything that is not a plain identifier before building SQL.
    for identifier in (schema_name, table_name):
        validate_name(identifier)

    # LIMIT 1 keeps the probe cheap: one row (or none) is all we need.
    probe_sql = f"SELECT 1 FROM {schema_name}.{table_name} LIMIT 1"
    cursor.execute(probe_sql)
    first_row = cursor.fetchone()
    return first_row is None
|
||||
|
||||
|
||||
def convert_to_df(
    arr: Optional[Union[pd.DataFrame, pd.Series, np.ndarray]],
    col_prefix: str = "feature",
) -> Optional[pd.DataFrame]:
    """
    Convert input data to a pandas DataFrame if necessary.

    Args:
        arr: Input data as a pandas DataFrame, NumPy ndarray, pandas Series, or None.
        col_prefix: Prefix used to synthesize column names for ndarray input.

    Returns:
        If the input is None, returns None.
        Otherwise, returns a DataFrame backed by the same underlying data whenever
        possible (except in cases where pandas or NumPy must copy, such as for
        certain views or non-contiguous arrays).

    Notes:
        - If an ndarray is passed, columns are named "<col_prefix>_0", "<col_prefix>_1", ...
        - If a DataFrame is passed, column names and indices are preserved.
        - The returned DataFrame is a shallow copy and shares data with the original
          input when possible; however, copies may still occur for certain input
          types or memory layouts.
    """
    if arr is None:
        return None

    # DataFrame input: shallow re-wrap preserves columns/index.
    if isinstance(arr, pd.DataFrame):
        return pd.DataFrame(arr)
    # Series input: promote to a single-column frame (keeps its name).
    if isinstance(arr, pd.Series):
        return arr.to_frame()

    # ndarray input: ensure 2-D, then attach synthetic column names.
    data = arr.reshape(-1, 1) if arr.ndim == 1 else arr
    headers = [f"{col_prefix}_{i}" for i in range(data.shape[1])]
    return pd.DataFrame(data, columns=headers, copy=False)
|
||||
Reference in New Issue
Block a user