Añadiendo todos los archivos del proyecto (incluidos secretos y venv)
This commit is contained in:
48
venv/lib/python3.12/site-packages/mysql/ai/ml/__init__.py
Normal file
48
venv/lib/python3.12/site-packages/mysql/ai/ml/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""ML package for MySQL Connector/Python.
|
||||
|
||||
Performs optional dependency checks and exposes ML utilities:
|
||||
- ML_TASK, MyModel
|
||||
- MyClassifier, MyRegressor, MyGenericTransformer
|
||||
- MyAnomalyDetector
|
||||
"""
|
||||
from mysql.ai.utils import check_dependencies as _check_dependencies
|
||||
|
||||
_check_dependencies(["ML"])
|
||||
del _check_dependencies
|
||||
|
||||
# Sklearn models
|
||||
from .classifier import MyClassifier
|
||||
|
||||
# Minimal interface
|
||||
from .model import ML_TASK, MyModel
|
||||
from .outlier import MyAnomalyDetector
|
||||
from .regressor import MyRegressor
|
||||
from .transformer import MyGenericTransformer
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
142
venv/lib/python3.12/site-packages/mysql/ai/ml/base.py
Normal file
142
venv/lib/python3.12/site-packages/mysql/ai/ml/base.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Base classes for MySQL HeatWave ML estimators for Connector/Python.
|
||||
|
||||
Implements a scikit-learn-compatible base estimator wrapping server-side ML.
|
||||
"""
|
||||
from typing import Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from sklearn.base import BaseEstimator
|
||||
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
from mysql.ai.ml.model import ML_TASK, MyModel
|
||||
from mysql.ai.utils import copy_dict
|
||||
|
||||
|
||||
class MyBaseMLModel(BaseEstimator):
|
||||
"""
|
||||
Base class for MySQL HeatWave machine learning estimators.
|
||||
|
||||
Implements the scikit-learn API and core model management logic,
|
||||
including fit, explain, serialization, and dynamic option handling.
|
||||
For use as a base class by classifiers, regressors, transformers, and outlier models.
|
||||
|
||||
Args:
|
||||
db_connection (MySQLConnectionAbstract): An active MySQL connector database connection.
|
||||
task (str): ML task type, e.g. "classification" or "regression".
|
||||
model_name (str, optional): Custom name for the deployed model.
|
||||
fit_extra_options (dict, optional): Extra options for fitting.
|
||||
|
||||
Attributes:
|
||||
_model: Underlying database helper for fit/predict/explain.
|
||||
fit_extra_options: User-provided options for fitting.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_connection: MySQLConnectionAbstract,
|
||||
task: Union[str, ML_TASK],
|
||||
model_name: Optional[str] = None,
|
||||
fit_extra_options: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Initialize a MyBaseMLModel with connection, task, and option parameters.
|
||||
|
||||
Args:
|
||||
db_connection: Active MySQL connector database connection.
|
||||
task: String label of ML task (e.g. "classification").
|
||||
model_name: Optional custom model name.
|
||||
fit_extra_options: Optional extra fit options.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
self._model = MyModel(db_connection, task=task, model_name=model_name)
|
||||
self.fit_extra_options = copy_dict(fit_extra_options)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
X: pd.DataFrame, # pylint: disable=invalid-name
|
||||
y: Optional[pd.DataFrame] = None,
|
||||
) -> "MyBaseMLModel":
|
||||
"""
|
||||
Fit the underlying ML model using pandas DataFrames.
|
||||
Delegates to MyMLModelPandasHelper.fit.
|
||||
|
||||
Args:
|
||||
X: Features DataFrame.
|
||||
y: (Optional) Target labels DataFrame or Series.
|
||||
|
||||
Returns:
|
||||
self
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported.
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Notes:
|
||||
Additional temp SQL resources may be created and cleaned up during the operation.
|
||||
"""
|
||||
self._model.fit(X, y, self.fit_extra_options)
|
||||
return self
|
||||
|
||||
def _delete_model(self) -> bool:
|
||||
"""
|
||||
Deletes the model from the model catalog if present
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Returns:
|
||||
Whether the model was deleted
|
||||
"""
|
||||
return self._model._delete_model()
|
||||
|
||||
def get_model_info(self) -> Optional[dict]:
|
||||
"""
|
||||
Checks if the model name is available. Model info will only be present in the
|
||||
catalog if the model has previously been fitted.
|
||||
|
||||
Returns:
|
||||
True if the model name is not part of the model catalog
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
return self._model.get_model_info()
|
||||
194
venv/lib/python3.12/site-packages/mysql/ai/ml/classifier.py
Normal file
194
venv/lib/python3.12/site-packages/mysql/ai/ml/classifier.py
Normal file
@@ -0,0 +1,194 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Classifier utilities for MySQL Connector/Python.
|
||||
|
||||
Provides a scikit-learn compatible classifier backed by HeatWave ML.
|
||||
"""
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.base import ClassifierMixin
|
||||
|
||||
from mysql.ai.ml.base import MyBaseMLModel
|
||||
from mysql.ai.ml.model import ML_TASK
|
||||
from mysql.ai.utils import copy_dict
|
||||
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
|
||||
class MyClassifier(MyBaseMLModel, ClassifierMixin):
|
||||
"""
|
||||
MySQL HeatWave scikit-learn compatible classifier estimator.
|
||||
|
||||
Provides prediction and probability output from a model deployed in MySQL,
|
||||
and manages fit, explain, and prediction options as per HeatWave ML interface.
|
||||
|
||||
Attributes:
|
||||
predict_extra_options (dict): Dictionary of optional parameters passed through
|
||||
to the MySQL backend for prediction and probability inference.
|
||||
_model (MyModel): Underlying interface for database model operations.
|
||||
fit_extra_options (dict): See MyBaseMLModel.
|
||||
|
||||
Args:
|
||||
db_connection (MySQLConnectionAbstract): Active MySQL connector DB connection.
|
||||
model_name (str, optional): Custom name for the model.
|
||||
fit_extra_options (dict, optional): Extra options for fitting.
|
||||
explain_extra_options (dict, optional): Extra options for explanations.
|
||||
predict_extra_options (dict, optional): Extra options for predict/predict_proba.
|
||||
|
||||
Methods:
|
||||
predict(X): Predict class labels.
|
||||
predict_proba(X): Predict class probabilities.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_connection: MySQLConnectionAbstract,
|
||||
model_name: Optional[str] = None,
|
||||
fit_extra_options: Optional[dict] = None,
|
||||
explain_extra_options: Optional[dict] = None,
|
||||
predict_extra_options: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Initialize a MyClassifier.
|
||||
|
||||
Args:
|
||||
db_connection: Active MySQL connector database connection.
|
||||
model_name: Optional, custom model name.
|
||||
fit_extra_options: Optional fit options.
|
||||
explain_extra_options: Optional explain options.
|
||||
predict_extra_options: Optional predict/predict_proba options.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
MyBaseMLModel.__init__(
|
||||
self,
|
||||
db_connection,
|
||||
ML_TASK.CLASSIFICATION,
|
||||
model_name=model_name,
|
||||
fit_extra_options=fit_extra_options,
|
||||
)
|
||||
self.predict_extra_options = copy_dict(predict_extra_options)
|
||||
self.explain_extra_options = copy_dict(explain_extra_options)
|
||||
|
||||
def predict(
|
||||
self, X: Union[pd.DataFrame, np.ndarray]
|
||||
) -> np.ndarray: # pylint: disable=invalid-name
|
||||
"""
|
||||
Predict class labels for the input features using the MySQL model.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
|
||||
A full list of supported options can be found under "ML_PREDICT_TABLE Options"
|
||||
|
||||
Args:
|
||||
X: Input samples as a numpy array or pandas DataFrame.
|
||||
|
||||
Returns:
|
||||
ndarray: Array of predicted class labels, shape (n_samples,).
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported,
|
||||
or if the model is not initialized, i.e., fit or import has not
|
||||
been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
result = self._model.predict(X, options=self.predict_extra_options)
|
||||
return result["Prediction"].to_numpy()
|
||||
|
||||
def predict_proba(
|
||||
self, X: Union[pd.DataFrame, np.ndarray]
|
||||
) -> np.ndarray: # pylint: disable=invalid-name
|
||||
"""
|
||||
Predict class probabilities for the input features using the MySQL model.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
|
||||
A full list of supported options can be found under "ML_PREDICT_TABLE Options"
|
||||
|
||||
Args:
|
||||
X: Input samples as a numpy array or pandas DataFrame.
|
||||
|
||||
Returns:
|
||||
ndarray: Array of shape (n_samples, n_classes) with class probabilities.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported,
|
||||
or if the model is not initialized, i.e., fit or import has not
|
||||
been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
result = self._model.predict(X, options=self.predict_extra_options)
|
||||
|
||||
classes = sorted(result["ml_results"].iloc[0]["probabilities"].keys())
|
||||
|
||||
return np.stack(
|
||||
result["ml_results"].map(
|
||||
lambda ml_result: [
|
||||
ml_result["probabilities"][class_name] for class_name in classes
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
def explain_predictions(
|
||||
self, X: Union[pd.DataFrame, np.ndarray]
|
||||
) -> pd.DataFrame: # pylint: disable=invalid-name
|
||||
"""
|
||||
Explain model predictions using provided data.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
|
||||
A full list of supported options can be found under "ML_EXPLAIN_TABLE Options"
|
||||
|
||||
Args:
|
||||
X: DataFrame for which predictions should be explained.
|
||||
|
||||
Returns:
|
||||
DataFrame containing explanation details (feature attributions, etc.)
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported,
|
||||
or if the model is not initialized, i.e., fit or import has not
|
||||
been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Notes:
|
||||
Temporary input/output tables are cleaned up after explanation.
|
||||
"""
|
||||
self._model.explain_predictions(X, options=self.explain_extra_options)
|
||||
780
venv/lib/python3.12/site-packages/mysql/ai/ml/model.py
Normal file
780
venv/lib/python3.12/site-packages/mysql/ai/ml/model.py
Normal file
@@ -0,0 +1,780 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
"""HeatWave ML model utilities for MySQL Connector/Python.
|
||||
|
||||
Provides classes to manage training, prediction, scoring, and explanations
|
||||
via MySQL HeatWave stored procedures.
|
||||
"""
|
||||
import copy
|
||||
import json
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from mysql.ai.utils import (
|
||||
VAR_NAME_SPACE,
|
||||
atomic_transaction,
|
||||
convert_to_df,
|
||||
execute_sql,
|
||||
format_value_sql,
|
||||
get_random_name,
|
||||
source_schema,
|
||||
sql_response_to_df,
|
||||
sql_table_from_df,
|
||||
sql_table_to_df,
|
||||
table_exists,
|
||||
temporary_sql_tables,
|
||||
validate_name,
|
||||
)
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
|
||||
class ML_TASK(Enum): # pylint: disable=invalid-name
|
||||
"""Enumeration of supported ML tasks for HeatWave."""
|
||||
|
||||
CLASSIFICATION = "classification"
|
||||
REGRESSION = "regression"
|
||||
FORECASTING = "forecasting"
|
||||
ANOMALY_DETECTION = "anomaly_detection"
|
||||
LOG_ANOMALY_DETECTION = "log_anomaly_detection"
|
||||
RECOMMENDATION = "recommendation"
|
||||
TOPIC_MODELING = "topic_modeling"
|
||||
|
||||
@staticmethod
|
||||
def get_task_string(task: Union[str, "ML_TASK"]) -> str:
|
||||
"""
|
||||
Return the string representation of a machine learning task.
|
||||
|
||||
Args:
|
||||
task (Union[str, ML_TASK]): The task to convert.
|
||||
Accepts either a task enum member (ML_TASK) or a string.
|
||||
|
||||
Returns:
|
||||
str: The string value of the ML task.
|
||||
"""
|
||||
|
||||
if isinstance(task, str):
|
||||
return task
|
||||
|
||||
return task.value
|
||||
|
||||
|
||||
class _MyModelCommon:
|
||||
"""
|
||||
Common utilities and workflow for MySQL HeatWave ML models.
|
||||
|
||||
This class handles model lifecycle steps such as loading, fitting, scoring,
|
||||
making predictions, and explaining models or predictions. Not intended for
|
||||
direct instantiation, but as a superclass for heatwave model wrappers.
|
||||
|
||||
Attributes:
|
||||
db_connection: MySQL connector database connection.
|
||||
task: ML task, e.g., "classification" or "regression".
|
||||
model_name: Identifier of model in MySQL.
|
||||
schema_name: Database schema used for operations and temp tables.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_connection: MySQLConnectionAbstract,
|
||||
task: Union[str, ML_TASK] = ML_TASK.CLASSIFICATION,
|
||||
model_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Instantiate _MyMLModelCommon.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-train.html
|
||||
A full list of supported tasks can be found under "Common ML_TRAIN Options"
|
||||
|
||||
Args:
|
||||
db_connection: MySQL database connection.
|
||||
task: ML task type (default: "classification").
|
||||
model_name: Name to register the model within MySQL (default: None).
|
||||
|
||||
Raises:
|
||||
ValueError: If the schema name is not valid
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.db_connection = db_connection
|
||||
self.task = ML_TASK.get_task_string(task)
|
||||
self.schema_name = source_schema(db_connection)
|
||||
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
execute_sql(cursor, "CALL sys.ML_CREATE_OR_UPGRADE_CATALOG();")
|
||||
|
||||
if model_name is None:
|
||||
model_name = get_random_name(self._is_model_name_available)
|
||||
|
||||
self.model_var = f"{VAR_NAME_SPACE}.{model_name}"
|
||||
self.model_var_score = f"{self.model_var}.score"
|
||||
|
||||
self.model_name = model_name
|
||||
validate_name(model_name)
|
||||
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
execute_sql(cursor, f"SET @{self.model_var} = %s;", params=(model_name,))
|
||||
|
||||
def _delete_model(self) -> bool:
|
||||
"""
|
||||
Deletes the model from the model catalog if present
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Returns:
|
||||
Whether the model was deleted
|
||||
"""
|
||||
current_user = self._get_user()
|
||||
|
||||
qualified_model_catalog = f"ML_SCHEMA_{current_user}.MODEL_CATALOG"
|
||||
delete_model = (
|
||||
f"DELETE FROM {qualified_model_catalog} "
|
||||
f"WHERE model_handle = @{self.model_var}"
|
||||
)
|
||||
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
execute_sql(cursor, delete_model)
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def _get_model_info(self, model_name: str) -> Optional[dict]:
|
||||
"""
|
||||
Retrieves the model info from the model_catalog
|
||||
|
||||
Args:
|
||||
model_var: The model alias to retrieve
|
||||
|
||||
Returns:
|
||||
The model info from the model_catalog (None if the model is not present in the catalog)
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
|
||||
def process_col(elem: Any) -> Any:
|
||||
if isinstance(elem, str):
|
||||
try:
|
||||
elem = json.loads(elem)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return elem
|
||||
|
||||
current_user = self._get_user()
|
||||
|
||||
qualified_model_catalog = f"ML_SCHEMA_{current_user}.MODEL_CATALOG"
|
||||
model_exists = (
|
||||
f"SELECT * FROM {qualified_model_catalog} WHERE model_handle = %s"
|
||||
)
|
||||
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
execute_sql(cursor, model_exists, params=(model_name,))
|
||||
model_info_df = sql_response_to_df(cursor)
|
||||
|
||||
if model_info_df.empty:
|
||||
result = None
|
||||
else:
|
||||
unprocessed_result = model_info_df.to_json(orient="records")
|
||||
unprocessed_result_json = json.loads(unprocessed_result)[0]
|
||||
result = {
|
||||
key: process_col(elem)
|
||||
for key, elem in unprocessed_result_json.items()
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def get_model_info(self) -> Optional[dict]:
|
||||
"""
|
||||
Checks if the model name is available.
|
||||
Model info is present in the catalog only if the model was previously fitted.
|
||||
|
||||
Returns:
|
||||
True if the model name is not part of the model catalog
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
return self._get_model_info(self.model_name)
|
||||
|
||||
def _is_model_name_available(self, model_name: str) -> bool:
|
||||
"""
|
||||
Checks if the model name is available
|
||||
|
||||
Returns:
|
||||
True if the model name is not part of the model catalog
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
return self._get_model_info(model_name) is None
|
||||
|
||||
def _load_model(self) -> None:
|
||||
"""
|
||||
Loads the model specified by `self.model_name` into MySQL.
|
||||
After loading, the model is ready to handle ML operations.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-model-load.html
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If the model is not initialized, i.e., fit or import has not been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
load_model_query = f"CALL sys.ML_MODEL_LOAD(@{self.model_var}, NULL);"
|
||||
execute_sql(cursor, load_model_query)
|
||||
|
||||
def _get_user(self) -> str:
|
||||
"""
|
||||
Fetch the current database user (without host).
|
||||
|
||||
Returns:
|
||||
The username string associated with the connection.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
ValueError: If the user name includes unsupported characters
|
||||
"""
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
cursor.execute("SELECT CURRENT_USER()")
|
||||
current_user = cursor.fetchone()[0].split("@")[0]
|
||||
|
||||
return validate_name(current_user)
|
||||
|
||||
def explain_model(self) -> dict:
|
||||
"""
|
||||
Get model explanations, such as detailed feature importances.
|
||||
|
||||
Returns:
|
||||
dict: Feature importances and model explainability data.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-model-explanations.html
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If the model is not initialized, i.e., fit or import has not been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
ValueError:
|
||||
If the model does not exist in the model catalog.
|
||||
Should only occur if model was not fitted or was deleted.
|
||||
"""
|
||||
self._load_model()
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
current_user = self._get_user()
|
||||
|
||||
qualified_model_catalog = f"ML_SCHEMA_{current_user}.MODEL_CATALOG"
|
||||
explain_query = (
|
||||
f"SELECT model_explanation FROM {qualified_model_catalog} "
|
||||
f"WHERE model_handle = @{self.model_var}"
|
||||
)
|
||||
|
||||
execute_sql(cursor, explain_query)
|
||||
df = sql_response_to_df(cursor)
|
||||
|
||||
return df.iloc[0, 0]
|
||||
|
||||
def _fit(
|
||||
self,
|
||||
table_name: str,
|
||||
target_column_name: Optional[str],
|
||||
options: Optional[dict],
|
||||
) -> None:
|
||||
"""
|
||||
Fit an ML model using a referenced SQL table and target column.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-train.html
|
||||
A full list of supported options can be found under "Common ML_TRAIN Options"
|
||||
|
||||
Args:
|
||||
table_name: Name of the training data table.
|
||||
target_column_name: Name of the target/label column.
|
||||
options: Additional fit/config options (may override defaults).
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported.
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
ValueError: If the table or target_column name is not valid
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
validate_name(table_name)
|
||||
if target_column_name is not None:
|
||||
validate_name(target_column_name)
|
||||
target_col_string = f"'{target_column_name}'"
|
||||
else:
|
||||
target_col_string = "NULL"
|
||||
|
||||
if options is None:
|
||||
options = {}
|
||||
options = copy.deepcopy(options)
|
||||
options["task"] = self.task
|
||||
|
||||
self._delete_model()
|
||||
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
placeholders, parameters = format_value_sql(options)
|
||||
execute_sql(
|
||||
cursor,
|
||||
(
|
||||
"CALL sys.ML_TRAIN("
|
||||
f"'{self.schema_name}.{table_name}', "
|
||||
f"{target_col_string}, "
|
||||
f"{placeholders}, "
|
||||
f"@{self.model_var}"
|
||||
")"
|
||||
),
|
||||
params=parameters,
|
||||
)
|
||||
|
||||
def _predict(
|
||||
self, table_name: str, output_table_name: str, options: Optional[dict]
|
||||
) -> None:
|
||||
"""
|
||||
Predict on a given data table and write results to an output table.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
|
||||
A full list of supported options can be found under "ML_PREDICT_TABLE Options"
|
||||
|
||||
Args:
|
||||
table_name: Name of the SQL table with input data.
|
||||
output_table_name: Name for the SQL output table to contain predictions.
|
||||
options: Optional prediction options.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported,
|
||||
or if the model is not initialized, i.e., fit or import has not
|
||||
been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
ValueError: If the table or output_table name is not valid
|
||||
"""
|
||||
validate_name(table_name)
|
||||
validate_name(output_table_name)
|
||||
|
||||
self._load_model()
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
placeholders, parameters = format_value_sql(options)
|
||||
execute_sql(
|
||||
cursor,
|
||||
(
|
||||
"CALL sys.ML_PREDICT_TABLE("
|
||||
f"'{self.schema_name}.{table_name}', "
|
||||
f"@{self.model_var}, "
|
||||
f"'{self.schema_name}.{output_table_name}', "
|
||||
f"{placeholders}"
|
||||
")"
|
||||
),
|
||||
params=parameters,
|
||||
)
|
||||
|
||||
def _score(
|
||||
self,
|
||||
table_name: str,
|
||||
target_column_name: str,
|
||||
metric: str,
|
||||
options: Optional[dict],
|
||||
) -> float:
|
||||
"""
|
||||
Evaluate model performance with a scoring metric.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-score.html
|
||||
A full list of supported options can be found under
|
||||
"Options for Recommendation Models" and
|
||||
"Options for Anomaly Detection Models"
|
||||
|
||||
Args:
|
||||
table_name: Table with features and ground truth.
|
||||
target_column_name: Column of true target labels.
|
||||
metric: String name of the metric to compute.
|
||||
options: Optional dictionary of further scoring options.
|
||||
|
||||
Returns:
|
||||
float: Computed score from the ML system.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported,
|
||||
or if the model is not initialized, i.e., fit or import has not
|
||||
been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
ValueError: If the table or target_column name or metric is not valid
|
||||
"""
|
||||
validate_name(table_name)
|
||||
validate_name(target_column_name)
|
||||
validate_name(metric)
|
||||
|
||||
self._load_model()
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
placeholders, parameters = format_value_sql(options)
|
||||
execute_sql(
|
||||
cursor,
|
||||
(
|
||||
"CALL sys.ML_SCORE("
|
||||
f"'{self.schema_name}.{table_name}', "
|
||||
f"'{target_column_name}', "
|
||||
f"@{self.model_var}, "
|
||||
"%s, "
|
||||
f"@{self.model_var_score}, "
|
||||
f"{placeholders}"
|
||||
")"
|
||||
),
|
||||
params=[metric, *parameters],
|
||||
)
|
||||
execute_sql(cursor, f"SELECT @{self.model_var_score}")
|
||||
df = sql_response_to_df(cursor)
|
||||
|
||||
return df.iloc[0, 0]
|
||||
|
||||
def _explain_predictions(
|
||||
self, table_name: str, output_table_name: str, options: Optional[dict]
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Produce explanations for model predictions on provided data.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
|
||||
A full list of supported options can be found under "ML_EXPLAIN_TABLE Options"
|
||||
|
||||
Args:
|
||||
table_name: Name of the SQL table with input data.
|
||||
output_table_name: Name for the SQL table to store explanations.
|
||||
options: Optional dictionary (default:
|
||||
{"prediction_explainer": "permutation_importance"}).
|
||||
|
||||
Returns:
|
||||
DataFrame: Prediction explanations from the output SQL table.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported,
|
||||
or if the model is not initialized, i.e., fit or import has not
|
||||
been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
ValueError: If the table or output_table name is not valid
|
||||
"""
|
||||
validate_name(table_name)
|
||||
validate_name(output_table_name)
|
||||
|
||||
if options is None:
|
||||
options = {"prediction_explainer": "permutation_importance"}
|
||||
|
||||
self._load_model()
|
||||
|
||||
with atomic_transaction(self.db_connection) as cursor:
|
||||
placeholders, parameters = format_value_sql(options)
|
||||
execute_sql(
|
||||
cursor,
|
||||
(
|
||||
"CALL sys.ML_EXPLAIN_TABLE("
|
||||
f"'{self.schema_name}.{table_name}', "
|
||||
f"@{self.model_var}, "
|
||||
f"'{self.schema_name}.{output_table_name}', "
|
||||
f"{placeholders}"
|
||||
")"
|
||||
),
|
||||
params=parameters,
|
||||
)
|
||||
execute_sql(cursor, f"SELECT * FROM {self.schema_name}.{output_table_name}")
|
||||
df = sql_response_to_df(cursor)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
class MyModel(_MyModelCommon):
|
||||
"""
|
||||
Convenience class for managing the ML workflow using pandas DataFrames.
|
||||
|
||||
Methods convert in-memory DataFrames into temp SQL tables before delegating to the
|
||||
_MyMLModelCommon routines, and automatically clean up temp resources.
|
||||
"""
|
||||
|
||||
def fit(
|
||||
self,
|
||||
X: Union[pd.DataFrame, np.ndarray], # pylint: disable=invalid-name
|
||||
y: Optional[Union[pd.DataFrame, np.ndarray]],
|
||||
options: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Fit a model using DataFrame inputs.
|
||||
|
||||
If an 'id' column is defined in either dataframe, it will be used as the primary key.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-train.html
|
||||
A full list of supported options can be found under "Common ML_TRAIN Options"
|
||||
|
||||
Args:
|
||||
X: Features DataFrame.
|
||||
y: (Optional) Target labels DataFrame or Series. If None, only X is used.
|
||||
options: Additional options to pass to training.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported.
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Notes:
|
||||
Combines X and y as necessary. Creates a temporary table in the schema for training,
|
||||
and deletes it afterward.
|
||||
"""
|
||||
X, y = convert_to_df(X), convert_to_df(y)
|
||||
|
||||
with (
|
||||
atomic_transaction(self.db_connection) as cursor,
|
||||
temporary_sql_tables(self.db_connection) as temporary_tables,
|
||||
):
|
||||
if y is not None:
|
||||
if isinstance(y, pd.DataFrame):
|
||||
# keep column name if it exists
|
||||
target_column_name = y.columns[0]
|
||||
else:
|
||||
target_column_name = get_random_name(
|
||||
lambda name: name not in X.columns
|
||||
)
|
||||
|
||||
if target_column_name in X.columns:
|
||||
raise ValueError(
|
||||
f"Target column y with name {target_column_name} already present "
|
||||
"in feature dataframe X"
|
||||
)
|
||||
|
||||
df_combined = X.copy()
|
||||
df_combined[target_column_name] = y
|
||||
final_df = df_combined
|
||||
else:
|
||||
target_column_name = None
|
||||
final_df = X
|
||||
|
||||
_, table_name = sql_table_from_df(cursor, self.schema_name, final_df)
|
||||
temporary_tables.append((self.schema_name, table_name))
|
||||
|
||||
self._fit(table_name, target_column_name, options)
|
||||
|
||||
def predict(
|
||||
self,
|
||||
X: Union[pd.DataFrame, np.ndarray], # pylint: disable=invalid-name
|
||||
options: Optional[dict] = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Generate model predictions using DataFrame input.
|
||||
|
||||
If an 'id' column is defined in either dataframe, it will be used as the primary key.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
|
||||
A full list of supported options can be found under "ML_PREDICT_TABLE Options"
|
||||
|
||||
Args:
|
||||
X: DataFrame containing prediction features (no labels).
|
||||
options: Additional prediction settings.
|
||||
|
||||
Returns:
|
||||
DataFrame with prediction results as returned by HeatWave.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported,
|
||||
or if the model is not initialized, i.e., fit or import has not
|
||||
been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Notes:
|
||||
Temporary SQL tables are created and deleted for input/output.
|
||||
"""
|
||||
X = convert_to_df(X)
|
||||
|
||||
with (
|
||||
atomic_transaction(self.db_connection) as cursor,
|
||||
temporary_sql_tables(self.db_connection) as temporary_tables,
|
||||
):
|
||||
_, table_name = sql_table_from_df(cursor, self.schema_name, X)
|
||||
temporary_tables.append((self.schema_name, table_name))
|
||||
|
||||
output_table_name = get_random_name(
|
||||
lambda table_name: not table_exists(
|
||||
cursor, self.schema_name, table_name
|
||||
)
|
||||
)
|
||||
temporary_tables.append((self.schema_name, output_table_name))
|
||||
|
||||
self._predict(table_name, output_table_name, options)
|
||||
predictions = sql_table_to_df(cursor, self.schema_name, output_table_name)
|
||||
|
||||
# ml_results is text but known to always follow JSON format
|
||||
predictions["ml_results"] = predictions["ml_results"].map(json.loads)
|
||||
|
||||
return predictions
|
||||
|
||||
def score(
|
||||
self,
|
||||
X: Union[pd.DataFrame, np.ndarray], # pylint: disable=invalid-name
|
||||
y: Union[pd.DataFrame, np.ndarray],
|
||||
metric: str,
|
||||
options: Optional[dict] = None,
|
||||
) -> float:
|
||||
"""
|
||||
Score the model using X/y data and a selected metric.
|
||||
|
||||
If an 'id' column is defined in either dataframe, it will be used as the primary key.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-score.html
|
||||
A full list of supported options can be found under
|
||||
"Options for Recommendation Models" and
|
||||
"Options for Anomaly Detection Models"
|
||||
|
||||
Args:
|
||||
X: DataFrame of features.
|
||||
y: DataFrame or Series of labels.
|
||||
metric: Metric name (e.g., "balanced_accuracy").
|
||||
options: Optional ml scoring options.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported,
|
||||
or if the model is not initialized, i.e., fit or import has not
|
||||
been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Returns:
|
||||
float: Computed score.
|
||||
"""
|
||||
X, y = convert_to_df(X), convert_to_df(y)
|
||||
|
||||
with (
|
||||
atomic_transaction(self.db_connection) as cursor,
|
||||
temporary_sql_tables(self.db_connection) as temporary_tables,
|
||||
):
|
||||
target_column_name = get_random_name(lambda name: name not in X.columns)
|
||||
df_combined = X.copy()
|
||||
df_combined[target_column_name] = y
|
||||
final_df = df_combined
|
||||
|
||||
_, table_name = sql_table_from_df(cursor, self.schema_name, final_df)
|
||||
temporary_tables.append((self.schema_name, table_name))
|
||||
|
||||
score = self._score(table_name, target_column_name, metric, options)
|
||||
|
||||
return score
|
||||
|
||||
def explain_predictions(
|
||||
self,
|
||||
X: Union[pd.DataFrame, np.ndarray], # pylint: disable=invalid-name
|
||||
options: Dict = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Explain model predictions using provided data.
|
||||
|
||||
If an 'id' column is defined in either dataframe, it will be used as the primary key.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
|
||||
A full list of supported options can be found under
|
||||
"ML_EXPLAIN_TABLE Options"
|
||||
|
||||
Args:
|
||||
X: DataFrame for which predictions should be explained.
|
||||
options: Optional dictionary of explainability options.
|
||||
|
||||
Returns:
|
||||
DataFrame containing explanation details (feature attributions, etc.)
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported, or if the model is not initialized,
|
||||
i.e., fit or import has not been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Notes:
|
||||
Temporary input/output tables are cleaned up after explanation.
|
||||
"""
|
||||
X = convert_to_df(X)
|
||||
|
||||
with (
|
||||
atomic_transaction(self.db_connection) as cursor,
|
||||
temporary_sql_tables(self.db_connection) as temporary_tables,
|
||||
):
|
||||
|
||||
_, table_name = sql_table_from_df(cursor, self.schema_name, X)
|
||||
temporary_tables.append((self.schema_name, table_name))
|
||||
|
||||
output_table_name = get_random_name(
|
||||
lambda table_name: not table_exists(
|
||||
cursor, self.schema_name, table_name
|
||||
)
|
||||
)
|
||||
temporary_tables.append((self.schema_name, output_table_name))
|
||||
|
||||
explanations = self._explain_predictions(
|
||||
table_name, output_table_name, options
|
||||
)
|
||||
|
||||
return explanations
|
||||
221
venv/lib/python3.12/site-packages/mysql/ai/ml/outlier.py
Normal file
221
venv/lib/python3.12/site-packages/mysql/ai/ml/outlier.py
Normal file
@@ -0,0 +1,221 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Outlier/anomaly detection utilities for MySQL Connector/Python.
|
||||
|
||||
Provides a scikit-learn compatible wrapper using HeatWave to score anomalies.
|
||||
"""
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.base import OutlierMixin
|
||||
|
||||
from mysql.ai.ml.base import MyBaseMLModel
|
||||
from mysql.ai.ml.model import ML_TASK
|
||||
from mysql.ai.utils import copy_dict
|
||||
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
EPS = 1e-5
|
||||
|
||||
|
||||
def _get_logits(prob: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
|
||||
"""
|
||||
Compute logit (logodds) for a probability, clipping to avoid numerical overflow.
|
||||
|
||||
Args:
|
||||
prob: Scalar or array of probability values in (0,1).
|
||||
|
||||
Returns:
|
||||
logit-transformed probabilities.
|
||||
"""
|
||||
result = np.clip(prob, EPS, 1 - EPS)
|
||||
return np.log(result / (1 - result))
|
||||
|
||||
|
||||
class MyAnomalyDetector(MyBaseMLModel, OutlierMixin):
|
||||
"""
|
||||
MySQL HeatWave scikit-learn compatible anomaly/outlier detector.
|
||||
|
||||
Flags samples as outliers when the probability of being an anomaly
|
||||
exceeds a user-tunable threshold.
|
||||
Includes helpers to obtain decision scores and anomaly probabilities
|
||||
for ranking.
|
||||
|
||||
Args:
|
||||
db_connection (MySQLConnectionAbstract): Active MySQL DB connection.
|
||||
model_name (str, optional): Custom model name in the database.
|
||||
fit_extra_options (dict, optional): Extra options for fitting.
|
||||
score_extra_options (dict, optional): Extra options for scoring/prediction.
|
||||
|
||||
Attributes:
|
||||
boundary: Decision threshold boundary in logit space. Derived from
|
||||
trained model's catalog info
|
||||
|
||||
Methods:
|
||||
predict(X): Predict outlier/inlier labels.
|
||||
score_samples(X): Compute anomaly (normal class) logit scores.
|
||||
decision_function(X): Compute signed score above/below threshold for ranking.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_connection: MySQLConnectionAbstract,
|
||||
model_name: Optional[str] = None,
|
||||
fit_extra_options: Optional[dict] = None,
|
||||
score_extra_options: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Initialize an anomaly detector instance with threshold and extra options.
|
||||
|
||||
Args:
|
||||
db_connection: Active MySQL DB connection.
|
||||
model_name: Optional model name in DB.
|
||||
fit_extra_options: Optional extra fit options.
|
||||
score_extra_options: Optional extra scoring options.
|
||||
|
||||
Raises:
|
||||
ValueError: If outlier_threshold is not in (0,1).
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
MyBaseMLModel.__init__(
|
||||
self,
|
||||
db_connection,
|
||||
ML_TASK.ANOMALY_DETECTION,
|
||||
model_name=model_name,
|
||||
fit_extra_options=fit_extra_options,
|
||||
)
|
||||
self.score_extra_options = copy_dict(score_extra_options)
|
||||
self.boundary: Optional[float] = None
|
||||
|
||||
def predict(
|
||||
self,
|
||||
X: Union[pd.DataFrame, np.ndarray], # pylint: disable=invalid-name
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Predict outlier/inlier binary labels for input samples.
|
||||
|
||||
Args:
|
||||
X: Samples to predict on.
|
||||
|
||||
Returns:
|
||||
ndarray: Values are -1 for outliers, +1 for inliers, as per scikit-learn convention.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported,
|
||||
or if the model is not initialized, i.e., fit or import has not
|
||||
been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported,
|
||||
or if the model is not initialized, i.e., fit or import has not
|
||||
been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
return np.where(self.decision_function(X) < 0.0, -1, 1)
|
||||
|
||||
def decision_function(
|
||||
self,
|
||||
X: Union[pd.DataFrame, np.ndarray], # pylint: disable=invalid-name
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute signed distance to the outlier threshold.
|
||||
|
||||
Args:
|
||||
X: Samples to predict on.
|
||||
|
||||
Returns:
|
||||
ndarray: Score > 0 means inlier, < 0 means outlier; |value| gives margin.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported,
|
||||
or if the model is not initialized, i.e., fit or import has not
|
||||
been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
ValueError:
|
||||
If the provided model info does not provide threshold
|
||||
"""
|
||||
sample_scores = self.score_samples(X)
|
||||
|
||||
if self.boundary is None:
|
||||
model_info = self.get_model_info()
|
||||
if model_info is None:
|
||||
raise ValueError("Model does not exist in catalog.")
|
||||
|
||||
threshold = model_info["model_metadata"]["training_params"].get(
|
||||
"anomaly_detection_threshold", None
|
||||
)
|
||||
if threshold is None:
|
||||
raise ValueError(
|
||||
"Trained model is outdated and does not support threshold. "
|
||||
"Try retraining or using an existing, trained model with MyModel."
|
||||
)
|
||||
|
||||
# scikit-learn uses large positive values as inlier
|
||||
# and negative as outlier, so we need to flip our threshold
|
||||
self.boundary = _get_logits(1.0 - threshold)
|
||||
|
||||
return sample_scores - self.boundary
|
||||
|
||||
def score_samples(
|
||||
self,
|
||||
X: Union[pd.DataFrame, np.ndarray], # pylint: disable=invalid-name
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute normal probability logit score for each sample.
|
||||
Used for ranking, thresholding.
|
||||
|
||||
Args:
|
||||
X: Samples to score.
|
||||
|
||||
Returns:
|
||||
ndarray: Logit scores based on "normal" class probability.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported,
|
||||
or if the model is not initialized, i.e., fit or import has not
|
||||
been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
result = self._model.predict(X, options=self.score_extra_options)
|
||||
|
||||
return _get_logits(
|
||||
result["ml_results"]
|
||||
.apply(lambda x: x["probabilities"]["normal"])
|
||||
.to_numpy()
|
||||
)
|
||||
154
venv/lib/python3.12/site-packages/mysql/ai/ml/regressor.py
Normal file
154
venv/lib/python3.12/site-packages/mysql/ai/ml/regressor.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Regressor utilities for MySQL Connector/Python.
|
||||
|
||||
Provides a scikit-learn compatible regressor backed by HeatWave ML.
|
||||
"""
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.base import RegressorMixin
|
||||
|
||||
from mysql.ai.ml.base import MyBaseMLModel
|
||||
from mysql.ai.ml.model import ML_TASK
|
||||
from mysql.ai.utils import copy_dict
|
||||
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
|
||||
class MyRegressor(MyBaseMLModel, RegressorMixin):
|
||||
"""
|
||||
MySQL HeatWave scikit-learn compatible regressor estimator.
|
||||
|
||||
Provides prediction output from a regression model deployed in MySQL,
|
||||
and manages fit, explain, and prediction options as per HeatWave ML interface.
|
||||
|
||||
Attributes:
|
||||
predict_extra_options (dict): Optional parameter dict passed to the backend for prediction.
|
||||
_model (MyModel): Underlying interface for database model operations.
|
||||
fit_extra_options (dict): See MyBaseMLModel.
|
||||
explain_extra_options (dict): See MyBaseMLModel.
|
||||
|
||||
Args:
|
||||
db_connection (MySQLConnectionAbstract): Active MySQL connector DB connection.
|
||||
model_name (str, optional): Custom name for the model.
|
||||
fit_extra_options (dict, optional): Extra options for fitting.
|
||||
explain_extra_options (dict, optional): Extra options for explanations.
|
||||
predict_extra_options (dict, optional): Extra options for predictions.
|
||||
|
||||
Methods:
|
||||
predict(X): Predict regression target.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_connection: MySQLConnectionAbstract,
|
||||
model_name: Optional[str] = None,
|
||||
fit_extra_options: Optional[dict] = None,
|
||||
explain_extra_options: Optional[dict] = None,
|
||||
predict_extra_options: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Initialize a MyRegressor.
|
||||
|
||||
Args:
|
||||
db_connection: Active MySQL connector database connection.
|
||||
model_name: Optional, custom model name.
|
||||
fit_extra_options: Optional fit options.
|
||||
explain_extra_options: Optional explain options.
|
||||
predict_extra_options: Optional prediction options.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
MyBaseMLModel.__init__(
|
||||
self,
|
||||
db_connection,
|
||||
ML_TASK.REGRESSION,
|
||||
model_name=model_name,
|
||||
fit_extra_options=fit_extra_options,
|
||||
)
|
||||
|
||||
self.predict_extra_options = copy_dict(predict_extra_options)
|
||||
self.explain_extra_options = copy_dict(explain_extra_options)
|
||||
|
||||
def predict(
|
||||
self, X: Union[pd.DataFrame, np.ndarray]
|
||||
) -> np.ndarray: # pylint: disable=invalid-name
|
||||
"""
|
||||
Predict a continuous target for the input features using the MySQL model.
|
||||
|
||||
Args:
|
||||
X: Input samples as a numpy array or pandas DataFrame.
|
||||
|
||||
Returns:
|
||||
ndarray: Array of predicted target values, shape (n_samples,).
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported,
|
||||
or if the model is not initialized, i.e., fit or import has not
|
||||
been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
result = self._model.predict(X, options=self.predict_extra_options)
|
||||
return result["Prediction"].to_numpy()
|
||||
|
||||
def explain_predictions(
|
||||
self, X: Union[pd.DataFrame, np.ndarray]
|
||||
) -> pd.DataFrame: # pylint: disable=invalid-name
|
||||
"""
|
||||
Explain model predictions using provided data.
|
||||
|
||||
References:
|
||||
https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
|
||||
A full list of supported options can be found under "ML_EXPLAIN_TABLE Options"
|
||||
|
||||
Args:
|
||||
X: DataFrame for which predictions should be explained.
|
||||
|
||||
Returns:
|
||||
DataFrame containing explanation details (feature attributions, etc.)
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported,
|
||||
or if the model is not initialized, i.e., fit or import has not
|
||||
been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
|
||||
Notes:
|
||||
Temporary input/output tables are cleaned up after explanation.
|
||||
"""
|
||||
self._model.explain_predictions(X, options=self.explain_extra_options)
|
||||
164
venv/lib/python3.12/site-packages/mysql/ai/ml/transformer.py
Normal file
164
venv/lib/python3.12/site-packages/mysql/ai/ml/transformer.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# Copyright (c) 2025 Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
"""Generic transformer utilities for MySQL Connector/Python.
|
||||
|
||||
Provides a scikit-learn compatible Transformer using HeatWave for fit/transform
|
||||
and scoring operations.
|
||||
"""
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.base import TransformerMixin
|
||||
|
||||
from mysql.ai.ml.base import MyBaseMLModel
|
||||
from mysql.ai.ml.model import ML_TASK
|
||||
from mysql.ai.utils import copy_dict
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
|
||||
|
||||
class MyGenericTransformer(MyBaseMLModel, TransformerMixin):
|
||||
"""
|
||||
MySQL HeatWave scikit-learn compatible generic transformer.
|
||||
|
||||
Can be used as the transformation step in an sklearn pipeline. Implements fit, transform,
|
||||
explain, and scoring capability, passing options for server-side transform logic.
|
||||
|
||||
Args:
|
||||
db_connection (MySQLConnectionAbstract): Active MySQL connector database connection.
|
||||
task (str): ML task type for transformer (default: "classification").
|
||||
score_metric (str): Scoring metric to request from backend (default: "balanced_accuracy").
|
||||
model_name (str, optional): Custom name for the deployed model.
|
||||
fit_extra_options (dict, optional): Extra fit options.
|
||||
transform_extra_options (dict, optional): Extra options for transformations.
|
||||
score_extra_options (dict, optional): Extra options for scoring.
|
||||
|
||||
Attributes:
|
||||
score_metric (str): Name of the backend metric to use for scoring
|
||||
(e.g. "balanced_accuracy").
|
||||
score_extra_options (dict): Dictionary of optional scoring parameters;
|
||||
passed to backend score.
|
||||
transform_extra_options (dict): Dictionary of inference (/predict)
|
||||
parameters for the backend.
|
||||
fit_extra_options (dict): See MyBaseMLModel.
|
||||
_model (MyModel): Underlying interface for database model operations.
|
||||
|
||||
Methods:
|
||||
fit(X, y): Fit the underlying model using the provided features/targets.
|
||||
transform(X): Transform features using the backend model.
|
||||
score(X, y): Score data using backend metric and options.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_connection: MySQLConnectionAbstract,
|
||||
task: Union[str, ML_TASK] = ML_TASK.CLASSIFICATION,
|
||||
score_metric: str = "balanced_accuracy",
|
||||
model_name: Optional[str] = None,
|
||||
fit_extra_options: Optional[dict] = None,
|
||||
transform_extra_options: Optional[dict] = None,
|
||||
score_extra_options: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Initialize transformer with required and optional arguments.
|
||||
|
||||
Args:
|
||||
db_connection: Active MySQL backend database connection.
|
||||
task: ML task type for transformer.
|
||||
score_metric: Requested backend scoring metric.
|
||||
model_name: Optional model name for storage.
|
||||
fit_extra_options: Optional extra options for fitting.
|
||||
transform_extra_options: Optional extra options for transformation/inference.
|
||||
score_extra_options: Optional extra scoring options.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
MyBaseMLModel.__init__(
|
||||
self,
|
||||
db_connection,
|
||||
task,
|
||||
model_name=model_name,
|
||||
fit_extra_options=fit_extra_options,
|
||||
)
|
||||
|
||||
self.score_metric = score_metric
|
||||
self.score_extra_options = copy_dict(score_extra_options)
|
||||
|
||||
self.transform_extra_options = copy_dict(transform_extra_options)
|
||||
|
||||
def transform(
|
||||
self, X: pd.DataFrame
|
||||
) -> pd.DataFrame: # pylint: disable=invalid-name
|
||||
"""
|
||||
Transform input data to model predictions using the underlying helper.
|
||||
|
||||
Args:
|
||||
X: DataFrame of features to predict/transform.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: Results of transformation as returned by backend.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported,
|
||||
or if the model is not initialized, i.e., fit or import has not
|
||||
been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
return self._model.predict(X, options=self.transform_extra_options)
|
||||
|
||||
def score(
|
||||
self,
|
||||
X: Union[pd.DataFrame, np.ndarray], # pylint: disable=invalid-name
|
||||
y: Union[pd.DataFrame, np.ndarray],
|
||||
) -> float:
|
||||
"""
|
||||
Score the transformed data using the backend scoring interface.
|
||||
|
||||
Args:
|
||||
X: Transformed features.
|
||||
y: Target labels or data for scoring.
|
||||
|
||||
Returns:
|
||||
float: Score based on backend metric.
|
||||
|
||||
Raises:
|
||||
DatabaseError:
|
||||
If provided options are invalid or unsupported,
|
||||
or if the model is not initialized, i.e., fit or import has not
|
||||
been called
|
||||
If a database connection issue occurs.
|
||||
If an operational error occurs during execution.
|
||||
"""
|
||||
return self._model.score(
|
||||
X, y, self.score_metric, options=self.score_extra_options
|
||||
)
|
||||
Reference in New Issue
Block a user