[S] predict.rpart bug

Atkinson, Elizabeth J. (atkinson@mayo.edu)
Tue, 10 Feb 1998 09:17:34 -0600


----------
X-Sun-Data-Type: text
X-Sun-Data-Description: text
X-Sun-Data-Name: text
X-Sun-Charset: us-ascii
X-Sun-Content-Lines: 31

~> Hello,
~>
~> Can anybody give some advice/help on the following problem?
~>
~> I'm using the RPart library, and the following error occurs, when I do
~> the following:
~>
~> CMD > species.rpart.tree <- rpart(crypto.frame)
~> CMD > predict.rpart(species.rpart.tree, crypto.test)
~> CMD > Error in "[.data.frame"(frame, ,c("n", "ncompete..:undefined
~> columns selected
~> dumped

I've already replied to Gary, but this is a known bug. The plan is to get the
fix for it (and for a few other, smaller bugs), plus a few enhancements, into
a new statlib version of rpart in the relatively near future.

In the meantime, I've attached the newest versions of the predict functions.

- Beth

*****************************************************************************
Beth Atkinson |
Mayo Clinic | phone: 507-284-0431
200 First Street SW | FAX: 507-284-9542
Harwick 7 - Statistics | internet: atkinson@mayo.edu
Rochester, MN 55905 |
*****************************************************************************

----------
X-Sun-Data-Type: default
X-Sun-Data-Description: default
X-Sun-Data-Name: pred.rpart.s
X-Sun-Charset: us-ascii
X-Sun-Content-Lines: 33

# SCCS @(#)pred.rpart.s 1.3 09/03/97
#
# Do rpart predictions given a tree and a matrix of predictors.
#
# Arguments:
#   fit : a fitted rpart object; its $frame, $splits, $csplit and
#         $control$usesurrogate components are passed to the C routine
#         "pred_rpart".
#   x   : numeric matrix of predictors; its column names must include every
#         variable named in the row names of fit$splits.
#
# Value: an integer vector, named by the row names of x, giving for each
#   observation the row of fit$frame at which it comes to rest.
pred.rpart <- function(fit, x) {
  frame <- fit$frame

  # A tree that is only the root has no split matrix to hand to the C code
  # (which would otherwise read from an empty 'nsplit' vector); every
  # observation trivially ends in row 1.
  if (nrow(frame) == 1L) {
    where <- rep(1L, nrow(x))
    names(where) <- dimnames(x)[[1]]
    return(where)
  }

  nc <- frame[, c("ncompete", "nsurrogate")]
  # 1-based index of each node's primary split within fit$splits. A node's
  # splits are stored contiguously: primary, competitors, then surrogates.
  frame$index <- 1 + c(0, cumsum((frame$var != "<leaf>") +
                                   nc[[1]] + nc[[2]]))[-(nrow(frame) + 1)]
  frame$index[frame$var == "<leaf>"] <- 0

  # Column of x used by each split. Spell out 'splits' rather than relying
  # on `$` partial matching of 'split'.
  vnum <- match(dimnames(fit$splits)[[1]], dimnames(x)[[2]])
  if (any(is.na(vnum))) stop("Tree has variables not found in new data")

  temp <- .C("pred_rpart",
    as.integer(dim(x)),
    as.integer(dim(frame)[1]),
    as.integer(dim(fit$splits)),
    as.integer(if (is.null(fit$csplit)) rep(0, 2) else dim(fit$csplit)),
    as.integer(row.names(frame)),
    as.integer(unlist(frame[, c("n", "ncompete", "nsurrogate", "index")])),
    as.integer(vnum),
    as.double(fit$splits),
    as.integer(fit$csplit - 2),   # recode directions to -1 / 0 / +1
    as.integer((fit$control)$usesurrogate),
    as.double(x),
    as.integer(is.na(x)),
    where = integer(dim(x)[1]),
    NAOK = TRUE)
  temp <- temp$where
  names(temp) <- dimnames(x)[[1]]
  temp
}
----------
X-Sun-Data-Type: c-file
X-Sun-Data-Description: c-file
X-Sun-Data-Name: pred_rpart.c
X-Sun-Charset: us-ascii
X-Sun-Content-Lines: 125

/* SCCS @(#)pred_rpart.c 1.4 05/27/97
**
** Do rpart predictions given the matrix form of the tree.
**
** Input
**    dimx   : # of rows and columns in the new data
**    nnode  : # of nodes in the tree
**    nsplit : # of split structures
**    dimc   : dimension (rows, columns) of the categorical splits matrix
**    nnum   : node number for each row of 'nodes'
**    nodes  : matrix of node info, stored as 4 consecutive vectors of
**             length nnode:
**               row 0 = count, 1 = # of competitor splits,
**               2 = # of surrogate splits, 3 = 1-based index of the
**               node's primary split (0 for a leaf)
**    vnum   : 1-based variable (column of xdata) of each split
**    split  : matrix of split info, stored as consecutive vectors of
**             length nsplit:
**               row 0 = usage count, 1 = # of categories if >1, otherwise
**               the split parity, 2 = utility, 3 = 1-based row of csplit
**               (categorical splits) or the numeric split point
**    csplit : matrix of categorical split directions, column-major with
**             dimc[0] rows and dimc[1] columns, indexed by [level][row];
**             -1 sends the observation left, 0 makes no decision (try the
**             next surrogate), any other value sends it right
**    usesur : at what level to use surrogates (0 = primary split only,
**             >0 = also try surrogates, >1 = then follow the child with
**             the larger count)
**    xdata  : the new data, column-major, dimx[0] rows by dimx[1] columns
**    xmiss  : nonzero where the corresponding xdata value is missing
**
** Output
**    where  : the "final" row in nodes (1-based) for each observation
**
** The splits belonging to one node are contiguous in 'split': the primary,
** then its competitors, then its surrogates.
*/

/* Storage allocator supplied by the S/R runtime (not defined in this file).
** Declared explicitly so that its pointer result is not treated as an
** implicit int, which truncates pointers on 64-bit platforms. */
extern char *S_alloc(long nelem, int eltsize);

/*
** Direction in which split 'spl' sends observation 'obs' on variable
** 'var': -1 = left, 0 = no decision, anything else = right.
** A categorical value outside 1..ncol_c yields 0 (no decision) instead of
** reading past the end of csplit.
*/
static int split_direction(int spl, int var, int obs, double **split,
                           long **csplit, long ncol_c, double **xdata)
{
    int ncat = split[1][spl];
    double cut = split[3][spl];
    double x = xdata[var][obs];

    if (ncat >= 2) {
        if (!(x >= 1 && x < ncol_c + 1.0)) return 0;
        return (int) csplit[(int) x - 1][(int) cut - 1];
    }
    if (x < cut) return ncat;
    return -ncat;
}

void pred_rpart(dimx, nnode, nsplit, dimc, nnum, nodes2, vnum, split2,
                csplit2, usesur, xdata2, xmiss2, where)
long   *dimx,
       *nnode,
       *nsplit,
       *nnum,
       *dimc,
       *nodes2,
       *vnum,
       *csplit2,
       *usesur,
       *xmiss2,
       *where;
double *split2,
       *xdata2;
{
    int i, j;
    int n;
    int node, nspl, var, dir;
    int lcount, rcount;
    int npos;
    long *nodes[4];
    double *split[4];
    long **csplit,
         **xmiss;
    double **xdata;

    n = dimx[0];
    for (i = 0; i < 4; i++) {
        nodes[i] = &(nodes2[*nnode * i]);
        split[i] = &(split2[*nsplit * i]);
    }

    /* column pointers, so that e.g. xdata[var][obs] addresses one value */
    csplit = (long **) S_alloc(dimc[1], sizeof(long *));
    xmiss = (long **) S_alloc(dimx[1], sizeof(long *));
    xdata = (double **) S_alloc(dimx[1], sizeof(double *));
    for (i = 0; i < dimc[1]; i++) csplit[i] = &(csplit2[i * dimc[0]]);
    for (i = 0; i < dimx[1]; i++) {
        xmiss[i] = &(xmiss2[i * dimx[0]]);
        xdata[i] = &(xdata2[i * dimx[0]]);
    }

    for (i = 0; i < n; i++) {
        node = 1;                       /* every observation starts at the root */
    next:
        /* row of 'nodes' whose node number (nnum) is the current node */
        for (npos = 0; nnum[npos] != node; npos++);

        nspl = nodes[3][npos] - 1;      /* 0-based index of primary split */
        if (nspl >= 0) {                /* not a leaf node */
            var = vnum[nspl] - 1;
            if (xmiss[var][i] == 0) {   /* primary var not missing */
                dir = split_direction(nspl, var, i, split, csplit, dimc[1],
                                      xdata);
                if (dir != 0) {
                    node = (dir == -1) ? 2 * node : 2 * node + 1;
                    goto next;
                }
            }

            if (*usesur > 0) {
                /* surrogates follow the primary split and its competitors */
                for (j = 0; j < nodes[2][npos]; j++) {
                    nspl = nodes[1][npos] + nodes[3][npos] + j;
                    var = vnum[nspl] - 1;
                    if (xmiss[var][i] == 0) {   /* surrogate not missing */
                        dir = split_direction(nspl, var, i, split, csplit,
                                              dimc[1], xdata);
                        if (dir != 0) {
                            node = (dir == -1) ? 2 * node : 2 * node + 1;
                            goto next;
                        }
                    }
                }
            }

            if (*usesur > 1) {          /* go with the majority */
                for (j = 0; nnum[j] != 2 * node; j++);
                lcount = nodes[0][j];
                for (j = 0; nnum[j] != 2 * node + 1; j++);
                rcount = nodes[0][j];
                if (lcount != rcount) {
                    node = (lcount > rcount) ? 2 * node : 2 * node + 1;
                    goto next;
                }
            }
        }
        where[i] = npos + 1;
    }
}
----------
X-Sun-Data-Type: default
X-Sun-Data-Description: default
X-Sun-Data-Name: predict.rpart.s
X-Sun-Charset: us-ascii
X-Sun-Content-Lines: 44

## SCCS @(#)predict.rpart.s 1.6 09/03/97

# Predictions from a fitted rpart tree.
#
# Arguments:
#   object  : a fitted tree inheriting from class "rpart".
#   newdata : optional data for new observations. If it has no "terms"
#             attribute it is first turned into a model frame with the
#             tree's terms (response removed). If missing, the terminal rows
#             of the fitted observations (object$where) are used.
#   type    : "vector" - rows of frame$yprob when the tree has "ylevels",
#                        otherwise frame$yval;
#             "class"  - factor of predicted levels (requires "ylevels");
#             "tree"   - returns object itself, only without newdata.
#
# Value: a named vector, a matrix with one row per observation, or a factor.
predict.rpart <- function(object, newdata = list(),
                          type = c("vector", "tree", "class")) {
  if (!inherits(object, "rpart")) {
    stop("Not legitimate tree")
  }
  type <- if (missing(type)) {
    "vector"
  } else {
    match.arg(type, c("vector", "tree", "class"))
  }

  if (missing(newdata)) {
    if (type == "tree") {
      return(object)  # idiot proofing
    }
    where <- object$where
  } else {
    if (is.null(attr(newdata, "terms"))) {
      Terms <- delete.response(object$terms)
      act <- (object$call)$na.action
      if (is.null(act)) act <- na.rpart
      newdata <- model.frame(Terms, newdata, na.action = act)
    }
    where <- pred.rpart(object, rpart.matrix(newdata))
  }

  frame <- object$frame
  ylevels <- attr(object, "ylevels")

  if (type == "vector") {
    if (length(ylevels) > 0) {
      # drop = FALSE keeps a one-row matrix for a single observation; without
      # it the result collapses to a vector and setting row names fails.
      pred <- frame$yprob[where, , drop = FALSE]
      rownames(pred) <- names(where)
    } else {
      pred <- frame$yval[where]
      names(pred) <- names(where)
    }
    return(pred)
  }

  if (type == "class") {
    if (length(ylevels) == 0) {
      stop("Type class is only appropriate for classification")
    }
    pred <- factor(ylevels[frame$yval[where]], levels = ylevels)
    names(pred) <- names(where)
    return(pred)
  }

  stop("Cannot do rpart objects yet")
}

-----------------------------------------------------------------------
This message was distributed by s-news@wubios.wustl.edu. To unsubscribe
send e-mail to s-news-request@wubios.wustl.edu with the BODY of the
message: unsubscribe s-news