# Copyright 2021 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classes for input/output Artifacts in KFP SDK.
These are only compatible with v2 Pipelines.
"""
from typing import Dict, List, Optional, Type
# Local path prefixes that ``Artifact`` swaps in for the ``gs://``,
# ``minio://`` and ``s3://`` URI schemes when reading ``path``, and swaps back
# out again when ``path`` is assigned.
_GCS_LOCAL_MOUNT_PREFIX = '/gcs/'
_MINIO_LOCAL_MOUNT_PREFIX = '/minio/'
_S3_LOCAL_MOUNT_PREFIX = '/s3/'
class Artifact:
    """Generic machine learning artifact.

    Every artifact class, this one included, carries a name, a URI and a
    metadata dictionary. Use ``Artifact`` directly when none of the more
    specific artifact types (e.g., ``Model``, ``Dataset``) fits.

    Args:
        name: Name of the artifact.
        uri: The artifact's location on disk or cloud storage.
        metadata: Arbitrary key-value pairs about the artifact.

    Example:
      ::

        from kfp import dsl
        from kfp.dsl import Output, Artifact, Input

        @dsl.component
        def create_artifact(
            data: str,
            output_artifact: Output[Artifact],
        ):
            with open(output_artifact.path, 'w') as f:
                f.write(data)

        @dsl.component
        def use_artifact(input_artifact: Input[Artifact]):
            with open(input_artifact.path) as input_file:
                artifact_contents = input_file.read()
                print(artifact_contents)

        @dsl.pipeline(name='my-pipeline', pipeline_root='gs://my/storage')
        def my_pipeline():
            create_task = create_artifact(data='my data')
            use_artifact(input_artifact=create_task.outputs['output_artifact'])

    Note: Other artifacts are used similarly to the usage of ``Artifact`` in
    the example above (within ``Input[]`` and ``Output[]``).
    """
    schema_title = 'system.Artifact'
    schema_version = '0.0.1'

    def __init__(
        self,
        name: Optional[str] = None,
        uri: Optional[str] = None,
        metadata: Optional[Dict] = None,
    ) -> None:
        """Stores name, URI and metadata, replacing falsy values with empties."""
        self.uri = uri if uri else ''
        self.name = name if name else ''
        self.metadata = metadata if metadata else {}

    @property
    def path(self) -> str:
        """Local path derived from ``uri``; see ``_get_path`` for the mapping."""
        return self._get_path()

    @path.setter
    def path(self, path: str) -> None:
        # Assigning a path rewrites ``uri``; see ``_set_path``.
        self._set_path(path)

    def _get_path(self) -> Optional[str]:
        """Translates ``uri`` into its local mount path.

        Returns:
            The mount path for ``gs://``, ``minio://`` and ``s3://`` URIs, or
            ``None`` for any other URI.
        """
        scheme, separator, remainder = self.uri.partition('://')
        if not separator:
            # No '<scheme>://' present, so this cannot be a recognised URI.
            return None
        if scheme == 'gs':
            return _GCS_LOCAL_MOUNT_PREFIX + remainder
        if scheme == 'minio':
            return _MINIO_LOCAL_MOUNT_PREFIX + remainder
        if scheme == 's3':
            return _S3_LOCAL_MOUNT_PREFIX + remainder
        return None

    def _set_path(self, path: str) -> None:
        """Sets ``uri`` from a local path.

        A path that starts with one of the local mount prefixes is converted
        back to the matching ``gs://``, ``minio://`` or ``s3://`` URI; any
        other path is stored unchanged.
        """
        mount_to_scheme = (
            (_GCS_LOCAL_MOUNT_PREFIX, 'gs://'),
            (_MINIO_LOCAL_MOUNT_PREFIX, 'minio://'),
            (_S3_LOCAL_MOUNT_PREFIX, 's3://'),
        )
        for mount_prefix, scheme in mount_to_scheme:
            if path.startswith(mount_prefix):
                self.uri = scheme + path[len(mount_prefix):]
                return
        self.uri = path
class Model(Artifact):
    """Artifact type for a machine learning model.

    Args:
        name: Name of the model.
        uri: The model's location on disk or cloud storage.
        metadata: Arbitrary key-value pairs about the model.
    """
    schema_title = 'system.Model'

    def __init__(
        self,
        name: Optional[str] = None,
        uri: Optional[str] = None,
        metadata: Optional[Dict] = None,
    ) -> None:
        super().__init__(name=name, uri=uri, metadata=metadata)

    def _get_framework(self) -> str:
        # An unset framework reads back as the empty string.
        return self.metadata.get('framework', '')

    def _set_framework(self, framework: str) -> None:
        self.metadata['framework'] = framework

    @property
    def framework(self) -> str:
        """The ``framework`` metadata entry, or ``''`` when it is absent."""
        return self._get_framework()

    @framework.setter
    def framework(self, framework: str) -> None:
        self._set_framework(framework)
class Dataset(Artifact):
    """Artifact type for a machine learning dataset.

    Args:
        name: Name of the dataset.
        uri: The dataset's location on disk or cloud storage.
        metadata: Arbitrary key-value pairs about the dataset.
    """
    schema_title = 'system.Dataset'

    def __init__(
        self,
        name: Optional[str] = None,
        uri: Optional[str] = None,
        metadata: Optional[Dict] = None,
    ) -> None:
        super().__init__(name=name, uri=uri, metadata=metadata)
class Metrics(Artifact):
    """Artifact type that holds key-value scalar metrics in its metadata.

    Args:
        name: Name of the metrics artifact.
        uri: The metrics artifact's location on disk or cloud storage.
        metadata: Key-value scalar metrics.
    """
    schema_title = 'system.Metrics'

    def __init__(
        self,
        name: Optional[str] = None,
        uri: Optional[str] = None,
        metadata: Optional[Dict] = None,
    ) -> None:
        super().__init__(name=name, uri=uri, metadata=metadata)

    def log_metric(self, metric: str, value: float) -> None:
        """Records a scalar metric in the artifact's metadata.

        An existing entry under the same key is overwritten.

        Args:
            metric: The metric key.
            value: The metric value.
        """
        self.metadata[metric] = value
class ClassificationMetrics(Artifact):
    """An artifact for storing classification metrics.

    ROC data points are accumulated under the ``confidenceMetrics`` metadata
    key and the confusion matrix is stored under the ``confusionMatrix``
    metadata key.

    Args:
        name: Name of the metrics artifact.
        uri: The metrics artifact's location on disk or cloud storage.
        metadata: The key-value scalar metrics.
    """
    schema_title = 'system.ClassificationMetrics'

    def __init__(self,
                 name: Optional[str] = None,
                 uri: Optional[str] = None,
                 metadata: Optional[Dict] = None) -> None:
        super().__init__(uri=uri, name=name, metadata=metadata)
        # Start from an empty confusion matrix so that logging a row or cell
        # before any categories are set raises the documented ValueError
        # instead of an AttributeError on a missing attribute.
        self._categories: List[str] = []
        self._matrix: List[Dict] = []
        self._confusion_matrix: Dict = {
            'annotationSpecs': [],
            'rows': self._matrix,
        }

    def log_roc_data_point(self, fpr: float, tpr: float,
                           threshold: float) -> None:
        """Logs a single data point in the ROC curve to metadata.

        Args:
            fpr: False positive rate value of the data point.
            tpr: True positive rate value of the data point.
            threshold: Threshold value for the data point.
        """
        roc_reading = {
            'confidenceThreshold': threshold,
            'recall': tpr,
            'falsePositiveRate': fpr,
        }
        self.metadata.setdefault('confidenceMetrics', []).append(roc_reading)

    def log_roc_curve(self, fpr: List[float], tpr: List[float],
                      threshold: List[float]) -> None:
        """Logs an ROC curve to metadata.

        Args:
            fpr: List of false positive rate values.
            tpr: List of true positive rate values.
            threshold: List of threshold values.

        Raises:
            ValueError: If the lists ``fpr``, ``tpr`` and ``threshold`` are not the same length.
        """
        if not len(fpr) == len(tpr) == len(threshold):
            raise ValueError(
                f'Length of fpr, tpr and threshold must be the same. Got lengths {len(fpr)}, {len(tpr)} and {len(threshold)} respectively.'
            )
        for fpr_value, tpr_value, threshold_value in zip(fpr, tpr, threshold):
            self.log_roc_data_point(
                fpr=fpr_value, tpr=tpr_value, threshold=threshold_value)

    def set_confusion_matrix_categories(self, categories: List[str]) -> None:
        """Stores confusion matrix categories to metadata.

        Any previously logged confusion matrix is replaced by a zero-filled
        square matrix with one row and one column per category.

        Args:
            categories: List of strings specifying the categories.
        """
        self._categories = list(categories)
        size = len(self._categories)
        annotation_specs = [
            {'displayName': category} for category in self._categories
        ]
        # Each row gets its own list so that cells can be updated independently.
        self._matrix = [{'row': [0] * size} for _ in range(size)]
        self._confusion_matrix = {
            'annotationSpecs': annotation_specs,
            'rows': self._matrix,
        }
        self.metadata['confusionMatrix'] = self._confusion_matrix

    def _category_index(self, category: str) -> int:
        """Returns the position of ``category``, raising ValueError if unknown."""
        if category not in self._categories:
            raise ValueError(
                f'Invalid category: {category} passed. Expected one of: {self._categories}'
            )
        return self._categories.index(category)

    def log_confusion_matrix_row(self, row_category: str,
                                 row: List[float]) -> None:
        """Logs a confusion matrix row to metadata.

        Args:
            row_category: Category to which the row belongs.
            row: List of integers specifying the values for the row.

        Raises:
            ValueError: If ``row_category`` is not in the list of categories
              set in the ``set_confusion_matrix_categories`` call, or if
              ``row`` does not have one value per category.
        """
        row_index = self._category_index(row_category)
        if len(row) != len(self._categories):
            raise ValueError(
                f'Invalid row. Expected size: {len(self._categories)} got: {len(row)}'
            )
        # Store a copy so later cell updates never mutate the caller's list.
        self._matrix[row_index] = {'row': list(row)}
        self.metadata['confusionMatrix'] = self._confusion_matrix

    def log_confusion_matrix_cell(self, row_category: str, col_category: str,
                                  value: int) -> None:
        """Logs a cell in the confusion matrix to metadata.

        Args:
            row_category: String representing the name of the row category.
            col_category: String representing the name of the column category.
            value: Value of the cell.

        Raises:
            ValueError: If ``row_category`` or ``col_category`` is not in the
              list of categories set in ``set_confusion_matrix_categories``.
        """
        row_index = self._category_index(row_category)
        # The helper reports the offending name, so an unknown column is
        # reported as the column category rather than the row category.
        col_index = self._category_index(col_category)
        self._matrix[row_index]['row'][col_index] = value
        self.metadata['confusionMatrix'] = self._confusion_matrix

    def log_confusion_matrix(self, categories: List[str],
                             matrix: List[List[int]]) -> None:
        """Logs a confusion matrix to metadata.

        The matrix shape is validated before anything is stored, so an invalid
        call leaves previously logged data untouched.

        Args:
            categories: List of the category names.
            matrix: Complete confusion matrix.

        Raises:
            ValueError: If the length of ``categories`` does not match number of rows or columns of ``matrix``.
        """
        size = len(categories)
        if len(matrix) != size or any(len(row) != size for row in matrix):
            raise ValueError(
                f'Invalid matrix: {matrix} passed for categories: {categories}')
        self.set_confusion_matrix_categories(categories)
        for category, row in zip(categories, matrix):
            self.log_confusion_matrix_row(category, row)
        self.metadata['confusionMatrix'] = self._confusion_matrix
class SlicedClassificationMetrics(Artifact):
    """An artifact for storing sliced classification metrics.

    Similar to ``ClassificationMetrics``, tasks using this class are
    expected to use log methods of the class to log metrics with the
    difference being each log method takes a slice to associate the
    ``ClassificationMetrics``.

    Each log call rebuilds ``metadata`` as a single ``evaluationSlices`` list
    holding one entry per slice.

    Args:
        name: Name of the metrics artifact.
        uri: The metrics artifact's location on disk or cloud storage.
        metadata: Arbitrary key-value pairs about the metrics artifact.
    """
    schema_title = 'system.SlicedClassificationMetrics'

    def __init__(self,
                 name: Optional[str] = None,
                 uri: Optional[str] = None,
                 metadata: Optional[Dict] = None) -> None:
        super().__init__(uri=uri, name=name, metadata=metadata)
        # Slice label -> ClassificationMetrics holding that slice's data.
        # Created here so the log methods can use it on their first call.
        self._sliced_metrics: Dict[str, ClassificationMetrics] = {}

    def _upsert_classification_metrics_for_slice(self, slice: str) -> None:
        """Upserts the classification metrics instance for a slice."""
        if slice not in self._sliced_metrics:
            self._sliced_metrics[slice] = ClassificationMetrics()

    def _update_metadata(self, slice: str) -> None:
        """Updates metadata to adhere to the metrics schema.

        Args:
            slice: Label of the slice that was just modified. Currently unused
              because every slice is re-exported on each call; kept so existing
              call sites remain valid.
        """
        self.metadata = {
            'evaluationSlices': [{
                'slice': label,
                'sliceClassificationMetrics': metrics.metadata,
            } for label, metrics in self._sliced_metrics.items()]
        }

    def log_roc_reading(self, slice: str, threshold: float, tpr: float,
                        fpr: float) -> None:
        """Logs a single data point in the ROC curve of a slice to metadata.

        Args:
            slice: String representing slice label.
            threshold: Threshold value for the data point.
            tpr: True positive rate value of the data point.
            fpr: False positive rate value of the data point.
        """
        self._upsert_classification_metrics_for_slice(slice)
        # Keyword arguments guard against the differing parameter order of
        # ClassificationMetrics.log_roc_data_point (fpr, tpr, threshold).
        self._sliced_metrics[slice].log_roc_data_point(
            fpr=fpr, tpr=tpr, threshold=threshold)
        self._update_metadata(slice)

    def load_roc_readings(self, slice: str,
                          readings: List[List[float]]) -> None:
        """Bulk loads ROC curve readings for a slice.

        Args:
            slice: String representing slice label.
            readings: A 2-dimensional list providing ROC curve data points. The expected order of the data points is: threshold, true positive rate, false positive rate.

        Raises:
            ValueError: If a reading does not contain exactly three values.
        """
        self._upsert_classification_metrics_for_slice(slice)
        slice_metrics = self._sliced_metrics[slice]
        for threshold, tpr, fpr in readings:
            slice_metrics.log_roc_data_point(
                fpr=fpr, tpr=tpr, threshold=threshold)
        self._update_metadata(slice)

    def set_confusion_matrix_categories(self, slice: str,
                                        categories: List[str]) -> None:
        """Logs confusion matrix categories for a slice to metadata.

        Categories are stored in the internal ``ClassificationMetrics``
        instance of the slice.

        Args:
            slice: String representing slice label.
            categories: List of strings specifying the categories.
        """
        self._upsert_classification_metrics_for_slice(slice)
        self._sliced_metrics[slice].set_confusion_matrix_categories(categories)
        self._update_metadata(slice)

    def log_confusion_matrix_row(self, slice: str, row_category: str,
                                 row: List[int]) -> None:
        """Logs a confusion matrix row for a slice to metadata.

        Row is updated on the internal ``ClassificationMetrics`` instance of
        the slice.

        Args:
            slice: String representing slice label.
            row_category: Category to which the row belongs.
            row: List of integers specifying the values for the row.
        """
        self._upsert_classification_metrics_for_slice(slice)
        self._sliced_metrics[slice].log_confusion_matrix_row(row_category, row)
        self._update_metadata(slice)

    def log_confusion_matrix_cell(self, slice: str, row_category: str,
                                  col_category: str, value: int) -> None:
        """Logs a confusion matrix cell for a slice to metadata.

        Cell is updated on the internal ``ClassificationMetrics`` instance of
        the slice.

        Args:
            slice: String representing slice label.
            row_category: String representing the name of the row category.
            col_category: String representing the name of the column category.
            value: Value of the cell.
        """
        self._upsert_classification_metrics_for_slice(slice)
        self._sliced_metrics[slice].log_confusion_matrix_cell(
            row_category, col_category, value)
        self._update_metadata(slice)

    def load_confusion_matrix(self, slice: str, categories: List[str],
                              matrix: List[List[int]]) -> None:
        """Bulk loads the whole confusion matrix for a slice.

        Args:
            slice: String representing slice label.
            categories: List of the category names.
            matrix: Complete confusion matrix.
        """
        self._upsert_classification_metrics_for_slice(slice)
        # The whole-matrix logger is the one that takes (categories, matrix);
        # the per-cell logger expects (row_category, col_category, value).
        self._sliced_metrics[slice].log_confusion_matrix(categories, matrix)
        self._update_metadata(slice)
class HTML(Artifact):
    """Artifact type for an HTML file.

    Args:
        name: Name of the HTML file.
        uri: The HTML file's location on disk or cloud storage.
        metadata: Arbitrary key-value pairs about the HTML file.
    """
    schema_title = 'system.HTML'

    def __init__(
        self,
        name: Optional[str] = None,
        uri: Optional[str] = None,
        metadata: Optional[Dict] = None,
    ) -> None:
        super().__init__(name=name, uri=uri, metadata=metadata)
class Markdown(Artifact):
    """Artifact type for a markdown file.

    Args:
        name: Name of the markdown file.
        uri: The markdown file's location on disk or cloud storage.
        metadata: Arbitrary key-value pairs about the markdown file.
    """
    schema_title = 'system.Markdown'

    def __init__(
        self,
        name: Optional[str] = None,
        uri: Optional[str] = None,
        metadata: Optional[Dict] = None,
    ):
        super().__init__(name=name, uri=uri, metadata=metadata)
# Lookup from each artifact class's ``schema_title`` to the class itself.
_SCHEMA_TITLE_TO_TYPE: Dict[str, Type[Artifact]] = {
    artifact_cls.schema_title: artifact_cls
    for artifact_cls in (
        Artifact,
        Model,
        Dataset,
        Metrics,
        ClassificationMetrics,
        SlicedClassificationMetrics,
        HTML,
        Markdown,
    )
}