Añadiendo todos los archivos del proyecto (incluidos secretos y venv)

This commit is contained in:
2026-03-06 18:31:45 -06:00
parent 3a15a3eafa
commit e4d50b6eb5
4965 changed files with 991048 additions and 0 deletions

View File

@@ -0,0 +1,48 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""ML package for MySQL Connector/Python.
Performs optional dependency checks and exposes ML utilities:
- ML_TASK, MyModel
- MyClassifier, MyRegressor, MyGenericTransformer
- MyAnomalyDetector
"""
from mysql.ai.utils import check_dependencies as _check_dependencies
_check_dependencies(["ML"])
del _check_dependencies
# Sklearn models
from .classifier import MyClassifier
# Minimal interface
from .model import ML_TASK, MyModel
from .outlier import MyAnomalyDetector
from .regressor import MyRegressor
from .transformer import MyGenericTransformer

View File

@@ -0,0 +1,142 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Base classes for MySQL HeatWave ML estimators for Connector/Python.
Implements a scikit-learn-compatible base estimator wrapping server-side ML.
"""
from typing import Optional, Union
import pandas as pd
from sklearn.base import BaseEstimator
from mysql.connector.abstracts import MySQLConnectionAbstract
from mysql.ai.ml.model import ML_TASK, MyModel
from mysql.ai.utils import copy_dict
class MyBaseMLModel(BaseEstimator):
"""
Base class for MySQL HeatWave machine learning estimators.
Implements the scikit-learn API and core model management logic,
including fit, explain, serialization, and dynamic option handling.
For use as a base class by classifiers, regressors, transformers, and outlier models.
Args:
db_connection (MySQLConnectionAbstract): An active MySQL connector database connection.
task (str): ML task type, e.g. "classification" or "regression".
model_name (str, optional): Custom name for the deployed model.
fit_extra_options (dict, optional): Extra options for fitting.
Attributes:
_model: Underlying database helper for fit/predict/explain.
fit_extra_options: User-provided options for fitting.
"""
def __init__(
self,
db_connection: MySQLConnectionAbstract,
task: Union[str, ML_TASK],
model_name: Optional[str] = None,
fit_extra_options: Optional[dict] = None,
):
"""
Initialize a MyBaseMLModel with connection, task, and option parameters.
Args:
db_connection: Active MySQL connector database connection.
task: String label of ML task (e.g. "classification").
model_name: Optional custom model name.
fit_extra_options: Optional extra fit options.
Raises:
DatabaseError:
If a database connection issue occurs.
If an operational error occurs during execution.
"""
self._model = MyModel(db_connection, task=task, model_name=model_name)
self.fit_extra_options = copy_dict(fit_extra_options)
def fit(
self,
X: pd.DataFrame, # pylint: disable=invalid-name
y: Optional[pd.DataFrame] = None,
) -> "MyBaseMLModel":
"""
Fit the underlying ML model using pandas DataFrames.
Delegates to MyMLModelPandasHelper.fit.
Args:
X: Features DataFrame.
y: (Optional) Target labels DataFrame or Series.
Returns:
self
Raises:
DatabaseError:
If provided options are invalid or unsupported.
If a database connection issue occurs.
If an operational error occurs during execution.
Notes:
Additional temp SQL resources may be created and cleaned up during the operation.
"""
self._model.fit(X, y, self.fit_extra_options)
return self
def _delete_model(self) -> bool:
"""
Deletes the model from the model catalog if present
Raises:
DatabaseError:
If a database connection issue occurs.
If an operational error occurs during execution.
Returns:
Whether the model was deleted
"""
return self._model._delete_model()
def get_model_info(self) -> Optional[dict]:
"""
Checks if the model name is available. Model info will only be present in the
catalog if the model has previously been fitted.
Returns:
True if the model name is not part of the model catalog
Raises:
DatabaseError:
If a database connection issue occurs.
If an operational error occurs during execution.
"""
return self._model.get_model_info()

View File

@@ -0,0 +1,194 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Classifier utilities for MySQL Connector/Python.
Provides a scikit-learn compatible classifier backed by HeatWave ML.
"""
from typing import Optional, Union
import numpy as np
import pandas as pd
from sklearn.base import ClassifierMixin
from mysql.ai.ml.base import MyBaseMLModel
from mysql.ai.ml.model import ML_TASK
from mysql.ai.utils import copy_dict
from mysql.connector.abstracts import MySQLConnectionAbstract
class MyClassifier(MyBaseMLModel, ClassifierMixin):
"""
MySQL HeatWave scikit-learn compatible classifier estimator.
Provides prediction and probability output from a model deployed in MySQL,
and manages fit, explain, and prediction options as per HeatWave ML interface.
Attributes:
predict_extra_options (dict): Dictionary of optional parameters passed through
to the MySQL backend for prediction and probability inference.
_model (MyModel): Underlying interface for database model operations.
fit_extra_options (dict): See MyBaseMLModel.
Args:
db_connection (MySQLConnectionAbstract): Active MySQL connector DB connection.
model_name (str, optional): Custom name for the model.
fit_extra_options (dict, optional): Extra options for fitting.
explain_extra_options (dict, optional): Extra options for explanations.
predict_extra_options (dict, optional): Extra options for predict/predict_proba.
Methods:
predict(X): Predict class labels.
predict_proba(X): Predict class probabilities.
"""
def __init__(
self,
db_connection: MySQLConnectionAbstract,
model_name: Optional[str] = None,
fit_extra_options: Optional[dict] = None,
explain_extra_options: Optional[dict] = None,
predict_extra_options: Optional[dict] = None,
):
"""
Initialize a MyClassifier.
Args:
db_connection: Active MySQL connector database connection.
model_name: Optional, custom model name.
fit_extra_options: Optional fit options.
explain_extra_options: Optional explain options.
predict_extra_options: Optional predict/predict_proba options.
Raises:
DatabaseError:
If a database connection issue occurs.
If an operational error occurs during execution.
"""
MyBaseMLModel.__init__(
self,
db_connection,
ML_TASK.CLASSIFICATION,
model_name=model_name,
fit_extra_options=fit_extra_options,
)
self.predict_extra_options = copy_dict(predict_extra_options)
self.explain_extra_options = copy_dict(explain_extra_options)
def predict(
self, X: Union[pd.DataFrame, np.ndarray]
) -> np.ndarray: # pylint: disable=invalid-name
"""
Predict class labels for the input features using the MySQL model.
References:
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
A full list of supported options can be found under "ML_PREDICT_TABLE Options"
Args:
X: Input samples as a numpy array or pandas DataFrame.
Returns:
ndarray: Array of predicted class labels, shape (n_samples,).
Raises:
DatabaseError:
If provided options are invalid or unsupported,
or if the model is not initialized, i.e., fit or import has not
been called
If a database connection issue occurs.
If an operational error occurs during execution.
"""
result = self._model.predict(X, options=self.predict_extra_options)
return result["Prediction"].to_numpy()
def predict_proba(
self, X: Union[pd.DataFrame, np.ndarray]
) -> np.ndarray: # pylint: disable=invalid-name
"""
Predict class probabilities for the input features using the MySQL model.
References:
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
A full list of supported options can be found under "ML_PREDICT_TABLE Options"
Args:
X: Input samples as a numpy array or pandas DataFrame.
Returns:
ndarray: Array of shape (n_samples, n_classes) with class probabilities.
Raises:
DatabaseError:
If provided options are invalid or unsupported,
or if the model is not initialized, i.e., fit or import has not
been called
If a database connection issue occurs.
If an operational error occurs during execution.
"""
result = self._model.predict(X, options=self.predict_extra_options)
classes = sorted(result["ml_results"].iloc[0]["probabilities"].keys())
return np.stack(
result["ml_results"].map(
lambda ml_result: [
ml_result["probabilities"][class_name] for class_name in classes
]
)
)
def explain_predictions(
self, X: Union[pd.DataFrame, np.ndarray]
) -> pd.DataFrame: # pylint: disable=invalid-name
"""
Explain model predictions using provided data.
References:
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
A full list of supported options can be found under "ML_EXPLAIN_TABLE Options"
Args:
X: DataFrame for which predictions should be explained.
Returns:
DataFrame containing explanation details (feature attributions, etc.)
Raises:
DatabaseError:
If provided options are invalid or unsupported,
or if the model is not initialized, i.e., fit or import has not
been called
If a database connection issue occurs.
If an operational error occurs during execution.
Notes:
Temporary input/output tables are cleaned up after explanation.
"""
self._model.explain_predictions(X, options=self.explain_extra_options)

View File

@@ -0,0 +1,780 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""HeatWave ML model utilities for MySQL Connector/Python.
Provides classes to manage training, prediction, scoring, and explanations
via MySQL HeatWave stored procedures.
"""
import copy
import json
from enum import Enum
from typing import Any, Dict, Optional, Union
import numpy as np
import pandas as pd
from mysql.ai.utils import (
VAR_NAME_SPACE,
atomic_transaction,
convert_to_df,
execute_sql,
format_value_sql,
get_random_name,
source_schema,
sql_response_to_df,
sql_table_from_df,
sql_table_to_df,
table_exists,
temporary_sql_tables,
validate_name,
)
from mysql.connector.abstracts import MySQLConnectionAbstract
class ML_TASK(Enum): # pylint: disable=invalid-name
"""Enumeration of supported ML tasks for HeatWave."""
CLASSIFICATION = "classification"
REGRESSION = "regression"
FORECASTING = "forecasting"
ANOMALY_DETECTION = "anomaly_detection"
LOG_ANOMALY_DETECTION = "log_anomaly_detection"
RECOMMENDATION = "recommendation"
TOPIC_MODELING = "topic_modeling"
@staticmethod
def get_task_string(task: Union[str, "ML_TASK"]) -> str:
"""
Return the string representation of a machine learning task.
Args:
task (Union[str, ML_TASK]): The task to convert.
Accepts either a task enum member (ML_TASK) or a string.
Returns:
str: The string value of the ML task.
"""
if isinstance(task, str):
return task
return task.value
class _MyModelCommon:
"""
Common utilities and workflow for MySQL HeatWave ML models.
This class handles model lifecycle steps such as loading, fitting, scoring,
making predictions, and explaining models or predictions. Not intended for
direct instantiation, but as a superclass for heatwave model wrappers.
Attributes:
db_connection: MySQL connector database connection.
task: ML task, e.g., "classification" or "regression".
model_name: Identifier of model in MySQL.
schema_name: Database schema used for operations and temp tables.
"""
def __init__(
self,
db_connection: MySQLConnectionAbstract,
task: Union[str, ML_TASK] = ML_TASK.CLASSIFICATION,
model_name: Optional[str] = None,
):
"""
Instantiate _MyMLModelCommon.
References:
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-train.html
A full list of supported tasks can be found under "Common ML_TRAIN Options"
Args:
db_connection: MySQL database connection.
task: ML task type (default: "classification").
model_name: Name to register the model within MySQL (default: None).
Raises:
ValueError: If the schema name is not valid
DatabaseError:
If a database connection issue occurs.
If an operational error occurs during execution.
Returns:
None
"""
self.db_connection = db_connection
self.task = ML_TASK.get_task_string(task)
self.schema_name = source_schema(db_connection)
with atomic_transaction(self.db_connection) as cursor:
execute_sql(cursor, "CALL sys.ML_CREATE_OR_UPGRADE_CATALOG();")
if model_name is None:
model_name = get_random_name(self._is_model_name_available)
self.model_var = f"{VAR_NAME_SPACE}.{model_name}"
self.model_var_score = f"{self.model_var}.score"
self.model_name = model_name
validate_name(model_name)
with atomic_transaction(self.db_connection) as cursor:
execute_sql(cursor, f"SET @{self.model_var} = %s;", params=(model_name,))
def _delete_model(self) -> bool:
"""
Deletes the model from the model catalog if present
Raises:
DatabaseError:
If a database connection issue occurs.
If an operational error occurs during execution.
Returns:
Whether the model was deleted
"""
current_user = self._get_user()
qualified_model_catalog = f"ML_SCHEMA_{current_user}.MODEL_CATALOG"
delete_model = (
f"DELETE FROM {qualified_model_catalog} "
f"WHERE model_handle = @{self.model_var}"
)
with atomic_transaction(self.db_connection) as cursor:
execute_sql(cursor, delete_model)
return cursor.rowcount > 0
def _get_model_info(self, model_name: str) -> Optional[dict]:
"""
Retrieves the model info from the model_catalog
Args:
model_var: The model alias to retrieve
Returns:
The model info from the model_catalog (None if the model is not present in the catalog)
Raises:
DatabaseError:
If a database connection issue occurs.
If an operational error occurs during execution.
"""
def process_col(elem: Any) -> Any:
if isinstance(elem, str):
try:
elem = json.loads(elem)
except json.JSONDecodeError:
pass
return elem
current_user = self._get_user()
qualified_model_catalog = f"ML_SCHEMA_{current_user}.MODEL_CATALOG"
model_exists = (
f"SELECT * FROM {qualified_model_catalog} WHERE model_handle = %s"
)
with atomic_transaction(self.db_connection) as cursor:
execute_sql(cursor, model_exists, params=(model_name,))
model_info_df = sql_response_to_df(cursor)
if model_info_df.empty:
result = None
else:
unprocessed_result = model_info_df.to_json(orient="records")
unprocessed_result_json = json.loads(unprocessed_result)[0]
result = {
key: process_col(elem)
for key, elem in unprocessed_result_json.items()
}
return result
def get_model_info(self) -> Optional[dict]:
"""
Checks if the model name is available.
Model info is present in the catalog only if the model was previously fitted.
Returns:
True if the model name is not part of the model catalog
Raises:
DatabaseError:
If a database connection issue occurs.
If an operational error occurs during execution.
"""
return self._get_model_info(self.model_name)
def _is_model_name_available(self, model_name: str) -> bool:
"""
Checks if the model name is available
Returns:
True if the model name is not part of the model catalog
Raises:
DatabaseError:
If a database connection issue occurs.
If an operational error occurs during execution.
"""
return self._get_model_info(model_name) is None
def _load_model(self) -> None:
"""
Loads the model specified by `self.model_name` into MySQL.
After loading, the model is ready to handle ML operations.
References:
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-model-load.html
Raises:
DatabaseError:
If the model is not initialized, i.e., fit or import has not been called
If a database connection issue occurs.
If an operational error occurs during execution.
Returns:
None
"""
with atomic_transaction(self.db_connection) as cursor:
load_model_query = f"CALL sys.ML_MODEL_LOAD(@{self.model_var}, NULL);"
execute_sql(cursor, load_model_query)
def _get_user(self) -> str:
"""
Fetch the current database user (without host).
Returns:
The username string associated with the connection.
Raises:
DatabaseError:
If a database connection issue occurs.
If an operational error occurs during execution.
ValueError: If the user name includes unsupported characters
"""
with atomic_transaction(self.db_connection) as cursor:
cursor.execute("SELECT CURRENT_USER()")
current_user = cursor.fetchone()[0].split("@")[0]
return validate_name(current_user)
def explain_model(self) -> dict:
"""
Get model explanations, such as detailed feature importances.
Returns:
dict: Feature importances and model explainability data.
References:
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-model-explanations.html
Raises:
DatabaseError:
If the model is not initialized, i.e., fit or import has not been called
If a database connection issue occurs.
If an operational error occurs during execution.
ValueError:
If the model does not exist in the model catalog.
Should only occur if model was not fitted or was deleted.
"""
self._load_model()
with atomic_transaction(self.db_connection) as cursor:
current_user = self._get_user()
qualified_model_catalog = f"ML_SCHEMA_{current_user}.MODEL_CATALOG"
explain_query = (
f"SELECT model_explanation FROM {qualified_model_catalog} "
f"WHERE model_handle = @{self.model_var}"
)
execute_sql(cursor, explain_query)
df = sql_response_to_df(cursor)
return df.iloc[0, 0]
def _fit(
self,
table_name: str,
target_column_name: Optional[str],
options: Optional[dict],
) -> None:
"""
Fit an ML model using a referenced SQL table and target column.
References:
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-train.html
A full list of supported options can be found under "Common ML_TRAIN Options"
Args:
table_name: Name of the training data table.
target_column_name: Name of the target/label column.
options: Additional fit/config options (may override defaults).
Raises:
DatabaseError:
If provided options are invalid or unsupported.
If a database connection issue occurs.
If an operational error occurs during execution.
ValueError: If the table or target_column name is not valid
Returns:
None
"""
validate_name(table_name)
if target_column_name is not None:
validate_name(target_column_name)
target_col_string = f"'{target_column_name}'"
else:
target_col_string = "NULL"
if options is None:
options = {}
options = copy.deepcopy(options)
options["task"] = self.task
self._delete_model()
with atomic_transaction(self.db_connection) as cursor:
placeholders, parameters = format_value_sql(options)
execute_sql(
cursor,
(
"CALL sys.ML_TRAIN("
f"'{self.schema_name}.{table_name}', "
f"{target_col_string}, "
f"{placeholders}, "
f"@{self.model_var}"
")"
),
params=parameters,
)
def _predict(
self, table_name: str, output_table_name: str, options: Optional[dict]
) -> None:
"""
Predict on a given data table and write results to an output table.
References:
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
A full list of supported options can be found under "ML_PREDICT_TABLE Options"
Args:
table_name: Name of the SQL table with input data.
output_table_name: Name for the SQL output table to contain predictions.
options: Optional prediction options.
Returns:
None
Raises:
DatabaseError:
If provided options are invalid or unsupported,
or if the model is not initialized, i.e., fit or import has not
been called
If a database connection issue occurs.
If an operational error occurs during execution.
ValueError: If the table or output_table name is not valid
"""
validate_name(table_name)
validate_name(output_table_name)
self._load_model()
with atomic_transaction(self.db_connection) as cursor:
placeholders, parameters = format_value_sql(options)
execute_sql(
cursor,
(
"CALL sys.ML_PREDICT_TABLE("
f"'{self.schema_name}.{table_name}', "
f"@{self.model_var}, "
f"'{self.schema_name}.{output_table_name}', "
f"{placeholders}"
")"
),
params=parameters,
)
def _score(
self,
table_name: str,
target_column_name: str,
metric: str,
options: Optional[dict],
) -> float:
"""
Evaluate model performance with a scoring metric.
References:
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-score.html
A full list of supported options can be found under
"Options for Recommendation Models" and
"Options for Anomaly Detection Models"
Args:
table_name: Table with features and ground truth.
target_column_name: Column of true target labels.
metric: String name of the metric to compute.
options: Optional dictionary of further scoring options.
Returns:
float: Computed score from the ML system.
Raises:
DatabaseError:
If provided options are invalid or unsupported,
or if the model is not initialized, i.e., fit or import has not
been called
If a database connection issue occurs.
If an operational error occurs during execution.
ValueError: If the table or target_column name or metric is not valid
"""
validate_name(table_name)
validate_name(target_column_name)
validate_name(metric)
self._load_model()
with atomic_transaction(self.db_connection) as cursor:
placeholders, parameters = format_value_sql(options)
execute_sql(
cursor,
(
"CALL sys.ML_SCORE("
f"'{self.schema_name}.{table_name}', "
f"'{target_column_name}', "
f"@{self.model_var}, "
"%s, "
f"@{self.model_var_score}, "
f"{placeholders}"
")"
),
params=[metric, *parameters],
)
execute_sql(cursor, f"SELECT @{self.model_var_score}")
df = sql_response_to_df(cursor)
return df.iloc[0, 0]
def _explain_predictions(
self, table_name: str, output_table_name: str, options: Optional[dict]
) -> pd.DataFrame:
"""
Produce explanations for model predictions on provided data.
References:
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
A full list of supported options can be found under "ML_EXPLAIN_TABLE Options"
Args:
table_name: Name of the SQL table with input data.
output_table_name: Name for the SQL table to store explanations.
options: Optional dictionary (default:
{"prediction_explainer": "permutation_importance"}).
Returns:
DataFrame: Prediction explanations from the output SQL table.
Raises:
DatabaseError:
If provided options are invalid or unsupported,
or if the model is not initialized, i.e., fit or import has not
been called
If a database connection issue occurs.
If an operational error occurs during execution.
ValueError: If the table or output_table name is not valid
"""
validate_name(table_name)
validate_name(output_table_name)
if options is None:
options = {"prediction_explainer": "permutation_importance"}
self._load_model()
with atomic_transaction(self.db_connection) as cursor:
placeholders, parameters = format_value_sql(options)
execute_sql(
cursor,
(
"CALL sys.ML_EXPLAIN_TABLE("
f"'{self.schema_name}.{table_name}', "
f"@{self.model_var}, "
f"'{self.schema_name}.{output_table_name}', "
f"{placeholders}"
")"
),
params=parameters,
)
execute_sql(cursor, f"SELECT * FROM {self.schema_name}.{output_table_name}")
df = sql_response_to_df(cursor)
return df
class MyModel(_MyModelCommon):
"""
Convenience class for managing the ML workflow using pandas DataFrames.
Methods convert in-memory DataFrames into temp SQL tables before delegating to the
_MyMLModelCommon routines, and automatically clean up temp resources.
"""
def fit(
self,
X: Union[pd.DataFrame, np.ndarray], # pylint: disable=invalid-name
y: Optional[Union[pd.DataFrame, np.ndarray]],
options: Optional[dict] = None,
) -> None:
"""
Fit a model using DataFrame inputs.
If an 'id' column is defined in either dataframe, it will be used as the primary key.
References:
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-train.html
A full list of supported options can be found under "Common ML_TRAIN Options"
Args:
X: Features DataFrame.
y: (Optional) Target labels DataFrame or Series. If None, only X is used.
options: Additional options to pass to training.
Returns:
None
Raises:
DatabaseError:
If provided options are invalid or unsupported.
If a database connection issue occurs.
If an operational error occurs during execution.
Notes:
Combines X and y as necessary. Creates a temporary table in the schema for training,
and deletes it afterward.
"""
X, y = convert_to_df(X), convert_to_df(y)
with (
atomic_transaction(self.db_connection) as cursor,
temporary_sql_tables(self.db_connection) as temporary_tables,
):
if y is not None:
if isinstance(y, pd.DataFrame):
# keep column name if it exists
target_column_name = y.columns[0]
else:
target_column_name = get_random_name(
lambda name: name not in X.columns
)
if target_column_name in X.columns:
raise ValueError(
f"Target column y with name {target_column_name} already present "
"in feature dataframe X"
)
df_combined = X.copy()
df_combined[target_column_name] = y
final_df = df_combined
else:
target_column_name = None
final_df = X
_, table_name = sql_table_from_df(cursor, self.schema_name, final_df)
temporary_tables.append((self.schema_name, table_name))
self._fit(table_name, target_column_name, options)
def predict(
self,
X: Union[pd.DataFrame, np.ndarray], # pylint: disable=invalid-name
options: Optional[dict] = None,
) -> pd.DataFrame:
"""
Generate model predictions using DataFrame input.
If an 'id' column is defined in either dataframe, it will be used as the primary key.
References:
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
A full list of supported options can be found under "ML_PREDICT_TABLE Options"
Args:
X: DataFrame containing prediction features (no labels).
options: Additional prediction settings.
Returns:
DataFrame with prediction results as returned by HeatWave.
Raises:
DatabaseError:
If provided options are invalid or unsupported,
or if the model is not initialized, i.e., fit or import has not
been called
If a database connection issue occurs.
If an operational error occurs during execution.
Notes:
Temporary SQL tables are created and deleted for input/output.
"""
X = convert_to_df(X)
with (
atomic_transaction(self.db_connection) as cursor,
temporary_sql_tables(self.db_connection) as temporary_tables,
):
_, table_name = sql_table_from_df(cursor, self.schema_name, X)
temporary_tables.append((self.schema_name, table_name))
output_table_name = get_random_name(
lambda table_name: not table_exists(
cursor, self.schema_name, table_name
)
)
temporary_tables.append((self.schema_name, output_table_name))
self._predict(table_name, output_table_name, options)
predictions = sql_table_to_df(cursor, self.schema_name, output_table_name)
# ml_results is text but known to always follow JSON format
predictions["ml_results"] = predictions["ml_results"].map(json.loads)
return predictions
def score(
self,
X: Union[pd.DataFrame, np.ndarray], # pylint: disable=invalid-name
y: Union[pd.DataFrame, np.ndarray],
metric: str,
options: Optional[dict] = None,
) -> float:
"""
Score the model using X/y data and a selected metric.
If an 'id' column is defined in either dataframe, it will be used as the primary key.
References:
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-score.html
A full list of supported options can be found under
"Options for Recommendation Models" and
"Options for Anomaly Detection Models"
Args:
X: DataFrame of features.
y: DataFrame or Series of labels.
metric: Metric name (e.g., "balanced_accuracy").
options: Optional ml scoring options.
Raises:
DatabaseError:
If provided options are invalid or unsupported,
or if the model is not initialized, i.e., fit or import has not
been called
If a database connection issue occurs.
If an operational error occurs during execution.
Returns:
float: Computed score.
"""
X, y = convert_to_df(X), convert_to_df(y)
with (
atomic_transaction(self.db_connection) as cursor,
temporary_sql_tables(self.db_connection) as temporary_tables,
):
target_column_name = get_random_name(lambda name: name not in X.columns)
df_combined = X.copy()
df_combined[target_column_name] = y
final_df = df_combined
_, table_name = sql_table_from_df(cursor, self.schema_name, final_df)
temporary_tables.append((self.schema_name, table_name))
score = self._score(table_name, target_column_name, metric, options)
return score
def explain_predictions(
self,
X: Union[pd.DataFrame, np.ndarray], # pylint: disable=invalid-name
options: Dict = None,
) -> pd.DataFrame:
"""
Explain model predictions using provided data.
If an 'id' column is defined in either dataframe, it will be used as the primary key.
References:
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
A full list of supported options can be found under
"ML_EXPLAIN_TABLE Options"
Args:
X: DataFrame for which predictions should be explained.
options: Optional dictionary of explainability options.
Returns:
DataFrame containing explanation details (feature attributions, etc.)
Raises:
DatabaseError:
If provided options are invalid or unsupported, or if the model is not initialized,
i.e., fit or import has not been called
If a database connection issue occurs.
If an operational error occurs during execution.
Notes:
Temporary input/output tables are cleaned up after explanation.
"""
X = convert_to_df(X)
with (
atomic_transaction(self.db_connection) as cursor,
temporary_sql_tables(self.db_connection) as temporary_tables,
):
_, table_name = sql_table_from_df(cursor, self.schema_name, X)
temporary_tables.append((self.schema_name, table_name))
output_table_name = get_random_name(
lambda table_name: not table_exists(
cursor, self.schema_name, table_name
)
)
temporary_tables.append((self.schema_name, output_table_name))
explanations = self._explain_predictions(
table_name, output_table_name, options
)
return explanations

View File

@@ -0,0 +1,221 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Outlier/anomaly detection utilities for MySQL Connector/Python.
Provides a scikit-learn compatible wrapper using HeatWave to score anomalies.
"""
from typing import Optional, Union
import numpy as np
import pandas as pd
from sklearn.base import OutlierMixin
from mysql.ai.ml.base import MyBaseMLModel
from mysql.ai.ml.model import ML_TASK
from mysql.ai.utils import copy_dict
from mysql.connector.abstracts import MySQLConnectionAbstract
EPS = 1e-5
def _get_logits(prob: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
"""
Compute logit (logodds) for a probability, clipping to avoid numerical overflow.
Args:
prob: Scalar or array of probability values in (0,1).
Returns:
logit-transformed probabilities.
"""
result = np.clip(prob, EPS, 1 - EPS)
return np.log(result / (1 - result))
class MyAnomalyDetector(MyBaseMLModel, OutlierMixin):
"""
MySQL HeatWave scikit-learn compatible anomaly/outlier detector.
Flags samples as outliers when the probability of being an anomaly
exceeds a user-tunable threshold.
Includes helpers to obtain decision scores and anomaly probabilities
for ranking.
Args:
db_connection (MySQLConnectionAbstract): Active MySQL DB connection.
model_name (str, optional): Custom model name in the database.
fit_extra_options (dict, optional): Extra options for fitting.
score_extra_options (dict, optional): Extra options for scoring/prediction.
Attributes:
boundary: Decision threshold boundary in logit space. Derived from
trained model's catalog info
Methods:
predict(X): Predict outlier/inlier labels.
score_samples(X): Compute anomaly (normal class) logit scores.
decision_function(X): Compute signed score above/below threshold for ranking.
"""
def __init__(
self,
db_connection: MySQLConnectionAbstract,
model_name: Optional[str] = None,
fit_extra_options: Optional[dict] = None,
score_extra_options: Optional[dict] = None,
):
"""
Initialize an anomaly detector instance with threshold and extra options.
Args:
db_connection: Active MySQL DB connection.
model_name: Optional model name in DB.
fit_extra_options: Optional extra fit options.
score_extra_options: Optional extra scoring options.
Raises:
ValueError: If outlier_threshold is not in (0,1).
DatabaseError:
If a database connection issue occurs.
If an operational error occurs during execution.
"""
MyBaseMLModel.__init__(
self,
db_connection,
ML_TASK.ANOMALY_DETECTION,
model_name=model_name,
fit_extra_options=fit_extra_options,
)
self.score_extra_options = copy_dict(score_extra_options)
self.boundary: Optional[float] = None
def predict(
self,
X: Union[pd.DataFrame, np.ndarray], # pylint: disable=invalid-name
) -> np.ndarray:
"""
Predict outlier/inlier binary labels for input samples.
Args:
X: Samples to predict on.
Returns:
ndarray: Values are -1 for outliers, +1 for inliers, as per scikit-learn convention.
Raises:
DatabaseError:
If provided options are invalid or unsupported,
or if the model is not initialized, i.e., fit or import has not
been called
If a database connection issue occurs.
If an operational error occurs during execution.
DatabaseError:
If provided options are invalid or unsupported,
or if the model is not initialized, i.e., fit or import has not
been called
If a database connection issue occurs.
If an operational error occurs during execution.
"""
return np.where(self.decision_function(X) < 0.0, -1, 1)
def decision_function(
self,
X: Union[pd.DataFrame, np.ndarray], # pylint: disable=invalid-name
) -> np.ndarray:
"""
Compute signed distance to the outlier threshold.
Args:
X: Samples to predict on.
Returns:
ndarray: Score > 0 means inlier, < 0 means outlier; |value| gives margin.
Raises:
DatabaseError:
If provided options are invalid or unsupported,
or if the model is not initialized, i.e., fit or import has not
been called
If a database connection issue occurs.
If an operational error occurs during execution.
ValueError:
If the provided model info does not provide threshold
"""
sample_scores = self.score_samples(X)
if self.boundary is None:
model_info = self.get_model_info()
if model_info is None:
raise ValueError("Model does not exist in catalog.")
threshold = model_info["model_metadata"]["training_params"].get(
"anomaly_detection_threshold", None
)
if threshold is None:
raise ValueError(
"Trained model is outdated and does not support threshold. "
"Try retraining or using an existing, trained model with MyModel."
)
# scikit-learn uses large positive values as inlier
# and negative as outlier, so we need to flip our threshold
self.boundary = _get_logits(1.0 - threshold)
return sample_scores - self.boundary
def score_samples(
self,
X: Union[pd.DataFrame, np.ndarray], # pylint: disable=invalid-name
) -> np.ndarray:
"""
Compute normal probability logit score for each sample.
Used for ranking, thresholding.
Args:
X: Samples to score.
Returns:
ndarray: Logit scores based on "normal" class probability.
Raises:
DatabaseError:
If provided options are invalid or unsupported,
or if the model is not initialized, i.e., fit or import has not
been called
If a database connection issue occurs.
If an operational error occurs during execution.
"""
result = self._model.predict(X, options=self.score_extra_options)
return _get_logits(
result["ml_results"]
.apply(lambda x: x["probabilities"]["normal"])
.to_numpy()
)

View File

@@ -0,0 +1,154 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Regressor utilities for MySQL Connector/Python.
Provides a scikit-learn compatible regressor backed by HeatWave ML.
"""
from typing import Optional, Union
import numpy as np
import pandas as pd
from sklearn.base import RegressorMixin
from mysql.ai.ml.base import MyBaseMLModel
from mysql.ai.ml.model import ML_TASK
from mysql.ai.utils import copy_dict
from mysql.connector.abstracts import MySQLConnectionAbstract
class MyRegressor(MyBaseMLModel, RegressorMixin):
"""
MySQL HeatWave scikit-learn compatible regressor estimator.
Provides prediction output from a regression model deployed in MySQL,
and manages fit, explain, and prediction options as per HeatWave ML interface.
Attributes:
predict_extra_options (dict): Optional parameter dict passed to the backend for prediction.
_model (MyModel): Underlying interface for database model operations.
fit_extra_options (dict): See MyBaseMLModel.
explain_extra_options (dict): See MyBaseMLModel.
Args:
db_connection (MySQLConnectionAbstract): Active MySQL connector DB connection.
model_name (str, optional): Custom name for the model.
fit_extra_options (dict, optional): Extra options for fitting.
explain_extra_options (dict, optional): Extra options for explanations.
predict_extra_options (dict, optional): Extra options for predictions.
Methods:
predict(X): Predict regression target.
"""
def __init__(
self,
db_connection: MySQLConnectionAbstract,
model_name: Optional[str] = None,
fit_extra_options: Optional[dict] = None,
explain_extra_options: Optional[dict] = None,
predict_extra_options: Optional[dict] = None,
):
"""
Initialize a MyRegressor.
Args:
db_connection: Active MySQL connector database connection.
model_name: Optional, custom model name.
fit_extra_options: Optional fit options.
explain_extra_options: Optional explain options.
predict_extra_options: Optional prediction options.
Raises:
DatabaseError:
If a database connection issue occurs.
If an operational error occurs during execution.
"""
MyBaseMLModel.__init__(
self,
db_connection,
ML_TASK.REGRESSION,
model_name=model_name,
fit_extra_options=fit_extra_options,
)
self.predict_extra_options = copy_dict(predict_extra_options)
self.explain_extra_options = copy_dict(explain_extra_options)
def predict(
self, X: Union[pd.DataFrame, np.ndarray]
) -> np.ndarray: # pylint: disable=invalid-name
"""
Predict a continuous target for the input features using the MySQL model.
Args:
X: Input samples as a numpy array or pandas DataFrame.
Returns:
ndarray: Array of predicted target values, shape (n_samples,).
Raises:
DatabaseError:
If provided options are invalid or unsupported,
or if the model is not initialized, i.e., fit or import has not
been called
If a database connection issue occurs.
If an operational error occurs during execution.
"""
result = self._model.predict(X, options=self.predict_extra_options)
return result["Prediction"].to_numpy()
def explain_predictions(
self, X: Union[pd.DataFrame, np.ndarray]
) -> pd.DataFrame: # pylint: disable=invalid-name
"""
Explain model predictions using provided data.
References:
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
A full list of supported options can be found under "ML_EXPLAIN_TABLE Options"
Args:
X: DataFrame for which predictions should be explained.
Returns:
DataFrame containing explanation details (feature attributions, etc.)
Raises:
DatabaseError:
If provided options are invalid or unsupported,
or if the model is not initialized, i.e., fit or import has not
been called
If a database connection issue occurs.
If an operational error occurs during execution.
Notes:
Temporary input/output tables are cleaned up after explanation.
"""
self._model.explain_predictions(X, options=self.explain_extra_options)

View File

@@ -0,0 +1,164 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Generic transformer utilities for MySQL Connector/Python.
Provides a scikit-learn compatible Transformer using HeatWave for fit/transform
and scoring operations.
"""
from typing import Optional, Union
import numpy as np
import pandas as pd
from sklearn.base import TransformerMixin
from mysql.ai.ml.base import MyBaseMLModel
from mysql.ai.ml.model import ML_TASK
from mysql.ai.utils import copy_dict
from mysql.connector.abstracts import MySQLConnectionAbstract
class MyGenericTransformer(MyBaseMLModel, TransformerMixin):
"""
MySQL HeatWave scikit-learn compatible generic transformer.
Can be used as the transformation step in an sklearn pipeline. Implements fit, transform,
explain, and scoring capability, passing options for server-side transform logic.
Args:
db_connection (MySQLConnectionAbstract): Active MySQL connector database connection.
task (str): ML task type for transformer (default: "classification").
score_metric (str): Scoring metric to request from backend (default: "balanced_accuracy").
model_name (str, optional): Custom name for the deployed model.
fit_extra_options (dict, optional): Extra fit options.
transform_extra_options (dict, optional): Extra options for transformations.
score_extra_options (dict, optional): Extra options for scoring.
Attributes:
score_metric (str): Name of the backend metric to use for scoring
(e.g. "balanced_accuracy").
score_extra_options (dict): Dictionary of optional scoring parameters;
passed to backend score.
transform_extra_options (dict): Dictionary of inference (/predict)
parameters for the backend.
fit_extra_options (dict): See MyBaseMLModel.
_model (MyModel): Underlying interface for database model operations.
Methods:
fit(X, y): Fit the underlying model using the provided features/targets.
transform(X): Transform features using the backend model.
score(X, y): Score data using backend metric and options.
"""
def __init__(
self,
db_connection: MySQLConnectionAbstract,
task: Union[str, ML_TASK] = ML_TASK.CLASSIFICATION,
score_metric: str = "balanced_accuracy",
model_name: Optional[str] = None,
fit_extra_options: Optional[dict] = None,
transform_extra_options: Optional[dict] = None,
score_extra_options: Optional[dict] = None,
):
"""
Initialize transformer with required and optional arguments.
Args:
db_connection: Active MySQL backend database connection.
task: ML task type for transformer.
score_metric: Requested backend scoring metric.
model_name: Optional model name for storage.
fit_extra_options: Optional extra options for fitting.
transform_extra_options: Optional extra options for transformation/inference.
score_extra_options: Optional extra scoring options.
Raises:
DatabaseError:
If a database connection issue occurs.
If an operational error occurs during execution.
"""
MyBaseMLModel.__init__(
self,
db_connection,
task,
model_name=model_name,
fit_extra_options=fit_extra_options,
)
self.score_metric = score_metric
self.score_extra_options = copy_dict(score_extra_options)
self.transform_extra_options = copy_dict(transform_extra_options)
def transform(
self, X: pd.DataFrame
) -> pd.DataFrame: # pylint: disable=invalid-name
"""
Transform input data to model predictions using the underlying helper.
Args:
X: DataFrame of features to predict/transform.
Returns:
pd.DataFrame: Results of transformation as returned by backend.
Raises:
DatabaseError:
If provided options are invalid or unsupported,
or if the model is not initialized, i.e., fit or import has not
been called
If a database connection issue occurs.
If an operational error occurs during execution.
"""
return self._model.predict(X, options=self.transform_extra_options)
def score(
self,
X: Union[pd.DataFrame, np.ndarray], # pylint: disable=invalid-name
y: Union[pd.DataFrame, np.ndarray],
) -> float:
"""
Score the transformed data using the backend scoring interface.
Args:
X: Transformed features.
y: Target labels or data for scoring.
Returns:
float: Score based on backend metric.
Raises:
DatabaseError:
If provided options are invalid or unsupported,
or if the model is not initialized, i.e., fit or import has not
been called
If a database connection issue occurs.
If an operational error occurs during execution.
"""
return self._model.score(
X, y, self.score_metric, options=self.score_extra_options
)