Tuesday, February 23, 2010

SASL Plain Server

For some reason Sun includes support for SASL PLAIN in its SaslClient, but not in its SaslServer. Here's a quick hack that implements a PLAIN SASL server. As per RFC 4616, there is only one message, sent from the client: three UTF-8 strings (authzid, authcid and passwd) delimited by NUL (\u0000) characters. The authzid may be empty, and the server must accept each string up to at least 255 octets long. There is no server response.
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;

/**
 * Server side of the SASL PLAIN mechanism as described in RFC 4616.
 *
 * <p>PLAIN is a single-message exchange: the client sends
 * {@code [authzid] NUL authcid NUL passwd} encoded as UTF-8 and the server
 * replies with no challenge. Credential and authorization checks are delegated
 * to the {@link CallbackHandler} given at construction, which receives an
 * {@link ExternalValidationCallback} (password check) followed by an
 * {@link AuthorizeCallback} (authorization check) in a single
 * {@code handle} call.
 *
 * <p>PLAIN negotiates no security layer, so {@link #wrap} and {@link #unwrap}
 * always throw {@link IllegalStateException}.
 *
 * @author toaster
 */
public class PlainSaslServer implements SaslServer
{

    /** Name of the SASL mechanism implemented here. */
    private static final String MECHANISM = "PLAIN";

    /**
     * True once a response has been authenticated and authorized.
     * Left public for backward compatibility; callers should use {@link #isComplete()}.
     */
    public boolean complete = false;
    /** Identity reported by {@link #getAuthorizationID()}; set only on success. */
    private String authID = null;
    private CallbackHandler cbh;

    PlainSaslServer(CallbackHandler cbh)
    {
        this.cbh = cbh;
    }

    /**
     * @return always {@code "PLAIN"}
     */
    public String getMechanismName()
    {
        return MECHANISM;
    }

    /**
     * Processes the single PLAIN client message and authenticates the client.
     *
     * @param response the raw client message: {@code [authzid] NUL authcid NUL passwd}
     * @return always {@code null}; PLAIN sends no challenge back to the client
     * @throws SaslException if the message is malformed, the password is not
     *         validated, the client is not authorized, or the callback handler fails
     * @throws IllegalStateException if authentication has already completed
     */
    public byte[] evaluateResponse(byte[] response) throws SaslException
    {
        if (complete)
        {
            throw new IllegalStateException("PLAIN authentication already completed");
        }
        if (response == null)
        {
            throw new SaslException("null auth data");
        }

        String[] fields = parseMessage(response);
        String authzID = fields[0];
        String authnID = fields[1];
        String password = fields[2];

        ExternalValidationCallback evc = new ExternalValidationCallback(authnID, password);
        // RFC 4616: an empty authzid means the client acts as its own authcid,
        // so the handler always sees a concrete identity to authorize.
        AuthorizeCallback ac = new AuthorizeCallback(authnID,
                authzID != null ? authzID : authnID);

        try
        {
            cbh.handle(new Callback[] { evc, ac });
        } catch (UnsupportedCallbackException ex)
        {
            throw new SaslException("unsupported callback", ex);
        } catch (SaslException e)
        {
            // SASL errors raised by the handler already carry a meaningful message.
            throw e;
        } catch (IOException e)
        {
            throw new SaslException("Callback error", e);
        }

        if (!evc.isValidated())
        {
            throw new SaslException("cannot validate password");
        }
        if (!ac.isAuthorized())
        {
            throw new SaslException("user not authorized");
        }

        // The handler may canonicalize the identity via AuthorizeCallback.setAuthorizedID().
        String authorized = ac.getAuthorizedID();
        authID = authorized != null ? authorized : ac.getAuthorizationID();
        complete = true;
        return null;
    }

    /**
     * Splits an RFC 4616 message into {@code {authzid, authcid, passwd}}.
     * The authzid slot is {@code null} when the client sent an empty authzid.
     *
     * @throws SaslException if the message does not contain exactly two NUL
     *         separators, or if authcid or passwd is empty
     */
    private static String[] parseMessage(byte[] response) throws SaslException
    {
        String[] fields = new String[3];
        int fieldIdx = 0;
        int start = 0;

        // Iterate one past the end so the final field (passwd) is flushed too.
        for (int i = 0; i <= response.length; i++)
        {
            boolean atEnd = (i == response.length);
            if (!atEnd && response[i] != 0)
            {
                continue;
            }
            if (fieldIdx > 2)
            {
                throw new SaslException("Unexpected data in packet");
            }

            int len = i - start;
            if (len == 0)
            {
                // Only the leading authzid may be empty.
                if (fieldIdx != 0)
                {
                    throw new SaslException("null auth data");
                }
            }
            else
            {
                fields[fieldIdx] = decodeUtf8(response, start, len);
            }
            fieldIdx++;
            // Always skip past the separator, including after an empty field.
            start = i + 1;
        }

        if (fieldIdx != 3)
        {
            throw new SaslException("null auth data");
        }
        return fields;
    }

    /** Decodes {@code len} bytes of {@code data} starting at {@code offset} as UTF-8. */
    private static String decodeUtf8(byte[] data, int offset, int len) throws SaslException
    {
        try
        {
            return new String(data, offset, len, "UTF-8");
        } catch (UnsupportedEncodingException e)
        {
            throw new SaslException("utf-8 encoding", e);
        }
    }

    public boolean isComplete()
    {
        return complete;
    }

    /**
     * @return the identity authorized by the callback handler during the
     *         successful exchange
     * @throws IllegalStateException if authentication has not completed
     */
    public String getAuthorizationID()
    {
        if (!complete)
        {
            throw new IllegalStateException("not complete");
        }
        return authID;
    }

    /**
     * Always fails: PLAIN negotiates neither integrity nor privacy, so there is
     * no security layer to unwrap.
     *
     * @throws IllegalStateException always
     */
    public byte[] unwrap(byte[] incoming, int offset, int len)
    {
        throw new IllegalStateException(complete
                ? "PLAIN supports neither integrity nor privacy"
                : "not complete");
    }

    /**
     * Always fails: PLAIN negotiates neither integrity nor privacy, so there is
     * no security layer to wrap.
     *
     * @throws IllegalStateException always
     */
    public byte[] wrap(byte[] outgoing, int offset, int len)
    {
        throw new IllegalStateException(complete
                ? "PLAIN supports neither integrity nor privacy"
                : "not complete");
    }

    /**
     * @param propName the property to look up
     * @return {@code "auth"} for {@link Sasl#QOP} (authentication only), otherwise {@code null}
     * @throws IllegalStateException if authentication has not completed
     */
    public Object getNegotiatedProperty(String propName)
    {
        if (!complete)
        {
            throw new IllegalStateException("not complete");
        }
        return Sasl.QOP.equals(propName) ? "auth" : null;
    }

    /**
     * Releases resources; this mechanism holds none.
     */
    public void dispose() throws SaslException
    {
    }
}
This also uses the following callback to validate the client's username and password. The SaslServer will throw a SaslException if isValidated() returns false.
/**
 * Callback that hands the client's claimed username and password to the
 * application's CallbackHandler. The handler reports its verdict through
 * {@link #setValidated(boolean)}; a fresh instance starts out not validated.
 */
public class ExternalValidationCallback implements Callback
{

    private final String username;
    private final String password;
    private boolean validated;

    /**
     * @param username the identity whose password should be checked
     * @param password the password presented for that identity
     */
    public ExternalValidationCallback(String username, String password)
    {
        this.password = password;
        this.username = username;
    }

    /** @return the password supplied at construction */
    public String getPassword()
    {
        return this.password;
    }

    /** @return the username supplied at construction */
    public String getUsername()
    {
        return this.username;
    }

    /** @return whether the handler accepted the credentials */
    public boolean isValidated()
    {
        return this.validated;
    }

    /** @param validated the handler's verdict on the credentials */
    public void setValidated(boolean validated)
    {
        this.validated = validated;
    }
}
Now that we have the new callback and server, we need to make the server available to Java through a security provider. This requires creating a server factory and a simple security provider that registers the factory. The server factory needs to check the SASL policy property Sasl.POLICY_NOPLAINTEXT to see whether it's allowed to return the PLAIN server.
/**
 * {@link SaslServerFactory} that creates {@link PlainSaslServer} instances for
 * the SASL "PLAIN" mechanism (RFC 4616).
 *
 * <p>PLAIN sends the password in the clear, so the mechanism is withheld
 * whenever the caller's policy properties forbid plaintext mechanisms or
 * mechanisms susceptible to active or dictionary attacks.
 */
public class PlainSaslServerFactory implements SaslServerFactory
{

    private static final String[] MECHANISMS =
    {
        "PLAIN"
    };

    /** Policy properties that rule out PLAIN when set to "true". */
    private static final String[] EXCLUDING_POLICIES =
    {
        Sasl.POLICY_NOPLAINTEXT,
        Sasl.POLICY_NOACTIVE,
        Sasl.POLICY_NODICTIONARY
    };

    /**
     * @param props possibly null policy/configuration properties
     * @param cbh handler used by the server; must support
     *        ExternalValidationCallback and AuthorizeCallback
     * @return a new PLAIN server, or {@code null} if the mechanism is not PLAIN,
     *         {@code cbh} is null, or the policy in {@code props} forbids PLAIN
     */
    public SaslServer createSaslServer(String mechanism, String protocol,
            String serverName,
            Map<String, ?> props, CallbackHandler cbh) throws SaslException
    {
        if (!MECHANISMS[0].equals(mechanism) || cbh == null || !isPlainAllowed(props))
        {
            return null;
        }
        return new PlainSaslServer(cbh);
    }

    /**
     * @param props possibly null policy properties
     * @return {"PLAIN"} if the policy permits it, otherwise an empty array
     */
    public String[] getMechanismNames(Map<String, ?> props)
    {
        if (!isPlainAllowed(props))
        {
            return new String[0];
        }
        // Return a copy so callers cannot corrupt the shared array used for matching.
        return MECHANISMS.clone();
    }

    /** Returns false if any policy that excludes PLAIN is set to "true" (case-insensitive). */
    private static boolean isPlainAllowed(Map<String, ?> props)
    {
        if (props == null)
        {
            return true;
        }
        for (String policy : EXCLUDING_POLICIES)
        {
            Object value = props.get(policy);
            if (value != null && "true".equalsIgnoreCase(value.toString()))
            {
                return false;
            }
        }
        return true;
    }
}
Now, create the Java security provider.
/**
 * Minimal JCA security provider whose only service entry maps the
 * "SaslServerFactory.PLAIN" key to the PLAIN server factory class name.
 */
public final class MySecurityProvider extends Provider
{

    private static final String SERVICE_KEY = "SaslServerFactory.PLAIN";
    private static final String FACTORY_CLASS = "org.swap.provider.PlainSaslServerFactory";

    /**
     * Builds the provider and registers the PLAIN server factory entry.
     */
    public MySecurityProvider()
    {
        super("My Provider", 1.0, "Simple sasl plain provider");
        this.put(SERVICE_KEY, FACTORY_CLASS);
    }
}
Putting it all together
// register provider
Security.addProvider(new MySecurityProvider());
// create callback
CallbackHandler myHandler = new CallbackHandler() {
    public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException
    {
        for (Callback cb : callbacks)
        {
            if (cb instanceof ExternalValidationCallback)
            {
                ExternalValidationCallback evc = (ExternalValidationCallback) cb;
                // Add your password validation, unless bob works for you
                evc.setValidated("bob".equals(evc.getUsername())
                        && "password".equals(evc.getPassword()));
            }
            else if (cb instanceof AuthorizeCallback)
            {
                AuthorizeCallback ac = (AuthorizeCallback) cb;
                // Add your test to see if the client may act as the requested identity.
                // Safe default: a user may only act as themselves (no authzid, or authzid == authcid),
                // so an authenticated user cannot impersonate someone else.
                String requested = ac.getAuthorizationID();
                ac.setAuthorized(requested == null
                        || requested.equals(ac.getAuthenticationID()));
            }
            else
            {
                throw new UnsupportedCallbackException(cb, "Unrecognized Callback");
            }
        }
    }
};

SaslServer ss = Sasl.createSaslServer("PLAIN", "myProtocol", serverName, null, myHandler);
if (ss == null)
{
    // No registered factory offered PLAIN (provider missing or policy excluded it).
    throw new SaslException("no PLAIN SaslServer available");
}
// Throws SaslException if the credentials are rejected; PLAIN sends no challenge back.
ss.evaluateResponse(clientPacket);