Skip to content

Commit

Permalink
Added support for categorical tree splits
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Mar 27, 2022
1 parent 154f030 commit c3a0575
Show file tree
Hide file tree
Showing 6 changed files with 241 additions and 27 deletions.
5 changes: 5 additions & 0 deletions pmml-xgboost/src/main/java/org/jpmml/xgboost/BinaryNode.java
Expand Up @@ -70,6 +70,11 @@ public int split_index(){
return (int)(this.sindex & ((1L << 31) - 1L));
}

@Override
public int split_type(){
return 0;
}

@Override
public int split_cond(){
return this.info;
Expand Down
110 changes: 88 additions & 22 deletions pmml-xgboost/src/main/java/org/jpmml/xgboost/FeatureMap.java
Expand Up @@ -18,6 +18,7 @@
*/
package org.jpmml.xgboost;

import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
Expand All @@ -33,6 +34,7 @@
import org.dmg.pmml.OpType;
import org.dmg.pmml.Value;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.PMMLEncoder;
Expand Down Expand Up @@ -77,28 +79,7 @@ public List<Feature> encodeFeatures(PMMLEncoder encoder){
}

public void addEntry(String name, String type){
addEntry(name, Entry.Type.fromString(type));

}

public void addEntry(String name, Entry.Type type){
Entry entry;

if(type == Entry.Type.INDICATOR){
String value = null;

int equals = name.indexOf('=');
if(equals > -1){
value = name.substring(equals + 1);
name = name.substring(0, equals);
}

entry = new IndicatorEntry(name, value, type);
} else

{
entry = new ContinuousEntry(name, type);
}
Entry entry = createEntry(name, Entry.Type.fromString(type));

addEntry(entry);
}
Expand Down Expand Up @@ -141,6 +122,31 @@ private void addValue(Value.Property property, String value){
values.add(value);
}

static
private Entry createEntry(String name, Entry.Type type){

switch(type){
case INDICATOR:
String value = null;

int equals = name.indexOf('=');
if(equals > -1){
value = name.substring(equals + 1);
name = name.substring(0, equals);
}

return new IndicatorEntry(name, value, type);
case QUANTITIVE:
case INTEGER:
case FLOAT:
return new ContinuousEntry(name, type);
case CATEGORICAL:
return new CategoricalEntry(name, type);
default:
throw new IllegalArgumentException();
}
}

abstract
static
public class Entry {
Expand Down Expand Up @@ -180,6 +186,7 @@ public enum Type {
QUANTITIVE,
INTEGER,
FLOAT,
CATEGORICAL,
;

static
Expand All @@ -194,6 +201,9 @@ public Type fromString(String string){
return Type.INTEGER;
case "float":
return Type.FLOAT;
case "c":
case "categorical":
return Type.CATEGORICAL;
default:
throw new IllegalArgumentException(string);
}
Expand Down Expand Up @@ -290,4 +300,60 @@ public Feature encodeFeature(PMMLEncoder encoder){
return new ContinuousFeature(encoder, dataField);
}
}

static
private class CategoricalEntry extends Entry {

public CategoricalEntry(String name, Type type){
super(name, type);
}

@Override
public Feature encodeFeature(PMMLEncoder encoder){
String name = getName();
Type type = getType();

DataField dataField = encoder.getDataField(name);
if(dataField == null){

switch(type){
case CATEGORICAL:
dataField = encoder.createDataField(name, OpType.CATEGORICAL, DataType.STRING);
break;
default:
throw new IllegalArgumentException();
}
}

List<Integer> values = new AbstractList<Integer>(){

private int max = -1;


@Override
public boolean isEmpty(){
return false;
}

@Override
public int size(){

if(this.max < 0){
throw new IllegalStateException();
}

return (this.max + 1);
}

@Override
public Integer get(int i){
this.max = Math.max(this.max, i);

return i;
}
};

return new CategoricalFeature(encoder, dataField, values);
}
}
}
28 changes: 28 additions & 0 deletions pmml-xgboost/src/main/java/org/jpmml/xgboost/JSONNode.java
Expand Up @@ -18,6 +18,8 @@
*/
package org.jpmml.xgboost;

import java.util.BitSet;

import com.google.gson.JsonObject;

public class JSONNode extends Node implements JSONLoadable {
Expand All @@ -32,8 +34,12 @@ public class JSONNode extends Node implements JSONLoadable {

private int split_index;

private int split_type;

private float split_condition;

private BitSet split_categories;


public JSONNode(){
}
Expand All @@ -45,7 +51,16 @@ public void loadJSON(JsonObject node){
this.right_child = node.getAsJsonPrimitive("right_child").getAsInt();
this.default_left = node.getAsJsonPrimitive("default_left").getAsBoolean();
this.split_index = node.getAsJsonPrimitive("split_index").getAsInt();
this.split_type = node.getAsJsonPrimitive("split_type").getAsInt();
this.split_condition = node.getAsJsonPrimitive("split_condition").getAsFloat();

switch(this.split_type){
case 0:
case 1:
break;
default:
throw new IllegalArgumentException();
}
}

@Override
Expand All @@ -68,6 +83,11 @@ public boolean default_left(){
return this.default_left;
}

@Override
public int split_type(){
return this.split_type;
}

@Override
public int split_index(){
return this.split_index;
Expand All @@ -82,4 +102,12 @@ public int split_cond(){
public float leaf_value(){
return this.split_condition;
}

public BitSet get_split_categories(){
return this.split_categories;
}

void set_split_categories(BitSet split_categories){
this.split_categories = split_categories;
}
}
7 changes: 7 additions & 0 deletions pmml-xgboost/src/main/java/org/jpmml/xgboost/Learner.java
Expand Up @@ -39,6 +39,7 @@
import org.dmg.pmml.Visitor;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
Expand Down Expand Up @@ -291,6 +292,12 @@ public Schema toXGBoostSchema(boolean numeric, Schema schema){
@Override
public Feature apply(Feature feature){

if(feature instanceof CategoricalFeature){
CategoricalFeature categoricalFeature = (CategoricalFeature)feature;

return categoricalFeature;
} else

if(feature instanceof BinaryFeature){
BinaryFeature binaryFeature = (BinaryFeature)feature;

Expand Down
3 changes: 3 additions & 0 deletions pmml-xgboost/src/main/java/org/jpmml/xgboost/Node.java
Expand Up @@ -30,6 +30,9 @@ public class Node {
abstract
public int split_index();

abstract
public int split_type();

abstract
public int split_cond();

Expand Down

0 comments on commit c3a0575

Please sign in to comment.