In the provided R code, a decision tree is generated using the rpart and caret packages, and visualized with the rattle package. The resulting plot displays a four-class decision tree, but due to the complexity, it is challenging to interpret. I seek advice on improving clarity by creating individual trees for each class, allowing for a more focused and comprehensible presentation.
library(rpart)
library(caret)
fitControl <- trainControl(method = "repeatedcv",
                           number = 10,
                           repeats = 10)
classifier = train(x = training_set[, names(training_set) != "Target"],
                   y = training_set$Target,
                   method = 'rpart',
                   parms = list(split = "gini"), trControl = fitControl,
                   tuneLength = 20)
classifier
complexity_parameter=classifier$bestTune
classifier = rpart(formula = Target ~ .,
                   data = training_set,parms = list(split = "information"),
                   control = rpart.control(cp = complexity_parameter))
library(RColorBrewer)
library(rattle)
fancyRpartPlot(classifier, caption = NULL, clip.right.labs=FALSE,branch=.3,type=3,
               tweak=1.4)
I have a four-class decision tree, and as shown in the attached figure, the visualization is intricate and hard to decipher. I am considering drawing four separate trees, each dedicated to one class (e.g., the first tree displaying only leaves of the first class, the second tree for the second class, and so forth). I would appreciate any guidance or code snippets on how to achieve this and improve the interpretability of the decision tree plots.
Edit. This is a sample of the training dataset:
training_set<- structure(list(AGE67CIYes = c(-0.176152387930331, -0.987016328202176,
0.05552302357591, -0.58468873762319, 0.162742606800352, 0.120896778307869,
-0.987016328202176, -0.64359160055763, -0.987016328202176, -0.628598979432629,
-0.987016328202176, -0.307993987241449, 0.889554504998379, -0.987016328202176,
-0.84077646366108, -0.122806076070342, 0.347797256654688, -0.585917218815798,
3.27330664446935, -0.227710210722183, -0.987016328202176, 0.0907486763211531,
-0.468831088265139, -0.0852317172820009, 2.14649177996699, -0.21957742854859,
-0.478526166947832, -0.987016328202176, 0.856614535142944, -0.987016328202176,
0.233732369261435, 0.773841021986012, 1.76557884040399, 1.70409677446699,
0.177204736845891, -0.987016328202176, -0.0303666544722618, 0.267016686824448,
-0.987016328202176, 0.133602302476064, -0.780150252101327, 0.569019137931335,
0.54169474123801, 1.04432323350976, -0.00304660292847676, 0.595633772087449,
0.0119187160870928, -0.987016328202176, 0.445153602815489, -0.0754726273166524,
-0.0553854181026097, -0.987016328202176, 0.447212111288207, -0.412974267062895,
0.565701855297101, 0.0332551927612325, 0.61493438306659, -0.987016328202176,
-0.245916422283762, 0.0936414642259315, 0.217200249252726, -0.774974426145616,
2.01102787070915, 0.644784396320045, 1.31792583076954, 0.0693891516694634,
0.152242180258608, 1.09469958100705, 1.13548440454805, -0.158246875053666,
0.755736021038804, 0.672766538708062, -0.735850059896174, -0.987016328202176,
0.145397105625745, 1.03910128090896, -0.987016328202176, -0.987016328202176,
-0.404673257734619, 0.215913693080231, 0.388480617599278, 0.411918265238067,
-0.987016328202176, 0.113253215693915, 0.334391574463053, -0.558854795203353,
1.27293994403935, 0.429900076191951, -0.535607536710634, 1.15212162607829,
1.49001293895707, -0.987016328202176, 0.886209458949893, -0.11303124287923,
3.37348021367463, 0.735737223588497, -0.0476125737990377, -0.14765557213803,
-0.481299453037445, -0.0487309116018985), AGE18_34M = c(-0.311288013988282,
-0.920252005474412, 0.839075436208503, 0.475468979668509, 0.361813050743929,
-0.0880491662273217, -1.30060058115752, 0.2381797045763, 0.46997486395285,
0.0951842010198588, -0.249345366669994, 0.0276151129841655, -0.679404318051254,
-0.460158578131396, -0.0580351642511107, -0.4914107121112, 0.0350108362887032,
-0.726823880120434, -1.21665514019721, -0.665041398673074, 2.03127028459426,
0.173414240816047, -1.06412941985671, -0.379467449594911, 0.971659698265082,
3.58247384148793, 0.0419848907173112, -0.122469737335293, 1.77830824717609,
-0.514049917080521, -0.18293712720485, 2.46932240611204, 0.103057098341037,
0.416637231962104, 0.0347935126703198, -0.0683916002418135, -0.140487578187558,
0.568918482749401, 0.80596534428402, 0.709398516775547, -0.894621921017203,
0.0649799930004623, -0.548245672308119, 1.44034787873404, 0.358249364890296,
2.26940685862186, -0.268422759950531, 1.5743831289976, -0.532610474943992,
-0.348407548366379, 0.853760927062079, -0.291156653609838, -0.243149894243735,
-0.844996329920804, 0.648164055112016, 0.273498687490394, 1.25994011310977,
-0.792349093161483, 1.1447573459646, -0.013918010591892, -0.526596681137413,
-0.658346817606973, 0.516140362965466, -0.78849760892925, -0.554157331051845,
0.243647599042001, 2.04373725057226, -0.97656446859546, -0.912356887137583,
-0.824415433098142, -0.335325459216275, -0.587458300768377, -0.258938778587358,
-0.555568876651955, 4.63824989974633, -0.412721089416945, 1.00499324152535,
-0.306301918538513, -0.228161687409781, -0.129884546841866, -0.0114026911643898,
1.87550512616216, -0.944634183197748, -0.262575555185161, 0.861798664854366,
-0.0509189201758156, -0.612609020943053, 0.110701490753568, -0.726367609420112,
-0.375128625430529, -0.99933686605059, -0.19110078784832, -0.371574440790039,
-0.938132768287129, 0.243368554114898, 0.929595838846939, -0.88523573776942,
-0.810391389504971, -0.331438750074501, 0.873586339714407), AGE67 = c(0.501041434601308,
-1.09665902567313, -0.499366515285846, -1.89627995143666, -0.298655565821342,
-0.657227918269277, 0.738638322699508, 0.844563870334618, -1.56421059601705,
-0.0241556789265915, -0.641594843326121, -0.478951705500202,
2.61976255704089, 1.96434823104864, -1.22122015725059, 1.55163432276405,
1.27827046071975, -0.343864172693654, 3.31198498838896, 1.66372654109949,
0.37030630478651, -0.419408311079017, -0.555972438485822, -0.774936554113457,
0.107872632262183, -1.34044707928847, -0.839545557277716, 1.92619155286728,
0.57973553068234, 0.221251081710207, -0.163884033780637, -0.687948777957268,
-0.0113182539359213, -0.37327476423973, -1.27171923730041, -1.38681856755255,
-1.40429288028817, 0.168119278247155, -0.659256439010892, -1.07437656281392,
-0.696986471528318, -0.864024735676032, -0.534449400062125, -0.180293103433951,
2.30163523766365, 3.15498511491188, 0.307283716170217, -0.157204374134115,
-0.518770979831071, -0.401045752552977, 0.997254420383102, 0.232181798069135,
-0.405998845074965, -0.568631832646478, 0.254930436291304, 0.404278675370718,
0.392765574171448, 2.47506069286078, -1.8557905876778, -0.0631740179401674,
0.0298313302450834, -0.647883494529081, -0.207182743392638, -1.39352452511716,
-1.34117850614179, 0.944792602443884, -1.24991896454826, -1.86690843412548,
-0.905681536580039, 0.665146208715289, -0.322988621135965, 0.014606612397604,
-0.911460946841895, -0.218306616580415, 0.26710983425054, 0.514703818235693,
2.14113331398361, -0.616093681265326, 0.0291501450665863, -0.44680387231444,
-0.25559820550726, -0.486781511123927, 1.46702523452746, 2.14512158886705,
0.586724522901828, -0.354832348471152, -0.0891185777716379, 0.684009304312411,
-0.946069291878499, -0.730270080428137, -0.101775097076054, 1.62556716185319,
0.37030630478651, -0.819482083082821, 1.64647740419461, -0.556561969797837,
0.0432415212999982, 0.279105885015719, -0.719800485475731, 0.352766684885891
), AGE67IT = c(0.546552412166264, -1.07526365158326, -0.468955640670514,
-1.8974512238005, -0.310213066218594, -0.648052048503732, 0.787735574714999,
0.854938313505668, -1.54987241979056, -0.0707358644941545, -0.726319975489725,
-0.475290635689973, 2.4034814230678, 2.03194628146796, -1.20170490207676,
1.61300289199568, 1.25342092240672, -0.311106154213347, 3.34434571314056,
1.72678702415178, 0.413843973443276, -0.423068606116984, -0.526415970583246,
-0.795742413966101, 0.0958131387143464, -1.34275483486837, -0.847437129878995,
1.85139271626387, 0.555660476058683, 0.188567559007975, -0.160260210148408,
-0.76145871420041, 0.02645919447723, -0.384844658381296, -1.27679963538867,
-1.36980284573822, -1.38754091479577, 0.208605146969538, -0.669199963819832,
-1.07508354694689, -0.669558461737779, -0.845642806091824, -0.50456806459325,
-0.145066152169093, 2.29730541056113, 3.24055457341925, 0.271679866495853,
-0.173297360212068, -0.50168800913321, -0.399729324986892, 0.977333088835057,
0.27363462296638, -0.404248928734908, -0.574213826955535, 0.272419254286123,
0.408398950123785, 0.436642241813503, 2.55036740928169, -1.84585350391826,
-0.0355778324729171, 0.0525201570882496, -0.619714327085901,
-0.182564609086414, -1.38913227212412, -1.32347389573356, 0.952476863305029,
-1.23894360432029, -1.86146509438973, -0.923463257162492, 0.648263392221682,
-0.358121272963515, -0.00290343411139714, -0.887270203066352,
-0.205933003746037, 0.309089861018009, 0.520773126763273, 2.11253418114413,
-0.58744468011722, -0.0301375036843241, -0.425015388384307, -0.228529543651722,
-0.456180679143214, 1.29606011127543, 2.1676026243595, 0.633528796321265,
-0.342348134083138, -0.0525154431967221, 0.686070667402912, -0.922400924572503,
-0.716563246171896, -0.0653629993121732, 1.68805166931954, 0.413843973443276,
-0.82126700480962, 1.66661355031845, -0.527014400248682, 0.0573323154964607,
0.255567177377771, -0.732301118039339, 0.396039611477956), DWE1tot = c(-1.24627225843832,
-0.972515585869268, -0.591117336947849, -0.950950893751152, -0.00341326678369597,
-0.317703674167421, -1.01860592322038, 1.70500797025228, -2.80147690897104,
0.0841927598217607, -2.31467874579201, -0.634564452794059, 1.63477150881602,
2.39176448207346, -0.603097208459657, 0.554218462480144, 0.596025520656533,
-0.363363481034292, 0.326011018689958, 1.45302282344243, 0.259839566760795,
-0.323873062136694, -0.645196692893364, 0.126535915869014, 0.273354029798474,
-1.6765246798592, -0.858119384992372, 0.559696169973685, 0.965787728534269,
0.168732755000225, -0.105465993392139, -0.476904337564692, 0.158116000800885,
0.368839813382008, -1.53101094871713, -0.768492476757933, -0.580381980928888,
0.383757877076662, -2.51125311699879, -0.285564675616673, -0.307208273135031,
-0.1958624767505, -0.52862437997276, -1.09793370211631, 2.37364271507784,
1.94326659579375, -0.29069268473587, -1.3926121391431, -0.454695942608588,
-0.574896772720613, 1.861531778717, 0.988010759757147, 0.0946589558731531,
-0.4279250695409, -0.641299416040736, 0.211479245514941, -0.594712699321882,
2.85291296625091, -0.317088857444325, 0.30614248023326, -0.375126422991232,
-0.631092891479471, -0.286277701436183, -1.74098962562517, -1.24105303470107,
0.709126059783464, -0.915573987456023, -2.33328457590235, -1.24384520039966,
0.371696888960207, 0.0337244463841061, 0.718542081472043, -1.15989953982117,
-0.249183962565405, -0.457050613250607, 0.543066989542432, 1.55057780183794,
-0.319116485641744, -0.486838166179266, -0.71584875469315, -0.292203836118496,
-0.0184323959900459, 2.93015633378211, 2.52801762478007, 0.435292040053886,
-0.373077593601094, 0.295243869739579, 0.425845049612918, -0.475620953880973,
-1.10368797436764, 0.346155709987028, 0.434151223026415, 0.632305650134888,
-0.499317186271939, 1.66284666421724, -0.464006541934083, -0.395742174424708,
0.0804687862477327, -0.714833790842392, 0.529197578800853), AGE18_34NoIT = c(0.19940732465257,
-0.46709113395183, 2.34119335826173, -0.287924912209743, -0.767850884074149,
-0.107050721676688, -0.767850884074149, 0.358720441384029, 0.30950776246791,
0.94233552589833, 0.0931152138034797, 0.47658498239766, -0.767850884074149,
-0.767850884074149, 0.671325955294912, -0.767850884074149, 0.369480488190256,
-0.289390335024597, 0.0791536014054641, -0.767850884074149, 2.79341797532877,
0.225596235651883, -0.767850884074149, -0.767850884074149, -0.472755159377862,
2.58882706714496, 1.12766318625321, 0.312810011192943, 2.00670978225976,
-0.486023852035069, 0.136757539237998, 1.68270945706995, -0.767850884074149,
0.970984615352628, 1.24052217326025, -0.510972474543447, 0.729924457795809,
0.54106416210258, 0.677682601292349, 1.00412854244753, 0.21920946120918,
0.458536361427269, -0.767850884074149, -0.767850884074149, -0.767850884074149,
-0.767850884074149, -0.767850884074149, -0.374142768964279, -0.45000944428971,
0.0476664458114263, 0.899126028837855, -0.767850884074149, -0.355413039955347,
-0.0830912490542143, -0.397411570093704, -0.159323389651903,
0.187611492838829, -0.767850884074149, 0.558206914067417, -0.624618864134726,
-0.767850884074149, -0.135502902582105, -0.262505861107523, -0.214425105898759,
-0.664095989312535, -0.0650703296976993, 2.52667379006469, -0.570071981006544,
-0.471397148679766, 0.344343491575343, -0.248130055504569, -0.485006425926625,
0.580392393481832, 0.844944764062817, 0.24526870523875, -0.365033840851197,
-0.767850884074149, 0.202401932109928, 0.298950997178277, -0.552609359604742,
-0.286313505862145, -0.267225229717189, 0.992776417203698, 0.252968854126036,
-0.337958346862274, -0.103885503507503, -0.767850884074149, -0.06359996356301,
-0.498614063088361, -0.264221690329575, -0.767850884074149, 0.0117303587802212,
-0.767850884074149, -0.455084756701438, -0.767850884074149, -0.510972474543447,
-0.767850884074149, 0.233400424639771, 0.137033599931982, -0.20822292045369
), AGE67CINo = c(0.520215683696846, -1.03163833351555, -0.509119810566456,
-1.87130724349216, -0.314585715658465, -0.673792033984295, 0.823598776742872,
0.903906066078584, -1.50426966218324, 0.02457959376149, -0.57163006325451,
-0.460147962605314, 2.57888667222345, 2.06262552649387, -1.16895180207709,
1.5780646265547, 1.26504842710326, -0.301929339443148, 3.09282231729659,
1.69955169568212, 0.45126498073184, -0.431038644185776, -0.52546864588487,
-0.776712327050538, -0.0582690276197286, -1.33789498354707, -0.811366964612021,
2.0240542914681, 0.519263167059362, 0.300590333065505, -0.18388341953198,
-0.755742030018846, -0.149063665921569, -0.510160314132326, -1.29934861634348,
-1.32495034122243, -1.41718294872725, 0.1491325448279, -0.589483545018553,
-1.09646319324319, -0.643748191001391, -0.91776562934997, -0.582479723072687,
-0.263654292536764, 2.32687904677992, 3.14283459045014, 0.309693195176315,
-0.0819768879271749, -0.55910582347913, -0.39952005005421, 1.01240633996832,
0.311639807227703, -0.445268920756996, -0.542619478394475, 0.213605110937565,
0.406078840439239, 0.349100219631252, 2.57888667222345, -1.8567844317766,
-0.0711594953896716, 0.0132252207188618, -0.594515172713188,
-0.366188280170573, -1.45892393133304, -1.45847888111253, 0.949648680774444,
-1.2753657237661, -1.97252072522736, -1.00402956173803, 0.684707699839703,
-0.385405489574873, -0.0376751575317439, -0.864005933475389,
-0.143742981474305, 0.258678534519238, 0.439300527132514, 2.24133130770795,
-0.545851841402966, 0.0610101153751223, -0.46848814631922, -0.288656266519766,
-0.524178190926454, 1.55989931716918, 2.15959966135947, 0.567034062253331,
-0.315126128071832, -0.189309291975311, 0.657931155146952, -0.914598686765673,
-0.828009160908805, -0.219023614915597, 1.72016373828187, 0.305251727394181,
-0.819574921820783, 1.40141138963966, -0.619957661890062, 0.0474226036519372,
0.293647638969982, -0.690104740437684, 0.360397804324386), AGE67CAR2 = c(-0.65814805058186,
-0.65814805058186, -0.65814805058186, 0.00768964415182037, -0.65814805058186,
1.19533826054216, -0.65814805058186, -0.231881369541573, -0.65814805058186,
0.231603684740978, 0.237714616989074, -0.397640616985576, 4.000337820787,
2.0821377561057, -0.476631762802389, 0.414529617167548, 0.584452971176845,
-0.160294598679845, 0.223187114271708, 0.284320189290452, -0.65814805058186,
-0.439244037830135, -0.65814805058186, 0.0880625008959889, 0.365373641181542,
-0.0231061596455018, -0.289977393005934, 0.466314056300278, -0.464325455824444,
-0.0716480308052054, -0.428569051007219, 1.23605308567625, 0.101092952900841,
1.01198772916626, -0.280209804949945, -0.65814805058186, -0.361294148901551,
-0.65814805058186, -0.0564985837999044, -0.124416160746424, -0.65814805058186,
-0.41672466856486, -0.116013368959549, -0.65814805058186, 0.563177833297616,
3.27069545539188, -0.65814805058186, -0.65814805058186, -0.451445457404997,
-0.415698971774559, -0.65814805058186, 1.37993951814202, -0.308466424440895,
-0.420643308414475, -0.272694136875217, -0.0249557962210446,
-0.65814805058186, 1.45934552731308, -0.198212915185781, -0.211035604061083,
0.338319515486345, 0.131425825921336, -0.0109756149662412, -0.261005180047431,
0.421454700604434, 0.0983335709673666, -0.101088636265742, 0.0964345427048702,
-0.508081777689168, -0.400976117619857, -0.117361962694928, 0.813395320337906,
-0.0346413373967359, -0.481499073877732, 0.0446407662194761,
-0.0294320910783511, 0.909611617667275, 0.351430930829348, 0.116299410617839,
-0.210216716796393, -0.398341759098735, -0.65814805058186, -0.65814805058186,
0.100563003387009, -0.509042421411389, 0.511028651163257, 0.0431293924199034,
0.660883904007974, -0.65814805058186, -0.378658861813212, -0.273829829019315,
4.20892674040053, -0.0768753037087408, -0.441185967417642, 2.04802631452868,
-0.65814805058186, -0.269478192422361, 0.383685850363252, -0.65814805058186,
-0.65814805058186), AGE18_34FLD = c(-0.130080127803667, -1.09592168708659,
1.38767089392664, 0.501488548787159, 2.32784926704544, -0.235270792676064,
1.13931163582532, 0.540327542757423, 0.577515025253877, -0.242080854323515,
-0.236216562889702, 0.683441681810391, -1.09592168708659, -1.09592168708659,
0.0363145076530911, -0.0665379199561058, 0.777926607570102, -0.618161892876412,
-0.250157727065868, 1.61736731760417, -1.09592168708659, 0.864721549837708,
-1.09592168708659, -0.379828631691471, 0.475617121238168, 0.122900767471024,
-0.439772550360837, -0.0168435311980823, 1.88907511940198, -0.533093152540427,
0.358144484675311, 1.00148223521681, -0.367324096381894, -0.561678980799494,
-0.519893465065387, -0.0699129487007959, 0.470878184230816, 0.0243623172368012,
0.0588113948605217, 0.937318028328768, 0.13609667987273, -0.268495285162563,
-0.315540975845525, -1.09592168708659, 1.24814322083601, 2.67435138770459,
-1.09592168708659, -1.09592168708659, -0.342153857605648, -0.165266920669736,
0.0137686150682584, -1.09592168708659, 0.246351394326684, -0.488137923633972,
1.49335568460806, 0.119350799156778, -1.09592168708659, 0.936108606469693,
-0.213177920633368, -0.523832719613707, 0.577515025253877, -0.843352950034398,
1.00012940708533, -0.486139365405593, -0.292998940733529, 0.743150579831344,
1.1246254773305, -0.722886464554487, -0.455879420465601, -0.849129364367611,
-0.576962043292781, 0.0337991259374796, -0.198409262972822, -0.756883046558241,
2.276197550065, -0.29146753488179, -1.09592168708659, -0.127089906257899,
0.440004947086572, 0.336920186574891, -0.205491840718313, 0.570386434359988,
-1.09592168708659, 0.0690207613365549, 0.478042335448546, -0.0759325566835026,
-0.422948213521715, -0.39270221470981, -0.0205517035550875, -0.626556215380212,
-0.542711318422477, -1.09592168708659, 0.577515025253877, -0.67951095494066,
0.202554102156843, -0.582917317893694, -0.349955858224452, -0.0961368142186437,
-0.794735291314053, 0.39423386152135), AGE18_34IT = c(-0.0922984151251803,
-1.10842633780253, 0.18885009667243, 0.507456373441544, 0.420167034168932,
-0.215706837795418, -1.55427067647275, 0.162389060267949, 0.754154056222705,
-0.707142977760336, -0.503207778521687, -0.0327393224943708,
-1.04820335523705, -0.869591359506808, -0.58143266150616, -0.428936764777582,
0.330878305136329, -1.19500150948863, -1.29439826827064, 0.60163461688395,
1.16009041015508, 0.175009713455721, -1.36162571986824, -0.27694286108837,
1.14189589937619, 2.21684247777924, -0.56335821198866, -0.594487825853325,
1.54764337967806, -0.786067764381153, -0.234733270862691, 2.52999561314261,
0.331551786013243, -0.0645795477648524, -0.785164273078549, -0.318139022962605,
-0.125442459074011, -0.0312375211548161, 0.292594934974131, 0.726270176261751,
-1.00038597055337, -0.164525780028025, -0.293519419828208, 1.40909260663938,
0.0624912149656389, 2.63450341616489, 0.0947775908237999, 1.49993461399998,
-0.497483035068795, -0.462494708354598, 0.200813863131902, -0.289102373383507,
0.136007989951257, -0.873494025109071, 0.870784784862377, 0.553524282848747,
0.963722824309746, -0.220093193215003, 0.63773601395719, -0.11748356517652,
-0.0577186516420515, -0.859500286301709, 1.15609583262114, -0.921362952472097,
-0.383220339847388, 0.525567398769661, 1.35644767444532, -0.823292431599732,
-0.767028966908216, -1.22221388525185, 0.0545702104407355, -0.142032267392627,
-0.637930804177903, -1.33111856407521, 6.33776194521154, -0.147949920031134,
0.664639885868386, -0.963597673049043, -0.142153756039259, -0.0490447979255479,
0.00402038984987186, 2.24042581854381, 0.3278448665499, -0.378941686111017,
1.01725476517897, -0.097430200216662, -0.536694449902048, 0.518899495690568,
-0.721259518837359, -0.390435116407882, -0.933093169027874, -0.297914490917462,
-0.165968346024019, -0.646177818606699, 0.879472167146227, 1.30791335346685,
-0.79366174670134, -1.38127002672125, -0.969297767132932, 0.976065929705738
)), row.names = c(6737L, 3053L, 831L, 2255L, 6090L, 5183L, 347L,
3260L, 2795L, 4098L, 2961L, 4487L, 576L, 1838L, 3515L, 6756L,
3888L, 5386L, 7080L, 145L, 1236L, 1962L, 1096L, 7603L, 6386L,
7120L, 2560L, 5374L, 3771L, 13L, 3489L, 6914L, 6893L, 5378L,
6236L, 1912L, 1734L, 6587L, 2806L, 5165L, 3419L, 7584L, 5958L,
7661L, 5073L, 5789L, 828L, 2947L, 6510L, 2500L, 274L, 1024L,
5486L, 4215L, 7079L, 7258L, 2931L, 4856L, 2683L, 6654L, 6953L,
1424L, 6876L, 6027L, 7459L, 3952L, 6722L, 6039L, 6223L, 3723L,
6206L, 5029L, 3131L, 3807L, 7124L, 3610L, 960L, 466L, 4465L,
5901L, 6073L, 6863L, 2636L, 4187L, 5715L, 4266L, 7746L, 4024L,
3481L, 6300L, 7738L, 1006L, 3714L, 1952L, 3997L, 6171L, 5086L,
2553L, 4783L, 7212L), class = "data.frame")
training_set$Target<-structure(c(2L, 3L, 1L, 4L, 4L, 3L, 1L, 1L, 3L, 3L, 3L, 4L, 3L,
4L, 4L, 2L, 4L, 3L, 3L, 2L, 3L, 3L, 4L, 1L, 2L, 4L, 2L, 3L, 1L,
3L, 2L, 4L, 1L, 1L, 4L, 1L, 1L, 4L, 2L, 3L, 3L, 2L, 4L, 2L, 2L,
3L, 2L, 1L, 1L, 4L, 2L, 3L, 4L, 3L, 4L, 2L, 2L, 3L, 4L, 1L, 2L,
1L, 3L, 3L, 4L, 3L, 3L, 3L, 2L, 1L, 3L, 2L, 2L, 3L, 4L, 4L, 1L,
2L, 2L, 2L, 2L, 2L, 2L, 2L, 3L, 2L, 1L, 2L, 2L, 2L, 2L, 3L, 2L,
3L, 1L, 3L, 4L, 4L, 2L, 4L), levels = c("Q1", "Q2", "Q3", "Q4"
), class = "factor")
This is a complex problem. It is probably easiest to convert your rpart object to an igraph object. This will allow you to select subgraphs that end in a particular class.
Step 1: Reproducing the problem
The sample data does not produce a decision tree, so it was not sufficient to reproduce the problem. However, we can easily create something similar using:
training_set <- expand.grid(A = 1:4, B = 1:4, C = 1:4, D = 1:4)
training_set$Target <- cut(rowSums(training_set),
                           breaks = c(0, 8, 10, 12, 20),
                           labels = c('w', 'x', 'y', 'z'))
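As a quick sanity check on this synthetic data, the sketch below rebuilds it and tabulates the target, to confirm that all four classes are present before any model is fitted. The name `class_counts` is only illustrative and is not used elsewhere.

```r
# Rebuild the synthetic data exactly as in the snippet above
training_set <- expand.grid(A = 1:4, B = 1:4, C = 1:4, D = 1:4)
training_set$Target <- cut(rowSums(training_set),
                           breaks = c(0, 8, 10, 12, 20),
                           labels = c('w', 'x', 'y', 'z'))

# Count the rows falling into each class (class_counts is a hypothetical name)
class_counts <- table(training_set$Target)
class_counts
```

The classes are not balanced, which is fine here: the only goal is a tree with enough leaves per class to reproduce the legibility problem.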
This now allows us to use your own code to create the classifier object:
library(rpart)
library(caret)
fitControl <- trainControl(method = "repeatedcv",
                           number = 10,
                           repeats = 10)
classifier = train(x = training_set[, names(training_set) != "Target"],
                   y = training_set$Target,
                   method = 'rpart',
                   parms = list(split = "gini"), trControl = fitControl,
                   tuneLength = 20)
classifier
# bestTune is a one-row data frame; extract the numeric cp value from it
complexity_parameter = classifier$bestTune$cp
classifier = rpart(formula = Target ~ .,
                   data = training_set, parms = list(split = "information"),
                   control = rpart.control(cp = complexity_parameter))
We now get a similar legibility problem with the output of fancyRpartPlot:
library(RColorBrewer)
library(rattle)
fancyRpartPlot(classifier, caption = NULL, clip.right.labs=FALSE,branch=.3,type=3,
               tweak=1.4)
Step 2: Converting rpart to igraph
I couldn't find an existing method for converting an rpart object to an igraph binary tree. The data.tree package allows for rpart to Node to igraph, but the end result is not a binary tree.
The method I have used here is to create a binary tree in igraph, copy over the node attributes and remove vertices that are missing from the rpart object:
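library(igraph)
df <- classifier$frame
# rpart numbers its nodes so that node k has children 2k and 2k + 1
nodes <- as.numeric(row.names(df))
non_nodes <- setdiff(seq(max(nodes)), nodes)
# Complete binary tree covering every possible node number
# (make_tree() replaces the deprecated graph.tree())
g <- make_tree(max(nodes), children = 2, mode = 'out')
# Leaves are labelled with the predicted class, other nodes with their split label
labs <- ifelse(df$var == '<leaf>',
               levels(training_set$Target)[df$yval],
               labels(classifier))
classed <- ifelse(df$var == '<leaf>', df$yval, NA)
vertex_attr(g, 'name') <- labs[match(V(g), nodes)]
vertex_attr(g, 'number') <- df$n[match(V(g), nodes)]
vertex_attr(g, 'class') <- as.character(classed[match(V(g), nodes)])
# Drop the vertices that have no counterpart in the rpart frame
g <- delete_vertices(g, non_nodes)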
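The reason a complete binary tree lines up with the rpart object is the node numbering: the row names of the rpart frame follow the convention that node k has children 2k and 2k + 1, and igraph numbers the vertices of a regular tree breadth-first in the same way. The following is a minimal sketch of that correspondence, assuming a toy fit on the built-in iris data (the objects `fit`, `toy_nodes`, `parents` and `toy_tree` are hypothetical and not part of the code above):

```r
library(rpart)
library(igraph)

# Hypothetical toy model, used only to inspect the node numbers
fit <- rpart(Species ~ ., data = iris)
toy_nodes <- as.numeric(row.names(fit$frame))

# Every non-root node k should have its parent, k %/% 2, in the frame too
parents <- toy_nodes[toy_nodes > 1] %/% 2
all(parents %in% toy_nodes)

# In a complete binary igraph tree, the children of vertex 3 are 6 and 7
toy_tree <- make_tree(7, children = 2, mode = "out")
as.integer(neighbors(toy_tree, 3, mode = "out"))
```

Because the two numbering schemes agree, matching vertex ids against the frame's row names is enough to copy attributes across, and any vertex id that is absent from the frame can simply be deleted.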
Step 3: Ensure our igraph is correct by plotting with ggraph
We will plot the igraph using ggraph to check it is correct. Note that we still run into legibility problems because we have not yet converted to subgraphs:
library(ggraph)
ggraph(g, layout = 'tree') +
  geom_edge_diagonal() +
  geom_node_label(aes(label = paste(name, number, sep = '\n n = '),
                      fill = class)) +
  scale_fill_manual(values = c(`1` = 'lightgreen', `2` = 'lightblue',
                               `3` = 'orange', `4` = '#E0A8FF'),
                    na.value = 'white', guide = 'none') +
  theme_graph()
Step 4: Getting subgraphs that end in each class
This involves obtaining the subcomponents that represent paths to each node containing a target class, then getting the induced subgraph:
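# For each class, collect every vertex lying on a path from the root to a leaf of that class
subs <- lapply(levels(training_set$Target), function(n) {
  which(V(g)$name == n) |>
    lapply(function(x) subcomponent(g, x, 'in')) |>
    unlist() |>
    unique()})
# Keep only those vertices (induced_subgraph() replaces the deprecated induced.subgraph())
subs <- lapply(subs, function(x) {
  induced_subgraph(g, x)
})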
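To see what `subcomponent()` with mode 'in' returns, here is a small sketch on a toy seven-vertex binary tree rather than on the classifier graph. The objects `toy`, `path_to_5` and `toy_sub` are hypothetical names used only for illustration:

```r
library(igraph)

# Toy complete binary tree with edges pointing from parent to child
toy <- make_tree(7, children = 2, mode = "out")

# All vertices from which vertex 5 can be reached, i.e. 5 and its ancestors
path_to_5 <- subcomponent(toy, 5, mode = "in")
sort(as.integer(path_to_5))

# The induced subgraph keeps only those vertices and the edges between them
toy_sub <- induced_subgraph(toy, path_to_5)
vcount(toy_sub)
ecount(toy_sub)
```

Taking the union of these ancestor sets over all leaves of one class, as in the code above, gives the smallest subtree that still contains every decision path ending in that class.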
Step 5: Plotting the result
This is really just replicating our plotting code above for each subgraph. We put them in a list for convenience:
plots <- lapply(subs, function(x) {
  ggraph(x, layout = 'tree') +
    geom_edge_diagonal() +
    geom_node_label(aes(label = paste(name, number, sep = '\n n = '),
                        fill = class)) +
    scale_fill_manual(values = c(`1` = 'lightgreen', `2` = 'lightblue',
                                 `3` = 'orange', `4` = '#E0A8FF'),
                      na.value = 'white', guide = 'none') +
    theme_graph()
})
Now we have:
plots[[1]]
plots[[2]]
plots[[3]]
plots[[4]]
We can see that this has given us the correct subtrees with the correct partition paths. The displayed information at each node is customizable - it is just a case of copying it over from rpart to igraph at step 2.
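As one possible illustration of that customization, the sketch below copies an extra piece of node information across in the same way as step 2. It assumes a toy fit on the built-in iris data, and the `share` attribute (the fraction of all observations reaching a node) is a hypothetical example rather than something the code above computes:

```r
library(rpart)
library(igraph)

# Hypothetical toy model standing in for the classifier object
fit <- rpart(Species ~ ., data = iris)
toy_df <- fit$frame
toy_nodes <- as.numeric(row.names(toy_df))

# Same construction as step 2: complete tree, copy attributes, prune
toy_g <- make_tree(max(toy_nodes), children = 2, mode = "out")
idx <- match(seq_len(vcount(toy_g)), toy_nodes)
vertex_attr(toy_g, "number") <- toy_df$n[idx]
# Extra attribute: share of the training rows that reach each node
# (the first row of the frame is the root, so toy_df$n[1] is the total)
vertex_attr(toy_g, "share") <- round(toy_df$n[idx] / toy_df$n[1], 2)
toy_g <- delete_vertices(toy_g, setdiff(seq_len(max(toy_nodes)), toy_nodes))

vertex_attr(toy_g, "share")
```

Any attribute added this way can then be referenced inside `aes()` in `geom_node_label()`, for example by pasting it into the label text.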