# Source code for kfp.components.types.artifact_types

# Copyright 2021 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classes for input/output Artifacts in KFP SDK.

These are only compatible with v2 Pipelines.
"""

from typing import Dict, List, Optional, Type

# Local filesystem mount points under which remote object-store URIs are
# exposed to pipeline code; used by Artifact.path to translate between a
# remote URI (gs://, minio://, s3://) and its local mounted path.
_GCS_LOCAL_MOUNT_PREFIX = '/gcs/'
_MINIO_LOCAL_MOUNT_PREFIX = '/minio/'
_S3_LOCAL_MOUNT_PREFIX = '/s3/'


class Artifact:
    """Represents a generic machine learning artifact.

    This class and all artifact classes store the name, uri, and metadata for
    a machine learning artifact. Use this artifact type when an artifact does
    not fit into another more specific artifact type (e.g., ``Model``,
    ``Dataset``).

    Args:
        name: Name of the artifact.
        uri: The artifact's location on disk or cloud storage.
        metadata: Arbitrary key-value pairs about the artifact.

    Example:
      ::

        from kfp import dsl
        from kfp.dsl import Output, Artifact, Input

        @dsl.component
        def create_artifact(
            data: str,
            output_artifact: Output[Artifact],
        ):
            with open(output_artifact.path, 'w') as f:
                f.write(data)

        @dsl.component
        def use_artifact(input_artifact: Input[Artifact]):
            with open(input_artifact.path) as input_file:
                artifact_contents = input_file.read()
                print(artifact_contents)

        @dsl.pipeline(name='my-pipeline', pipeline_root='gs://my/storage')
        def my_pipeline():
            create_task = create_artifact(data='my data')
            use_artifact(input_artifact=create_task.outputs['output_artifact'])

    Note: Other artifacts are used similarly to the usage of ``Artifact`` in
    the example above (within ``Input[]`` and ``Output[]``).
    """
    schema_title = 'system.Artifact'
    schema_version = '0.0.1'

    # (remote scheme, local mount prefix) pairs used to translate between a
    # remote object-store URI and its local mounted filesystem path.
    _SCHEME_TO_MOUNT = (
        ('gs://', '/gcs/'),
        ('minio://', '/minio/'),
        ('s3://', '/s3/'),
    )

    def __init__(self,
                 name: Optional[str] = None,
                 uri: Optional[str] = None,
                 metadata: Optional[Dict] = None) -> None:
        """Initializes the Artifact with the given name, URI and metadata."""
        self.uri = uri or ''
        self.name = name or ''
        self.metadata = metadata or {}

    @property
    def path(self) -> str:
        """Local filesystem path corresponding to ``self.uri``."""
        return self._get_path()

    @path.setter
    def path(self, path: str) -> None:
        self._set_path(path)

    def _get_path(self) -> Optional[str]:
        # Translate a known remote scheme to its local mount; any other URI
        # has no local path here and yields None.
        for scheme, mount in self._SCHEME_TO_MOUNT:
            if self.uri.startswith(scheme):
                return mount + self.uri[len(scheme):]
        return None

    def _set_path(self, path: str) -> None:
        # Translate a path under a known local mount back to its remote URI;
        # any other path is stored in ``uri`` verbatim.
        for scheme, mount in self._SCHEME_TO_MOUNT:
            if path.startswith(mount):
                path = scheme + path[len(mount):]
                break
        self.uri = path
class Model(Artifact):
    """An artifact representing a machine learning model.

    Args:
        name: Name of the model.
        uri: The model's location on disk or cloud storage.
        metadata: Arbitrary key-value pairs about the model.
    """
    schema_title = 'system.Model'

    def __init__(self,
                 name: Optional[str] = None,
                 uri: Optional[str] = None,
                 metadata: Optional[Dict] = None) -> None:
        super().__init__(uri=uri, name=name, metadata=metadata)

    @property
    def framework(self) -> str:
        """The ML framework name stored under the ``'framework'`` metadata
        key; empty string when unset."""
        return self.metadata.get('framework', '')

    @framework.setter
    def framework(self, framework: str) -> None:
        self.metadata['framework'] = framework
class Dataset(Artifact):
    """An artifact representing a machine learning dataset.

    Args:
        name: Name of the dataset.
        uri: The dataset's location on disk or cloud storage.
        metadata: Arbitrary key-value pairs about the dataset.
    """
    schema_title = 'system.Dataset'

    def __init__(self,
                 name: Optional[str] = None,
                 uri: Optional[str] = None,
                 metadata: Optional[Dict] = None) -> None:
        # Dataset adds no behavior beyond Artifact; only the schema_title
        # differs.
        super().__init__(uri=uri, name=name, metadata=metadata)
class Metrics(Artifact):
    """An artifact for storing key-value scalar metrics.

    Args:
        name: Name of the metrics artifact.
        uri: The metrics artifact's location on disk or cloud storage.
        metadata: Key-value scalar metrics.
    """
    schema_title = 'system.Metrics'

    def __init__(self,
                 name: Optional[str] = None,
                 uri: Optional[str] = None,
                 metadata: Optional[Dict] = None) -> None:
        super().__init__(uri=uri, name=name, metadata=metadata)

    def log_metric(self, metric: str, value: float) -> None:
        """Sets a custom scalar metric in the artifact's metadata.

        Args:
            metric: The metric key.
            value: The metric value.
        """
        # Metrics are stored directly as metadata key-value pairs.
        self.metadata[metric] = value
class ClassificationMetrics(Artifact):
    """An artifact for storing classification metrics.

    Args:
        name: Name of the metrics artifact.
        uri: The metrics artifact's location on disk or cloud storage.
        metadata: The key-value scalar metrics.
    """
    schema_title = 'system.ClassificationMetrics'

    def __init__(self,
                 name: Optional[str] = None,
                 uri: Optional[str] = None,
                 metadata: Optional[Dict] = None):
        super().__init__(uri=uri, name=name, metadata=metadata)

    def log_roc_data_point(self, fpr: float, tpr: float,
                           threshold: float) -> None:
        """Logs a single data point in the ROC curve to metadata.

        Args:
            fpr: False positive rate value of the data point.
            tpr: True positive rate value of the data point.
            threshold: Threshold value for the data point.
        """
        roc_reading = {
            'confidenceThreshold': threshold,
            'recall': tpr,
            'falsePositiveRate': fpr
        }
        # Append to the 'confidenceMetrics' list, creating it on first use.
        self.metadata.setdefault('confidenceMetrics', []).append(roc_reading)

    def log_roc_curve(self, fpr: List[float], tpr: List[float],
                      threshold: List[float]) -> None:
        """Logs an ROC curve to metadata.

        Args:
            fpr: List of false positive rate values.
            tpr: List of true positive rate values.
            threshold: List of threshold values.

        Raises:
            ValueError: If the lists ``fpr``, ``tpr`` and ``threshold`` are
                not the same length.
        """
        # A single chained comparison replaces the original's three pairwise
        # (and partly redundant) length checks.
        if not len(fpr) == len(tpr) == len(threshold):
            raise ValueError(
                f'Length of fpr, tpr and threshold must be the same. Got lengths {len(fpr)}, {len(tpr)} and {len(threshold)} respectively.'
            )
        for fpr_val, tpr_val, threshold_val in zip(fpr, tpr, threshold):
            self.log_roc_data_point(
                fpr=fpr_val, tpr=tpr_val, threshold=threshold_val)

    def set_confusion_matrix_categories(self, categories: List[str]) -> None:
        """Stores confusion matrix categories to metadata.

        Args:
            categories: List of strings specifying the categories.
        """
        self._categories = list(categories)
        annotation_specs = [
            {'displayName': category} for category in self._categories
        ]
        # One all-zero row per category; each row uses the schema's
        # {'row': [...]} shape.
        self._matrix = [
            {'row': [0] * len(self._categories)} for _ in self._categories
        ]
        self._confusion_matrix = {
            'annotationSpecs': annotation_specs,
            'rows': self._matrix
        }
        self.metadata['confusionMatrix'] = self._confusion_matrix

    def log_confusion_matrix_row(self, row_category: str,
                                 row: List[float]) -> None:
        """Logs a confusion matrix row to metadata.

        Args:
            row_category: Category to which the row belongs.
            row: List of integers specifying the values for the row.

        Raises:
            ValueError: If ``row_category`` is not in the list of categories
                set in ``set_categories`` call.
        """
        if row_category not in self._categories:
            raise ValueError(
                f'Invalid category: {row_category} passed. Expected one of: {self._categories}'
            )
        if len(row) != len(self._categories):
            raise ValueError(
                f'Invalid row. Expected size: {len(self._categories)} got: {len(row)}'
            )
        self._matrix[self._categories.index(row_category)] = {'row': row}
        self.metadata['confusionMatrix'] = self._confusion_matrix

    def log_confusion_matrix_cell(self, row_category: str, col_category: str,
                                  value: int) -> None:
        """Logs a cell in the confusion matrix to metadata.

        Args:
            row_category: String representing the name of the row category.
            col_category: String representing the name of the column category.
            value: Value of the cell.

        Raises:
            ValueError: If ``row_category`` or ``col_category`` is not in the
                list of categories set in ``set_categories``.
        """
        if row_category not in self._categories:
            raise ValueError(
                f'Invalid category: {row_category} passed. Expected one of: {self._categories}'
            )
        if col_category not in self._categories:
            # Bug fix: the original message reported row_category here,
            # masking which argument was actually invalid.
            raise ValueError(
                f'Invalid category: {col_category} passed. Expected one of: {self._categories}'
            )
        self._matrix[self._categories.index(row_category)]['row'][
            self._categories.index(col_category)] = value
        self.metadata['confusionMatrix'] = self._confusion_matrix

    def log_confusion_matrix(self, categories: List[str],
                             matrix: List[List[int]]) -> None:
        """Logs a confusion matrix to metadata.

        Args:
            categories: List of the category names.
            matrix: Complete confusion matrix.

        Raises:
            ValueError: If the length of ``categories`` does not match number
                of rows or columns of ``matrix``.
        """
        self.set_confusion_matrix_categories(categories)
        if len(matrix) != len(categories):
            raise ValueError(
                f'Invalid matrix: {matrix} passed for categories: {categories}')
        for category, row in zip(categories, matrix):
            if len(row) != len(categories):
                raise ValueError(
                    f'Invalid matrix: {matrix} passed for categories: {categories}'
                )
            self.log_confusion_matrix_row(category, row)
        self.metadata['confusionMatrix'] = self._confusion_matrix
class SlicedClassificationMetrics(Artifact):
    """An artifact for storing sliced classification metrics.

    Similar to ``ClassificationMetrics``, tasks using this class are expected
    to use log methods of the class to log metrics with the difference being
    each log method takes a slice to associate the ``ClassificationMetrics``.

    Args:
        name: Name of the metrics artifact.
        uri: The metrics artifact's location on disk or cloud storage.
        metadata: Arbitrary key-value pairs about the metrics artifact.
    """
    schema_title = 'system.SlicedClassificationMetrics'

    def __init__(self,
                 name: Optional[str] = None,
                 uri: Optional[str] = None,
                 metadata: Optional[Dict] = None) -> None:
        super().__init__(uri=uri, name=name, metadata=metadata)
        # Bug fix: this mapping was never initialized, so every log method
        # raised AttributeError on first use.
        self._sliced_metrics: Dict[str, ClassificationMetrics] = {}

    def _upsert_classification_metrics_for_slice(self, slice: str) -> None:
        """Upserts the classification metrics instance for a slice."""
        if slice not in self._sliced_metrics:
            self._sliced_metrics[slice] = ClassificationMetrics()

    def _update_metadata(self, slice: str) -> None:
        """Updates metadata to adhere to the metrics schema."""
        # Rebuilds the full metadata from all slices (the ``slice`` argument
        # is kept for interface compatibility with callers).
        self.metadata = {'evaluationSlices': []}
        for slice_label, slice_metrics in self._sliced_metrics.items():
            self.metadata['evaluationSlices'].append({
                'slice': slice_label,
                'sliceClassificationMetrics': slice_metrics.metadata
            })

    def log_roc_reading(self, slice: str, threshold: float, tpr: float,
                        fpr: float) -> None:
        """Logs a single data point in the ROC curve of a slice to metadata.

        Args:
            slice: String representing slice label.
            threshold: Threshold value for the data point.
            tpr: True positive rate value of the data point.
            fpr: False positive rate value of the data point.
        """
        self._upsert_classification_metrics_for_slice(slice)
        # Bug fix: ClassificationMetrics has no ``log_roc_reading`` method;
        # the original call raised AttributeError. Route to
        # ``log_roc_data_point`` with explicit keywords.
        self._sliced_metrics[slice].log_roc_data_point(
            fpr=fpr, tpr=tpr, threshold=threshold)
        self._update_metadata(slice)

    def load_roc_readings(self, slice: str,
                          readings: List[List[float]]) -> None:
        """Bulk loads ROC curve readings for a slice.

        Args:
            slice: String representing slice label.
            readings: A 2-dimensional list providing ROC curve data points.
                The expected order of the data points is: threshold,
                true positive rate, false positive rate.
        """
        self._upsert_classification_metrics_for_slice(slice)
        # Bug fix: ClassificationMetrics has no ``load_roc_readings`` method;
        # log each reading individually in the documented order.
        for threshold, tpr, fpr in readings:
            self._sliced_metrics[slice].log_roc_data_point(
                fpr=fpr, tpr=tpr, threshold=threshold)
        self._update_metadata(slice)

    def set_confusion_matrix_categories(self, slice: str,
                                        categories: List[str]) -> None:
        """Logs confusion matrix categories for a slice to metadata.

        Categories are stored in the internal ``ClassificationMetrics``
        instance of the slice.

        Args:
            slice: String representing slice label.
            categories: List of strings specifying the categories.
        """
        self._upsert_classification_metrics_for_slice(slice)
        self._sliced_metrics[slice].set_confusion_matrix_categories(categories)
        self._update_metadata(slice)

    def log_confusion_matrix_row(self, slice: str, row_category: str,
                                 row: List[int]) -> None:
        """Logs a confusion matrix row for a slice to metadata.

        Row is updated on the internal ``ClassificationMetrics`` instance of
        the slice.

        Args:
            slice: String representing slice label.
            row_category: Category to which the row belongs.
            row: List of integers specifying the values for the row.
        """
        self._upsert_classification_metrics_for_slice(slice)
        self._sliced_metrics[slice].log_confusion_matrix_row(row_category, row)
        self._update_metadata(slice)

    def log_confusion_matrix_cell(self, slice: str, row_category: str,
                                  col_category: str, value: int) -> None:
        """Logs a confusion matrix cell for a slice to metadata.

        Cell is updated on the internal ``ClassificationMetrics`` instance of
        the slice.

        Args:
            slice: String representing slice label.
            row_category: String representing the name of the row category.
            col_category: String representing the name of the column category.
            value: Value of the cell.
        """
        self._upsert_classification_metrics_for_slice(slice)
        self._sliced_metrics[slice].log_confusion_matrix_cell(
            row_category, col_category, value)
        self._update_metadata(slice)

    def load_confusion_matrix(self, slice: str, categories: List[str],
                              matrix: List[List[int]]) -> None:
        """Bulk loads the whole confusion matrix for a slice.

        Args:
            slice: String representing slice label.
            categories: List of the category names.
            matrix: Complete confusion matrix.
        """
        self._upsert_classification_metrics_for_slice(slice)
        # Bug fix: the original called ``log_confusion_matrix_cell`` with
        # (categories, matrix), which neither matches that method's signature
        # nor its purpose; bulk-load via ``log_confusion_matrix``.
        self._sliced_metrics[slice].log_confusion_matrix(categories, matrix)
        self._update_metadata(slice)
class HTML(Artifact):
    """An artifact representing an HTML file.

    Args:
        name: Name of the HTML file.
        uri: The HTML file's location on disk or cloud storage.
        metadata: Arbitrary key-value pairs about the HTML file.
    """
    schema_title = 'system.HTML'

    def __init__(self,
                 name: Optional[str] = None,
                 uri: Optional[str] = None,
                 metadata: Optional[Dict] = None) -> None:
        # Behaves exactly like Artifact; only the schema_title differs.
        super().__init__(uri=uri, name=name, metadata=metadata)
class Markdown(Artifact):
    """An artifact representing a markdown file.

    Args:
        name: Name of the markdown file.
        uri: The markdown file's location on disk or cloud storage.
        metadata: Arbitrary key-value pairs about the markdown file.
    """
    schema_title = 'system.Markdown'

    def __init__(self,
                 name: Optional[str] = None,
                 uri: Optional[str] = None,
                 metadata: Optional[Dict] = None):
        # Behaves exactly like Artifact; only the schema_title differs.
        super().__init__(uri=uri, name=name, metadata=metadata)
# Registry mapping each artifact schema title (e.g. 'system.Model') to its
# Python class, so artifact types can be resolved from their schema title.
_SCHEMA_TITLE_TO_TYPE: Dict[str, Type[Artifact]] = {}
for _artifact_cls in (
        Artifact,
        Model,
        Dataset,
        Metrics,
        ClassificationMetrics,
        SlicedClassificationMetrics,
        HTML,
        Markdown,
):
    _SCHEMA_TITLE_TO_TYPE[_artifact_cls.schema_title] = _artifact_cls
del _artifact_cls