Añadiendo todos los archivos del proyecto (incluidos secretos y venv)

This commit is contained in:
2026-03-06 18:31:45 -06:00
parent 3a15a3eafa
commit e4d50b6eb5
4965 changed files with 991048 additions and 0 deletions

View File

@@ -0,0 +1,27 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA

View File

@@ -0,0 +1,43 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""GenAI package for MySQL Connector/Python.
Performs optional dependency checks and exposes public classes:
- MyEmbeddings
- MyLLM
- MyVectorStore
"""
from mysql.ai.utils import check_dependencies as _check_dependencies
_check_dependencies(["GENAI"])
del _check_dependencies
from .embedding import MyEmbeddings
from .generation import MyLLM
from .vector_store import MyVectorStore

View File

@@ -0,0 +1,197 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Embeddings integration utilities for MySQL Connector/Python.
Provides MyEmbeddings class to generate embeddings via MySQL HeatWave
using ML_EMBED_TABLE and ML_EMBED_ROW.
"""
from typing import Dict, List, Optional
import pandas as pd
from langchain_core.embeddings import Embeddings
from pydantic import PrivateAttr
from mysql.ai.utils import (
atomic_transaction,
execute_sql,
format_value_sql,
source_schema,
sql_table_from_df,
sql_table_to_df,
temporary_sql_tables,
)
from mysql.connector.abstracts import MySQLConnectionAbstract
class MyEmbeddings(Embeddings):
    """
    Embedding generator class that uses a MySQL database to compute embeddings for input text.

    This class batches input text into temporary SQL tables, invokes MySQL's ML_EMBED_TABLE
    to generate embeddings, and retrieves the results as lists of floats.

    Attributes:
        _db_connection (MySQLConnectionAbstract): MySQL connection used for all database operations.
        schema_name (str): Name of the database schema to use.
        options_placeholder (str): SQL-ready placeholder string for ML_EMBED_TABLE options.
        options_params: Concrete option values passed as SQL parameters, produced
            by format_value_sql together with options_placeholder.
    """

    _db_connection: MySQLConnectionAbstract = PrivateAttr()

    def __init__(
        self, db_connection: MySQLConnectionAbstract, options: Optional[Dict] = None
    ):
        """
        Initialize MyEmbeddings with a database connection and optional embedding parameters.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwgenai-ml-embed-row.html
            A full list of supported options can be found under "options"

        NOTE: The supported "options" are the intersection of the options provided in
            https://dev.mysql.com/doc/heatwave/en/mys-hwgenai-ml-embed-row.html
            https://dev.mysql.com/doc/heatwave/en/mys-hwgenai-ml-embed-table.html

        Args:
            db_connection: Active MySQL connector database connection.
            options: Optional dictionary of options for embedding operations.

        Raises:
            ValueError: If the schema name is not valid
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        super().__init__()
        self._db_connection = db_connection
        self.schema_name = source_schema(db_connection)
        options = options or {}
        # Pre-render the options once; both embed_documents and embed_query
        # reuse the same placeholder string and parameter values.
        self.options_placeholder, self.options_params = format_value_sql(options)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for a list of input texts using the MySQL ML embedding procedure.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwgenai-ml-embed-table.html

        Args:
            texts: List of input strings to embed.

        Returns:
            List of lists of floats, with each inner list containing the embedding for a text.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported.
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError:
                If one or more text entries were unable to be embedded.

        Implementation notes:
            - Creates a temporary table to pass input text to the MySQL embedding service.
            - Adds a primary key to ensure results preserve input order.
            - Calls ML_EMBED_TABLE and fetches the resulting embeddings.
            - Deletes the temporary table after use to avoid polluting the database.
            - Embedding vectors are extracted from the "embeddings" column of the result table.
        """
        if not texts:
            return []
        df = pd.DataFrame({"id": range(len(texts)), "text": texts})
        with (
            atomic_transaction(self._db_connection) as cursor,
            temporary_sql_tables(self._db_connection) as temporary_tables,
        ):
            qualified_table_name, table_name = sql_table_from_df(
                cursor, self.schema_name, df
            )
            # Register the staging table for cleanup on context exit.
            temporary_tables.append((self.schema_name, table_name))
            # ML_EMBED_TABLE expects input/output columns and options as parameters
            embed_query = (
                "CALL sys.ML_EMBED_TABLE("
                f"'{qualified_table_name}.text', "
                f"'{qualified_table_name}.embeddings', "
                f"{self.options_placeholder}"
                ")"
            )
            execute_sql(cursor, embed_query, params=self.options_params)
            # Read back all columns, including "embeddings"
            df_embeddings = sql_table_to_df(cursor, self.schema_name, table_name)
            # Series.isnull() already flags both None and NaN, so a single
            # check covers every "missing embedding" case (the original code
            # re-scanned the column for None a second time, redundantly).
            if df_embeddings["embeddings"].isnull().any():
                raise ValueError(
                    "Failure to generate embeddings for one or more text entry."
                )
            # Convert fetched embeddings to lists of floats
            embeddings = df_embeddings["embeddings"].tolist()
            embeddings = [list(e) for e in embeddings]
            return embeddings

    def embed_query(self, text: str) -> List[float]:
        """
        Generate an embedding for a single text string.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwgenai-ml-embed-row.html

        Args:
            text: The input string to embed.

        Returns:
            List of floats representing the embedding vector.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported.
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Example:
            >>> MyEmbeddings(db_conn).embed_query("Hello world")
            [0.1, 0.2, ...]
        """
        with atomic_transaction(self._db_connection) as cursor:
            # BUGFIX: the %s placeholder must not itself be wrapped in quote
            # marks. The connector substitutes an already-quoted, escaped
            # literal for %s; writing "%s" embedded that literal inside a
            # second string, so the server received the text surrounded by
            # stray quote characters.
            execute_sql(
                cursor,
                f"SELECT sys.ML_EMBED_ROW(%s, {self.options_placeholder})",
                params=(text, *self.options_params),
            )
            return list(cursor.fetchone()[0])

View File

@@ -0,0 +1,162 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""GenAI LLM integration utilities for MySQL Connector/Python.
Provides MyLLM wrapper that issues ML_GENERATE calls via SQL.
"""
import json
from typing import Any, List, Optional
try:
from langchain_core.language_models.llms import LLM
except ImportError:
from langchain.llms.base import LLM
from pydantic import PrivateAttr
from mysql.ai.utils import atomic_transaction, execute_sql, format_value_sql
from mysql.connector.abstracts import MySQLConnectionAbstract
class MyLLM(LLM):
    """
    Custom Large Language Model (LLM) interface for MySQL HeatWave.

    This class wraps the generation functionality provided by HeatWave LLMs,
    exposing an interface compatible with common LLM APIs for text generation.
    It provides full support for generative queries and limited support for
    agentic queries.

    Attributes:
        _db_connection (MySQLConnectionAbstract):
            Underlying MySQL connector database connection.
    """

    _db_connection: MySQLConnectionAbstract = PrivateAttr()

    class Config:
        """
        Pydantic config for the model.

        By default, LangChain (through Pydantic BaseModel) does not allow
        setting or storing undeclared attributes such as _db_connection.
        Setting extra = "allow" makes it possible to store extra attributes
        on the class instance, which is required for MyLLM.
        """

        extra = "allow"

    def __init__(self, db_connection: MySQLConnectionAbstract):
        """
        Initialize the MyLLM instance with an active MySQL database connection.

        Args:
            db_connection: A MySQL connection object used to run LLM queries.

        Notes:
            The db_connection is stored as a private attribute, which is
            compatible with Pydantic models.
        """
        super().__init__()
        self._db_connection = db_connection

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        """
        Generate a text completion from the LLM for a given input prompt.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwgenai-ml-generate.html
            A full list of supported options (specified by kwargs) can be found under "options"

        Args:
            prompt: The input prompt string for the language model.
            stop: Optional list of stop strings to support agentic and chain-of-thought
                reasoning workflows.
            **kwargs: Additional keyword arguments providing generation options to
                the LLM (these are serialized to JSON and passed to the HeatWave syscall).

        Returns:
            The generated model output as a string.
            (The actual completion does NOT include the input prompt.)

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported.
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Implementation Notes:
            - Serializes kwargs into a SQL-compatible JSON string.
            - Calls the LLM stored procedure using a database cursor context.
            - Uses `sys.ML_GENERATE` on the server to produce the model output.
            - Expects the server response to be a JSON object with a 'text' key.
        """
        options = kwargs.copy()
        if stop is not None:
            options["stop_sequences"] = stop
        options_placeholder, options_params = format_value_sql(options)
        with atomic_transaction(self._db_connection) as cursor:
            # The prompt is passed as a parameterized argument (avoids SQL injection).
            # BUGFIX: %s must not be wrapped in quote marks — the connector
            # substitutes an already-quoted, escaped literal for the
            # placeholder, so "%s" handed the server a prompt wrapped in
            # stray quote characters.
            generate_query = f"SELECT sys.ML_GENERATE(%s, {options_placeholder});"
            execute_sql(cursor, generate_query, params=(prompt, *options_params))
            # Expect a JSON-encoded result from MySQL; parse to extract the output.
            llm_response = json.loads(cursor.fetchone()[0])["text"]
        return llm_response

    @property
    def _identifying_params(self) -> dict:
        """
        Return a dictionary of params that uniquely identify this LLM instance.

        Returns:
            dict: Dictionary of identifier parameters (should include
            model_name for tracing/caching).
        """
        return {
            "model_name": "mysql_heatwave_llm",
        }

    @property
    def _llm_type(self) -> str:
        """
        Get the type name of this LLM implementation.

        Returns:
            A string identifying the LLM provider (used for logging or metrics).
        """
        return "mysql_heatwave_llm"

View File

@@ -0,0 +1,520 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""MySQL-backed vector store for embeddings and semantic document retrieval.
Provides a VectorStore implementation persisting documents, metadata, and
embeddings in MySQL, plus similarity search utilities.
"""
import json
from typing import Any, Iterable, List, Optional, Sequence, Union
import pandas as pd
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from pydantic import PrivateAttr
from mysql.ai.genai.embedding import MyEmbeddings
from mysql.ai.utils import (
VAR_NAME_SPACE,
atomic_transaction,
delete_sql_table,
execute_sql,
extend_sql_table,
format_value_sql,
get_random_name,
is_table_empty,
source_schema,
table_exists,
)
from mysql.connector.abstracts import MySQLConnectionAbstract
# Probe text embedded once at construction time to discover the fixed
# embedding dimension of the configured embedder.
BASIC_EMBEDDING_QUERY = "Hello world!"
# Presumably labels externally supplied embeddings — not referenced in this
# module's visible code; confirm against other callers before removing.
EMBEDDING_SOURCE = "external_source"
# Names of MySQL session variables (@...) used to exchange data with the
# ML_SIMILARITY_SEARCH stored procedure, namespaced under VAR_NAME_SPACE to
# avoid collisions with user-defined variables.
VAR_EMBEDDING = f"{VAR_NAME_SPACE}.embedding"
VAR_CONTEXT = f"{VAR_NAME_SPACE}.context"
VAR_CONTEXT_MAP = f"{VAR_NAME_SPACE}.context_map"
VAR_RETRIEVAL_INFO = f"{VAR_NAME_SPACE}.retrieval_info"
VAR_OPTIONS = f"{VAR_NAME_SPACE}.options"
# Prefix for internally generated document IDs.
ID_SPACE = "internal_ai_id_"
class MyVectorStore(VectorStore):
    """
    MySQL-backed vector store for handling embeddings and semantic document retrieval.

    Supports adding, deleting, and searching high-dimensional vector representations
    of documents using efficient storage and HeatWave ML similarity search procedures.

    Supports use as a context manager: when used in a `with` statement, all backing
    tables/data are deleted automatically when the block exits (even on exception).

    Attributes:
        db_connection (MySQLConnectionAbstract): Active MySQL database connection.
        embedder (Embeddings): Embeddings generator for computing vector representations.
        schema_name (str): SQL schema for table storage.
        table_name (Optional[str]): Name of the active table backing the store
            (or None until created).
        embedding_dimension (int): Size of embedding vectors stored.
        next_id (int): Internal counter for unique document ID generation.
    """

    _db_connection: MySQLConnectionAbstract = PrivateAttr()
    _embedder: Embeddings = PrivateAttr()
    _schema_name: str = PrivateAttr()
    _table_name: Optional[str] = PrivateAttr()
    _embedding_dimension: int = PrivateAttr()
    _next_id: int = PrivateAttr()

    def __init__(
        self,
        db_connection: MySQLConnectionAbstract,
        embedder: Optional[Embeddings] = None,
    ) -> None:
        """
        Initialize a MyVectorStore with a database connection and embedding generator.

        Args:
            db_connection: MySQL database connection for all vector operations.
            embedder: Embeddings generator used for creating and querying embeddings.
                Defaults to a MyEmbeddings instance over the same connection.

        Raises:
            ValueError: If the schema name is not valid
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        super().__init__()
        self._next_id = 0
        self._schema_name = source_schema(db_connection)
        self._embedder = embedder or MyEmbeddings(db_connection)
        self._db_connection = db_connection
        self._table_name: Optional[str] = None
        # Embedding dimension determined using an example call.
        # Assumes embeddings have fixed length.
        self._embedding_dimension = len(
            self._embedder.embed_query(BASIC_EMBEDDING_QUERY)
        )

    def _get_ids(self, num_ids: int) -> list[str]:
        """
        Generate a batch of unique internal document IDs for vector storage.

        Args:
            num_ids: Number of IDs to create.

        Returns:
            List of sequentially numbered internal string IDs.
        """
        # Use the module-level ID_SPACE prefix (the original inlined the same
        # literal, bypassing the constant).
        ids = [
            f"{ID_SPACE}{i}" for i in range(self._next_id, self._next_id + num_ids)
        ]
        self._next_id += num_ids
        return ids

    def _make_vector_store(self) -> None:
        """
        Create a backing SQL table for storing vectors if not already created.

        Returns:
            None

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            The table name is randomized to avoid collisions.
            Schema includes content, metadata, and embedding vector.
        """
        if self._table_name is None:
            with atomic_transaction(self._db_connection) as cursor:
                table_name = get_random_name(
                    lambda table_name: not table_exists(
                        cursor, self._schema_name, table_name
                    )
                )
                create_table_stmt = f"""
                    CREATE TABLE {self._schema_name}.{table_name} (
                        `id` VARCHAR(128) NOT NULL,
                        `content` TEXT,
                        `metadata` JSON DEFAULT NULL,
                        `embed` vector(%s),
                        PRIMARY KEY (`id`)
                    ) ENGINE=InnoDB;
                """
                execute_sql(
                    cursor, create_table_stmt, params=(self._embedding_dimension,)
                )
                self._table_name = table_name

    def delete(self, ids: Optional[Sequence[str]] = None, **_: Any) -> None:
        """
        Delete documents by ID. Optionally deletes the vector table if empty after deletions.

        Args:
            ids: Optional sequence of document IDs to delete. If None or empty,
                no rows are removed (but the empty-table check below still runs).

        Returns:
            None

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            If the backing table is empty after deletions, the table is dropped and
            table_name is set to None.
        """
        # Guard: without a backing table there is nothing to delete, and the
        # emptiness check below would otherwise run SQL against table None.
        if self._table_name is None:
            return
        with atomic_transaction(self._db_connection) as cursor:
            if ids:
                for _id in ids:
                    execute_sql(
                        cursor,
                        f"DELETE FROM {self._schema_name}.{self._table_name} WHERE id = %s",
                        params=(_id,),
                    )
            if is_table_empty(cursor, self._schema_name, self._table_name):
                self.delete_all()

    def delete_all(self) -> None:
        """
        Delete and drop the entire vector store table.

        Returns:
            None
        """
        if self._table_name is not None:
            with atomic_transaction(self._db_connection) as cursor:
                delete_sql_table(cursor, self._schema_name, self._table_name)
            self._table_name = None

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[list[dict]] = None,
        ids: Optional[List[str]] = None,
        **_: dict,
    ) -> List[str]:
        """
        Add a batch of text strings and corresponding metadata to the vector store.

        Args:
            texts: List of strings to embed and store.
            metadatas: Optional list of metadata dicts (one per text).
            ids: Optional custom document IDs.

        Returns:
            List of document IDs corresponding to the added texts.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            If metadatas is None, an empty dict is assigned to each document.
        """
        texts = list(texts)
        documents = [
            Document(page_content=text, metadata=meta)
            for text, meta in zip(texts, metadatas or [{}] * len(texts))
        ]
        return self.add_documents(documents, ids=ids)

    @classmethod
    def from_texts(
        cls,
        texts: Iterable[str],
        embedder: Embeddings,
        metadatas: Optional[list[dict]] = None,
        db_connection: Optional[MySQLConnectionAbstract] = None,
    ) -> VectorStore:
        """
        Construct and populate a MyVectorStore instance from raw texts and metadata.

        Args:
            texts: List of strings to vectorize and store.
            embedder: Embeddings generator to use.
            metadatas: Optional list of metadata dicts per text.
            db_connection: Active MySQL connection (required; keyword default
                exists only to satisfy the base-class signature).

        Returns:
            Instance of MyVectorStore containing the added texts.

        Raises:
            ValueError: If db_connection is not provided.
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        if db_connection is None:
            raise ValueError(
                "db_connection must be specified to create a MyVectorStore object"
            )
        texts = list(texts)
        instance = cls(db_connection=db_connection, embedder=embedder)
        instance.add_texts(texts, metadatas=metadatas)
        return instance

    def add_documents(
        self, documents: list[Document], ids: Optional[list[str]] = None
    ) -> list[str]:
        """
        Embed and store Document objects as high-dimensional vectors with metadata.

        Args:
            documents: List of Document objects (each with 'page_content' and 'metadata').
            ids: Optional list of explicit document IDs. Must match the length of documents.

        Returns:
            List of document IDs stored.

        Raises:
            ValueError: If provided IDs do not match the number of documents.
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Automatically creates the backing table if it does not exist.
        """
        if ids and len(ids) != len(documents):
            msg = (
                "ids must be the same length as documents. "
                f"Got {len(ids)} ids and {len(documents)} documents."
            )
            raise ValueError(msg)
        if len(documents) > 0:
            self._make_vector_store()
        else:
            return []
        if ids is None:
            ids = self._get_ids(len(documents))
        content = [doc.page_content for doc in documents]
        vectors = self._embedder.embed_documents(content)
        df = pd.DataFrame()
        df["id"] = ids
        df["content"] = content
        df["embed"] = vectors
        df["metadata"] = [doc.metadata for doc in documents]
        with atomic_transaction(self._db_connection) as cursor:
            # string_to_vector converts the textual vector representation to
            # the native VECTOR column type on insert.
            extend_sql_table(
                cursor,
                self._schema_name,
                self._table_name,
                df,
                col_name_to_placeholder_string={"embed": "string_to_vector(%s)"},
            )
        return ids

    def similarity_search(
        self,
        query: str,
        k: int = 3,
        **kwargs: Any,
    ) -> list[Document]:
        """
        Search for and return the most similar documents in the store to the given query.

        Args:
            query: String query to embed and use for similarity search.
            k: Number of top documents to return.
            kwargs: options to pass to ML_SIMILARITY_SEARCH. Currently supports
                distance_metric, max_distance, percentage_distance, and segment_overlap

        Returns:
            List of Document objects, ordered from most to least similar.

        Raises:
            DatabaseError:
                If provided kwargs are invalid or unsupported.
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Implementation Notes:
            - Calls ML similarity search within MySQL using stored procedures.
            - Retrieves IDs, content, and metadata for search matches.
            - Parsing and retrieval for context results are handled via intermediate JSONs.
        """
        if self._table_name is None:
            return []
        embedding = self._embedder.embed_query(query)
        with atomic_transaction(self._db_connection) as cursor:
            # Set the embedding variable for the similarity search SP
            execute_sql(
                cursor,
                f"SET @{VAR_EMBEDDING} = string_to_vector(%s)",
                params=[str(embedding)],
            )
            distance_metric = kwargs.get("distance_metric", "COSINE")
            retrieval_options = {
                "max_distance": kwargs.get("max_distance", 0.6),
                "percentage_distance": kwargs.get("percentage_distance", 20.0),
                "segment_overlap": kwargs.get("segment_overlap", 0),
            }
            retrieval_options_placeholder, retrieval_options_params = format_value_sql(
                retrieval_options
            )
            # int(k) hardens the f-string interpolation: k is documented as an
            # int, and coercing guarantees no non-numeric text reaches the SQL.
            similarity_search_query = f"""
                CALL sys.ML_SIMILARITY_SEARCH(
                    @{VAR_EMBEDDING},
                    JSON_ARRAY(
                        '{self._schema_name}.{self._table_name}'
                    ),
                    JSON_OBJECT(
                        "segment", "content",
                        "segment_embedding", "embed",
                        "document_name", "id"
                    ),
                    {int(k)},
                    %s,
                    NULL,
                    NULL,
                    {retrieval_options_placeholder},
                    @{VAR_CONTEXT},
                    @{VAR_CONTEXT_MAP},
                    @{VAR_RETRIEVAL_INFO}
                )
            """
            execute_sql(
                cursor,
                similarity_search_query,
                params=[distance_metric, *retrieval_options_params],
            )
            execute_sql(cursor, f"SELECT @{VAR_CONTEXT_MAP}")
            results = []
            context_maps = json.loads(cursor.fetchone()[0])
            for context in context_maps:
                # Look up the full row for each match by its document id.
                execute_sql(
                    cursor,
                    (
                        "SELECT id, content, metadata "
                        f"FROM {self._schema_name}.{self._table_name} "
                        "WHERE id = %s"
                    ),
                    params=(context["document_name"],),
                )
                doc_id, content, metadata = cursor.fetchone()
                doc_args = {
                    "id": doc_id,
                    "page_content": content,
                }
                if metadata is not None:
                    doc_args["metadata"] = json.loads(metadata)
                doc = Document(**doc_args)
                results.append(doc)
            return results

    def __enter__(self) -> "VectorStore":
        """
        Enter the runtime context related to this vector store instance.

        Returns:
            The current MyVectorStore object, allowing use within a `with` statement block.

        Usage Notes:
            - Intended for use in a `with` statement to ensure automatic
              cleanup of resources.
            - No special initialization occurs during context entry.

        Example:
            with MyVectorStore(db_connection, embedder) as vectorstore:
                vectorstore.add_texts([...])
                # Vector store is active within this block.
            # All storage and resources are now cleaned up.
        """
        return self

    def __exit__(
        self,
        exc_type: Union[type, None],
        exc_val: Union[BaseException, None],
        exc_tb: Union[object, None],
    ) -> None:
        """
        Exit the runtime context for the vector store, ensuring all storage
        resources are cleaned up.

        Args:
            exc_type: The exception type, if any exception occurred in the context block.
            exc_val: The exception value, if any exception occurred in the context block.
            exc_tb: The traceback object, if any exception occurred in the context block.

        Returns:
            None: Indicates that exceptions are never suppressed; they will propagate as normal.

        Implementation Notes:
            - Automatically deletes all vector store data and backing tables via
              `delete_all()` upon exiting the context, whether the block exits
              normally or due to an exception.
            - Does not suppress exceptions; errors in the context block will
              continue to propagate.
        """
        self.delete_all()
        # No return, so exceptions are never suppressed

View File

@@ -0,0 +1,48 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""ML package for MySQL Connector/Python.
Performs optional dependency checks and exposes ML utilities:
- ML_TASK, MyModel
- MyClassifier, MyRegressor, MyGenericTransformer
- MyAnomalyDetector
"""
from mysql.ai.utils import check_dependencies as _check_dependencies
_check_dependencies(["ML"])
del _check_dependencies
# Sklearn models
from .classifier import MyClassifier
# Minimal interface
from .model import ML_TASK, MyModel
from .outlier import MyAnomalyDetector
from .regressor import MyRegressor
from .transformer import MyGenericTransformer

View File

@@ -0,0 +1,142 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Base classes for MySQL HeatWave ML estimators for Connector/Python.
Implements a scikit-learn-compatible base estimator wrapping server-side ML.
"""
from typing import Optional, Union
import pandas as pd
from sklearn.base import BaseEstimator
from mysql.connector.abstracts import MySQLConnectionAbstract
from mysql.ai.ml.model import ML_TASK, MyModel
from mysql.ai.utils import copy_dict
class MyBaseMLModel(BaseEstimator):
    """
    Base class for MySQL HeatWave machine learning estimators.

    Provides the scikit-learn estimator API (``fit`` etc.) on top of a
    server-side model managed through :class:`MyModel`. Subclasses
    (classifiers, regressors, transformers, outlier detectors) add the
    task-specific prediction methods.

    Args:
        db_connection (MySQLConnectionAbstract): An active MySQL connector
            database connection.
        task (str): ML task type, e.g. "classification" or "regression".
        model_name (str, optional): Custom name for the deployed model.
        fit_extra_options (dict, optional): Extra options for fitting.

    Attributes:
        _model: Underlying database helper for fit/predict/explain.
        fit_extra_options: User-provided options forwarded to ``fit``.
    """

    def __init__(
        self,
        db_connection: MySQLConnectionAbstract,
        task: Union[str, ML_TASK],
        model_name: Optional[str] = None,
        fit_extra_options: Optional[dict] = None,
    ):
        """
        Initialize the estimator with a connection, task, and fit options.

        Args:
            db_connection: Active MySQL connector database connection.
            task: ML task label (e.g. "classification") or ML_TASK member.
            model_name: Optional custom model name.
            fit_extra_options: Optional extra fit options.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        # Copy the options so later mutation by the caller has no effect.
        self.fit_extra_options = copy_dict(fit_extra_options)
        self._model = MyModel(db_connection, task=task, model_name=model_name)

    def fit(
        self,
        X: pd.DataFrame,  # pylint: disable=invalid-name
        y: Optional[pd.DataFrame] = None,
    ) -> "MyBaseMLModel":
        """
        Fit the underlying ML model from pandas DataFrames.

        Args:
            X: Features DataFrame.
            y: (Optional) Target labels DataFrame or Series.

        Returns:
            self, to allow scikit-learn style chaining.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported.
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Additional temp SQL resources may be created and cleaned up
            during the operation.
        """
        options = self.fit_extra_options
        self._model.fit(X, y, options)
        return self

    def _delete_model(self) -> bool:
        """
        Delete the model from the model catalog if it is present.

        Returns:
            bool: True if a catalog entry was actually removed.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        # Pure delegation; the helper owns the catalog bookkeeping.
        return self._model._delete_model()

    def get_model_info(self) -> Optional[dict]:
        """
        Fetch the model's catalog entry, if any.

        A catalog entry only exists after the model has been fitted (or
        imported), so ``None`` means the model has not been trained yet.

        Returns:
            dict with the catalog row for this model, or None if absent.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        return self._model.get_model_info()

View File

@@ -0,0 +1,194 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Classifier utilities for MySQL Connector/Python.
Provides a scikit-learn compatible classifier backed by HeatWave ML.
"""
from typing import Optional, Union
import numpy as np
import pandas as pd
from sklearn.base import ClassifierMixin
from mysql.ai.ml.base import MyBaseMLModel
from mysql.ai.ml.model import ML_TASK
from mysql.ai.utils import copy_dict
from mysql.connector.abstracts import MySQLConnectionAbstract
class MyClassifier(MyBaseMLModel, ClassifierMixin):
    """
    MySQL HeatWave scikit-learn compatible classifier estimator.

    Provides prediction and probability output from a model deployed in MySQL,
    and manages fit, explain, and prediction options as per HeatWave ML interface.

    Attributes:
        predict_extra_options (dict): Dictionary of optional parameters passed through
            to the MySQL backend for prediction and probability inference.
        explain_extra_options (dict): Options forwarded to explanation calls.
        _model (MyModel): Underlying interface for database model operations.
        fit_extra_options (dict): See MyBaseMLModel.

    Args:
        db_connection (MySQLConnectionAbstract): Active MySQL connector DB connection.
        model_name (str, optional): Custom name for the model.
        fit_extra_options (dict, optional): Extra options for fitting.
        explain_extra_options (dict, optional): Extra options for explanations.
        predict_extra_options (dict, optional): Extra options for predict/predict_proba.

    Methods:
        predict(X): Predict class labels.
        predict_proba(X): Predict class probabilities.
        explain_predictions(X): Explain predictions for the given samples.
    """

    def __init__(
        self,
        db_connection: MySQLConnectionAbstract,
        model_name: Optional[str] = None,
        fit_extra_options: Optional[dict] = None,
        explain_extra_options: Optional[dict] = None,
        predict_extra_options: Optional[dict] = None,
    ):
        """
        Initialize a MyClassifier.

        Args:
            db_connection: Active MySQL connector database connection.
            model_name: Optional, custom model name.
            fit_extra_options: Optional fit options.
            explain_extra_options: Optional explain options.
            predict_extra_options: Optional predict/predict_proba options.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        MyBaseMLModel.__init__(
            self,
            db_connection,
            ML_TASK.CLASSIFICATION,
            model_name=model_name,
            fit_extra_options=fit_extra_options,
        )
        # Copies protect against later caller-side mutation of the dicts.
        self.predict_extra_options = copy_dict(predict_extra_options)
        self.explain_extra_options = copy_dict(explain_extra_options)

    def predict(
        self, X: Union[pd.DataFrame, np.ndarray]
    ) -> np.ndarray:  # pylint: disable=invalid-name
        """
        Predict class labels for the input features using the MySQL model.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
            A full list of supported options can be found under "ML_PREDICT_TABLE Options"

        Args:
            X: Input samples as a numpy array or pandas DataFrame.

        Returns:
            ndarray: Array of predicted class labels, shape (n_samples,).

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        result = self._model.predict(X, options=self.predict_extra_options)
        # HeatWave returns predictions in a dedicated "Prediction" column.
        return result["Prediction"].to_numpy()

    def predict_proba(
        self, X: Union[pd.DataFrame, np.ndarray]
    ) -> np.ndarray:  # pylint: disable=invalid-name
        """
        Predict class probabilities for the input features using the MySQL model.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
            A full list of supported options can be found under "ML_PREDICT_TABLE Options"

        Args:
            X: Input samples as a numpy array or pandas DataFrame.

        Returns:
            ndarray: Array of shape (n_samples, n_classes) with class probabilities.
            Columns follow the sorted order of the class names.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        result = self._model.predict(X, options=self.predict_extra_options)
        # Sort class names so the column order is deterministic across rows.
        classes = sorted(result["ml_results"].iloc[0]["probabilities"].keys())
        return np.stack(
            result["ml_results"].map(
                lambda ml_result: [
                    ml_result["probabilities"][class_name] for class_name in classes
                ]
            )
        )

    def explain_predictions(
        self, X: Union[pd.DataFrame, np.ndarray]
    ) -> pd.DataFrame:  # pylint: disable=invalid-name
        """
        Explain model predictions using provided data.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
            A full list of supported options can be found under "ML_EXPLAIN_TABLE Options"

        Args:
            X: DataFrame for which predictions should be explained.

        Returns:
            DataFrame containing explanation details (feature attributions, etc.)

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Temporary input/output tables are cleaned up after explanation.
        """
        # BUG FIX: the explanation DataFrame was computed but never returned,
        # so callers always received None despite the documented return value.
        return self._model.explain_predictions(X, options=self.explain_extra_options)

View File

@@ -0,0 +1,780 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""HeatWave ML model utilities for MySQL Connector/Python.
Provides classes to manage training, prediction, scoring, and explanations
via MySQL HeatWave stored procedures.
"""
import copy
import json
from enum import Enum
from typing import Any, Dict, Optional, Union
import numpy as np
import pandas as pd
from mysql.ai.utils import (
VAR_NAME_SPACE,
atomic_transaction,
convert_to_df,
execute_sql,
format_value_sql,
get_random_name,
source_schema,
sql_response_to_df,
sql_table_from_df,
sql_table_to_df,
table_exists,
temporary_sql_tables,
validate_name,
)
from mysql.connector.abstracts import MySQLConnectionAbstract
class ML_TASK(Enum):  # pylint: disable=invalid-name
    """Enumeration of the ML task types supported by HeatWave."""

    CLASSIFICATION = "classification"
    REGRESSION = "regression"
    FORECASTING = "forecasting"
    ANOMALY_DETECTION = "anomaly_detection"
    LOG_ANOMALY_DETECTION = "log_anomaly_detection"
    RECOMMENDATION = "recommendation"
    TOPIC_MODELING = "topic_modeling"

    @staticmethod
    def get_task_string(task: Union[str, "ML_TASK"]) -> str:
        """
        Normalize a task given as either a string or an ML_TASK member.

        Args:
            task (Union[str, ML_TASK]): The task to convert. Strings are
                passed through unchanged; enum members yield their value.

        Returns:
            str: The string value of the ML task.
        """
        # Strings are already in canonical form; only enums need unwrapping.
        return task if isinstance(task, str) else task.value
class _MyModelCommon:
    """
    Common utilities and workflow for MySQL HeatWave ML models.

    This class handles model lifecycle steps such as loading, fitting, scoring,
    making predictions, and explaining models or predictions. Not intended for
    direct instantiation, but as a superclass for heatwave model wrappers.

    Attributes:
        db_connection: MySQL connector database connection.
        task: ML task string, e.g., "classification" or "regression".
        model_name: Identifier of the model in MySQL.
        model_var: Session-variable name holding the model handle.
        model_var_score: Session-variable name that receives ML_SCORE output.
        schema_name: Database schema used for operations and temp tables.
    """

    def __init__(
        self,
        db_connection: MySQLConnectionAbstract,
        task: Union[str, ML_TASK] = ML_TASK.CLASSIFICATION,
        model_name: Optional[str] = None,
    ):
        """
        Instantiate _MyModelCommon.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-train.html
            A full list of supported tasks can be found under "Common ML_TRAIN Options"

        Args:
            db_connection: MySQL database connection.
            task: ML task type (default: classification).
            model_name: Name to register the model within MySQL. If None, a
                random unused name is generated.

        Raises:
            ValueError: If the model name is not valid.
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Returns:
            None
        """
        self.db_connection = db_connection
        self.task = ML_TASK.get_task_string(task)
        self.schema_name = source_schema(db_connection)
        # Make sure the per-user ML model catalog exists before any operation.
        with atomic_transaction(self.db_connection) as cursor:
            execute_sql(cursor, "CALL sys.ML_CREATE_OR_UPGRADE_CATALOG();")
        if model_name is None:
            # Generate a name that is not yet present in the model catalog.
            model_name = get_random_name(self._is_model_name_available)
        # Session-variable names used to reference the model handle and score.
        self.model_var = f"{VAR_NAME_SPACE}.{model_name}"
        self.model_var_score = f"{self.model_var}.score"
        self.model_name = model_name
        # Reject names unsafe to interpolate into SQL identifiers.
        validate_name(model_name)
        with atomic_transaction(self.db_connection) as cursor:
            execute_sql(cursor, f"SET @{self.model_var} = %s;", params=(model_name,))

    def _delete_model(self) -> bool:
        """
        Delete the model from the model catalog if present.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Returns:
            bool: True if a catalog row was deleted, False otherwise.
        """
        current_user = self._get_user()
        # Each user has a dedicated catalog schema: ML_SCHEMA_<user>.
        qualified_model_catalog = f"ML_SCHEMA_{current_user}.MODEL_CATALOG"
        delete_model = (
            f"DELETE FROM {qualified_model_catalog} "
            f"WHERE model_handle = @{self.model_var}"
        )
        with atomic_transaction(self.db_connection) as cursor:
            execute_sql(cursor, delete_model)
            # rowcount > 0 means a catalog entry actually existed.
            return cursor.rowcount > 0

    def _get_model_info(self, model_name: str) -> Optional[dict]:
        """
        Retrieve a model's row from the model catalog.

        Args:
            model_name: The model handle to look up.

        Returns:
            dict with the catalog row (JSON columns parsed into Python
            objects), or None if the model is not present in the catalog.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """

        def process_col(elem: Any) -> Any:
            # Catalog columns may contain JSON stored as text; decode when
            # possible, otherwise return the raw value unchanged.
            if isinstance(elem, str):
                try:
                    elem = json.loads(elem)
                except json.JSONDecodeError:
                    pass
            return elem

        current_user = self._get_user()
        qualified_model_catalog = f"ML_SCHEMA_{current_user}.MODEL_CATALOG"
        model_exists = (
            f"SELECT * FROM {qualified_model_catalog} WHERE model_handle = %s"
        )
        with atomic_transaction(self.db_connection) as cursor:
            execute_sql(cursor, model_exists, params=(model_name,))
            model_info_df = sql_response_to_df(cursor)
        if model_info_df.empty:
            result = None
        else:
            # Round-trip through JSON to normalize pandas values into plain
            # Python types before per-column decoding.
            unprocessed_result = model_info_df.to_json(orient="records")
            unprocessed_result_json = json.loads(unprocessed_result)[0]
            result = {
                key: process_col(elem)
                for key, elem in unprocessed_result_json.items()
            }
        return result

    def get_model_info(self) -> Optional[dict]:
        """
        Retrieve this model's catalog entry, if any.

        Model info is present in the catalog only if the model was previously
        fitted (or imported).

        Returns:
            dict with the catalog row for this model, or None if absent.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        return self._get_model_info(self.model_name)

    def _is_model_name_available(self, model_name: str) -> bool:
        """
        Check whether a model name is available (unused).

        Args:
            model_name: Candidate model name.

        Returns:
            bool: True if the name is not part of the model catalog.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        return self._get_model_info(model_name) is None

    def _load_model(self) -> None:
        """
        Load the model specified by `self.model_name` into MySQL.

        After loading, the model is ready to handle ML operations
        (predict/score/explain).

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-model-load.html

        Raises:
            DatabaseError:
                If the model is not initialized, i.e., fit or import has not been called
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Returns:
            None
        """
        with atomic_transaction(self.db_connection) as cursor:
            load_model_query = f"CALL sys.ML_MODEL_LOAD(@{self.model_var}, NULL);"
            execute_sql(cursor, load_model_query)

    def _get_user(self) -> str:
        """
        Fetch the current database user (without host).

        Returns:
            The username string associated with the connection.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError: If the user name includes unsupported characters.
        """
        with atomic_transaction(self.db_connection) as cursor:
            cursor.execute("SELECT CURRENT_USER()")
            # CURRENT_USER() returns 'user@host'; keep only the user part.
            current_user = cursor.fetchone()[0].split("@")[0]
        # Validated because the user name is interpolated into SQL identifiers.
        return validate_name(current_user)

    def explain_model(self) -> dict:
        """
        Get model explanations, such as detailed feature importances.

        Returns:
            dict: Feature importances and model explainability data.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-model-explanations.html

        Raises:
            DatabaseError:
                If the model is not initialized, i.e., fit or import has not been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError:
                If the model does not exist in the model catalog.
                Should only occur if model was not fitted or was deleted.
        """
        self._load_model()
        with atomic_transaction(self.db_connection) as cursor:
            current_user = self._get_user()
            qualified_model_catalog = f"ML_SCHEMA_{current_user}.MODEL_CATALOG"
            explain_query = (
                f"SELECT model_explanation FROM {qualified_model_catalog} "
                f"WHERE model_handle = @{self.model_var}"
            )
            execute_sql(cursor, explain_query)
            df = sql_response_to_df(cursor)
            # Single row, single column: the model_explanation payload.
            return df.iloc[0, 0]

    def _fit(
        self,
        table_name: str,
        target_column_name: Optional[str],
        options: Optional[dict],
    ) -> None:
        """
        Fit an ML model using a referenced SQL table and target column.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-train.html
            A full list of supported options can be found under "Common ML_TRAIN Options"

        Args:
            table_name: Name of the training data table.
            target_column_name: Name of the target/label column, or None for
                unsupervised tasks.
            options: Additional fit/config options (may override defaults).

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported.
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError: If the table or target_column name is not valid.

        Returns:
            None
        """
        validate_name(table_name)
        if target_column_name is not None:
            validate_name(target_column_name)
            target_col_string = f"'{target_column_name}'"
        else:
            # ML_TRAIN expects SQL NULL when there is no target column.
            target_col_string = "NULL"
        if options is None:
            options = {}
        # Deep copy so the caller's options dict is never mutated.
        options = copy.deepcopy(options)
        options["task"] = self.task
        # Re-training replaces any existing model under the same handle.
        self._delete_model()
        with atomic_transaction(self.db_connection) as cursor:
            placeholders, parameters = format_value_sql(options)
            execute_sql(
                cursor,
                (
                    "CALL sys.ML_TRAIN("
                    f"'{self.schema_name}.{table_name}', "
                    f"{target_col_string}, "
                    f"{placeholders}, "
                    f"@{self.model_var}"
                    ")"
                ),
                params=parameters,
            )

    def _predict(
        self, table_name: str, output_table_name: str, options: Optional[dict]
    ) -> None:
        """
        Predict on a given data table and write results to an output table.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
            A full list of supported options can be found under "ML_PREDICT_TABLE Options"

        Args:
            table_name: Name of the SQL table with input data.
            output_table_name: Name for the SQL output table to contain predictions.
            options: Optional prediction options.

        Returns:
            None

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError: If the table or output_table name is not valid.
        """
        validate_name(table_name)
        validate_name(output_table_name)
        # Model must be loaded into the HeatWave cluster before predicting.
        self._load_model()
        with atomic_transaction(self.db_connection) as cursor:
            placeholders, parameters = format_value_sql(options)
            execute_sql(
                cursor,
                (
                    "CALL sys.ML_PREDICT_TABLE("
                    f"'{self.schema_name}.{table_name}', "
                    f"@{self.model_var}, "
                    f"'{self.schema_name}.{output_table_name}', "
                    f"{placeholders}"
                    ")"
                ),
                params=parameters,
            )

    def _score(
        self,
        table_name: str,
        target_column_name: str,
        metric: str,
        options: Optional[dict],
    ) -> float:
        """
        Evaluate model performance with a scoring metric.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-score.html
            A full list of supported options can be found under
            "Options for Recommendation Models" and
            "Options for Anomaly Detection Models"

        Args:
            table_name: Table with features and ground truth.
            target_column_name: Column of true target labels.
            metric: String name of the metric to compute.
            options: Optional dictionary of further scoring options.

        Returns:
            float: Computed score from the ML system.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError: If the table or target_column name or metric is not valid.
        """
        validate_name(table_name)
        validate_name(target_column_name)
        validate_name(metric)
        self._load_model()
        with atomic_transaction(self.db_connection) as cursor:
            placeholders, parameters = format_value_sql(options)
            # ML_SCORE writes its result into the @<model_var>.score variable.
            execute_sql(
                cursor,
                (
                    "CALL sys.ML_SCORE("
                    f"'{self.schema_name}.{table_name}', "
                    f"'{target_column_name}', "
                    f"@{self.model_var}, "
                    "%s, "
                    f"@{self.model_var_score}, "
                    f"{placeholders}"
                    ")"
                ),
                params=[metric, *parameters],
            )
            # Read the score back from the session variable.
            execute_sql(cursor, f"SELECT @{self.model_var_score}")
            df = sql_response_to_df(cursor)
            return df.iloc[0, 0]

    def _explain_predictions(
        self, table_name: str, output_table_name: str, options: Optional[dict]
    ) -> pd.DataFrame:
        """
        Produce explanations for model predictions on provided data.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
            A full list of supported options can be found under "ML_EXPLAIN_TABLE Options"

        Args:
            table_name: Name of the SQL table with input data.
            output_table_name: Name for the SQL table to store explanations.
            options: Optional dictionary (default:
                {"prediction_explainer": "permutation_importance"}).

        Returns:
            DataFrame: Prediction explanations from the output SQL table.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError: If the table or output_table name is not valid.
        """
        validate_name(table_name)
        validate_name(output_table_name)
        if options is None:
            options = {"prediction_explainer": "permutation_importance"}
        self._load_model()
        with atomic_transaction(self.db_connection) as cursor:
            placeholders, parameters = format_value_sql(options)
            execute_sql(
                cursor,
                (
                    "CALL sys.ML_EXPLAIN_TABLE("
                    f"'{self.schema_name}.{table_name}', "
                    f"@{self.model_var}, "
                    f"'{self.schema_name}.{output_table_name}', "
                    f"{placeholders}"
                    ")"
                ),
                params=parameters,
            )
            # Read the explanation rows back from the output table.
            execute_sql(cursor, f"SELECT * FROM {self.schema_name}.{output_table_name}")
            df = sql_response_to_df(cursor)
            return df
class MyModel(_MyModelCommon):
    """
    Convenience class for managing the ML workflow using pandas DataFrames.

    Methods convert in-memory DataFrames into temp SQL tables before delegating to the
    _MyModelCommon routines, and automatically clean up temp resources.
    """

    def fit(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
        y: Optional[Union[pd.DataFrame, np.ndarray]],
        options: Optional[dict] = None,
    ) -> None:
        """
        Fit a model using DataFrame inputs.

        If an 'id' column is defined in either dataframe, it will be used as the primary key.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-train.html
            A full list of supported options can be found under "Common ML_TRAIN Options"

        Args:
            X: Features DataFrame.
            y: (Optional) Target labels DataFrame or Series. If None, only X is used.
            options: Additional options to pass to training.

        Returns:
            None

        Raises:
            ValueError: If y's column name collides with a column of X.
            DatabaseError:
                If provided options are invalid or unsupported.
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Combines X and y as necessary. Creates a temporary table in the schema for training,
            and deletes it afterward.
        """
        X, y = convert_to_df(X), convert_to_df(y)
        with (
            atomic_transaction(self.db_connection) as cursor,
            temporary_sql_tables(self.db_connection) as temporary_tables,
        ):
            if y is not None:
                if isinstance(y, pd.DataFrame):
                    # keep column name if it exists
                    target_column_name = y.columns[0]
                else:
                    # NOTE(review): convert_to_df appears to always yield a
                    # DataFrame or None, so this branch may be unreachable —
                    # confirm convert_to_df semantics.
                    target_column_name = get_random_name(
                        lambda name: name not in X.columns
                    )
                # Only a caller-supplied column name can collide with X.
                if target_column_name in X.columns:
                    raise ValueError(
                        f"Target column y with name {target_column_name} already present "
                        "in feature dataframe X"
                    )
                # Copy so the caller's X is not mutated by the column insert.
                df_combined = X.copy()
                df_combined[target_column_name] = y
                final_df = df_combined
            else:
                target_column_name = None
                final_df = X
            _, table_name = sql_table_from_df(cursor, self.schema_name, final_df)
            # Register for cleanup once the context manager exits.
            temporary_tables.append((self.schema_name, table_name))
            self._fit(table_name, target_column_name, options)

    def predict(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
        options: Optional[dict] = None,
    ) -> pd.DataFrame:
        """
        Generate model predictions using DataFrame input.

        If an 'id' column is defined in either dataframe, it will be used as the primary key.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
            A full list of supported options can be found under "ML_PREDICT_TABLE Options"

        Args:
            X: DataFrame containing prediction features (no labels).
            options: Additional prediction settings.

        Returns:
            DataFrame with prediction results as returned by HeatWave; the
            "ml_results" column is decoded from JSON text into Python objects.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Temporary SQL tables are created and deleted for input/output.
        """
        X = convert_to_df(X)
        with (
            atomic_transaction(self.db_connection) as cursor,
            temporary_sql_tables(self.db_connection) as temporary_tables,
        ):
            _, table_name = sql_table_from_df(cursor, self.schema_name, X)
            temporary_tables.append((self.schema_name, table_name))
            # Pick an output table name that does not clash with existing tables.
            output_table_name = get_random_name(
                lambda table_name: not table_exists(
                    cursor, self.schema_name, table_name
                )
            )
            temporary_tables.append((self.schema_name, output_table_name))
            self._predict(table_name, output_table_name, options)
            predictions = sql_table_to_df(cursor, self.schema_name, output_table_name)
        # ml_results is text but known to always follow JSON format
        predictions["ml_results"] = predictions["ml_results"].map(json.loads)
        return predictions

    def score(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
        y: Union[pd.DataFrame, np.ndarray],
        metric: str,
        options: Optional[dict] = None,
    ) -> float:
        """
        Score the model using X/y data and a selected metric.

        If an 'id' column is defined in either dataframe, it will be used as the primary key.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-score.html
            A full list of supported options can be found under
            "Options for Recommendation Models" and
            "Options for Anomaly Detection Models"

        Args:
            X: DataFrame of features.
            y: DataFrame or Series of labels.
            metric: Metric name (e.g., "balanced_accuracy").
            options: Optional ml scoring options.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Returns:
            float: Computed score.
        """
        X, y = convert_to_df(X), convert_to_df(y)
        with (
            atomic_transaction(self.db_connection) as cursor,
            temporary_sql_tables(self.db_connection) as temporary_tables,
        ):
            # Random target column name avoids collisions with X's columns.
            target_column_name = get_random_name(lambda name: name not in X.columns)
            df_combined = X.copy()
            df_combined[target_column_name] = y
            final_df = df_combined
            _, table_name = sql_table_from_df(cursor, self.schema_name, final_df)
            temporary_tables.append((self.schema_name, table_name))
            score = self._score(table_name, target_column_name, metric, options)
        return score

    def explain_predictions(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
        options: Optional[Dict] = None,
    ) -> pd.DataFrame:
        """
        Explain model predictions using provided data.

        If an 'id' column is defined in either dataframe, it will be used as the primary key.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
            A full list of supported options can be found under
            "ML_EXPLAIN_TABLE Options"

        Args:
            X: DataFrame for which predictions should be explained.
            options: Optional dictionary of explainability options.

        Returns:
            DataFrame containing explanation details (feature attributions, etc.)

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported, or if the model is not initialized,
                i.e., fit or import has not been called
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Temporary input/output tables are cleaned up after explanation.
        """
        X = convert_to_df(X)
        with (
            atomic_transaction(self.db_connection) as cursor,
            temporary_sql_tables(self.db_connection) as temporary_tables,
        ):
            _, table_name = sql_table_from_df(cursor, self.schema_name, X)
            temporary_tables.append((self.schema_name, table_name))
            output_table_name = get_random_name(
                lambda table_name: not table_exists(
                    cursor, self.schema_name, table_name
                )
            )
            temporary_tables.append((self.schema_name, output_table_name))
            explanations = self._explain_predictions(
                table_name, output_table_name, options
            )
        return explanations

View File

@@ -0,0 +1,221 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Outlier/anomaly detection utilities for MySQL Connector/Python.
Provides a scikit-learn compatible wrapper using HeatWave to score anomalies.
"""
from typing import Optional, Union
import numpy as np
import pandas as pd
from sklearn.base import OutlierMixin
from mysql.ai.ml.base import MyBaseMLModel
from mysql.ai.ml.model import ML_TASK
from mysql.ai.utils import copy_dict
from mysql.connector.abstracts import MySQLConnectionAbstract
EPS = 1e-5
def _get_logits(prob: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
"""
Compute logit (logodds) for a probability, clipping to avoid numerical overflow.
Args:
prob: Scalar or array of probability values in (0,1).
Returns:
logit-transformed probabilities.
"""
result = np.clip(prob, EPS, 1 - EPS)
return np.log(result / (1 - result))
class MyAnomalyDetector(MyBaseMLModel, OutlierMixin):
    """
    MySQL HeatWave scikit-learn compatible anomaly/outlier detector.

    Flags samples as outliers when the probability of being an anomaly
    exceeds the threshold stored with the trained model.
    Includes helpers to obtain decision scores and anomaly probabilities
    for ranking.

    Args:
        db_connection (MySQLConnectionAbstract): Active MySQL DB connection.
        model_name (str, optional): Custom model name in the database.
        fit_extra_options (dict, optional): Extra options for fitting.
        score_extra_options (dict, optional): Extra options for scoring/prediction.

    Attributes:
        boundary: Decision threshold boundary in logit space. Lazily derived
            from the trained model's catalog info on first use.

    Methods:
        predict(X): Predict outlier/inlier labels.
        score_samples(X): Compute anomaly (normal class) logit scores.
        decision_function(X): Compute signed score above/below threshold for ranking.
    """

    def __init__(
        self,
        db_connection: MySQLConnectionAbstract,
        model_name: Optional[str] = None,
        fit_extra_options: Optional[dict] = None,
        score_extra_options: Optional[dict] = None,
    ):
        """
        Initialize an anomaly detector instance with extra options.

        Args:
            db_connection: Active MySQL DB connection.
            model_name: Optional model name in DB.
            fit_extra_options: Optional extra fit options.
            score_extra_options: Optional extra scoring options.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        MyBaseMLModel.__init__(
            self,
            db_connection,
            ML_TASK.ANOMALY_DETECTION,
            model_name=model_name,
            fit_extra_options=fit_extra_options,
        )
        self.score_extra_options = copy_dict(score_extra_options)
        # Lazily resolved in decision_function() from the model catalog.
        self.boundary: Optional[float] = None

    def predict(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
    ) -> np.ndarray:
        """
        Predict outlier/inlier binary labels for input samples.

        Args:
            X: Samples to predict on.

        Returns:
            ndarray: Values are -1 for outliers, +1 for inliers, as per scikit-learn convention.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        # Negative decision scores fall below the boundary -> outliers (-1).
        return np.where(self.decision_function(X) < 0.0, -1, 1)

    def decision_function(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
    ) -> np.ndarray:
        """
        Compute signed distance to the outlier threshold.

        Args:
            X: Samples to predict on.

        Returns:
            ndarray: Score > 0 means inlier, < 0 means outlier; |value| gives margin.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError:
                If the provided model info does not provide threshold
        """
        sample_scores = self.score_samples(X)
        if self.boundary is None:
            model_info = self.get_model_info()
            if model_info is None:
                raise ValueError("Model does not exist in catalog.")
            threshold = model_info["model_metadata"]["training_params"].get(
                "anomaly_detection_threshold", None
            )
            if threshold is None:
                raise ValueError(
                    "Trained model is outdated and does not support threshold. "
                    "Try retraining or using an existing, trained model with MyModel."
                )
            # scikit-learn uses large positive values as inlier
            # and negative as outlier, so we need to flip our threshold
            self.boundary = _get_logits(1.0 - threshold)
        return sample_scores - self.boundary

    def score_samples(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
    ) -> np.ndarray:
        """
        Compute normal probability logit score for each sample.

        Used for ranking, thresholding.

        Args:
            X: Samples to score.

        Returns:
            ndarray: Logit scores based on "normal" class probability.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        result = self._model.predict(X, options=self.score_extra_options)
        return _get_logits(
            result["ml_results"]
            .apply(lambda x: x["probabilities"]["normal"])
            .to_numpy()
        )

View File

@@ -0,0 +1,154 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Regressor utilities for MySQL Connector/Python.
Provides a scikit-learn compatible regressor backed by HeatWave ML.
"""
from typing import Optional, Union
import numpy as np
import pandas as pd
from sklearn.base import RegressorMixin
from mysql.ai.ml.base import MyBaseMLModel
from mysql.ai.ml.model import ML_TASK
from mysql.ai.utils import copy_dict
from mysql.connector.abstracts import MySQLConnectionAbstract
class MyRegressor(MyBaseMLModel, RegressorMixin):
    """
    MySQL HeatWave scikit-learn compatible regressor estimator.

    Provides prediction output from a regression model deployed in MySQL,
    and manages fit, explain, and prediction options as per HeatWave ML interface.

    Attributes:
        predict_extra_options (dict): Optional parameter dict passed to the backend for prediction.
        _model (MyModel): Underlying interface for database model operations.
        fit_extra_options (dict): See MyBaseMLModel.
        explain_extra_options (dict): See MyBaseMLModel.

    Args:
        db_connection (MySQLConnectionAbstract): Active MySQL connector DB connection.
        model_name (str, optional): Custom name for the model.
        fit_extra_options (dict, optional): Extra options for fitting.
        explain_extra_options (dict, optional): Extra options for explanations.
        predict_extra_options (dict, optional): Extra options for predictions.

    Methods:
        predict(X): Predict regression target.
        explain_predictions(X): Explain model predictions.
    """

    def __init__(
        self,
        db_connection: MySQLConnectionAbstract,
        model_name: Optional[str] = None,
        fit_extra_options: Optional[dict] = None,
        explain_extra_options: Optional[dict] = None,
        predict_extra_options: Optional[dict] = None,
    ):
        """
        Initialize a MyRegressor.

        Args:
            db_connection: Active MySQL connector database connection.
            model_name: Optional, custom model name.
            fit_extra_options: Optional fit options.
            explain_extra_options: Optional explain options.
            predict_extra_options: Optional prediction options.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        MyBaseMLModel.__init__(
            self,
            db_connection,
            ML_TASK.REGRESSION,
            model_name=model_name,
            fit_extra_options=fit_extra_options,
        )
        self.predict_extra_options = copy_dict(predict_extra_options)
        self.explain_extra_options = copy_dict(explain_extra_options)

    def predict(
        self, X: Union[pd.DataFrame, np.ndarray]
    ) -> np.ndarray:  # pylint: disable=invalid-name
        """
        Predict a continuous target for the input features using the MySQL model.

        Args:
            X: Input samples as a numpy array or pandas DataFrame.

        Returns:
            ndarray: Array of predicted target values, shape (n_samples,).

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        result = self._model.predict(X, options=self.predict_extra_options)
        return result["Prediction"].to_numpy()

    def explain_predictions(
        self, X: Union[pd.DataFrame, np.ndarray]
    ) -> pd.DataFrame:  # pylint: disable=invalid-name
        """
        Explain model predictions using provided data.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
            A full list of supported options can be found under "ML_EXPLAIN_TABLE Options"

        Args:
            X: DataFrame for which predictions should be explained.

        Returns:
            DataFrame containing explanation details (feature attributions, etc.)

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Temporary input/output tables are cleaned up after explanation.
        """
        # BUG FIX: the explanation DataFrame was previously computed but
        # dropped; propagate it to the caller as documented.
        return self._model.explain_predictions(X, options=self.explain_extra_options)

View File

@@ -0,0 +1,164 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Generic transformer utilities for MySQL Connector/Python.
Provides a scikit-learn compatible Transformer using HeatWave for fit/transform
and scoring operations.
"""
from typing import Optional, Union
import numpy as np
import pandas as pd
from sklearn.base import TransformerMixin
from mysql.ai.ml.base import MyBaseMLModel
from mysql.ai.ml.model import ML_TASK
from mysql.ai.utils import copy_dict
from mysql.connector.abstracts import MySQLConnectionAbstract
class MyGenericTransformer(MyBaseMLModel, TransformerMixin):
    """
    Scikit-learn compatible generic transformer backed by MySQL HeatWave.

    Usable as the transformation step of an sklearn pipeline: fitting,
    transforming, explaining, and scoring are delegated to the server-side
    model, with user-supplied option dictionaries forwarded to each backend
    call.

    Args:
        db_connection (MySQLConnectionAbstract): Active MySQL connector database connection.
        task (str): ML task type for transformer (default: "classification").
        score_metric (str): Scoring metric to request from backend (default: "balanced_accuracy").
        model_name (str, optional): Custom name for the deployed model.
        fit_extra_options (dict, optional): Extra fit options.
        transform_extra_options (dict, optional): Extra options for transformations.
        score_extra_options (dict, optional): Extra options for scoring.

    Attributes:
        score_metric (str): Backend metric used by score()
            (e.g. "balanced_accuracy").
        score_extra_options (dict): Optional scoring parameters forwarded to
            the backend.
        transform_extra_options (dict): Optional inference (/predict)
            parameters forwarded to the backend.
        fit_extra_options (dict): See MyBaseMLModel.
        _model (MyModel): Underlying interface for database model operations.

    Methods:
        fit(X, y): Fit the underlying model using the provided features/targets.
        transform(X): Transform features using the backend model.
        score(X, y): Score data using backend metric and options.
    """

    def __init__(
        self,
        db_connection: MySQLConnectionAbstract,
        task: Union[str, ML_TASK] = ML_TASK.CLASSIFICATION,
        score_metric: str = "balanced_accuracy",
        model_name: Optional[str] = None,
        fit_extra_options: Optional[dict] = None,
        transform_extra_options: Optional[dict] = None,
        score_extra_options: Optional[dict] = None,
    ):
        """
        Set up the transformer and store per-operation option dictionaries.

        Args:
            db_connection: Active MySQL backend database connection.
            task: ML task type for transformer.
            score_metric: Requested backend scoring metric.
            model_name: Optional model name for storage.
            fit_extra_options: Optional extra options for fitting.
            transform_extra_options: Optional extra options for transformation/inference.
            score_extra_options: Optional extra scoring options.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        MyBaseMLModel.__init__(
            self,
            db_connection,
            task,
            model_name=model_name,
            fit_extra_options=fit_extra_options,
        )
        self.score_metric = score_metric
        # Defensive copies so callers can't mutate our options later.
        self.score_extra_options = copy_dict(score_extra_options)
        self.transform_extra_options = copy_dict(transform_extra_options)

    def transform(
        self, X: pd.DataFrame
    ) -> pd.DataFrame:  # pylint: disable=invalid-name
        """
        Transform input data to model predictions using the underlying helper.

        Args:
            X: DataFrame of features to predict/transform.

        Returns:
            pd.DataFrame: Results of transformation as returned by backend.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        predictions = self._model.predict(X, options=self.transform_extra_options)
        return predictions

    def score(
        self,
        X: Union[pd.DataFrame, np.ndarray],  # pylint: disable=invalid-name
        y: Union[pd.DataFrame, np.ndarray],
    ) -> float:
        """
        Score the transformed data using the backend scoring interface.

        Args:
            X: Transformed features.
            y: Target labels or data for scoring.

        Returns:
            float: Score based on backend metric.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        backend_score = self._model.score(
            X, y, self.score_metric, options=self.score_extra_options
        )
        return backend_score

View File

@@ -0,0 +1,44 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Utilities for AI-related helpers in MySQL Connector/Python.
This package exposes:
- check_dependencies(): runtime dependency guard for optional AI features
- atomic_transaction(): context manager ensuring atomic DB transactions
- utils: general-purpose helpers used by AI integrations
Importing this package validates base dependencies required for AI utilities.
"""
from .dependencies import check_dependencies
check_dependencies(["BASE"])
from .atomic_cursor import atomic_transaction
from .utils import *

View File

@@ -0,0 +1,94 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Atomic transaction context manager utilities for MySQL Connector/Python.
Provides context manager atomic_transaction() that ensures commit on success
and rollback on error without obscuring the original exception.
"""
from contextlib import contextmanager
from typing import Iterator
from mysql.connector.abstracts import MySQLConnectionAbstract
from mysql.connector.cursor import MySQLCursorAbstract
@contextmanager
def atomic_transaction(
    conn: MySQLConnectionAbstract,
) -> Iterator[MySQLCursorAbstract]:
    """
    Context manager that wraps a MySQL database cursor and ensures transaction
    rollback in case of exception.

    NOTE: DDL statements such as CREATE TABLE cause implicit commits. These cannot
    be managed by a cursor object. Changes made at or before a DDL statement will
    be committed and not rolled back. Callers are responsible for any cleanup of
    this type.

    This acts as a robust, PEP 343-compliant context manager for handling
    database cursor operations on a MySQL connection. It ensures that all
    operations executed within the context block are part of the same
    transaction, and automatically calls `connection.rollback()` if an
    exception occurs, helping to maintain database integrity. On normal
    completion (no exception), it simply closes the cursor after use.
    Exceptions are always propagated to the caller.

    Args:
        conn: A MySQLConnectionAbstract instance.

    Yields:
        MySQLCursorAbstract: A cursor bound to `conn` for use inside the
        transaction block.
    """
    # Remember the caller's autocommit setting so it can be restored on exit.
    old_autocommit = conn.autocommit
    cursor = conn.cursor()
    # Tracks whether the body raised; used below so a failing cursor.close()
    # never masks the original exception.
    exception_raised = False
    try:
        if old_autocommit:
            # Disable autocommit so all statements join a single transaction.
            conn.autocommit = False
        yield cursor  # provide cursor to block
        # Body completed without error: make the transaction durable.
        conn.commit()
    except Exception:  # pylint: disable=broad-exception-caught
        exception_raised = True
        try:
            conn.rollback()
        except Exception:  # pylint: disable=broad-exception-caught
            # Don't obscure original exception
            pass
        # Raise original exception
        raise
    finally:
        # Always restore the caller's autocommit mode, success or failure.
        conn.autocommit = old_autocommit
        try:
            cursor.close()
        except Exception:  # pylint: disable=broad-exception-caught
            # don't obscure original exception if exists
            if not exception_raised:
                raise

View File

@@ -0,0 +1,88 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Dependency checking utilities for AI features in MySQL Connector/Python.
Provides check_dependencies() to assert required optional packages are present
with acceptable minimum versions at runtime.
"""
import importlib.metadata
from typing import List
def check_dependencies(tasks: List[str]) -> None:
    """
    Check required runtime dependencies and minimum versions; raise an error
    if any are missing or version-incompatible.

    This verifies the presence and minimum version of essential Python packages.
    Missing or insufficient versions cause an ImportError listing the packages
    and a suggested install command.

    Args:
        tasks (List[str]): Task types to check requirements for.

    Raises:
        ImportError: If any required dependencies are missing or below the
            minimum version.
    """

    def _version_tuple(version: str) -> tuple:
        """Parse a dotted version string into a numeric tuple for comparison.

        Only leading digits of each dot-separated component are used; parsing
        stops at the first non-numeric component (e.g. "1.26.0rc1" -> (1, 26, 0)).
        """
        parts = []
        for piece in version.split("."):
            digits = ""
            for char in piece:
                if not char.isdigit():
                    break
                digits += char
            if not digits:
                break
            parts.append(int(digits))
        return tuple(parts)

    task_set = set(tasks)
    task_set.add("BASE")
    # Requirements: (import_name, min_version)
    task_to_requirement = {
        "BASE": [("pandas", "1.5.0")],
        "GENAI": [
            ("langchain", "0.1.11"),
            ("langchain_core", "0.1.11"),
            ("pydantic", "1.10.0"),
        ],
        "ML": [("scikit-learn", "1.3.0")],
    }
    requirements = []
    for task in task_set:
        requirements.extend(task_to_requirement[task])
    requirements_set = set(requirements)
    problems = []
    for name, min_version in requirements_set:
        try:
            installed_version = importlib.metadata.version(name)
            # BUG FIX: plain string comparison mis-orders multi-digit
            # components (e.g. "1.10.0" < "1.5.0" as strings). Compare
            # numeric tuples instead.
            error = _version_tuple(installed_version) < _version_tuple(min_version)
        except importlib.metadata.PackageNotFoundError:
            error = True
        if error:
            problems.append(f"{name} v{min_version} (or later)")
    if problems:
        raise ImportError("Please install " + ", ".join(problems) + ".")

View File

@@ -0,0 +1,573 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""General utilities for AI features in MySQL Connector/Python.
Includes helpers for:
- defensive dict copying
- temporary table lifecycle management
- SQL execution and result conversions
- DataFrame to/from SQL table utilities
- schema/table/column name validation
- array-like to DataFrame conversion
"""
import copy
import json
import random
import re
import string
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
from mysql.ai.utils.atomic_cursor import atomic_transaction
from mysql.connector.abstracts import MySQLConnectionAbstract
from mysql.connector.cursor import MySQLCursorAbstract
from mysql.connector.types import ParamsSequenceOrDictType
VAR_NAME_SPACE = "mysql_ai"
RANDOM_TABLE_NAME_LENGTH = 32
PD_TO_SQL_DTYPE_MAPPING = {
"int64": "BIGINT",
"float64": "DOUBLE",
"object": "LONGTEXT",
"bool": "BOOLEAN",
"datetime64[ns]": "DATETIME",
}
DEFAULT_SCHEMA = "mysql_ai"
# Misc Utilities
def copy_dict(options: Optional[dict]) -> dict:
    """
    Return a defensive deep copy of *options*, or a fresh empty dict for None.

    Args:
        options: param dict or None

    Returns:
        dict
    """
    return {} if options is None else copy.deepcopy(options)
@contextmanager
def temporary_sql_tables(
    db_connection: MySQLConnectionAbstract,
) -> Iterator[list[tuple[str, str]]]:
    """
    Context manager tracking temporary SQL tables and dropping them on exit.

    Args:
        db_connection: Database connection object used to create and delete tables.

    Raises:
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.

    Yields:
        A list of (schema_name, table_name) tuples; append every table created
        during the context. All listed tables are deleted on context exit,
        even if the body raises.
    """
    tracked: list[tuple[str, str]] = []
    try:
        yield tracked
    finally:
        # Cleanup runs unconditionally so tables never outlive the context.
        with atomic_transaction(db_connection) as cursor:
            for schema, table in tracked:
                delete_sql_table(cursor, schema, table)
def execute_sql(
cursor: MySQLCursorAbstract, query: str, params: ParamsSequenceOrDictType = None
) -> None:
"""
Execute an SQL query with optional parameters using the given cursor.
Args:
cursor: MySQLCursorAbstract object to execute the query.
query: SQL query string to execute.
params: Optional sequence or dict providing parameters for the query.
Raises:
DatabaseError:
If the provided SQL query/params are invalid
If the query is valid but the sql raises as an exception
If a database connection issue occurs.
If an operational error occurs during execution.
Returns:
None
"""
cursor.execute(query, params or ())
def _get_name() -> str:
    """
    Produce a random candidate table name.

    Returns:
        Random uppercase-letter string of length RANDOM_TABLE_NAME_LENGTH.
    """
    letters = random.choices(string.ascii_uppercase, k=RANDOM_TABLE_NAME_LENGTH)
    return "".join(letters)
def get_random_name(condition: Callable[[str], bool], max_calls: int = 100) -> str:
    """
    Draw random names until one satisfies *condition*.

    Args:
        condition: Callable that takes a generated name and returns True if it is valid.
        max_calls: Maximum number of attempts before giving up (default 100).

    Returns:
        A random string that fulfills the provided condition.

    Raises:
        RuntimeError: If the maximum number of attempts is reached without success.
    """
    attempts = 0
    while attempts < max_calls:
        candidate = _get_name()
        if condition(candidate):
            return candidate
        attempts += 1
    # condition never met
    raise RuntimeError("Reached max tries without successfully finding a unique name")
# Format conversions
def format_value_sql(value: Any) -> Tuple[str, List[Any]]:
    """
    Build the SQL placeholder and bound-parameter list for a Python value.

    dicts/lists become JSON casts (empty ones become NULL); everything else is
    bound directly.

    Args:
        value: The value to format.

    Returns:
        Tuple containing:
            - A string for substitution into a SQL query.
            - A list of parameters to be bound into the query.
    """
    if not isinstance(value, (dict, list)):
        return "%s", [value]
    if not value:
        # Empty container: store NULL rather than an empty JSON document.
        return "%s", [None]
    return "CAST(%s as JSON)", [json.dumps(value)]
def sql_response_to_df(cursor: MySQLCursorAbstract) -> pd.DataFrame:
    """
    Build a pandas DataFrame from the rows produced by the cursor's last query.
    Args:
        cursor: MySQLCursorAbstract whose most recent statement has completed.
    Returns:
        DataFrame holding the fetched rows, with JSON columns decoded into
        Python objects.
    Raises:
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.
            If a compatible SELECT query wasn't the last statement ran
    """
    # 245 is the MySQL type code for JSON columns.
    json_type_code = 245
    # One decoder per result column, in positional order.
    decoders = []
    for column in cursor.description:
        if column[1] == json_type_code:
            decoders.append(lambda raw: None if raw is None else json.loads(raw))
        else:
            decoders.append(lambda raw: raw)
    decoded_rows = [
        [decode(cell) for decode, cell in zip(decoders, row)]
        for row in cursor.fetchall()
    ]
    return pd.DataFrame(decoded_rows, columns=cursor.column_names)
def sql_table_to_df(
    cursor: MySQLCursorAbstract, schema_name: str, table_name: str
) -> pd.DataFrame:
    """
    Read every row of a SQL table into a pandas DataFrame.
    Args:
        cursor: MySQLCursorAbstract used to run the SELECT.
        schema_name: Schema that contains the table.
        table_name: Table whose contents are fetched.
    Returns:
        DataFrame with all rows of the specified table.
    Raises:
        DatabaseError:
            If the table does not exist
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError: If the schema or table name is not valid
    """
    for identifier in (schema_name, table_name):
        validate_name(identifier)
    execute_sql(cursor, f"SELECT * FROM {schema_name}.{table_name}")
    return sql_response_to_df(cursor)
# Table operations
def table_exists(
    cursor: MySQLCursorAbstract, schema_name: str, table_name: str
) -> bool:
    """
    Report whether the given table is present in the given schema.
    Args:
        cursor: MySQLCursorAbstract object to execute the query.
        schema_name: Name of the database schema.
        table_name: Name of the table.
    Returns:
        True if the table exists, False otherwise.
    Raises:
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError: If the schema or table name is not valid
    """
    validate_name(schema_name)
    validate_name(table_name)
    # Probe information_schema with bound parameters rather than interpolation.
    lookup_sql = """
        SELECT 1
        FROM information_schema.tables
        WHERE table_schema = %s AND table_name = %s
        LIMIT 1
        """
    cursor.execute(lookup_sql, (schema_name, table_name))
    found = cursor.fetchone()
    return found is not None
def delete_sql_table(
    cursor: MySQLCursorAbstract, schema_name: str, table_name: str
) -> None:
    """
    Remove a table from the SQL database if it is present.
    Args:
        cursor: MySQLCursorAbstract to execute the drop command.
        schema_name: Name of the schema.
        table_name: Name of the table to delete.
    Returns:
        None
    Raises:
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError: If the schema or table name is not valid
    """
    for identifier in (schema_name, table_name):
        validate_name(identifier)
    drop_stmt = f"DROP TABLE IF EXISTS {schema_name}.{table_name}"
    execute_sql(cursor, drop_stmt)
def extend_sql_table(
    cursor: MySQLCursorAbstract,
    schema_name: str,
    table_name: str,
    df: pd.DataFrame,
    col_name_to_placeholder_string: Optional[Dict[str, str]] = None,
) -> None:
    """
    Insert all rows from a pandas DataFrame into an existing SQL table.
    Args:
        cursor: MySQLCursorAbstract for execution.
        schema_name: Name of the database schema.
        table_name: Table to insert new rows into.
        df: DataFrame containing the rows to insert.
        col_name_to_placeholder_string:
            Optional mapping of column names to custom SQL value/placeholder
            strings.
    Returns:
        None
    Raises:
        DatabaseError:
            If the rows could not be inserted into the table, e.g., a type or shape issue
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError: If the schema or table name is not valid
    """
    if col_name_to_placeholder_string is None:
        col_name_to_placeholder_string = {}
    validate_name(schema_name)
    validate_name(table_name)
    for col in df.columns:
        validate_name(str(col))
    qualified_table_name = f"{schema_name}.{table_name}"
    # The column list is identical for every row; build it once instead of per row.
    cols_sql = ", ".join(str(col) for col in df.columns)
    # Iterate over all rows in the DataFrame to build insert statements row by row
    # (placeholders can vary per row, e.g. JSON CAST for dict/list cells).
    for row in df.values:
        placeholders, params = [], []
        for elem, col in zip(row, df.columns):
            # Unwrap NumPy scalars into native Python values for the driver.
            elem = elem.item() if hasattr(elem, "item") else elem
            if col in col_name_to_placeholder_string:
                placeholders.append(col_name_to_placeholder_string[col])
                params.append(str(elem))
            else:
                elem_placeholder, elem_params = format_value_sql(elem)
                placeholders.append(elem_placeholder)
                params.extend(elem_params)
        insert_sql = (
            f"INSERT INTO {qualified_table_name} "
            f"({cols_sql}) VALUES ({', '.join(placeholders)})"
        )
        execute_sql(cursor, insert_sql, params=params)
def sql_table_from_df(
    cursor: MySQLCursorAbstract, schema_name: str, df: pd.DataFrame
) -> Tuple[str, str]:
    """
    Create a new SQL table with a random name, and populate it with data from a DataFrame.
    If an 'id' column is defined in the dataframe, it will be used as the primary key.
    Args:
        cursor: MySQLCursorAbstract for executing SQL.
        schema_name: Schema in which to create the table.
        df: DataFrame containing the data to be inserted.
    Returns:
        Tuple (qualified_table_name, table_name): The schema-qualified and
        unqualified table names.
    Raises:
        RuntimeError: If a random available table name could not be found.
        ValueError: If any schema, table, or a column name is invalid.
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.
    """
    validate_name(schema_name)
    table_name = get_random_name(
        lambda candidate: not table_exists(cursor, schema_name, candidate)
    )
    validate_name(table_name)
    qualified_table_name = f"{schema_name}.{table_name}"
    # Build the column definitions, validating each column name exactly once.
    column_defs = []
    for col, dtype in df.dtypes.items():
        validate_name(str(col))
        # Map pandas dtype to SQL type, fallback is LONGTEXT
        sql_type = PD_TO_SQL_DTYPE_MAPPING.get(str(dtype), "LONGTEXT")
        column_defs.append(f"{col} {sql_type}")
    columns_str = ", ".join(column_defs)
    # str() guards against non-string column labels (e.g. integer columns),
    # which validate_name(str(col)) above would otherwise accept.
    if any(str(col).lower() == "id" for col in df.columns):
        columns_str += ", PRIMARY KEY (id)"
    # Create table with generated columns
    execute_sql(cursor, f"CREATE TABLE {qualified_table_name} ({columns_str})")
    try:
        # Insert provided data into new table
        extend_sql_table(cursor, schema_name, table_name, df)
    except Exception:  # pylint: disable=broad-exception-caught
        # Drop the half-built table before we lose access to it, then re-raise.
        delete_sql_table(cursor, schema_name, table_name)
        raise
    return qualified_table_name, table_name
def validate_name(name: str) -> str:
    """
    Validate that the string is a legal SQL identifier (letters, digits, underscores).
    Args:
        name: Name (schema, table, or column) to validate.
    Returns:
        The validated name.
    Raises:
        ValueError: If the name is not a string or does not meet format requirements.
    """
    # Accepts only letters, digits, and underscores; change as needed.
    # fullmatch (rather than match with a trailing "$") also rejects names with
    # a trailing newline, which "$" would let through — important because
    # validated names are interpolated directly into SQL statements.
    if not (isinstance(name, str) and re.fullmatch(r"[A-Za-z0-9_]+", name)):
        raise ValueError(f"Unsupported name format {name}")
    return name
def source_schema(db_connection: MySQLConnectionAbstract) -> str:
    """
    Retrieve the name of the currently selected schema, or set and ensure the default schema.
    Args:
        db_connection: MySQL connector database connection object.
    Returns:
        Name of the schema (database in use).
    Raises:
        ValueError: If the schema name is not valid
        DatabaseError:
            If a database connection issue occurs.
            If an operational error occurs during execution.
    """
    schema = db_connection.database
    if schema is None:
        schema = DEFAULT_SCHEMA
    # Validate BEFORE interpolating into SQL: validating after executing the
    # CREATE DATABASE statement would defeat the purpose of the check.
    validate_name(schema)
    with atomic_transaction(db_connection) as cursor:
        create_database_stmt = f"CREATE DATABASE IF NOT EXISTS {schema}"
        execute_sql(cursor, create_database_stmt)
    return schema
def is_table_empty(
    cursor: MySQLCursorAbstract, schema_name: str, table_name: str
) -> bool:
    """
    Determine whether a SQL table contains no rows.
    Args:
        cursor: MySQLCursorAbstract with access to the database.
        schema_name: Name of the schema containing the table.
        table_name: Name of the table to check.
    Returns:
        True if the table has no rows, False otherwise.
    Raises:
        DatabaseError:
            If the table does not exist
            If a database connection issue occurs.
            If an operational error occurs during execution.
        ValueError: If the schema or table name is not valid
    """
    for identifier in (schema_name, table_name):
        validate_name(identifier)
    probe_sql = f"SELECT 1 FROM {schema_name}.{table_name} LIMIT 1"
    cursor.execute(probe_sql)
    first_row = cursor.fetchone()
    return first_row is None
def convert_to_df(
    arr: Optional[Union[pd.DataFrame, pd.Series, np.ndarray]],
    col_prefix: str = "feature",
) -> Optional[pd.DataFrame]:
    """
    Convert input data to a pandas DataFrame if necessary.
    Args:
        arr: Input data as a pandas DataFrame, NumPy ndarray, pandas Series, or None.
        col_prefix: Prefix used to build column names ("<col_prefix>_<idx>")
            when the input is an ndarray with no column labels of its own.
    Returns:
        If the input is None, returns None.
        Otherwise, returns a DataFrame backed by the same underlying data whenever
        possible (except in cases where pandas or NumPy must copy, such as for
        certain views or non-contiguous arrays).
    Notes:
        - If an ndarray is passed, columns are named "<col_prefix>_0",
          "<col_prefix>_1", ... ("feature_0", "feature_1", ... by default).
        - If a DataFrame is passed, column names and indices are preserved.
        - The returned DataFrame is a shallow copy and shares data with the original
          input when possible; however, copies may still occur for certain input
          types or memory layouts.
    """
    if arr is None:
        return None
    if isinstance(arr, pd.DataFrame):
        # New frame object over the same data (shallow copy).
        return pd.DataFrame(arr)
    if isinstance(arr, pd.Series):
        return arr.to_frame()
    # 1-D arrays become a single-column frame.
    if arr.ndim == 1:
        arr = arr.reshape(-1, 1)
    col_names = [f"{col_prefix}_{idx}" for idx in range(arr.shape[1])]
    return pd.DataFrame(arr, columns=col_names, copy=False)