# Source code for kfp.dsl.for_loop

# Copyright 2021 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classes and methods that support arguments for ParallelFor."""

import re
from typing import Any, Dict, List, Optional, Union

from kfp.dsl import pipeline_channel

# Type alias for the raw (non-channel) items a loop can iterate over: a list
# whose elements are ints, floats, strings, or string-keyed dicts.
ItemList = List[Union[int, float, str, Dict[str, Any]]]

# Suffix appended to a channel's name when a loop argument is derived from an
# existing pipeline channel (see the `from_pipeline_channel` factories).
LOOP_ITEM_NAME_BASE = 'loop-item'
# Prefix of the generated name when a loop argument is identified by a unique
# name code (see `_make_name`).
LOOP_ITEM_PARAM_NAME_BASE = 'loop-item-param'


def _get_loop_item_type(type_name: str) -> Optional[str]:
    """Extracts the loop item type.

    This method is used for extract the item type from a collection type.
    For example:

        List[str] -> str
        typing.List[int] -> int
        typing.Sequence[str] -> str
        List -> None
        str -> None

    Args:
        type_name: The collection type name, like `List`, Sequence`, etc.

    Returns:
        The collection item type or None if no match found.
    """
    match = re.match('(typing\.)?(?:\w+)(?:\[(?P<item_type>.+)\])', type_name)
    return match['item_type'].lstrip().rstrip() if match else None


def _get_subvar_type(type_name: str) -> Optional[str]:
    """Extracts the subvar type.

    This method is used for extract the value type from a dictionary type.
    For example:

        Dict[str, int] -> int
        typing.Mapping[str, float] -> float

    Args:
        type_name: The dictionary type.

    Returns:
        The dictionary value type or None if no match found.
    """
    match = re.match(
        '(typing\.)?(?:\w+)(?:\[\s*(?:\w+)\s*,\s*(?P<value_type>.+)\])',
        type_name)
    return match['value_type'].lstrip().rstrip() if match else None


def _get_first_element_type(item_list: ItemList) -> str:
    """Describes the type of the first item in `item_list`.

    Only the first item (and, for a dict, its first key and first value) is
    inspected; the remaining items are not checked.

    Args:
        item_list: List of items to loop over. If it is a list of dicts, all
            dicts are expected to have the same keys.

    Returns:
        The type name of the first item (e.g. "int"), or "Dict[K, V]" built
        from the type names of the first key and first value when the first
        item is a dict (e.g. "Dict[str, int]").
    """
    head = item_list[0]

    # Scalars: the builtin type name is the answer.
    if not isinstance(head, dict):
        return type(head).__name__

    # Dicts: sample the first key/value pair to describe the mapping.
    first_key = list(head.keys())[0]
    key_type_name = type(first_key).__name__
    first_value = list(head.values())[0]
    value_type_name = type(first_value).__name__
    return f'Dict[{key_type_name}, {value_type_name}]'


def _make_name(code: str) -> str:
    """Builds the loop argument name for a unique code.

    Args:
        code: The unique code identifying the loop argument.

    Returns:
        `LOOP_ITEM_PARAM_NAME_BASE` and `code` joined by a hyphen.
    """
    prefix = LOOP_ITEM_PARAM_NAME_BASE
    return f'{prefix}-{code}'


class LoopParameterArgument(pipeline_channel.PipelineParameterChannel):
    """Represents the parameter arguments that are looped over in a ParallelFor
    loop.

    The class shouldn't be instantiated by the end user, rather it is
    created automatically by a ParallelFor ops group.

    To create a LoopParameterArgument instance, use one of its factory methods::

        LoopParameterArgument.from_pipeline_channel(...)
        LoopParameterArgument.from_raw_items(...)


    Attributes:
        items_or_pipeline_channel: The raw items (always stored as a list) or
            the PipelineParameterChannel object this LoopParameterArgument is
            associated to.
        is_with_items_loop_argument: True when the argument was built from raw
            items rather than from a PipelineParameterChannel.
    """

    def __init__(
        self,
        items: Union[ItemList, pipeline_channel.PipelineParameterChannel],
        name_code: Optional[str] = None,
        name_override: Optional[str] = None,
        **kwargs,
    ):
        """Initializes a LoopParameterArgument object.

        Args:
            items: List of items to loop over.  If a list of dicts then, all
                dicts must have the same keys and every key must be a legal
                Python variable name.
            name_code: A unique code used to identify these loop arguments.
                Should match the code for the ParallelFor ops_group which created
                these LoopArguments. This prevents parameter name collisions.
            name_override: The override name for PipelineParameterChannel.
            **kwargs: Any other keyword arguments passed down to PipelineParameterChannel.

        Raises:
            ValueError: If both or neither of `name_code` and `name_override`
                are given, or if a dict key of the first item is not a legal
                subvariable name.
            TypeError: If `items` is not a list, tuple, or
                PipelineParameterChannel.
        """
        if (name_code is None) == (name_override is None):
            raise ValueError(
                'Expect one and only one of `name_code` and `name_override` to '
                'be specified.')

        if name_override is None:
            super().__init__(name=_make_name(name_code), **kwargs)
        else:
            super().__init__(name=name_override, **kwargs)

        if not isinstance(
                items,
            (list, tuple, pipeline_channel.PipelineParameterChannel)):
            raise TypeError(
                f'Expected list, tuple, or PipelineParameterChannel, got {items}.'
            )

        if isinstance(items, tuple):
            items = list(items)

        self.items_or_pipeline_channel = items
        self.is_with_items_loop_argument = not isinstance(
            items, pipeline_channel.PipelineParameterChannel)
        self._referenced_subvars: Dict[str, LoopArgumentVariable] = {}

        # `items` can be empty when the class is constructed directly (only
        # `from_raw_items` rejects empty lists), so check before indexing.
        if isinstance(items, list) and items and isinstance(items[0], dict):
            # For items like [{'a': ..., 'b': ...}, ...] this creates
            # loop_arg.a and loop_arg.b. Iterate the dict itself (insertion
            # order) instead of a set so registration order is deterministic.
            # NOTE(review): a key equal to an existing attribute name (e.g.
            # `name`) overwrites that attribute via setattr -- confirm this is
            # acceptable.
            for subvar_name in items[0]:
                loop_arg_var = LoopArgumentVariable(
                    loop_argument=self,
                    subvar_name=subvar_name,
                )
                self._referenced_subvars[subvar_name] = loop_arg_var
                setattr(self, subvar_name, loop_arg_var)

    def __getattr__(self, name: str) -> 'LoopArgumentVariable':
        """Returns (creating on first use) the subvariable called `name`.

        This is overridden so that subvariables of the loop argument (i.e.
        `item.a`) can be accessed without knowing their names ahead of time.
        Python only calls it when normal attribute lookup has failed.

        Args:
            name: The attribute, i.e. subvariable, name being looked up.

        Returns:
            The LoopArgumentVariable registered under `name`.

        Raises:
            AttributeError: For `_referenced_subvars` itself and for dunder
                names, which are never subvariables.
            ValueError: If `name` is not a legal subvariable name.
        """
        # If `_referenced_subvars` is not set yet (e.g. on an instance rebuilt
        # by copy/pickle without running __init__), reading it below would
        # re-enter this method forever. Dunder lookups are protocol probes
        # (`__deepcopy__`, `__setstate__`, ...) and must report "missing"
        # rather than be turned into subvariables.
        if name == '_referenced_subvars' or (name.startswith('__') and
                                             name.endswith('__')):
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {name!r}')

        subvars = self._referenced_subvars
        # Only build a new LoopArgumentVariable when one is actually missing.
        if name not in subvars:
            subvars[name] = LoopArgumentVariable(
                loop_argument=self,
                subvar_name=name,
            )
        return subvars[name]

    @classmethod
    def from_pipeline_channel(
        cls,
        channel: pipeline_channel.PipelineParameterChannel,
    ) -> 'LoopParameterArgument':
        """Creates a LoopParameterArgument object from a
        PipelineParameterChannel object.

        Provide a flexible default channel_type ('String') if extraction
        from PipelineParameterChannel is unsuccessful. This maintains
        compilation progress in cases of unknown or missing type
        information.

        Args:
            channel: The channel whose value is the collection to loop over.

        Returns:
            A LoopParameterArgument named `<channel.name>-loop-item` that
            keeps the channel's task name.
        """
        return LoopParameterArgument(
            items=channel,
            name_override=channel.name + '-' + LOOP_ITEM_NAME_BASE,
            task_name=channel.task_name,
            channel_type=_get_loop_item_type(channel.channel_type) or 'String',
        )

    @classmethod
    def from_raw_items(
        cls,
        raw_items: ItemList,
        name_code: str,
    ) -> 'LoopParameterArgument':
        """Creates a LoopParameterArgument object from raw item list.

        Args:
            raw_items: The non-empty list of items to loop over.
            name_code: A unique code used to build the argument's name.

        Returns:
            A LoopParameterArgument whose channel type is derived from the
            first element of `raw_items`.

        Raises:
            ValueError: If `raw_items` is empty.
        """
        if len(raw_items) == 0:
            raise ValueError('Got an empty item list for loop argument.')

        return LoopParameterArgument(
            items=raw_items,
            name_code=name_code,
            channel_type=_get_first_element_type(raw_items),
        )


class LoopArtifactArgument(pipeline_channel.PipelineArtifactChannel):
    """An artifact argument iterated over by a ParallelFor loop.

    End users should not instantiate this class; a ParallelFor ops group
    creates it automatically.

    Build instances through the factory method::

        LoopArtifactArgument.from_pipeline_channel(...)


    Attributes:
        items_or_pipeline_channel: The PipelineArtifactChannel object this
            LoopArtifactArgument is associated to.
        is_with_items_loop_argument: True when `items` is not a
            PipelineArtifactChannel.
    """

    def __init__(
        self,
        items: pipeline_channel.PipelineArtifactChannel,
        name_code: Optional[str] = None,
        name_override: Optional[str] = None,
        **kwargs,
    ):
        """Sets up a LoopArtifactArgument.

        Args:
            items: The PipelineArtifactChannel this argument is associated to.
            name_code: A unique code identifying these loop arguments. It
                should match the code of the ParallelFor ops_group that
                created them, which prevents parameter name collisions.
            name_override: An explicit name for the PipelineArtifactChannel.
            **kwargs: Extra keyword arguments forwarded to
                PipelineArtifactChannel.

        Raises:
            ValueError: If both or neither of `name_code` and `name_override`
                are given.
        """
        exactly_one_name_given = (name_code is None) != (name_override is None)
        if not exactly_one_name_given:
            raise ValueError(
                'Expect one and only one of `name_code` and `name_override` to '
                'be specified.')

        if name_override is None:
            channel_name = _make_name(name_code)
        else:
            channel_name = name_override

        # Nested lists are not supported, so the loop item itself is never an
        # artifact list.
        super().__init__(name=channel_name, is_artifact_list=False, **kwargs)

        self.items_or_pipeline_channel = items
        comes_from_channel = isinstance(
            items, pipeline_channel.PipelineArtifactChannel)
        self.is_with_items_loop_argument = not comes_from_channel

    @classmethod
    def from_pipeline_channel(
        cls,
        channel: pipeline_channel.PipelineArtifactChannel,
    ) -> 'LoopArtifactArgument':
        """Builds a LoopArtifactArgument from a PipelineArtifactChannel.

        Args:
            channel: A channel holding a list of artifacts.

        Returns:
            A LoopArtifactArgument named `<channel.name>-loop-item` that keeps
            the channel's task name and channel type.

        Raises:
            ValueError: If `channel` is not an artifact list.
        """
        if channel.is_artifact_list:
            loop_item_name = channel.name + '-' + LOOP_ITEM_NAME_BASE
            return LoopArtifactArgument(
                items=channel,
                name_override=loop_item_name,
                task_name=channel.task_name,
                channel_type=channel.channel_type,
            )

        raise ValueError(
            'Cannot iterate over a single Artifact using `dsl.ParallelFor`. Expected a list of Artifacts as argument to `items`.'
        )

    # TODO: support artifact constants here.


class LoopArgumentVariable(pipeline_channel.PipelineParameterChannel):
    """Represents a subvariable for a loop argument.

    This is used for cases where we're looping over maps, each of which contains
    several variables. If the user ran:

        with dsl.ParallelFor([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]) as item:
            ...

    Then there's one LoopArgumentVariable for 'a' and another for 'b'.

    Attributes:
        loop_argument: The original LoopArgument object this subvariable is
          attached to.
        subvar_name: The subvariable name.
    """
    SUBVAR_NAME_DELIMITER = '-subvar-'
    LEGAL_SUBVAR_NAME_REGEX = re.compile(r'^[a-zA-Z_][0-9a-zA-Z_]*$')

    def __init__(
        self,
        loop_argument: LoopParameterArgument,
        subvar_name: str,
    ):
        """Initializes a LoopArgumentVariable instance.

        Args:
            loop_argument: The LoopParameterArgument object this subvariable
                belongs to.
            subvar_name: The name of this subvariable, which is the name of the
                dict key that spawned this subvariable.

        Raises:
            ValueError: If the subvar name is illegal.
        """
        if not self._subvar_name_is_legal(subvar_name):
            raise ValueError(
                f'Tried to create subvariable named {subvar_name}, but that is '
                'not a legal Python variable name.')

        self.subvar_name = subvar_name
        self.loop_argument = loop_argument
        # Handle potential channel_type extraction errors from LoopArgument by
        # defaulting to 'String'. This maintains compilation progress.
        super().__init__(
            name=self._get_name_override(
                loop_arg_name=loop_argument.name,
                subvar_name=subvar_name,
            ),
            task_name=loop_argument.task_name,
            channel_type=_get_subvar_type(loop_argument.channel_type) or
            'String',
        )

    @property
    def items_or_pipeline_channel(
            self) -> Union[ItemList, pipeline_channel.PipelineParameterChannel]:
        """Returns the loop argument items."""
        return self.loop_argument.items_or_pipeline_channel

    @property
    def is_with_items_loop_argument(self) -> bool:
        """Whether the loop argument is originated from raw items."""
        return self.loop_argument.is_with_items_loop_argument

    def _subvar_name_is_legal(self, proposed_variable_name: str) -> bool:
        """Returns True if the subvar name is legal.

        Args:
            proposed_variable_name: The candidate subvariable name.

        Returns:
            True if the whole name matches `LEGAL_SUBVAR_NAME_REGEX`; False
            otherwise, including when the name is not a string.
        """
        # Non-string dict keys (e.g. ints) are reported as illegal so the
        # caller raises its descriptive ValueError instead of `re` raising a
        # TypeError.
        if not isinstance(proposed_variable_name, str):
            return False
        # `fullmatch` is required: with `match`, the pattern's trailing `$`
        # also matches just before a final newline, so 'a\n' would be accepted.
        return self.LEGAL_SUBVAR_NAME_REGEX.fullmatch(
            proposed_variable_name) is not None

    def _get_name_override(self, loop_arg_name: str, subvar_name: str) -> str:
        """Gets the name.

        Args:
            loop_arg_name: the name of the loop argument parameter that this
              LoopArgumentVariable is attached to.
            subvar_name: The name of this subvariable.

        Returns:
            The name of this loop arg variable: the loop argument name and
            the subvariable name joined by `SUBVAR_NAME_DELIMITER`.
        """
        return f'{loop_arg_name}{self.SUBVAR_NAME_DELIMITER}{subvar_name}'


# TODO: migrate Collected to OneOfMixin style implementation
class Collected(pipeline_channel.PipelineChannel):
    """For collecting into a list the output from a task in dsl.ParallelFor
    loops.

    Args:
        output: The output of an upstream task within a dsl.ParallelFor loop.

    Example:
      ::

        @dsl.pipeline
        def math_pipeline() -> int:
            with dsl.ParallelFor([1, 2, 3]) as x:
                t = double(num=x)

            return add(nums=dsl.Collected(t.output)).output
    """

    def __init__(
        self,
        output: pipeline_channel.PipelineChannel,
    ) -> None:
        """Wraps `output` so it is treated as a collected list.

        Args:
            output: The upstream channel to collect.

        Raises:
            ValueError: If `output` is a OneOf channel.
        """
        self.output = output
        # we know all dsl.Collected instances are lists, so set `is_artifact_list`
        # for type checking, which occurs before dsl.Collected is updated to
        # its "correct" channel during compilation
        if isinstance(output, pipeline_channel.PipelineArtifactChannel):
            channel_type = output.channel_type
            self.is_artifact_channel = True
            self.is_artifact_list = True
        else:
            channel_type = 'LIST'
            self.is_artifact_channel = False
            self.is_artifact_list = False

        super().__init__(
            output.name,
            channel_type=channel_type,
            task_name=output.task_name,
        )
        self._validate_no_oneof_channel(self.output)

    def _validate_no_oneof_channel(
        self, channel: Union[pipeline_channel.PipelineParameterChannel,
                             pipeline_channel.PipelineArtifactChannel]
    ) -> None:
        """Rejects OneOf channels, which cannot be collected.

        Args:
            channel: The channel being wrapped by this Collected instance.

        Raises:
            ValueError: If `channel` is an instance of OneOfMixin.
        """
        if isinstance(channel, pipeline_channel.OneOfMixin):
            raise ValueError(
                f'dsl.{pipeline_channel.OneOf.__name__} cannot be used inside of dsl.{Collected.__name__}.'
            )