Skip to content
Merged
248 changes: 248 additions & 0 deletions google/cloud/aiplatform/metadata/schema/google/artifact_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,251 @@ def __init__(
metadata=extended_metadata,
state=state,
)


class ClassificationMetrics(base_artifact.BaseArtifactSchema):
"""A Google artifact representing evaluation Classification Metrics."""

schema_title = "google.ClassificationMetrics"

def __init__(
self,
*,
au_prc: Optional[float] = None,
au_roc: Optional[float] = None,
log_loss: Optional[float] = None,
artifact_id: Optional[str] = None,
uri: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
):
"""Args:
au_prc (float):
Optional. The Area Under Precision-Recall Curve metric.
Micro-averaged for the overall evaluation.
au_roc (float):
Optional. The Area Under Receiver Operating Characteristic curve metric.
Micro-averaged for the overall evaluation.
log_loss (float):
Optional. The Log Loss metric.
artifact_id (str):
Optional. The <resource_id> portion of the Artifact name with
the format. This is globally unique in a metadataStore:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
uri (str):
Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
artifact file.
display_name (str):
Optional. The user-defined name of the Artifact.
schema_version (str):
Optional. schema_version specifies the version used by the Artifact.
If not set, defaults to use the latest version.
description (str):
Optional. Describes the purpose of the Artifact to be created.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Artifact.
state (google.cloud.gapic.types.Artifact.State):
Optional. The state of this Artifact. This is a
property of the Artifact, and does not imply or
capture any ongoing process. This property is
managed by clients (such as Vertex AI
Pipelines), and the system does not prescribe or
check the validity of state transitions.
"""
extended_metadata = copy.deepcopy(metadata) if metadata else {}
if au_prc:
extended_metadata["auPrc"] = au_prc
if au_roc:
extended_metadata["auRoc"] = au_roc
if log_loss:
extended_metadata["logLoss"] = log_loss

super(ClassificationMetrics, self).__init__(
uri=uri,
artifact_id=artifact_id,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=extended_metadata,
state=state,
)


class RegressionMetrics(base_artifact.BaseArtifactSchema):
"""A Google artifact representing evaluation Regression Metrics."""

schema_title = "google.RegressionMetrics"

def __init__(
self,
*,
root_mean_squared_error: Optional[float] = None,
mean_absolute_error: Optional[float] = None,
mean_absolute_percentage_error: Optional[float] = None,
r_squared: Optional[float] = None,
root_mean_squared_log_error: Optional[float] = None,
artifact_id: Optional[str] = None,
uri: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
):
"""Args:
root_mean_squared_error (float):
Optional. Root Mean Squared Error (RMSE).
mean_absolute_error (float):
Optional. Mean Absolute Error (MAE).
mean_absolute_percentage_error (float):
Optional. Mean absolute percentage error.
r_squared (float):
Optional. Coefficient of determination as Pearson correlation coefficient.
root_mean_squared_log_error (float):
Optional. Root mean squared log error.
artifact_id (str):
Optional. The <resource_id> portion of the Artifact name with
the format. This is globally unique in a metadataStore:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
uri (str):
Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
artifact file.
display_name (str):
Optional. The user-defined name of the Artifact.
schema_version (str):
Optional. schema_version specifies the version used by the Artifact.
If not set, defaults to use the latest version.
description (str):
Optional. Describes the purpose of the Artifact to be created.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Artifact.
state (google.cloud.gapic.types.Artifact.State):
Optional. The state of this Artifact. This is a
property of the Artifact, and does not imply or
capture any ongoing process. This property is
managed by clients (such as Vertex AI
Pipelines), and the system does not prescribe or
check the validity of state transitions.
"""
extended_metadata = copy.deepcopy(metadata) if metadata else {}
if root_mean_squared_error:
extended_metadata["rootMeanSquaredError"] = root_mean_squared_error
if mean_absolute_error:
extended_metadata["meanAbsoluteError"] = mean_absolute_error
if mean_absolute_percentage_error:
extended_metadata["meanAbsolutePercentageError"] = mean_absolute_percentage_error
if r_squared:
extended_metadata["rSquared"] = r_squared
if root_mean_squared_log_error:
extended_metadata["rootMeanSquaredLogError"] = root_mean_squared_log_error

super(RegressionMetrics, self).__init__(
uri=uri,
artifact_id=artifact_id,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=extended_metadata,
state=state,
)


class ForecastingMetrics(base_artifact.BaseArtifactSchema):
"""A Google artifact representing evaluation Forecasting Metrics."""

schema_title = "google.ForecastingMetrics"

def __init__(
self,
*,
root_mean_squared_error: Optional[float] = None,
mean_absolute_error: Optional[float] = None,
mean_absolute_percentage_error: Optional[float] = None,
r_squared: Optional[float] = None,
root_mean_squared_log_error: Optional[float] = None,
weighted_absolute_percentage_error: Optional[float] = None,
root_mean_squared_percentage_error: Optional[float] = None,
symmetric_mean_absolute_percentage_error: Optional[float] = None,
artifact_id: Optional[str] = None,
uri: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
):
"""Args:
root_mean_squared_error (float):
Optional. Root Mean Squared Error (RMSE).
mean_absolute_error (float):
Optional. Mean Absolute Error (MAE).
mean_absolute_percentage_error (float):
Optional. Mean absolute percentage error.
r_squared (float):
Optional. Coefficient of determination as Pearson correlation coefficient.
root_mean_squared_log_error (float):
Optional. Root mean squared log error.
weighted_absolute_percentage_error (float):
Optional. Weighted Absolute Percentage Error.
Does not use weights, this is just what the metric is called.
Undefined if actual values sum to zero.
Will be very large if actual values sum to a very small number.
root_mean_squared_percentage_error (float):
Optional. Root Mean Square Percentage Error. Square root of MSPE.
Undefined/imaginary when MSPE is negative.
symmetric_mean_absolute_percentage_error (float):
Optional. Symmetric Mean Absolute Percentage Error.
artifact_id (str):
Optional. The <resource_id> portion of the Artifact name with
the format. This is globally unique in a metadataStore:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
uri (str):
Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
artifact file.
display_name (str):
Optional. The user-defined name of the Artifact.
schema_version (str):
Optional. schema_version specifies the version used by the Artifact.
If not set, defaults to use the latest version.
description (str):
Optional. Describes the purpose of the Artifact to be created.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Artifact.
state (google.cloud.gapic.types.Artifact.State):
Optional. The state of this Artifact. This is a
property of the Artifact, and does not imply or
capture any ongoing process. This property is
managed by clients (such as Vertex AI
Pipelines), and the system does not prescribe or
check the validity of state transitions.
"""
extended_metadata = copy.deepcopy(metadata) if metadata else {}
if root_mean_squared_error:
extended_metadata["rootMeanSquaredError"] = root_mean_squared_error
if mean_absolute_error:
extended_metadata["meanAbsoluteError"] = mean_absolute_error
if mean_absolute_percentage_error:
extended_metadata["meanAbsolutePercentageError"] = mean_absolute_percentage_error
if r_squared:
extended_metadata["rSquared"] = r_squared
if root_mean_squared_log_error:
extended_metadata["rootMeanSquaredLogError"] = root_mean_squared_log_error
if weighted_absolute_percentage_error:
extended_metadata["weightedAbsolutePercentageError"] = weighted_absolute_percentage_error
if root_mean_squared_percentage_error:
extended_metadata["rootMeanSquaredPercentageError"] = root_mean_squared_percentage_error
if symmetric_mean_absolute_percentage_error:
extended_metadata["symmetricMeanAbsolutePercentageError"] = symmetric_mean_absolute_percentage_error

super(ForecastingMetrics, self).__init__(
uri=uri,
artifact_id=artifact_id,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=extended_metadata,
state=state,
)
138 changes: 138 additions & 0 deletions tests/unit/aiplatform/test_metadata_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,144 @@ def test_unmanaged_container_model_constructor_parameters_are_set_correctly(self
assert artifact.schema_version == _TEST_SCHEMA_VERSION


def test_classification_metrics_title_is_set_correctly(self):
artifact = google_artifact_schema.ClassificationMetrics()
assert artifact.schema_title == "google.ClassificationMetrics"

def test_classification_metrics_constructor_parameters_are_set_correctly(self):
au_prc = 1.0
au_roc = 2.0
log_loss = 0.5

artifact = google_artifact_schema.ClassificationMetrics(
au_prc=au_prc,
au_roc=au_roc,
log_loss=log_loss,
artifact_id=_TEST_ARTIFACT_ID,
uri=_TEST_URI,
display_name=_TEST_DISPLAY_NAME,
schema_version=_TEST_SCHEMA_VERSION,
description=_TEST_DESCRIPTION,
metadata=_TEST_UPDATED_METADATA,
)
expected_metadata = {
"test-param1": 2.0,
"test-param2": "test-value-1",
"test-param3": False,
"auPrc": 1.0,
"auRoc": 2.0,
"logLoss": 0.5,
}

assert artifact.artifact_id == _TEST_ARTIFACT_ID
assert artifact.uri == _TEST_URI
assert artifact.display_name == _TEST_DISPLAY_NAME
assert artifact.description == _TEST_DESCRIPTION
assert json.dumps(artifact.metadata, sort_keys=True) == json.dumps(
expected_metadata, sort_keys=True
)
assert artifact.schema_version == _TEST_SCHEMA_VERSION


def test_regression_metrics_title_is_set_correctly(self):
artifact = google_artifact_schema.RegressionMetrics()
assert artifact.schema_title == "google.RegressionMetrics"

def test_regression_metrics_constructor_parameters_are_set_correctly(self):
root_mean_squared_error = 1.0
mean_absolute_error = 2.0
mean_absolute_percentage_error = 0.2
r_squared = 0.5
root_mean_squared_log_error = 0.9

artifact = google_artifact_schema.RegressionMetrics(
root_mean_squared_error=root_mean_squared_error,
mean_absolute_error=mean_absolute_error,
mean_absolute_percentage_error=mean_absolute_percentage_error,
r_squared=r_squared,
root_mean_squared_log_error=root_mean_squared_log_error,
artifact_id=_TEST_ARTIFACT_ID,
uri=_TEST_URI,
display_name=_TEST_DISPLAY_NAME,
schema_version=_TEST_SCHEMA_VERSION,
description=_TEST_DESCRIPTION,
metadata=_TEST_UPDATED_METADATA,
)
expected_metadata = {
"test-param1": 2.0,
"test-param2": "test-value-1",
"test-param3": False,
"rootMeanSquaredError": 1.0,
"meanAbsoluteError": 2.0,
"meanAbsolutePercentageError": 0.2,
"rSquared": 0.5,
"rootMeanSquaredLogError": 0.9,
}

assert artifact.artifact_id == _TEST_ARTIFACT_ID
assert artifact.uri == _TEST_URI
assert artifact.display_name == _TEST_DISPLAY_NAME
assert artifact.description == _TEST_DESCRIPTION
assert json.dumps(artifact.metadata, sort_keys=True) == json.dumps(
expected_metadata, sort_keys=True
)
assert artifact.schema_version == _TEST_SCHEMA_VERSION


def test_forecasting_metrics_title_is_set_correctly(self):
artifact = google_artifact_schema.ForecastingMetrics()
assert artifact.schema_title == "google.ForecastingMetrics"

def test_forecasting_metrics_constructor_parameters_are_set_correctly(self):
root_mean_squared_error = 1.0
mean_absolute_error = 2.0
mean_absolute_percentage_error = 0.2
r_squared = 0.5
root_mean_squared_log_error = 0.9
weighted_absolute_percentage_error = 4.0
root_mean_squared_percentage_error = 0.7
symmetric_mean_absolute_percentage_error = 0.8

artifact = google_artifact_schema.UnmanagedContainerModel(
Comment thread
jaycee-li marked this conversation as resolved.
Outdated
root_mean_squared_error=root_mean_squared_error,
mean_absolute_error=mean_absolute_error,
mean_absolute_percentage_error=mean_absolute_percentage_error,
r_squared=r_squared,
root_mean_squared_log_error=root_mean_squared_log_error,
weighted_absolute_percentage_error=weighted_absolute_percentage_error,
root_mean_squared_percentage_error=root_mean_squared_percentage_error,
symmetric_mean_absolute_percentage_error=symmetric_mean_absolute_percentage_error,
artifact_id=_TEST_ARTIFACT_ID,
uri=_TEST_URI,
display_name=_TEST_DISPLAY_NAME,
schema_version=_TEST_SCHEMA_VERSION,
description=_TEST_DESCRIPTION,
metadata=_TEST_UPDATED_METADATA,
)
expected_metadata = {
"test-param1": 2.0,
"test-param2": "test-value-1",
"test-param3": False,
"rootMeanSquaredError": 1.0,
"meanAbsoluteError": 2.0,
"meanAbsolutePercentageError": 0.2,
"rSquared": 0.5,
"rootMeanSquaredLogError": 0.9,
"weightedAbsolutePercentageError": 4.0,
"rootMeanSquaredPercentageError": 0.7,
"symmetricMeanAbsolutePercentageError": 0.8,
}

assert artifact.artifact_id == _TEST_ARTIFACT_ID
assert artifact.uri == _TEST_URI
assert artifact.display_name == _TEST_DISPLAY_NAME
assert artifact.description == _TEST_DESCRIPTION
assert json.dumps(artifact.metadata, sort_keys=True) == json.dumps(
expected_metadata, sort_keys=True
)
assert artifact.schema_version == _TEST_SCHEMA_VERSION


@pytest.mark.usefixtures("google_auth_mock")
class TestMetadataSystemArtifactSchema:
def setup_method(self):
Expand Down