Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT allow metadata to be transformed in a Pipeline #28901

Open
wants to merge 9 commits into
base: main
Choose a base branch
from

Conversation

adrinjalali
Copy link
Member

Initial proposal: #28440 (comment)
xref: #28440 (comment)

This adds transform_input as a constructor argument to Pipeline, as:

    transform_input : list of str, default=None
        This enables transforming some input arguments to ``fit`` (other than ``X``)
        to be transformed by the steps of the pipeline up to the step which requires
        them. Requirement is defined via :ref:`metadata routing <metadata_routing>`.
        This can be used to pass a validation set through the pipeline for instance.

        See the example TBD for more details.

        You can only set this if metadata routing is enabled, which you
        can enable using ``sklearn.set_config(enable_metadata_routing=True)``.

It simply allows to transform metadata with fitted estimators up to the step which needs the metadata.

How does this look?

cc @lorentzenchr @ogrisel @amueller @betatim

Copy link

github-actions bot commented Apr 26, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 1622203. Link to the linter CI: here

@adrinjalali
Copy link
Member Author

So for simple cases where metadata is only used in fit and transform expects no metadata, we're fine. But things get a bit trickier when we start having a transform method of a step accepting the same metadata as the fit of that step.

Specifically, in this test:

@pytest.mark.usefixtures("enable_slep006")
@pytest.mark.parametrize("method", ["fit", "fit_transform"])
def test_transform_input_pipeline(method):
    """Test that with transform_input, data is correctly transformed for each step."""

    def get_transformer(registry, sample_weight, metadata):
        """Get a transformer with requests set."""
        return (
            ConsumingTransformer(registry=registry)
            .set_fit_request(sample_weight=sample_weight, metadata=metadata)
            .set_transform_request(sample_weight=sample_weight, metadata=metadata)
        )

    def get_pipeline():
        """Get a pipeline and corresponding registries.

        The pipeline has 4 steps, with different request values set to test different
        cases. One is aliased.
        """
        registry_1, registry_2, registry_3, registry_4 = (
            _Registry(),
            _Registry(),
            _Registry(),
            _Registry(),
        )
        pipe = make_pipeline(
            get_transformer(registry_1, sample_weight=True, metadata=True),
            get_transformer(registry_2, sample_weight=False, metadata=False),
            get_transformer(registry_3, sample_weight=True, metadata=True),
            get_transformer(registry_4, sample_weight="other_weights", metadata=True),
            transform_input=["sample_weight"],
        )
        return pipe, registry_1, registry_2, registry_3, registry_4

    def check_metadata(registry, methods, **metadata):
        """Check that the right metadata was recorded for the given methods."""
        assert registry
        for estimator in registry:
            for method in methods:
                check_recorded_metadata(
                    estimator,
                    method=method,
                    **metadata,
                )

    X = np.array([[1, 2], [3, 4]])
    y = np.array([0, 1])
    sample_weight = np.array([[1, 2]])
    other_weights = np.array([[30, 40]])
    metadata = np.array([[100, 200]])

    pipe, registry_1, registry_2, registry_3, registry_4 = get_pipeline()
    pipe.fit(
        X,
        y,
        sample_weight=sample_weight,
        other_weights=other_weights,
        metadata=metadata,
    )

    check_metadata(
        registry_1, ["fit", "transform"], sample_weight=sample_weight, metadata=metadata
    )
    check_metadata(registry_2, ["fit", "transform"])
    check_metadata(
        registry_3,
        ["fit", "transform"],
        sample_weight=sample_weight + 2,
        metadata=metadata,
    )
    check_metadata(
        registry_4,
        method.split("_"),  # ["fit", "transform"] if "fit_transform", ["fit"] otherwise
        sample_weight=other_weights + 3,
        metadata=metadata,
    )

Step 3 receives transformed data in its transform method during fit of the pipeline cause all metadata are transformed if they're in transform_input, but a second time when step3.transform is called, the metadata is not transformed (cause I haven't implemented it in pipeline.transform yet).

The question is, what should be the expected behavior?

Do we want transform_input to only transform when calling fit of sub estimators? That's a bit tricky cause all TransformerMixin estimators implement a fit_transform which accepts all metadata together, which means the metadata (if the same name) is either transformed or not transformed. (Wish we didn't have fit_transform in the first place, it's giving us so much headache)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant