Initial support for multioutput regression. #7514

Merged
trivialfis merged 18 commits into dmlc:master from multi-output-reg on Dec 18, 2021

trivialfis
Member

Close #7309.

  • Add a num-target model parameter, which is configured from the input labels.
  • Change the elementwise metric and the indexing for weights.
  • Add demo.
  • Add tests.

@@ -83,8 +88,10 @@ class RegLossObj : public ObjFunction {
// for better performance.
const size_t n_data_blocks = std::max(static_cast<size_t>(1), (on_device ? ndata : nthreads));
const size_t block_size = ndata / n_data_blocks + !!(ndata % n_data_blocks);
auto const n_targets = std::max(info.labels.Shape(1), static_cast<size_t>(1));

Member Author

We should use more ndarray-friendly ways to express this calculation in the future.
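
For readers less used to the C++ idiom, the arithmetic in the hunk above can be restated as a small Python sketch. The function names `block_layout` and `n_targets_from_shape` are hypothetical and exist only for illustration. Handling 1-D label shapes is an assumption here, since the C++ code always reads the second dimension of a tensor.

```python
from typing import Tuple


def block_layout(ndata: int, nthreads: int, on_device: bool) -> Tuple[int, int]:
    """Mirror the block-size arithmetic from the diff hunk (illustrative only)."""
    # One block per element on device, one block per thread on CPU,
    # and never fewer than one block.
    n_data_blocks = max(1, ndata if on_device else nthreads)
    # Ceiling division: `!!(ndata % n_data_blocks)` in C++ adds 1 when
    # ndata does not divide evenly into the blocks.
    block_size = ndata // n_data_blocks + (1 if ndata % n_data_blocks else 0)
    return n_data_blocks, block_size


def n_targets_from_shape(label_shape: Tuple[int, ...]) -> int:
    """Treat labels without a (non-empty) second dimension as one target."""
    n_cols = label_shape[1] if len(label_shape) > 1 else 0  # assumption for 1-D input
    return max(n_cols, 1)


cpu_layout = block_layout(ndata=10, nthreads=4, on_device=False)
gpu_layout = block_layout(ndata=10, nthreads=4, on_device=True)
```

With 10 values and 4 threads, the CPU path produces 4 blocks of up to 3 values each, while the device path produces 10 blocks of 1 value.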

@hcho3
Collaborator

hcho3 commented Dec 16, 2021

Can we throw an error when the user attempts to use other objectives with multi-output labels?

@trivialfis
Member Author

> Can we throw an error when the user attempts to use other objectives with multi-output labels?

Added a check along with a test. Thanks for the suggestion.

@trivialfis trivialfis added this to 1.6 In Progress in 2.0 Roadmap via automation Dec 17, 2021
@trivialfis
Member Author

The support is primitive. We need to extend it to other regression objectives and add more documentation.

@trivialfis trivialfis merged commit 58a6723 into dmlc:master Dec 18, 2021
2.0 Roadmap automation moved this from 1.6 In Progress to 1.6 Done Dec 18, 2021
@trivialfis trivialfis deleted the multi-output-reg branch December 18, 2021 01:28
@Craigacp
Contributor

Is this API suitable for wrapping up in XGBoost4J, or do you want to build out a more general one before doing that?

@trivialfis
Member Author

The API is good for regression and binary classification and should be stable unless a major issue is found. So it's good enough to start looking into language bindings. I will try to add xgboost4j support later and will keep you posted on the progress.

It's unlikely that we will implement multi-class multi-target models on top of the existing interface, so there's no need to worry about that at the moment. :-)

@trivialfis
Member Author

@Craigacp Apologies for the slow progress. I looked into the JVM packages, and I can't build a complete stack for the JVM from Spark down to the basic Java wrapper. The required change is not trivial since we will have 2-dim inputs for both base_margin and label, and I'm not sure what the best way is to implement that for the JVM packages.

@Craigacp
Contributor

Do you have a branch somewhere with the current state of it? I'd assumed that DMatrix would grow to accept a 2d matrix for the targets and then most of the rest of the changes would be plumbing, but if there's something that needs some design effort I can take a look.

@trivialfis
Member Author

trivialfis commented Feb 22, 2022

@Craigacp The master branch contains support for Python. If you are interested in a code walk I'm happy to chat offline.

> I'd assumed that DMatrix would grow to accept a 2d matrix for the targets

Yes, that's pretty much the only requirement for regression. For multi-label classification some configuration needs to be done to make sure binary:logistic is used and no num_class is passed in.

The difficult part for me is just not being familiar with the jvm stack ...
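
A binding could guard the multi-label configuration described above with a check along these lines. This is only a sketch: `check_multilabel_params` is a hypothetical helper, not part of XGBoost, and it assumes the parameters arrive as a plain dictionary and the labels as a `(rows, columns)` shape.

```python
from typing import Dict, Tuple


def check_multilabel_params(params: Dict[str, object], label_shape: Tuple[int, ...]) -> None:
    """Reject configurations that conflict with 2-D multi-label targets.

    Hypothetical helper for illustration; not an XGBoost API.
    """
    multi_output = len(label_shape) == 2 and label_shape[1] > 1
    if not multi_output:
        return  # single-column labels need no extra checks in this sketch
    if params.get("objective") != "binary:logistic":
        raise ValueError("multi-label classification expects objective=binary:logistic")
    if "num_class" in params:
        raise ValueError("num_class must not be set for multi-label classification")


# Passes silently: 2-D labels, binary:logistic, no num_class.
check_multilabel_params({"objective": "binary:logistic"}, (100, 3))
```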

@trivialfis
Member Author

In Python we use a JSON string to represent the input memory buffer, so we have complete information about shape, strides, etc. I'm not sure what the best way is to handle these numeric data types on the JVM.
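
To make the idea concrete, here is a minimal sketch of such a JSON description, modelled on NumPy's `__array_interface__` protocol (version 3) and built with the standard library only. The exact fields XGBoost expects are not spelled out in this thread, so the layout below is an assumption, and `array_interface_json` is a hypothetical name.

```python
import array
import json
import sys


def array_interface_json(buf: array.array, n_rows: int, n_cols: int) -> str:
    """Describe a row-major float buffer as a JSON string (illustrative only)."""
    if buf.typecode != "f" or len(buf) != n_rows * n_cols:
        raise ValueError("expected a C float buffer with n_rows * n_cols elements")
    address, _ = buf.buffer_info()  # (memory address, element count)
    itemsize = buf.itemsize
    byte_order = "<" if sys.byteorder == "little" else ">"
    interface = {
        "data": [address, True],  # pointer and read-only flag
        "shape": [n_rows, n_cols],
        "strides": [n_cols * itemsize, itemsize],  # byte strides, row-major
        "typestr": f"{byte_order}f{itemsize}",
        "version": 3,
    }
    return json.dumps(interface)


# Three examples with two targets each, stored row by row.
labels = array.array("f", [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
described = json.loads(array_interface_json(labels, n_rows=3, n_cols=2))
```

The pointer is only valid while `labels` stays alive and is not resized, which is one of the lifetime questions a JVM binding would also have to answer.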

@Craigacp
Contributor

OK, I should have some time next week to look through how the Python code works. I might take you up on that code walkthrough once I've familiarised myself with it a bit.

The length of the header array should tell us the number of examples, and then for multidimensional things we can require that the target array not be sparse and have explicit zeros (though that might be pretty wasteful for multi-label classification). Then the target array is of known shape [numExamples, numOutputDimensions] linearised into a single vector (row-wise) and that should be enough information for the C API. But there might be a better way to do it.
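
The row-wise layout proposed here can be sketched as follows. `linearise_targets` and `target_at` are hypothetical names, and the sketch assumes dense targets with explicit zeros, as suggested above.

```python
from typing import List, Sequence


def linearise_targets(targets: Sequence[Sequence[float]]) -> List[float]:
    """Flatten [numExamples, numOutputDimensions] targets row by row."""
    n_outputs = len(targets[0]) if targets else 0
    flat: List[float] = []
    for row in targets:
        if len(row) != n_outputs:
            raise ValueError("every example needs the same number of outputs")
        flat.extend(float(value) for value in row)
    return flat


def target_at(flat: Sequence[float], n_outputs: int, example: int, output: int) -> float:
    """Look up one target in the flat vector using row-major indexing."""
    return flat[example * n_outputs + output]


# Three examples, two output dimensions, zeros stored explicitly.
flat_targets = linearise_targets([[1, 0], [0, 1], [1, 1]])
# A flat position maps back to (example, output) with divmod.
example_idx, output_idx = divmod(3, 2)
```

Because the shape is known, the flat length divided by the number of examples recovers the number of output dimensions, which is the information the C API would need.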
