From: Ron S. <ron...@ya...> - 2006-06-20 15:34:49
|
User: rsigal
Date: 06/06/20 11:11:25

Modified: src/main/org/jboss/remoting AbstractInvoker.java

Log:
JBREM-368: Moved createSocketFactory() into AbstractInvoker from RemoteClientInvoker, and extended its functionality to accept configuration from more sources.

Revision Changes Path
1.9 +171 -2 JBossRemoting/src/main/org/jboss/remoting/AbstractInvoker.java

(In the diff below, changes in quantity of whitespace are not shown.)

Index: AbstractInvoker.java =================================================================== RCS file: /cvsroot/jboss/JBossRemoting/src/main/org/jboss/remoting/AbstractInvoker.java,v retrieving revision 1.8 retrieving revision 1.9 diff -u -b -r1.8 -r1.9 --- AbstractInvoker.java 19 Jan 2006 04:25:24 -0000 1.8 +++ AbstractInvoker.java 20 Jun 2006 15:11:25 -0000 1.9 @@ -22,11 +22,20 @@ package org.jboss.remoting; +import java.io.IOException; +import java.lang.reflect.Constructor; import java.util.HashMap; import java.util.Map; + +import javax.net.ServerSocketFactory; +import javax.net.SocketFactory; +import javax.net.ssl.SSLSocketFactory; + import org.jboss.logging.Logger; import org.jboss.remoting.loading.ClassByteClassLoader; import org.jboss.remoting.marshal.MarshallLoaderFactory; +import org.jboss.remoting.security.CustomSSLSocketFactory; +import org.jboss.remoting.security.SSLSocketBuilder; import org.jboss.remoting.serialization.SerializationStreamFactory; /** @@ -35,16 +44,22 @@ * * @author <a href="mailto:jh...@vo...">Jeff Haynie</a> * @author <a href="mailto:te...@e2...">Tom Elrod</a> - * @version $Revision: 1.8 $ + * @version $Revision: 1.9 $ */ public abstract class AbstractInvoker implements Invoker { + public static final String SOCKET_FACTORY = "socketFactory"; + protected final Logger log = Logger.getLogger(getClass()); protected ClassByteClassLoader classbyteloader; protected InvokerLocator locator; protected Map localServerLocators = new HashMap(); protected String serializationType; protected Map configuration = new HashMap(); 
+ protected SocketFactory socketFactory; + + // Indicates if the serverSocketFactory was generated internally. + protected boolean socketFactoryCreatedFromSSLParameters; public AbstractInvoker(InvokerLocator locator) { @@ -145,4 +160,158 @@ { this.serializationType = serializationType; } + + public SocketFactory getSocketFactory() + { + return socketFactory; + } + + public void setSocketFactory(SocketFactory socketFactory) + { + this.socketFactory = socketFactory; + } + + public boolean isSocketFactoryCreatedFromSSLParameters() + { + return socketFactoryCreatedFromSSLParameters; + } + + /** + * If any configuration parameters relate to the construction of a SSLSocketBuilder, + * create one. + * + * @param configuration + * @return + */ + protected SocketFactory createSocketFactory(Map configuration) + { + if (configuration == null) + return null; + + if (socketFactory != null) + return socketFactory; + + SocketFactory factory = null; + + Object obj = configuration.get(Remoting.CUSTOM_SOCKET_FACTORY); + if (obj != null) + { + if (obj instanceof SocketFactory) + { + factory = (SocketFactory) obj; + } + else + { + throw new RuntimeException("Can not set custom socket factory (" + obj + ") as is not of type javax.net.SocketFactory"); + } + } + + if(factory == null) + { + String socketFactoryString = (String)configuration.get(Remoting.SOCKET_FACTORY_NAME); + if(socketFactoryString != null && socketFactoryString.length() > 0) + { + //ClassLoader classLoader = invoker.getClassLoader(); + ClassLoader classLoader = null; + if(classLoader == null) + { + classLoader = Thread.currentThread().getContextClassLoader(); + + if(classLoader == null) + { + classLoader = getClass().getClassLoader(); + } + } + try + { + Class cl = classLoader.loadClass(socketFactoryString); + + Constructor socketFactoryConstructor = null; + socketFactoryConstructor = cl.getConstructor(new Class[]{}); + factory = (SocketFactory)socketFactoryConstructor.newInstance(new Object[] {}); + 
log.trace("SocketFactory (" + socketFactoryString + ") loaded"); + } + catch(Exception e) + { + log.debug("Could not create socket factory by classname (" + socketFactoryString + "). Error message: " + e.getMessage()); + } + } + } + + if (factory == null) + { + if (justNeedsSSLServerMode(configuration)) + { + SSLSocketBuilder socketBuilder = new SSLSocketBuilder(); + socketBuilder.setSocketUseClientMode( false ); + SSLSocketFactory defaultFactory = (SSLSocketFactory) SSLSocketFactory.getDefault(); + factory = new CustomSSLSocketFactory(defaultFactory, socketBuilder); + socketFactoryCreatedFromSSLParameters = true; + } + else if (needsCustomSSLConfiguration(configuration)) + { + try + { + SSLSocketBuilder socketBuilder = new SSLSocketBuilder(configuration); + socketBuilder.setUseSSLSocketFactory( false ); + factory = socketBuilder.createSSLSocketFactory(); + socketFactoryCreatedFromSSLParameters = true; + } + catch (IOException e) + { + throw new RuntimeException("Unable to create customized SSL socket factory", e); + } + } + } + + return factory; + } + + protected boolean justNeedsSSLServerMode(Map configuration) + { + if (configuration.size() == 1 && configuration.containsKey(SSLSocketBuilder.REMOTING_SOCKET_USE_CLIENT_MODE)) + { + String useClientModeString = (String) configuration.get(SSLSocketBuilder.REMOTING_SOCKET_USE_CLIENT_MODE); + return !Boolean.parseBoolean(useClientModeString); + } + + if (configuration.size() == 1 && configuration.containsKey(SSLSocketBuilder.REMOTING_SERVER_SOCKET_USE_CLIENT_MODE)) + { + String useClientModeString = (String) configuration.get(SSLSocketBuilder.REMOTING_SERVER_SOCKET_USE_CLIENT_MODE); + return !Boolean.parseBoolean(useClientModeString); + } + + if (configuration.size() == 2 + && configuration.containsKey(SSLSocketBuilder.REMOTING_SOCKET_USE_CLIENT_MODE) + && configuration.containsKey(SSLSocketBuilder.REMOTING_SERVER_SOCKET_USE_CLIENT_MODE)) + { + String useClientModeString = (String) 
configuration.get(SSLSocketBuilder.REMOTING_SOCKET_USE_CLIENT_MODE); + return !Boolean.parseBoolean(useClientModeString); + } + + return false; + } + + protected boolean needsCustomSSLConfiguration(Map configuration) + { + if (configuration.get(SSLSocketBuilder.REMOTING_KEY_ALIAS) != null || + configuration.get(SSLSocketBuilder.REMOTING_CLIENT_AUTH_MODE) != null || + configuration.get(SSLSocketBuilder.REMOTING_SERVER_AUTH_MODE) != null || + configuration.get(SSLSocketBuilder.REMOTING_SSL_PROTOCOL) != null || + configuration.get(SSLSocketBuilder.REMOTING_SSL_PROVIDER_NAME) != null || + configuration.get(SSLSocketBuilder.REMOTING_SOCKET_USE_CLIENT_MODE) != null || + configuration.get(SSLSocketBuilder.REMOTING_KEY_PASSWORD) != null || + configuration.get(SSLSocketBuilder.REMOTING_KEY_STORE_ALGORITHM) != null || + configuration.get(SSLSocketBuilder.REMOTING_KEY_STORE_FILE_PATH) != null || + configuration.get(SSLSocketBuilder.REMOTING_KEY_STORE_PASSWORD) != null || + configuration.get(SSLSocketBuilder.REMOTING_KEY_STORE_TYPE) != null || + configuration.get(SSLSocketBuilder.REMOTING_TRUST_STORE_ALGORITHM) != null || + configuration.get(SSLSocketBuilder.REMOTING_TRUST_STORE_FILE_PATH) != null || + configuration.get(SSLSocketBuilder.REMOTING_TRUST_STORE_PASSWORD) != null || + configuration.get(SSLSocketBuilder.REMOTING_TRUST_STORE_TYPE) != null + ) + return true; + else + return false; + } } |