From: Rajarshi G. <raj...@us...> - 2005-03-25 15:30:28
|
Update of /cvsroot/cdk/cdk/src/org/openscience/cdk/qsar/model/data In directory sc8-pr-cvs1.sourceforge.net:/tmp/cvs-serv13287/src/org/openscience/cdk/qsar/model/data Modified Files: cdkSJava.R Log Message: CNN regression now works for both fitting and prediction. A unit test has been added Index: cdkSJava.R =================================================================== RCS file: /cvsroot/cdk/cdk/src/org/openscience/cdk/qsar/model/data/cdkSJava.R,v retrieving revision 1.6 retrieving revision 1.7 diff -u -r1.6 -r1.7 --- cdkSJava.R 25 Mar 2005 01:34:36 -0000 1.6 +++ cdkSJava.R 25 Mar 2005 15:30:19 -0000 1.7 @@ -109,9 +109,15 @@ noutput, nobs,obj$wts, obj$fitted, obj$residuals, obj$value) } } -cnnPredictConverter <- function(preds,...) { +cnnPredictConverter <- +function(obj,...) { + # The obj we get is actually a 'matrix' but we set its class + # to cnnregprediction so that SJava would send it specifically + # to us. So we should convert obj back to class 'matrix' so + # that SJava can send it correctly to the Java side + class(obj) <- 'matrix' .JNew("org.openscience.cdk.qsar.model.R.CNNRegressionModelPredict", - ncol(preds), preds) + ncol(obj), obj) } ############################################# @@ -120,13 +126,13 @@ setJavaFunctionConverter(lmFitConverter, function(x,...){inherits(x,"lm")}, description="lm fit object to Java", fromJava=F) -setJavaFunctionConverter(lmPredictConverter, function(x,...){inherits(x,"lmprediction")}, +setJavaFunctionConverter(lmPredictConverter, function(x,...){inherits(x,"lmregprediction")}, description="lm predict object to Java", fromJava=F) setJavaFunctionConverter(cnnFitConverter, function(x,...){inherits(x,"nnet")}, description="cnn (nnet) fit object to Java", fromJava=F) -setJavaFunctionConverter(cnnPredictConverter, function(x,...){inherits(x,"cnnprediction")}, +setJavaFunctionConverter(cnnPredictConverter, function(x,...){inherits(x,"cnnregprediction")}, description="cnn (nnet) predict object to Java", fromJava=F) @@ -168,7 +174,7 @@ interval = 'confidence' } preds <- predict( get(modelname), newx, se.fit = TRUE, interval=interval); - class(preds) <- 'lmprediction' + class(preds) <- 'lmregprediction' detach(paramlist) preds @@ -199,11 +205,27 @@ decay=decay,maxit=maxit,Hess=Hess,trace=trace,MaxNWts=MaxNWts, abstol=abstol,reltol=reltol), pos=1) - tmp <- get(modelname) - save(x,y,Wts, file='myrun') detach(paramlist) get(modelname) } + +predictCNN <- function(modelname, params) { + # Since buildCNN should have been called before this + # we don't bother loading the nnet library + paramlist <- hashmap.to.list(params) + attach(paramlist) + + newx <- data.frame( matrix(unlist(newdata), nrow=length(newdata), byrow=TRUE) ) + if (type == '' || !(type %in% c('raw','class')) ) { + type = 'raw' + } + + preds <- predict( get(modelname), newdata=newx, type=type); + class(preds) <- 'cnnregprediction' + + detach(paramlist) + preds +} |