# Source code for kfp.dsl.importer_node

# Copyright 2020 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility function for building Importer Node spec."""

from typing import Optional, Union, Type

from kfp.dsl import _container_op
from kfp.dsl import _pipeline_param
from kfp.dsl import dsl_utils
from kfp.dsl import io_types
from kfp.dsl import type_utils
from kfp.pipeline_spec import pipeline_spec_pb2

# Key of the importer's single input parameter (the artifact URI). It names the
# runtime parameter in the importer spec and the input parameter entry in the
# task spec and component spec built in this module.
INPUT_KEY = 'uri'
# Key of the importer's single output artifact. It names the artifact entry in
# the component spec's output definitions and in the ContainerOp file_outputs.
OUTPUT_KEY = 'artifact'


def _build_importer_spec(
    artifact_uri: Union[_pipeline_param.PipelineParam, str],
    artifact_type_schema: pipeline_spec_pb2.ArtifactTypeSchema,
) -> pipeline_spec_pb2.PipelineDeploymentConfig.ImporterSpec:
  """Creates the importer executor spec for an artifact.

  The given type schema is copied into the spec. The artifact uri is recorded
  either as a reference to the runtime parameter named by INPUT_KEY (when a
  PipelineParam is passed) or as a constant string value (when a str is
  passed). For any other type the uri is left unset.

  Args:
    artifact_uri: The artifact uri to import from.
    artifact_type_schema: The user specified artifact type schema of the
      artifact to be imported.

  Returns:
    An importer spec.
  """
  spec = pipeline_spec_pb2.PipelineDeploymentConfig.ImporterSpec()
  spec.type_schema.CopyFrom(artifact_type_schema)

  # A pipeline parameter is resolved at runtime, so only its key is recorded.
  if isinstance(artifact_uri, _pipeline_param.PipelineParam):
    spec.artifact_uri.runtime_parameter = INPUT_KEY
    return spec

  # A plain string is embedded directly as a constant.
  if isinstance(artifact_uri, str):
    spec.artifact_uri.constant_value.string_value = artifact_uri

  return spec


def _build_importer_task_spec(
    importer_base_name: str,
    artifact_uri: Union[_pipeline_param.PipelineParam, str],
) -> pipeline_spec_pb2.PipelineTaskSpec:
  """Creates the task spec for an importer node.

  The task name and the component reference name are both derived from
  `importer_base_name` through the dsl_utils sanitizers. The INPUT_KEY input
  parameter is bound to the pipeline parameter's full name (PipelineParam) or
  to a constant string value (str). For any other type no input is added.

  Args:
    importer_base_name: The base name of the importer node.
    artifact_uri: The artifact uri to import from.

  Returns:
    An importer node task spec.
  """
  task_spec = pipeline_spec_pb2.PipelineTaskSpec()

  task_name = dsl_utils.sanitize_task_name(importer_base_name)
  task_spec.task_info.name = task_name
  component_name = dsl_utils.sanitize_component_name(importer_base_name)
  task_spec.component_ref.name = component_name

  # The parameters map entry is only touched inside a matching branch, so an
  # unsupported artifact_uri type leaves the inputs empty.
  if isinstance(artifact_uri, _pipeline_param.PipelineParam):
    uri_input = task_spec.inputs.parameters[INPUT_KEY]
    uri_input.component_input_parameter = artifact_uri.full_name
    return task_spec

  if isinstance(artifact_uri, str):
    uri_input = task_spec.inputs.parameters[INPUT_KEY]
    uri_input.runtime_value.constant_value.string_value = artifact_uri

  return task_spec


def _build_importer_component_spec(
    importer_base_name: str,
    artifact_type_schema: pipeline_spec_pb2.ArtifactTypeSchema,
) -> pipeline_spec_pb2.ComponentSpec:
  """Creates the component spec for an importer node.

  The component declares one STRING input parameter keyed by INPUT_KEY and one
  output artifact keyed by OUTPUT_KEY whose type is a copy of
  `artifact_type_schema`. The executor label is derived from
  `importer_base_name`.

  Args:
    importer_base_name: The base name of the importer node.
    artifact_type_schema: The user specified artifact type schema of the
      artifact to be imported.

  Returns:
    An importer node component spec.
  """
  component_spec = pipeline_spec_pb2.ComponentSpec()
  executor_label = dsl_utils.sanitize_executor_label(importer_base_name)
  component_spec.executor_label = executor_label

  # Input: the artifact uri, always declared as a string parameter.
  uri_parameter = component_spec.input_definitions.parameters[INPUT_KEY]
  uri_parameter.type = pipeline_spec_pb2.PrimitiveType.STRING

  # Output: the imported artifact, typed by the caller-provided schema.
  imported_artifact = component_spec.output_definitions.artifacts[OUTPUT_KEY]
  imported_artifact.artifact_type.CopyFrom(artifact_type_schema)

  return component_spec


def importer(artifact_uri: Union[_pipeline_param.PipelineParam, str],
             artifact_class: Type[io_types.Artifact],
             reimport: bool = False) -> _container_op.ContainerOp:
  """dsl.importer for importing an existing artifact. Only for v2 pipeline.

  Args:
    artifact_uri: The artifact uri to import from. Either a PipelineParam or a
      constant string.
    artifact_class: The artifact class of the artifact to be imported. Its
      type schema is obtained via type_utils.get_artifact_type_schema.
    reimport: Whether to reimport the artifact. Defaults to False.
      NOTE(review): this value is not referenced anywhere in this function,
      so it currently has no effect on the specs built here -- TODO confirm
      whether it should be forwarded to the importer spec.

  Returns:
    A ContainerOp instance carrying the importer spec, task spec and component
    spec, with the uri parameter as its single input.

  Raises:
    ValueError if the passed in artifact_uri is neither a PipelineParam nor a
      constant string value.
  """
  if isinstance(artifact_uri, _pipeline_param.PipelineParam):
    input_param = artifact_uri
  elif isinstance(artifact_uri, str):
    # Wrap the constant so the task always exposes a PipelineParam input.
    input_param = _pipeline_param.PipelineParam(
        name=INPUT_KEY, value=artifact_uri, param_type='String')
  else:
    raise ValueError(
        'Importer got unexpected artifact_uri: {} of type: {}.'.format(
            artifact_uri, type(artifact_uri)))

  # Temporarily silence the reusable-component warning while the placeholder
  # ContainerOp is constructed. The previous value is restored in `finally` so
  # the class-level flag does not stay flipped if the constructor raises.
  old_warn_value = _container_op.ContainerOp._DISABLE_REUSABLE_COMPONENT_WARNING
  _container_op.ContainerOp._DISABLE_REUSABLE_COMPONENT_WARNING = True
  try:
    task = _container_op.ContainerOp(
        name='importer',
        image='importer_image',  # TODO: need a v1 implementation of importer.
        file_outputs={
            OUTPUT_KEY:
                "{{{{$.outputs.artifacts['{}'].uri}}}}".format(OUTPUT_KEY)
        },
    )
  finally:
    _container_op.ContainerOp._DISABLE_REUSABLE_COMPONENT_WARNING = (
        old_warn_value)

  artifact_type_schema = type_utils.get_artifact_type_schema(artifact_class)
  task.importer_spec = _build_importer_spec(
      artifact_uri=artifact_uri, artifact_type_schema=artifact_type_schema)
  task.task_spec = _build_importer_task_spec(
      importer_base_name=task.name, artifact_uri=artifact_uri)
  task.component_spec = _build_importer_component_spec(
      importer_base_name=task.name, artifact_type_schema=artifact_type_schema)
  task.inputs = [input_param]

  return task