Is there any way to see alphas/coefs/intercept associated with *all* scenarios tested within ElasticNetCV #28726

Open
cppt opened this issue Mar 29, 2024 · 2 comments

Comments

cppt commented Mar 29, 2024

Describe the workflow you want to enable

I like that ElasticNetCV outputs the MSE path for the CV folds/alphas, but is there any way to similarly track the associated model parameters (i.e., coef/intercept) for each scenario and include them as part of the output?

I understand that it's simpler to output only the 'best' estimator/params, but more granularity would help identify a 'sweet spot', either via the MSE curve or some other criterion, so outputting all params would be a useful addition.
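To make the 'sweet spot' idea concrete: one common alternative to the minimum-MSE alpha is the one-standard-error rule (the idea behind `lambda.1se` in R's glmnet), which picks the most strongly regularised alpha whose mean CV error lies within one standard error of the minimum. ElasticNetCV does not do this itself; the sketch below is a hypothetical helper that only reads the fitted attributes `mse_path_` and `alphas_`, assuming a single scalar `l1_ratio` (so that `mse_path_` has shape `(n_alphas, n_folds)`). The `make_regression` data is purely illustrative.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import ElasticNetCV


def one_standard_error_alpha(cv_model):
    """Hypothetical helper: the largest alpha whose mean CV error is within one
    standard error of the minimum. Assumes a scalar l1_ratio, so that
    cv_model.mse_path_ has shape (n_alphas, n_folds)."""
    mse = cv_model.mse_path_
    mean_mse = mse.mean(axis=1)
    # Standard error of the mean CV error, computed across folds.
    se = mse.std(axis=1, ddof=1) / np.sqrt(mse.shape[1])
    best = np.argmin(mean_mse)
    admissible = np.flatnonzero(mean_mse <= mean_mse[best] + se[best])
    # alphas_ is sorted in decreasing order, so the first admissible index
    # corresponds to the most strongly regularised candidate.
    return cv_model.alphas_[admissible[0]]


X, y = make_regression(n_samples=200, n_features=20, n_informative=5,
                       noise=10.0, random_state=0)
cv_model = ElasticNetCV(l1_ratio=0.5, cv=5).fit(X, y)
alpha_1se = one_standard_error_alpha(cv_model)
print("alpha with minimum mean CV error:", cv_model.alpha_)
print("one-standard-error alpha:", alpha_1se)
```

By construction the returned alpha is never smaller than `cv_model.alpha_`, so it trades a little CV error for a sparser, more regularised model.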

Describe your proposed solution

As described, run the existing scenarios as they are, but instead of keeping only the 'best' model after evaluation, save all model params/outputs and return them in an additional data object/structure.
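A rough approximation of the proposed additional data structure can already be built outside the estimator. The hypothetical helper `cv_path_details` below re-runs `enet_path` on each training fold and keeps the coefficients, intercepts and held-out MSE for every (fold, alpha) pair. It assumes dense input, a 1-D target, no sample weights, a single `l1_ratio`, and the same unshuffled `KFold` splits that `ElasticNetCV` uses when `cv` is an integer. Because `enet_path` fits no intercept, each training fold is centred first and the intercept is recovered as `y_mean - X_mean @ coef`. Under those assumptions, `details["mse"]` should agree with `cv_model.mse_path_.T` up to solver tolerance.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import ElasticNetCV, enet_path
from sklearn.model_selection import KFold


def cv_path_details(X, y, alphas, l1_ratio=0.5, n_splits=5):
    """Hypothetical helper: coefficients, intercepts and held-out MSE for every
    (fold, alpha) pair. Assumes dense X, a 1-D y and no sample weights."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    # enet_path works on a decreasing alpha grid, so sort the grid the same way.
    alphas = np.sort(np.asarray(alphas, dtype=float))[::-1]
    n_alphas, n_features = len(alphas), X.shape[1]
    coefs = np.empty((n_splits, n_alphas, n_features))
    intercepts = np.empty((n_splits, n_alphas))
    mse = np.empty((n_splits, n_alphas))
    for k, (train, test) in enumerate(KFold(n_splits=n_splits).split(X)):
        # enet_path fits no intercept, so centre the training fold first.
        X_mean, y_mean = X[train].mean(axis=0), y[train].mean()
        _, fold_coefs, _ = enet_path(X[train] - X_mean, y[train] - y_mean,
                                     l1_ratio=l1_ratio, alphas=alphas)
        # fold_coefs has shape (n_features, n_alphas).
        coefs[k] = fold_coefs.T
        intercepts[k] = y_mean - X_mean @ fold_coefs
        predictions = X[test] @ fold_coefs + intercepts[k]  # (n_test, n_alphas)
        mse[k] = ((predictions - y[test][:, np.newaxis]) ** 2).mean(axis=0)
    return {"alphas": alphas, "coefs": coefs, "intercepts": intercepts, "mse": mse}


X, y = make_regression(n_samples=200, n_features=20, n_informative=5,
                       noise=10.0, random_state=0)
cv_model = ElasticNetCV(l1_ratio=0.5, cv=5).fit(X, y)
details = cv_path_details(X, y, cv_model.alphas_, l1_ratio=cv_model.l1_ratio_)
print(details["coefs"].shape)       # (n_folds, n_alphas, n_features)
print(details["intercepts"].shape)  # (n_folds, n_alphas)
```

Returning a plain dict of arrays keeps the sketch dependency-free; the same arrays could just as well be stored as extra fitted attributes if this were implemented inside the estimator.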

Describe alternatives you've considered, if relevant

No response

Additional context

No response

cppt added the Needs Triage and New Feature labels Mar 29, 2024
@glemaitre
Member

glemaitre removed the Needs Triage label Apr 1, 2024
hammad7 commented Apr 16, 2024

@cppt, here is the code:

import numpy as np
from sklearn.linear_model import enet_path

# Create some sample data
X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
y = np.dot(X, np.array([1, 2])) + 3

# Compute the regularization path. enet_path returns (alphas, coefs, dual_gaps):
# coefs has shape (n_features, n_alphas), and no intercept is fitted or returned.
alphas, coefs, dual_gaps = enet_path(X, y, verbose=5)

# Print the alpha, coefficient vector and duality gap for every alpha on the path.
# Transpose coefs so that each iteration yields the coefficients for one alpha.
for alpha, coef, dual_gap in zip(alphas, coefs.T, dual_gaps):
    print("Alpha:", alpha)
    print("Coefficients:", coef)
    print("Dual gap:", dual_gap)
    print()

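Output. With `verbose=5`, `enet_path` prints one `(coef, dual_gap, tol, n_iter)` tuple per alpha. The `Alpha`/`Coefficients`/`Intercept` lines at the end come from looping over `zip(alphas, coefs, intercepts)`, where `coefs` has shape `(n_features, n_alphas)` and the third return value holds the duality gaps, so each "Coefficients" entry is one feature's coefficient path across all alphas and each "Intercept" value is actually a duality gap: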
(array([0., 0.]), 0.0, 0.0302, 1)
(array([0. , 0.05715559]), 0.0, 0.0302, 2)
(array([0. , 0.11667845]), 5.684341886080802e-14, 0.0302, 2)
(array([0. , 0.17856487]), 0.0, 0.0302, 2)
(array([0. , 0.24279889]), -5.684341886080802e-14, 0.0302, 2)
(array([0. , 0.30935145]), 0.0, 0.0302, 2)
(array([0.01901234, 0.37443649]), 1.3194308507991082e-06, 0.0302, 4)
(array([0.06464451, 0.43584286]), 5.815761113581175e-06, 0.0302, 4)
(array([0.11104886, 0.49826743]), 1.3875526946094396e-05, 0.0302, 4)
(array([0.15812687, 0.56158046]), 1.6494158217028598e-06, 0.0302, 5)
(array([0.2057831 , 0.62564193]), 3.247510619530658e-06, 0.0302, 5)
(array([0.25390901, 0.69030648]), 5.998067450718736e-06, 0.0302, 5)
(array([0.30239422, 0.75542238]), 1.0576523891359102e-05, 0.0302, 5)
(array([0.35112521, 0.82083339]), 1.797996975483329e-05, 0.0302, 5)
(array([0.39998632, 0.88638024]), 2.9643879514651417e-05, 0.0302, 5)
(array([0.44886095, 0.95190213]), 4.75841175102687e-05, 0.0302, 5)
(array([0.4976327 , 1.01723829]), 7.456191761434638e-05, 0.0302, 5)
(array([0.54618652, 1.08222949]), 0.00011426530710423322, 0.0302, 5)
(array([0.59440985, 1.14671955]), 0.00017149506439295692, 0.0302, 5)
(array([0.64219371, 1.21055678]), 0.00025233719833295254, 0.0302, 5)
(array([0.68943375, 1.27359533]), 0.0003642978385300921, 0.0302, 5)
(array([0.73600503, 1.33570614]), 9.317236222727843e-05, 0.0302, 6)
(array([0.78185772, 1.39674312]), 0.0001400085866976042, 0.0302, 6)
(array([0.82688711, 1.45659175]), 0.00020630181253977753, 0.0302, 6)
(array([0.87101494, 1.51514063]), 0.0002981673163162668, 0.0302, 6)
(array([0.91417102, 1.5722888 ]), 0.00042285210944470464, 0.0302, 6)
(array([0.95629359, 1.6279462 ]), 0.0005886332188183019, 0.0302, 6)
(array([0.99732949, 1.68203387]), 0.0008046054945296532, 0.0302, 6)
(array([1.03718305, 1.73450754]), 0.0003165159688549579, 0.0302, 7)
(array([1.07590454, 1.78527184]), 0.0004431231468089436, 0.0302, 7)
(array([1.11342785, 1.83429845]), 0.0006085892242211344, 0.0302, 7)
(array([1.14973329, 1.88155252]), 0.0008200316172946032, 0.0302, 7)
(array([1.18480902, 1.92700936]), 0.0010844981319380054, 0.0302, 7)
(array([1.21865063, 1.97065383]), 0.001408358238293772, 0.0302, 7)
(array([1.25117521, 2.01252455]), 0.0007281750637666562, 0.0302, 8)
(array([1.28254042, 2.05254682]), 0.0009553156995636414, 0.0302, 8)
(array([1.31269417, 2.09076491]), 0.0012308004081234003, 0.0302, 8)
(array([1.34165629, 2.12719574]), 0.001556845843310839, 0.0302, 8)
(array([1.36945136, 2.16186239]), 0.0019345101170102907, 0.0302, 8)
(array([1.39599369, 2.19485863]), 0.0011828625463508047, 0.0302, 9)
(array([1.42152022, 2.22610224]), 0.0014705246886403955, 0.0302, 9)
(array([1.44597336, 2.25568279]), 0.0017984996816977628, 0.0302, 9)
(array([1.46939059, 2.28364133]), 0.002161982013863195, 0.0302, 9)
(array([1.49181143, 2.31002169]), 0.0025561466318571036, 0.0302, 9)
(array([1.51313001, 2.33495931]), 0.001755255011964607, 0.0302, 10)
(array([1.53365873, 2.35833888]), 0.002060662494283605, 0.0302, 10)
(array([1.55331617, 2.38028421]), 0.0023873322183831647, 0.0302, 10)
(array([1.57214462, 2.40084582]), 0.0027231035870052267, 0.0302, 10)
(array([1.59018678, 2.42007443]), 0.0030601661872324826, 0.0302, 10)
(array([1.60731468, 2.43812958]), 0.002267589904306533, 0.0302, 11)
(array([1.62389131, 2.45485819]), 0.0025147333974242514, 0.0302, 11)
(array([1.6398085 , 2.47040489]), 0.0027633870586996068, 0.0302, 11)
(array([1.65510415, 2.48482075]), 0.0029952686152476815, 0.0302, 11)
(array([1.66981648, 2.49815509]), 0.0032039541598951615, 0.0302, 11)
(array([1.68398224, 2.51045643]), 0.003383846876634067, 0.0302, 11)
(array([1.69763653, 2.52177235]), 0.003530214226813655, 0.0302, 11)
(array([1.7108127 , 2.53214943]), 0.003639320592169426, 0.0302, 11)
(array([1.72354233, 2.54163315]), 0.003708501889304827, 0.0302, 11)
(array([1.73585512, 2.55026788]), 0.0037361839696927746, 0.0302, 11)
(array([1.74777888, 2.5580968 ]), 0.003721852482371446, 0.0302, 11)
(array([1.75933949, 2.56516191]), 0.0036659837869166267, 0.0302, 11)
(array([1.77056088, 2.57150404]), 0.00356994707446745, 0.0302, 11)
(array([1.78146505, 2.57716285]), 0.0034358874668125594, 0.0302, 11)
(array([1.79227591, 2.58203659]), 0.003982404951348606, 0.0302, 10)
(array([1.80262024, 2.58643133]), 0.003807589819487589, 0.0302, 10)
(array([1.81289059, 2.59012413]), 0.004213254973232772, 0.0302, 9)
(array([1.82271506, 2.59341721]), 0.003930192435560365, 0.0302, 9)
(array([1.83247549, 2.59608888]), 0.004142907300813903, 0.0302, 8)
(array([1.84181918, 2.59842501]), 0.003749310136090145, 0.0302, 8)
(array([1.85130873, 2.60007179]), 0.004356432275689048, 0.0302, 6)
(array([1.86025481, 2.60154708]), 0.0039903360093003215, 0.0302, 6)
(array([1.86928474, 2.60244646]), 0.0043643070791796745, 0.0302, 4)
(array([1.87801167, 2.60307353]), 0.004425443659130224, 0.0302, 3)
(array([1.88645693, 2.60344448]), 0.004249947511216767, 0.0302, 2)
(array([1.89441145, 2.60373623]), 0.0033649108293953844, 0.0302, 2)
(array([1.90196687, 2.60391233]), 0.0020436509041052986, 0.0302, 2)
(array([1.90918874, 2.60395269]), 0.0004711652838800262, 0.0302, 2)
(array([1.91612453, 2.6038482 ]), 0.00167596941914816, 0.0302, 2)
(array([1.92280901, 2.6035971 ]), 0.004071920576677712, 0.0302, 2)
(array([1.92950366, 2.60303499]), 0.005798944142170193, 0.0302, 3)
(array([1.9365236 , 2.60195723]), 0.005998496484806548, 0.0302, 5)
(array([1.94326137, 2.60081402]), 0.006433019368176307, 0.0302, 5)
(array([1.95001382, 2.59941164]), 0.006279225706987468, 0.0302, 6)
(array([1.95646083, 2.59799461]), 0.006416734108870781, 0.0302, 6)
(array([1.96290002, 2.59636619]), 0.006047004658215727, 0.0302, 7)
(array([1.96901443, 2.59476707]), 0.006006448644048312, 0.0302, 7)
(array([1.97487238, 2.59316222]), 0.006104199110143593, 0.0302, 7)
(array([1.98050722, 2.59154059]), 0.006250992622957696, 0.0302, 7)
(array([1.98593567, 2.58990216]), 0.006404871518275179, 0.0302, 7)
(array([1.99116647, 2.58825181]), 0.006546459726492593, 0.0302, 7)
(array([1.9962046 , 2.58659636]), 0.006667379095167458, 0.0302, 7)
(array([2.00105333, 2.58494317]), 0.006764688637442617, 0.0302, 7)
(array([2.00571525, 2.58329945]), 0.006838155449508498, 0.0302, 7)
(array([2.01019279, 2.58167192]), 0.006888886693848573, 0.0302, 7)
(array([2.01448848, 2.58006668]), 0.006918629284542455, 0.0302, 7)
(array([2.01860507, 2.5784891 ]), 0.006929403744237117, 0.0302, 7)
(array([2.02254557, 2.57694391]), 0.006923308817433416, 0.0302, 7)
(array([2.0260774 , 2.57560473]), 0.007546682213824152, 0.0302, 6)
(array([2.02977259, 2.57406636]), 0.007255271415334885, 0.0302, 7)
(array([2.03303197, 2.57276514]), 0.007721195772496525, 0.0302, 6)
[1, 2, 2, 2, 2, 2, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 10, 9, 9, 8, 8, 6, 6, 4, 3, 2, 2, 2, 2, 2, 2, 3, 5, 5, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, 7, 6]
Alpha: 36.5
Coefficients: [0. 0. 0. 0. 0. 0.
0.01901234 0.06464451 0.11104886 0.15812687 0.2057831 0.25390901
0.30239422 0.35112521 0.39998632 0.44886095 0.4976327 0.54618652
0.59440985 0.64219371 0.68943375 0.73600503 0.78185772 0.82688711
0.87101494 0.91417102 0.95629359 0.99732949 1.03718305 1.07590454
1.11342785 1.14973329 1.18480902 1.21865063 1.25117521 1.28254042
1.31269417 1.34165629 1.36945136 1.39599369 1.42152022 1.44597336
1.46939059 1.49181143 1.51313001 1.53365873 1.55331617 1.57214462
1.59018678 1.60731468 1.62389131 1.6398085 1.65510415 1.66981648
1.68398224 1.69763653 1.7108127 1.72354233 1.73585512 1.74777888
1.75933949 1.77056088 1.78146505 1.79227591 1.80262024 1.81289059
1.82271506 1.83247549 1.84181918 1.85130873 1.86025481 1.86928474
1.87801167 1.88645693 1.89441145 1.90196687 1.90918874 1.91612453
1.92280901 1.92950366 1.9365236 1.94326137 1.95001382 1.95646083
1.96290002 1.96901443 1.97487238 1.98050722 1.98593567 1.99116647
1.9962046 2.00105333 2.00571525 2.01019279 2.01448848 2.01860507
2.02254557 2.0260774 2.02977259 2.03303197]
Intercept: 0.0

Alpha: 34.04002216123752
Coefficients: [0. 0.05715559 0.11667845 0.17856487 0.24279889 0.30935145
0.37443649 0.43584286 0.49826743 0.56158046 0.62564193 0.69030648
0.75542238 0.82083339 0.88638024 0.95190213 1.01723829 1.08222949
1.14671955 1.21055678 1.27359533 1.33570614 1.39674312 1.45659175
1.51514063 1.5722888 1.6279462 1.68203387 1.73450754 1.78527184
1.83429845 1.88155252 1.92700936 1.97065383 2.01252455 2.05254682
2.09076491 2.12719574 2.16186239 2.19485863 2.22610224 2.25568279
2.28364133 2.31002169 2.33495931 2.35833888 2.38028421 2.40084582
2.42007443 2.43812958 2.45485819 2.47040489 2.48482075 2.49815509
2.51045643 2.52177235 2.53214943 2.54163315 2.55026788 2.5580968
2.56516191 2.57150404 2.57716285 2.58203659 2.58643133 2.59012413
2.59341721 2.59608888 2.59842501 2.60007179 2.60154708 2.60244646
2.60307353 2.60344448 2.60373623 2.60391233 2.60395269 2.6038482
2.6035971 2.60303499 2.60195723 2.60081402 2.59941164 2.59799461
2.59636619 2.59476707 2.59316222 2.59154059 2.58990216 2.58825181
2.58659636 2.58494317 2.58329945 2.58167192 2.58006668 2.5784891
2.57694391 2.57560473 2.57406636 2.57276514]
Intercept: 0.0
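Note that `enet_path` does not fit an intercept, so the snippet above cannot report one. A minimal sketch of getting both, assuming dense `X`, a 1-D `y` and no sample weights: centre `X` and `y`, run `enet_path` on the centred data, then recover one intercept per alpha as `y_mean - X_mean @ coef`, which is how `ElasticNet(fit_intercept=True)` sets `intercept_` in that case. `enet_path_with_intercepts` is a hypothetical helper name; extra keyword arguments such as `tol` and `max_iter` are forwarded to the coordinate-descent solver.

```python
import numpy as np
from sklearn.linear_model import enet_path


def enet_path_with_intercepts(X, y, l1_ratio=0.5, alphas=None, **solver_params):
    """Hypothetical helper: enet_path plus one intercept per alpha.

    Assumes dense X, a 1-D y and no sample weights. Extra keyword arguments
    (e.g. tol, max_iter) are forwarded to enet_path.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    # enet_path fits no intercept, so run it on centred data ...
    X_mean, y_mean = X.mean(axis=0), y.mean()
    alphas, coefs, _ = enet_path(X - X_mean, y - y_mean, l1_ratio=l1_ratio,
                                 alphas=alphas, **solver_params)
    # ... and recover the intercept for each alpha (coefs: n_features x n_alphas).
    intercepts = y_mean - X_mean @ coefs
    return alphas, coefs, intercepts


# Same toy data as in the snippet above.
X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
y = np.dot(X, np.array([1, 2])) + 3

alphas, coefs, intercepts = enet_path_with_intercepts(X, y)
# Show the three least-regularised points of the path.
for alpha, coef, intercept in zip(alphas[-3:], coefs.T[-3:], intercepts[-3:]):
    print("Alpha:", alpha, "Coefficients:", coef, "Intercept:", intercept)
```

Each column of `coefs`, paired with the matching entry of `intercepts`, should agree with `ElasticNet(alpha=..., l1_ratio=0.5).fit(X, y)` up to solver tolerance. The alpha grid differs from the uncentred call above, because its largest alpha is computed from the centred data.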

Projects
None yet
Development

No branches or pull requests

3 participants