From: <kr_...@us...> - 2003-09-07 19:16:15
|
{- CVS commit notice:

   Update of /cvsroot/htoolkit/HSQL/MySQL
   In directory sc8-pr-cvs1:/tmp/cvs-serv1819/MySQL

   Added Files: HSQL.hsc
   Log Message: Add support for MySQL

   --- NEW FILE: HSQL.hsc ---
-}
#include <config.h>

-- | MySQL back end for HSQL: connecting, running statements, fetching rows
-- and converting column values to and from Haskell types.
module Database.MySQL.HSQL
    ( SqlBind(..), SqlError(..), SqlType(..), Connection, Statement
    , catchSql          -- :: IO a -> (SqlError -> IO a) -> IO a
    , handleSql         -- :: (SqlError -> IO a) -> IO a -> IO a
    , sqlExceptions     -- :: Exception -> Maybe SqlError
    , connect           -- :: String -> String -> String -> String -> IO Connection
    , disconnect        -- :: Connection -> IO ()
    , execute           -- :: Connection -> String -> IO ()
    , query             -- :: Connection -> String -> IO Statement
    , closeStatement    -- :: Statement -> IO ()
    , fetch             -- :: Statement -> IO Bool
    , inTransaction     -- :: Connection -> (Connection -> IO a) -> IO a
    , getFieldValueMB   -- :: SqlBind a => Statement -> String -> IO (Maybe a)
    , getFieldValue     -- :: SqlBind a => Statement -> String -> IO a
    , getFieldValue'    -- :: SqlBind a => Statement -> String -> a -> IO a
    , getFieldValueType -- :: Statement -> String -> (SqlType, Bool)
    , getFieldsTypes    -- :: Statement -> [(String, SqlType, Bool)]
    , forEachRow        -- :: (Statement -> s -> IO s) -> Statement -> s -> IO s
    , forEachRow'       -- :: (Statement -> IO ()) -> Statement -> IO ()
    , collectRows       -- :: (Statement -> IO s) -> Statement -> IO [s]
    ) where

import Data.Dynamic
import Data.Bits
import Data.IORef
import Data.Char
import Data.List (foldl')
import Foreign
import Foreign.C
import Control.Monad (when, unless)
import Control.Exception (throwDyn, catchDyn, dynExceptions, Exception(..))
import System.Time
import System.IO.Unsafe
import Text.ParserCombinators.ReadP
import Text.Read

#ifdef ENABLED_GUI
import Graphics.UI.HGUI.BasicTypes
import Graphics.UI.HGUI.BasicClasses
#endif

#include <mysql.h>
#include <time.h>

type MYSQL       = Ptr ()
type MYSQL_RES   = Ptr ()
type MYSQL_FIELD = Ptr ()
type MYSQL_ROW   = Ptr CString

-- The C prototypes use int / unsigned int / unsigned long.  They are mirrored
-- with the Foreign.C types so the marshalled widths match the C side even
-- where Haskell's Int is wider than C's int.
foreign import ccall "mysql.h mysql_init" mysql_init :: MYSQL -> IO MYSQL
foreign import ccall "mysql.h mysql_real_connect" mysql_real_connect
    :: MYSQL -> CString -> CString -> CString -> CString
    -> CUInt -> CString -> CULong -> IO MYSQL
foreign import ccall "mysql.h mysql_close" mysql_close :: MYSQL -> IO ()
foreign import ccall "mysql.h mysql_errno" mysql_errno :: MYSQL -> IO CUInt
foreign import ccall "mysql.h mysql_error" mysql_error :: MYSQL -> IO CString
foreign import ccall "mysql.h mysql_query" mysql_query :: MYSQL -> CString -> IO CInt
foreign import ccall "mysql.h mysql_use_result" mysql_use_result :: MYSQL -> IO MYSQL_RES
foreign import ccall "mysql.h mysql_fetch_field" mysql_fetch_field :: MYSQL_RES -> IO MYSQL_FIELD
foreign import ccall "mysql.h mysql_free_result" mysql_free_result :: MYSQL_RES -> IO ()
foreign import ccall "mysql.h mysql_fetch_row" mysql_fetch_row :: MYSQL_RES -> IO MYSQL_ROW

-- | An open connection handle.
newtype Connection = Connection MYSQL

-- | The result of 'query': the (possibly NULL) result set, the owning
-- connection, the column descriptions and the row most recently fetched.
data Statement = Statement
    { pRes       :: !MYSQL_RES
    , connection :: !Connection
    , fields     :: ![FieldDef]
    , currRow    :: IORef MYSQL_ROW
    }

-- | Column name, column type, and whether the column may be NULL.
type FieldDef = (String, SqlType, Bool)

data SqlType
    = SqlChar Int
    | SqlVarChar Int
    | SqlNumeric Int Int
    | SqlSmallInt
    | SqlMedInt
    | SqlInteger
    | SqlReal
    | SqlDouble
    | SqlTinyInt
    | SqlBigInt
    | SqlDate
    | SqlTime
    | SqlTimeStamp
    | SqlDateTime
    | SqlYear
    | SqlSET
    | SqlENUM
    | SqlBLOB
    | SqlUnknown
    deriving (Eq, Show)

data SqlError
    = SqlError        { seNativeError :: Int
                      , seErrorMsg    :: String }
    | SqlBadTypeCast  { seFieldName   :: String
                      , seFieldType   :: SqlType }
    | SqlFetchNull    { seFieldName   :: String }
    deriving (Typeable, Show)

-----------------------------------------------------------------------------------------
-- routines for handling exceptions
-----------------------------------------------------------------------------------------

catchSql :: IO a -> (SqlError -> IO a) -> IO a
catchSql = catchDyn

handleSql :: (SqlError -> IO a) -> IO a -> IO a
handleSql h f = catchDyn f h

sqlExceptions :: Exception -> Maybe SqlError
sqlExceptions e = dynExceptions e >>= fromDynamic

-- | Snapshot the error number and message currently stored in a handle.
getSqlError :: MYSQL -> IO SqlError
getSqlError pMYSQL = do
    errno  <- mysql_errno pMYSQL
    errMsg <- mysql_error pMYSQL >>= peekCString
    return (SqlError (fromIntegral errno) errMsg)

-- | Throw the handle's current error as a dynamic 'SqlError'.
handleSqlError :: MYSQL -> IO a
handleSqlError pMYSQL = getSqlError pMYSQL >>= throwDyn

-----------------------------------------------------------------------------------------
-- Connect/Disconnect
-----------------------------------------------------------------------------------------

-- | Open a connection to @server@, selecting @database@ and logging in as
-- @user@ with password @authentication@.  Throws 'SqlError' on failure.
connect :: String -> String -> String -> String -> IO Connection
connect server database user authentication = do
    pMYSQL <- mysql_init nullPtr
    -- mysql_init returns NULL when it cannot allocate the handle.
    when (pMYSQL == nullPtr) $
        throwDyn (SqlError 0 "mysql_init failed: insufficient memory")
    -- withCString releases every buffer even if something below throws.
    res <- withCString server         $ \pServer ->
           withCString database       $ \pDatabase ->
           withCString user           $ \pUser ->
           withCString authentication $ \pAuthentication ->
               mysql_real_connect pMYSQL pServer pUser pAuthentication pDatabase 0 nullPtr 0
    when (res == nullPtr) $ do
        -- Read the error first: mysql_close frees the handle that holds it.
        -- Closing here avoids leaking the handle allocated by mysql_init.
        err <- getSqlError pMYSQL
        mysql_close pMYSQL
        throwDyn err
    return (Connection pMYSQL)

disconnect :: Connection -> IO ()
disconnect (Connection pMYSQL) = mysql_close pMYSQL

-----------------------------------------------------------------------------------------
-- queries
-----------------------------------------------------------------------------------------

-- | Run a statement whose result set (if any) is not needed.
execute :: Connection -> String -> IO ()
execute (Connection pMYSQL) sqlText = do
    res <- withCString sqlText (mysql_query pMYSQL)
    when (res /= 0) (handleSqlError pMYSQL)

-- | Run a statement and open its result set for row-by-row fetching.
-- Statements that produce no result set yield a 'Statement' with no fields
-- for which 'fetch' immediately returns False.
query :: Connection -> String -> IO Statement
query conn@(Connection pMYSQL) sqlText = do
    res <- withCString sqlText (mysql_query pMYSQL)
    when (res /= 0) (handleSqlError pMYSQL)
    rowRef  <- newIORef nullPtr
    pResult <- mysql_use_result pMYSQL
    if pResult == nullPtr
        then do
            -- NULL is normal for statements without a result set; it is
            -- only an error if the handle reports one.
            errno <- mysql_errno pMYSQL
            when (errno /= 0) (handleSqlError pMYSQL)
            return Statement { pRes = nullPtr, fields = [], connection = conn, currRow = rowRef }
        else do
            fieldDefs <- getFieldDefs pResult
            return Statement { pRes = pResult, fields = fieldDefs, connection = conn, currRow = rowRef }
  where
    getFieldDefs pResult = do
        pField <- mysql_fetch_field pResult
        if pField == nullPtr
            then return []
            else do
                name          <- (#peek MYSQL_FIELD, name) pField >>= peekCString
                -- Peek with the C field widths (enum/unsigned int/unsigned
                -- long) rather than Int, which may be wider than the field.
                dataType      <- (#peek MYSQL_FIELD, type) pField     :: IO CInt
                columnSize    <- (#peek MYSQL_FIELD, length) pField   :: IO CULong
                flags         <- (#peek MYSQL_FIELD, flags) pField    :: IO CUInt
                decimalDigits <- (#peek MYSQL_FIELD, decimals) pField :: IO CUInt
                let sqlType  = mkSqlType (fromIntegral dataType)
                                         (fromIntegral columnSize)
                                         (fromIntegral decimalDigits)
                    nullable = (flags .&. (#const NOT_NULL_FLAG)) == 0
                defs <- getFieldDefs pResult
                return ((name, sqlType, nullable) : defs)

-- | Map a MySQL field type code (plus length and decimals) to an 'SqlType'.
-- Codes without a dedicated constructor map to 'SqlUnknown' instead of
-- failing with a pattern-match error.
mkSqlType :: Int -> Int -> Int -> SqlType
mkSqlType (#const FIELD_TYPE_STRING)      size _    = SqlChar size
mkSqlType (#const FIELD_TYPE_VAR_STRING)  size _    = SqlVarChar size
mkSqlType (#const FIELD_TYPE_DECIMAL)     size prec = SqlNumeric size prec
mkSqlType (#const FIELD_TYPE_SHORT)       _    _    = SqlSmallInt
mkSqlType (#const FIELD_TYPE_INT24)       _    _    = SqlMedInt
mkSqlType (#const FIELD_TYPE_LONG)        _    _    = SqlInteger
mkSqlType (#const FIELD_TYPE_FLOAT)       _    _    = SqlReal
mkSqlType (#const FIELD_TYPE_DOUBLE)      _    _    = SqlDouble
mkSqlType (#const FIELD_TYPE_TINY)        _    _    = SqlTinyInt
mkSqlType (#const FIELD_TYPE_LONGLONG)    _    _    = SqlBigInt
mkSqlType (#const FIELD_TYPE_DATE)        _    _    = SqlDate
mkSqlType (#const FIELD_TYPE_TIME)        _    _    = SqlTime
mkSqlType (#const FIELD_TYPE_TIMESTAMP)   _    _    = SqlTimeStamp
mkSqlType (#const FIELD_TYPE_DATETIME)    _    _    = SqlDateTime
mkSqlType (#const FIELD_TYPE_YEAR)        _    _    = SqlYear
mkSqlType (#const FIELD_TYPE_BLOB)        _    _    = SqlBLOB
mkSqlType (#const FIELD_TYPE_TINY_BLOB)   _    _    = SqlBLOB
mkSqlType (#const FIELD_TYPE_MEDIUM_BLOB) _    _    = SqlBLOB
mkSqlType (#const FIELD_TYPE_LONG_BLOB)   _    _    = SqlBLOB
mkSqlType (#const FIELD_TYPE_SET)         _    _    = SqlSET
mkSqlType (#const FIELD_TYPE_ENUM)        _    _    = SqlENUM
mkSqlType (#const FIELD_TYPE_NULL)        _    _    = SqlUnknown
mkSqlType _                               _    _    = SqlUnknown

-- | Advance to the next row.  Returns False at the end of the result set (or
-- when there is no result set).  Throws 'SqlError' if the server reports an
-- error while fetching: with mysql_use_result a NULL row can also mean an
-- error, which is only visible through mysql_errno.
fetch :: Statement -> IO Bool
fetch (Statement {pRes = pResult, connection = Connection pMYSQL, currRow = rowRef})
    | pResult == nullPtr = return False
    | otherwise = do
        pRow <- mysql_fetch_row pResult
        writeIORef rowRef pRow
        if pRow /= nullPtr
            then return True
            else do
                errno <- mysql_errno pMYSQL
                when (errno /= 0) (handleSqlError pMYSQL)
                return False

-- | Release the result set.  The current-row pointer is cleared first, so a
-- later field read fails cleanly instead of touching freed memory.
closeStatement :: Statement -> IO ()
closeStatement (Statement {pRes = pResult, currRow = rowRef})
    | pResult == nullPtr = return ()
    | otherwise = do
        writeIORef rowRef nullPtr
        mysql_free_result pResult

-----------------------------------------------------------------------------------------
-- transactions
-----------------------------------------------------------------------------------------

-- | Run @action@ between "begin" and "commit".  If it throws an 'SqlError',
-- "rollback" is executed and the error is re-thrown.
inTransaction :: Connection -> (Connection -> IO a) -> IO a
inTransaction conn action = do
    execute conn "begin"
    r <- catchSql (action conn) (\err -> execute conn "rollback" >> throwDyn err)
    execute conn "commit"
    return r

-----------------------------------------------------------------------------------------
-- binding
-----------------------------------------------------------------------------------------

class SqlBind a where
    fromSqlValue :: SqlType -> String -> Maybe a
    toSqlValue   :: a -> String

-- | Total replacement for 'read': succeeds only on a single complete parse
-- (trailing whitespace allowed).  A failure therefore surfaces as
-- 'SqlBadTypeCast' rather than a lazy "no parse" error.
readValue :: Read a => String -> Maybe a
readValue s = case [x | (x, rest) <- reads s, all isSpace rest] of
                  [x] -> Just x
                  _   -> Nothing

instance SqlBind Int where
    fromSqlValue SqlTinyInt  s = readValue s
    fromSqlValue SqlSmallInt s = readValue s
    fromSqlValue SqlMedInt   s = readValue s
    fromSqlValue SqlInteger  s = readValue s
    fromSqlValue _           _ = Nothing
    toSqlValue = show

instance SqlBind Integer where
    fromSqlValue SqlTinyInt  s = readValue s
    fromSqlValue SqlSmallInt s = readValue s
    fromSqlValue SqlMedInt   s = readValue s
    fromSqlValue SqlInteger  s = readValue s
    fromSqlValue SqlBigInt   s = readValue s
    fromSqlValue _           _ = Nothing
    toSqlValue = show

instance SqlBind String where
    fromSqlValue _ = Just
    -- Quote as a MySQL string literal.  NUL must be escaped: withCString
    -- hands mysql_query a NUL-terminated string, so a raw NUL would cut the
    -- statement short.  Ctrl-Z is escaped as \Z, like mysql_real_escape_string.
    toSqlValue s = '\'' : foldr escape "'" s
      where
        escape '\\'   rest = '\\' : '\\' : rest
        escape '\''   rest = '\\' : '\'' : rest
        escape '\n'   rest = '\\' : 'n'  : rest
        escape '\r'   rest = '\\' : 'r'  : rest
        escape '\t'   rest = '\\' : 't'  : rest
        escape '\0'   rest = '\\' : '0'  : rest
        escape '\SUB' rest = '\\' : 'Z'  : rest
        escape c      rest = c : rest

instance SqlBind Double where
    fromSqlValue (SqlNumeric _ _) s = readValue s
    fromSqlValue SqlDouble        s = readValue s
    fromSqlValue SqlReal          s = readValue s
    fromSqlValue _                _ = Nothing
    toSqlValue = show

-- | Build a 'ClockTime' from calendar fields interpreted as LOCAL time via
-- C mktime.  @month@ is 1-based (January = 1).  unsafePerformIO is used on
-- the assumption that the process timezone does not change while running.
mkClockTime :: Int -> Int -> Int -> Int -> Int -> Int -> ClockTime
mkClockTime year month mday hour minute sec = unsafePerformIO $
    allocaBytes (#const sizeof(struct tm)) $ \p_tm -> do
        (#poke struct tm,tm_sec  ) p_tm (fromIntegral sec          :: CInt)
        (#poke struct tm,tm_min  ) p_tm (fromIntegral minute       :: CInt)
        (#poke struct tm,tm_hour ) p_tm (fromIntegral hour         :: CInt)
        (#poke struct tm,tm_mday ) p_tm (fromIntegral mday         :: CInt)
        (#poke struct tm,tm_mon  ) p_tm (fromIntegral (month-1)    :: CInt)
        (#poke struct tm,tm_year ) p_tm (fromIntegral (year-1900)  :: CInt)
        (#poke struct tm,tm_isdst) p_tm (-1 :: CInt)
        t <- mktime p_tm
        -- CTime is only guaranteed to be Real, so go through Rational.
        return (TOD (truncate (toRational t)) 0)

foreign import ccall unsafe "time.h mktime" mktime :: Ptr () -> IO CTime

-- | Run a ReadP parser and accept exactly one result, else 'Nothing'.
parseWith :: ReadP a -> String -> Maybe a
parseWith p s = case readP_to_S p s of
                    [(x, _)] -> Just x
                    _        -> Nothing

-- | @YYYY-MM-DD@
ymdP :: ReadP (Int, Int, Int)
ymdP = do
    y  <- readS_to_P reads
    satisfy (== '-')
    mo <- readS_to_P reads
    satisfy (== '-')
    d  <- readS_to_P reads
    return (y, mo, d)

-- | @HH:MM:SS@
hmsP :: ReadP (Int, Int, Int)
hmsP = do
    h <- readS_to_P reads
    satisfy (== ':')
    m <- readS_to_P reads
    satisfy (== ':')
    s <- readS_to_P reads
    return (h, m, s)

-- | A TIME value, anchored on 1970-01-01 (local time).
timeP :: ReadP ClockTime
timeP = do
    (h, m, s) <- hmsP
    return (mkClockTime 1970 1 1 h m s)

dateP :: ReadP ClockTime
dateP = do
    (y, mo, d) <- ymdP
    return (mkClockTime y mo d 0 0 0)

dateTimeP :: ReadP ClockTime
dateTimeP = do
    (y, mo, d) <- ymdP
    skipSpaces
    (h, m, s) <- hmsP
    return (mkClockTime y mo d h m s)

-- | Decode the 14-digit @YYYYMMDDHHMMSS@ TIMESTAMP form.
-- The caller guarantees exactly 14 digits.
compactTimeStamp :: String -> ClockTime
compactTimeStamp s = mkClockTime year month day hour minute sec
  where
    [year, month, day, hour, minute, sec] = splitDigits [4, 2, 2, 2, 2, 2] s
    splitDigits []     _  = []
    splitDigits (w:ws) xs = let (field, rest) = splitAt w xs
                            in digitsToInt field : splitDigits ws rest
    digitsToInt = foldl' (\acc c -> acc * 10 + (ord c - ord '0')) 0

-- | Render a non-negative number zero-padded to at least @width@ digits.
showsPad :: Int -> Int -> ShowS
showsPad width n = showString (replicate (width - length digits) '0') . showString digits
  where digits = show n

instance SqlBind ClockTime where
    fromSqlValue SqlTime      s = parseWith timeP s
    fromSqlValue SqlDate      s = parseWith dateP s
    fromSqlValue SqlDateTime  s = parseWith dateTimeP s
    -- TIMESTAMP comes either as 14 packed digits or as YYYY-MM-DD HH:MM:SS.
    fromSqlValue SqlTimeStamp s
        | length s == 14 && all isDigit s = Just (compactTimeStamp s)
        | otherwise                       = parseWith dateTimeP s
    fromSqlValue _            _ = Nothing

    -- Render as 'YYYY-MM-DD HH:MM:SS' in LOCAL time, the same interpretation
    -- fromSqlValue uses (mkClockTime / mktime), so values round-trip.
    -- Month is an enumeration starting at January = 0, hence the + 1.
    toSqlValue ct =
        '\'' : ( showsPad 4 (ctYear t) . showChar '-'
               . showsPad 2 (fromEnum (ctMonth t) + 1) . showChar '-'
               . showsPad 2 (ctDay t) . showChar ' '
               . showsPad 2 (ctHour t) . showChar ':'
               . showsPad 2 (ctMin t) . showChar ':'
               . showsPad 2 (ctSec t)
               ) "'"
      where
        t = unsafePerformIO (toCalendarTime ct)

-- | Read column @name@ of the current row.  Returns 'Nothing' for SQL NULL
-- and throws 'SqlBadTypeCast' if the text cannot be converted.  Raises an IO
-- error when there is no current row (fetch not called, it returned False,
-- or the statement was closed).
getFieldValueMB :: SqlBind a => Statement -> String -> IO (Maybe a)
getFieldValueMB (Statement {currRow = rowRef, fields = fieldDefs}) name = do
    row <- readIORef rowRef
    when (row == nullPtr) $
        ioError (userError ("Database.MySQL.HSQL.getFieldValueMB: no current row (reading column \"" ++ name ++ "\")"))
    let (sqlType, _, colNumber) = findFieldInfo name fieldDefs 0
    pValue <- peekElemOff row colNumber
    if pValue == nullPtr
        then return Nothing
        else do
            value <- peekCString pValue
            case fromSqlValue sqlType value of
                Just v  -> return (Just v)
                Nothing -> throwDyn (SqlBadTypeCast name sqlType)

-- | Like 'getFieldValueMB' but throws 'SqlFetchNull' for SQL NULL.
getFieldValue :: SqlBind a => Statement -> String -> IO a
getFieldValue stmt name = do
    mb_v <- getFieldValueMB stmt name
    case mb_v of
        Nothing -> throwDyn (SqlFetchNull name)
        Just a  -> return a

-- | Like 'getFieldValueMB' but substitutes @def@ for SQL NULL.
getFieldValue' :: SqlBind a => Statement -> String -> a -> IO a
getFieldValue' stmt name def = do
    mb_v <- getFieldValueMB stmt name
    return (maybe def id mb_v)

-- | Type and nullability of column @name@.
getFieldValueType :: Statement -> String -> (SqlType, Bool)
getFieldValueType stmt name = (sqlType, nullable)
  where
    (sqlType, nullable, _) = findFieldInfo name (fields stmt) 0

getFieldsTypes :: Statement -> [(String, SqlType, Bool)]
getFieldsTypes = fields

-- | Find a column by name, returning its type, nullability and 0-based
-- index.  Calls 'error' if no column has that name.
findFieldInfo :: String -> [FieldDef] -> Int -> (SqlType, Bool, Int)
findFieldInfo name [] _ = error ("Undefined column name \"" ++ name ++ "\"")
findFieldInfo name ((name', sqlType, nullable) : rest) colNumber
    | name == name' = (sqlType, nullable, colNumber)
    | otherwise     = findFieldInfo name rest (colNumber + 1)

-----------------------------------------------------------------------------------------
-- helpers
-----------------------------------------------------------------------------------------

-- | Fold @f@ over the remaining rows, closing the statement at the end.
forEachRow :: (Statement -> s -> IO s) -> Statement -> s -> IO s
forEachRow f stmt s = do
    success <- fetch stmt
    if success
        then f stmt s >>= forEachRow f stmt
        else closeStatement stmt >> return s

-- | Run @f@ on every remaining row, closing the statement at the end.
forEachRow' :: (Statement -> IO ()) -> Statement -> IO ()
forEachRow' f stmt = do
    success <- fetch stmt
    if success
        then f stmt >> forEachRow' f stmt
        else closeStatement stmt

-- | Collect @f@'s result for every remaining row, closing the statement at
-- the end.
collectRows :: (Statement -> IO a) -> Statement -> IO [a]
collectRows f stmt = loop
  where
    loop = do
        success <- fetch stmt
        if success
            then do
                x  <- f stmt
                xs <- loop
                return (x : xs)
            else closeStatement stmt >> return []