From: <kr_...@us...> - 2003-09-06 19:59:10
|
Update of /cvsroot/htoolkit/HSQL/PostgreSQL
In directory sc8-pr-cvs1:/tmp/cvs-serv11008/PostgreSQL

Added Files:
    HSQL.hsc
Log Message:
Support for PostgreSQL. The new implementation has better support for Sql<->Haskell data translation

--- NEW FILE: HSQL.hsc ---
-----------------------------------------------------------------------------------------
{-| Module      :  Database.PostgreSQL.HSQL
    Copyright   :  (c) Krasimir Angelov 2003
    License     :  BSD-style

    Maintainer  :  ka2...@ya...
    Stability   :  provisional
    Portability :  portable

    This module provides an interface to the PostgreSQL database (via libpq).
-}
-----------------------------------------------------------------------------------------

module Database.PostgreSQL.HSQL
     ( SqlBind(..), SqlError(..), SqlType(..), Connection, Statement
     , catchSql          -- :: IO a -> (SqlError -> IO a) -> IO a
     , handleSql         -- :: (SqlError -> IO a) -> IO a -> IO a
     , sqlExceptions     -- :: Exception -> Maybe SqlError
     , connect           -- :: String -> String -> String -> String -> IO Connection
     , disconnect        -- :: Connection -> IO ()
     , execute           -- :: Connection -> String -> IO ()
     , query             -- :: Connection -> String -> IO Statement
     , closeStatement    -- :: Statement -> IO ()
     , fetch             -- :: Statement -> IO Bool
     , inTransaction     -- :: Connection -> (Connection -> IO a) -> IO a
     , getFieldValueMB   -- :: SqlBind a => Statement -> String -> IO (Maybe a)
     , getFieldValue     -- :: SqlBind a => Statement -> String -> IO a
     , getFieldValue'    -- :: SqlBind a => Statement -> String -> a -> IO a
     , getFieldValueType -- :: Statement -> String -> SqlType
     , getFieldsTypes    -- :: Statement -> [(String, SqlType)]
     , forEachRow        -- :: (Statement -> s -> IO s) -> Statement -> s -> IO s
     , forEachRow'       -- :: (Statement -> IO ()) -> Statement -> IO ()
     , collectRows       -- :: (Statement -> IO s) -> Statement -> IO [s]
     , Point(..), Line(..), Path(..), Box(..), Circle(..), Polygon(..)
     ) where

import Data.Dynamic
import Data.IORef
import Foreign
import Foreign.C
import Control.Exception (throwDyn, catchDyn, dynExceptions, Exception(..))
import Control.Monad(when,unless,mplus)
import System.Time
import System.IO.Unsafe
import Text.ParserCombinators.ReadP
import Text.Read

#include <time.h>
#include <libpq-fe.h>
#include <postgres.h>
#include <catalog/pg_type.h>

-- Opaque libpq handles.
type PGconn   = Ptr ()
type PGresult = Ptr ()

type ConnStatusType = #type ConnStatusType
type ExecStatusType = #type ExecStatusType
type Oid            = #type Oid

-- libpq bindings.
--
-- Functions whose C result is an @int@ are imported with a 'CInt' result.
-- A 64-bit Haskell 'Int' would read bits the C callee never set on some LP64
-- ABIs. Row/column arguments are still passed as 'Int' to keep the existing
-- call sites unchanged.
foreign import ccall "libpq-fe.h PQsetdbLogin" pqSetdbLogin :: CString -> CString -> CString -> CString -> CString -> CString -> CString -> IO PGconn
foreign import ccall "libpq-fe.h PQstatus" pqStatus :: PGconn -> IO ConnStatusType
foreign import ccall "libpq-fe.h PQerrorMessage" pqErrorMessage :: PGconn -> IO CString
foreign import ccall "libpq-fe.h PQfinish" pqFinish :: PGconn -> IO ()
foreign import ccall "libpq-fe.h PQexec" pqExec :: PGconn -> CString -> IO PGresult
foreign import ccall "libpq-fe.h PQresultStatus" pqResultStatus :: PGresult -> IO ExecStatusType
foreign import ccall "libpq-fe.h PQresStatus" pqResStatus :: ExecStatusType -> IO CString
foreign import ccall "libpq-fe.h PQresultErrorMessage" pqResultErrorMessage :: PGresult -> IO CString
-- Used to release results that are not kept in a 'Statement'.
foreign import ccall "libpq-fe.h PQclear" pqClear :: PGresult -> IO ()
foreign import ccall "libpq-fe.h PQnfields" pgNFields :: PGresult -> IO CInt
foreign import ccall "libpq-fe.h PQntuples" pqNTuples :: PGresult -> IO CInt
foreign import ccall "libpq-fe.h PQfname" pgFName :: PGresult -> Int -> IO CString
foreign import ccall "libpq-fe.h PQftype" pqFType :: PGresult -> Int -> IO Oid
foreign import ccall "libpq-fe.h PQfmod" pqFMod :: PGresult -> Int -> IO CInt
foreign import ccall "libpq-fe.h PQfnumber" pqFNumber :: PGresult -> CString -> IO CInt
foreign import ccall "libpq-fe.h PQgetvalue" pqGetvalue :: PGresult -> Int -> Int -> IO CString
foreign import ccall "libpq-fe.h PQgetisnull" pqGetisnull :: PGresult -> Int -> Int -> IO CInt

-- | An open database connection.
newtype Connection = Connection PGconn

-- | The result of 'query' together with a cursor over its rows.
data Statement
   = Statement
   { pRes        :: !PGresult    -- result handle returned by PQexec
   , tupleIndex  :: IORef Int    -- current row; -1 until the first successful 'fetch'
   , countTuples :: !Int         -- number of rows in the result
   , connection  :: !Connection  -- connection the statement was run on
   , fields      :: ![FieldDef]  -- column names and types, in column order
   }

-- | Column name and its SQL type.
type FieldDef = (String, SqlType)

-- | Column types recognised by this binding (see 'mkSqlType').
data SqlType
   = SqlChar Int
   | SqlVarChar Int
   | SqlText
   | SqlNumeric Int Int
   | SqlSmallInt
   | SqlInteger
   | SqlReal
   | SqlDouble
   | SqlBool
   | SqlBit Int
   | SqlVarBit Int
   | SqlTinyInt
   | SqlBigInt
   | SqlDate
   | SqlTime
   | SqlAbsTime
   | SqlRelTime
   | SqlTimeTZ
   | SqlTimeInterval
   | SqlAbsTimeInterval
   | SqlTimeStamp
   | SqlMoney
   | SqlINetAddr
   | SqlCIDRAddr
   | SqlMacAddr
   | SqlPoint
   | SqlLSeg
   | SqlPath
   | SqlBox
   | SqlPolygon
   | SqlLine
   | SqlCircle
   | SqlUnknown
   deriving (Eq, Show)

-- | Errors raised (via 'throwDyn') by this module.
data SqlError
   = SqlError
   { seState       :: String  -- "C" for connection failures, "E" for command failures
   , seNativeError :: Int     -- numeric libpq connection/exec status
   , seErrorMsg    :: String  -- message reported by libpq
   }
   | SqlNoData                -- no current row (see 'fetch')
   | SqlBadTypeCast
   { seFieldName :: String
   , seFieldType :: SqlType
   }
   | SqlFetchNull
   { seFieldName :: String
   }
   deriving (Show, Typeable)

-----------------------------------------------------------------------------------------
-- routines for handling exceptions
-----------------------------------------------------------------------------------------

-- | Run an action, handling any 'SqlError' it throws.
catchSql :: IO a -> (SqlError -> IO a) -> IO a
catchSql = catchDyn

-- | 'catchSql' with the arguments flipped.
handleSql :: (SqlError -> IO a) -> IO a -> IO a
handleSql h f = catchDyn f h

-- | Extract a 'SqlError' from a general exception, if it is one.
sqlExceptions :: Exception -> Maybe SqlError
sqlExceptions e = dynExceptions e >>= fromDynamic

-----------------------------------------------------------------------------------------
-- Connect/Disconnect
-----------------------------------------------------------------------------------------

-- | Open a connection with @PQsetdbLogin@, passing host, database, user and
-- password. Port, options and tty are left NULL. Throws a 'SqlError' with
-- state "C" if the connection does not reach @CONNECTION_OK@.
connect :: String -> String -> String -> String -> IO Connection
connect server database user authentication = do
  -- withCString frees every buffer even if an exception escapes. The original
  -- code leaked the database name and leaked everything on exceptions.
  pConn <- withCString server         $ \pServer ->
           withCString database       $ \pDatabase ->
           withCString user           $ \pUser ->
           withCString authentication $ \pAuthentication ->
             pqSetdbLogin pServer nullPtr nullPtr nullPtr pDatabase pUser pAuthentication
  status <- pqStatus pConn
  unless (status == (#const CONNECTION_OK)) $ do
    -- Read the message before PQfinish releases the connection object.
    errMsg <- pqErrorMessage pConn >>= peekCString
    pqFinish pConn
    throwDyn (SqlError {seState="C", seNativeError=fromIntegral status, seErrorMsg=errMsg})
  return (Connection pConn)

-- | Close a connection.
disconnect :: Connection -> IO ()
disconnect (Connection pConn) = pqFinish pConn

-----------------------------------------------------------------------------------------
-- queries
-----------------------------------------------------------------------------------------

-- | Send one SQL string with PQexec and return the result with its status.
-- Throws a 'SqlError' with state "E" unless the status is @PGRES_COMMAND_OK@
-- or @PGRES_TUPLES_OK@. On that error path the result is cleared first.
execRaw :: PGconn -> String -> IO (PGresult, ExecStatusType)
execRaw pConn sqlExpr = do
  pRes <- withCString sqlExpr (pqExec pConn)
  when (pRes == nullPtr) $ do
    errMsg <- pqErrorMessage pConn >>= peekCString
    throwDyn (SqlError {seState="E", seNativeError=(#const PGRES_FATAL_ERROR), seErrorMsg=errMsg})
  status <- pqResultStatus pRes
  unless (status == (#const PGRES_COMMAND_OK) || status == (#const PGRES_TUPLES_OK)) $ do
    errMsg <- pqResultErrorMessage pRes >>= peekCString
    pqClear pRes
    throwDyn (SqlError {seState="E", seNativeError=fromIntegral status, seErrorMsg=errMsg})
  return (pRes, status)

-- | Execute a command whose result rows (if any) are not needed.
-- The libpq result is released before returning.
execute :: Connection -> String -> IO ()
execute (Connection pConn) sqlExpr = do
  (pRes, _) <- execRaw pConn sqlExpr
  pqClear pRes

-- | Execute a query and return a 'Statement' positioned before the first row.
-- Call 'fetch' to advance. Column definitions are read only when the server
-- returned rows (@PGRES_TUPLES_OK@).
query :: Connection -> String -> IO Statement
query conn@(Connection pConn) sqlExpr = do
  (pRes, status) <- execRaw pConn sqlExpr
  defs <- if status == (#const PGRES_TUPLES_OK)
            then fmap fromIntegral (pgNFields pRes) >>= getFieldDefs pRes 0
            else return []
  countTuples <- fmap fromIntegral (pqNTuples pRes)
  tupleIndex  <- newIORef (-1)
  return (Statement {pRes=pRes, connection=conn, fields=defs, countTuples=countTuples, tupleIndex=tupleIndex})
  where
    -- Column definitions for columns i .. n-1.
    getFieldDefs pRes i n
      | i >= n    = return []
      | otherwise = do
          name     <- pgFName pRes i >>= peekCString
          dataType <- pqFType pRes i
          modifier <- fmap fromIntegral (pqFMod pRes i)
          defs     <- getFieldDefs pRes (i+1) n
          return ((name, mkSqlType dataType modifier) : defs)

-- | Map a PostgreSQL type OID and its type modifier (from PQfmod) to a 'SqlType'.
-- OIDs not listed here map to 'SqlUnknown' instead of failing with a pattern-match error.
--
-- NOTE(review): PQfmod reports -1 when a column has no modifier. The
-- @size-4@ arithmetic below then yields negative lengths. This is kept as is.
mkSqlType :: Oid -> Int -> SqlType
mkSqlType (#const BPCHAROID)    size = SqlChar (size-4)
mkSqlType (#const VARCHAROID)   size = SqlVarChar (size-4)
mkSqlType (#const NAMEOID)      _    = SqlVarChar 31
mkSqlType (#const TEXTOID)      _    = SqlText
mkSqlType (#const NUMERICOID)   size = SqlNumeric ((size-4) `div` 0x10000) ((size-4) `mod` 0x10000)
mkSqlType (#const INT2OID)      _    = SqlSmallInt
mkSqlType (#const INT4OID)      _    = SqlInteger
mkSqlType (#const FLOAT4OID)    _    = SqlReal
mkSqlType (#const FLOAT8OID)    _    = SqlDouble
mkSqlType (#const BOOLOID)      _    = SqlBool
mkSqlType (#const BITOID)       size = SqlBit size
mkSqlType (#const VARBITOID)    size = SqlVarBit size
mkSqlType (#const BYTEAOID)     _    = SqlTinyInt   -- NOTE(review): bytea is a binary string type; this mapping looks odd but is kept
mkSqlType (#const INT8OID)      _    = SqlBigInt
mkSqlType (#const DATEOID)      _    = SqlDate
mkSqlType (#const TIMEOID)      _    = SqlTime
mkSqlType (#const TIMETZOID)    _    = SqlTimeTZ
mkSqlType (#const ABSTIMEOID)   _    = SqlAbsTime
mkSqlType (#const RELTIMEOID)   _    = SqlRelTime
mkSqlType (#const INTERVALOID)  _    = SqlTimeInterval
mkSqlType (#const TINTERVALOID) _    = SqlAbsTimeInterval
mkSqlType (#const TIMESTAMPOID) _    = SqlTimeStamp
mkSqlType (#const CASHOID)      _    = SqlMoney
mkSqlType (#const INETOID)      _    = SqlINetAddr
mkSqlType (#const 829)          _    = SqlMacAddr   -- hack: literal OID of macaddr (presumably no usable header constant)
mkSqlType (#const CIDROID)      _    = SqlCIDRAddr
mkSqlType (#const POINTOID)     _    = SqlPoint
mkSqlType (#const LSEGOID)      _    = SqlLSeg
mkSqlType (#const PATHOID)      _    = SqlPath
mkSqlType (#const BOXOID)       _    = SqlBox
mkSqlType (#const POLYGONOID)   _    = SqlPolygon
mkSqlType (#const LINEOID)      _    = SqlLine
mkSqlType (#const CIRCLEOID)    _    = SqlCircle
mkSqlType (#const UNKNOWNOID)   _    = SqlUnknown
mkSqlType _                     _    = SqlUnknown

-- | Advance to the next row. Returns 'False', leaving the cursor unchanged,
-- when there are no more rows.
fetch :: Statement -> IO Bool
fetch (Statement {countTuples=countTuples, tupleIndex=tupleIndex}) = do
  index <- readIORef tupleIndex
  let index' = index + 1
  if index' >= countTuples
    then return False
    else writeIORef tupleIndex index' >> return True
-- | Finish with a statement.
--
-- NOTE(review): this is a no-op. The statement's PGresult is never passed to
-- PQclear, so each 'query' result stays allocated.
closeStatement :: Statement -> IO ()
closeStatement _ = return ()

-----------------------------------------------------------------------------------------
-- transactions
-----------------------------------------------------------------------------------------

-- | Run an action between @begin@ and @commit@.
--
-- If the action throws a 'SqlError', a @rollback@ is issued and the same
-- error is re-thrown. Other exceptions are not caught here, so they bypass
-- both the rollback and the commit.
inTransaction :: Connection -> (Connection -> IO a) -> IO a
inTransaction conn action = do
  execute conn "begin"
  r <- catchSql (action conn) (\err -> execute conn "rollback" >> throwDyn err)
  execute conn "commit"
  return r

-----------------------------------------------------------------------------------------
-- binding
-----------------------------------------------------------------------------------------

-- | Conversion between Haskell values and PostgreSQL's textual representation.
class SqlBind a where
  -- | Decode the text of a field, given its column type. 'Nothing' means the
  -- column type is not accepted by this instance. Some instances also return
  -- 'Nothing' for text they cannot parse.
  fromSqlValue :: SqlType -> String -> Maybe a
  -- | Encode a value as a SQL literal to be spliced into a statement.
  toSqlValue   :: a -> String

instance SqlBind Int where
  fromSqlValue SqlInteger  s = Just (read s)
  fromSqlValue SqlSmallInt s = Just (read s)
  fromSqlValue _           _ = Nothing
  toSqlValue = show

instance SqlBind Integer where
  fromSqlValue SqlInteger  s = Just (read s)
  fromSqlValue SqlSmallInt s = Just (read s)
  fromSqlValue SqlBigInt   s = Just (read s)
  fromSqlValue _           _ = Nothing
  toSqlValue = show

instance SqlBind String where
  fromSqlValue _ = Just
  -- Quote as a string literal. Backslashes are doubled so that user data can
  -- no longer escape the closing quote. Single quotes are doubled, which is
  -- valid whether or not the server interprets backslash escapes.
  -- NOTE(review): the \n, \r, \t and \\ escapes assume the server processes
  -- backslash escapes in ordinary literals (standard_conforming_strings = off).
  toSqlValue s = '\'' : foldr mapChar "'" s
    where
      mapChar '\\' rest = '\\' : '\\' : rest
      mapChar '\'' rest = '\'' : '\'' : rest
      mapChar '\n' rest = '\\' : 'n'  : rest
      mapChar '\r' rest = '\\' : 'r'  : rest
      mapChar '\t' rest = '\\' : 't'  : rest
      mapChar c    rest = c : rest

instance SqlBind Bool where
  fromSqlValue SqlBool s = Just (s == "t")
  fromSqlValue _       _ = Nothing
  toSqlValue True  = "'t'"
  toSqlValue False = "'f'"

instance SqlBind Double where
  fromSqlValue (SqlNumeric _ _) s = Just (read s)
  fromSqlValue SqlDouble        s = Just (read s)
  fromSqlValue SqlReal          s = Just (read s)
  fromSqlValue _                _ = Nothing
  toSqlValue = show

-- | Build a 'ClockTime' from calendar fields.
--
-- @mon@ is 1-based (January = 1). @tz@ is the UTC offset, in seconds, of the
-- zone the fields are expressed in. The fields are first converted by C
-- @mktime@ as local time. The result is then shifted by @currTZ - tz@, which
-- turns "local wall clock" into "wall clock in zone @tz@". The original code
-- applied this shift with the opposite sign.
--
-- NOTE(review): 'currTZ' is sampled once. For an explicit @tz@ on a date with
-- a different DST state than "now", the result may be off by the DST delta.
mkClockTime :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> ClockTime
mkClockTime year mon mday hour minute sec tz = unsafePerformIO $
  -- The struct tm is scratch memory private to this call.
  allocaBytes (#const sizeof(struct tm)) $ \p_tm -> do
    (#poke struct tm,tm_sec  ) p_tm (fromIntegral sec         :: CInt)
    (#poke struct tm,tm_min  ) p_tm (fromIntegral minute      :: CInt)
    (#poke struct tm,tm_hour ) p_tm (fromIntegral hour        :: CInt)
    (#poke struct tm,tm_mday ) p_tm (fromIntegral mday        :: CInt)
    (#poke struct tm,tm_mon  ) p_tm (fromIntegral (mon-1)     :: CInt)
    (#poke struct tm,tm_year ) p_tm (fromIntegral (year-1900) :: CInt)
    (#poke struct tm,tm_isdst) p_tm (-1                       :: CInt)
    t <- mktime p_tm
    return (TOD (fromIntegral t + fromIntegral (currTZ - tz)) 0)

foreign import ccall unsafe mktime :: Ptr () -> IO CTime

-- | UTC offset in seconds ('ctTZ') of the local calendar time, sampled once
-- on first use. Hack: it relies on unsafePerformIO plus NOINLINE so the value
-- is computed a single time.
{-# NOINLINE currTZ #-}
currTZ :: Int
currTZ = ctTZ (unsafePerformIO (getClockTime >>= toCalendarTime)) -- Hack

-- | Parse a signed zone suffix such as @+02@ or @-05@, returned in hours.
parseTZ :: ReadP Int
parseTZ =       (char '+' >> readS_to_P reads)
        `mplus` (char '-' >> fmap negate (readS_to_P reads))

-- | Run a parser. 'Just' when it produces exactly one result (trailing input
-- is tolerated), otherwise 'Nothing' rather than a pattern-match crash.
f_read :: ReadP a -> String -> Maybe a
f_read f s = case readP_to_S f s of
               [(x, _)] -> Just x
               _        -> Nothing

instance SqlBind ClockTime where
  -- "hh:mm:ss+zz", placed on 1970-01-01
  fromSqlValue SqlTimeTZ s = f_read getTime s
    where
      getTime :: ReadP ClockTime
      getTime = do
        hour    <- readS_to_P reads
        char ':'
        minutes <- readS_to_P reads
        char ':'
        seconds <- readS_to_P reads
        tz      <- parseTZ
        return (mkClockTime 1970 1 1 hour minutes seconds (tz*3600))
  -- "hh:mm:ss" in local time, placed on 1970-01-01
  fromSqlValue SqlTime s = f_read getTime s
    where
      getTime :: ReadP ClockTime
      getTime = do
        hour    <- readS_to_P reads
        char ':'
        minutes <- readS_to_P reads
        char ':'
        seconds <- readS_to_P reads
        return (mkClockTime 1970 1 1 hour minutes seconds currTZ)
  -- "YYYY-MM-DD" at local midnight
  fromSqlValue SqlDate s = f_read getDate s
    where
      getDate :: ReadP ClockTime
      getDate = do
        year  <- readS_to_P reads
        satisfy (=='-')
        month <- readS_to_P reads
        satisfy (=='-')
        day   <- readS_to_P reads
        return (mkClockTime year month day 0 0 0 currTZ)
  -- "YYYY-MM-DD hh:mm:ss+zz"
  fromSqlValue SqlTimeStamp s = f_read getTimeStamp s
    where
      getTimeStamp :: ReadP ClockTime
      getTimeStamp = do
        year    <- readS_to_P reads
        satisfy (=='-')
        month   <- readS_to_P reads
        satisfy (=='-')
        day     <- readS_to_P reads
        skipSpaces
        hour    <- readS_to_P reads
        satisfy (==':')
        minutes <- readS_to_P reads
        satisfy (==':')
        seconds <- readS_to_P reads
        tz      <- parseTZ
        return (mkClockTime year month day hour minutes seconds (tz*3600))
  fromSqlValue _ _ = Nothing

  -- Emit 'YYYY-MM-DD hh:mm:ss+00' in UTC. The original printed the Month
  -- constructor name, used '.' as the date separator and gave no zone.
  toSqlValue ct = '\'' : (shows (ctYear t) . dash . pad2 (fromEnum (ctMonth t) + 1) . dash . pad2 (ctDay t) .
                          space . pad2 (ctHour t) . colon . pad2 (ctMin t) . colon . pad2 (ctSec t) .
                          showString "+00") "'"
    where
      t     = toUTCTime ct
      dash  = showChar '-'
      space = showChar ' '
      colon = showChar ':'
      pad2 :: Int -> ShowS
      pad2 n | n < 10    = showChar '0' . shows n
             | otherwise = shows n

data Point   = Point Double Double deriving (Eq, Show)
data Line    = Line Point Point deriving (Eq, Show)
data Path    = OpenPath [Point] | ClosedPath [Point] deriving (Eq, Show)
data Box     = Box Double Double Double Double deriving (Eq, Show)
data Circle  = Circle Point Double deriving (Eq, Show)
data Polygon = Polygon [Point] deriving (Eq, Show)

-- | Total alternative to 'read'. Returns 'Just' when the whole string (ignoring
-- surrounding whitespace) parses unambiguously, otherwise 'Nothing'.
readFull :: Read a => String -> Maybe a
readFull s = case [x | (x, rest) <- reads s, ("", "") <- lex rest] of
               [x] -> Just x
               _   -> Nothing

-- | Drop the first and last character (the outer brackets of a literal).
-- Strings shorter than two characters become empty.
stripOuter :: String -> String
stripOuter s = take (length s - 2) (drop 1 s)

-- | Render points as "(x1,y1),(x2,y2),..." with no enclosing brackets.
showPoints :: [Point] -> String
showPoints ps = case [show (x, y) | Point x y <- ps] of
                  []     -> ""
                  (p:qs) -> p ++ concatMap (',' :) qs

instance SqlBind Point where
  fromSqlValue SqlPoint s = fmap (uncurry Point) (readFull s)
  fromSqlValue _        _ = Nothing
  toSqlValue (Point x y) = '\'' : shows (x,y) "'"

instance SqlBind Line where
  fromSqlValue SqlLSeg s = readFull s >>= fromPairs
    where
      fromPairs [(x1,y1),(x2,y2)] = Just (Line (Point x1 y1) (Point x2 y2))
      fromPairs _                 = Nothing
  fromSqlValue _ _ = Nothing
  toSqlValue (Line (Point x1 y1) (Point x2 y2)) = '\'' : shows [(x1,y1),(x2,y2)] "'"

instance SqlBind Path where
  -- closed path: "((x1,y1),...)"
  fromSqlValue SqlPath s@('(':_) = fmap (ClosedPath . map (uncurry Point)) (readFull ("[" ++ stripOuter s ++ "]"))
  -- open path: "[(x1,y1),...]"
  fromSqlValue SqlPath s = fmap (OpenPath . map (uncurry Point)) (readFull s)
  fromSqlValue SqlLSeg s = readFull s >>= \ps -> case ps of
      [(x1,y1),(x2,y2)] -> Just (OpenPath [Point x1 y1, Point x2 y2])
      _                 -> Nothing
  fromSqlValue SqlPoint s = fmap (\(x,y) -> ClosedPath [Point x y]) (readFull s)
  fromSqlValue _ _ = Nothing
  -- The closing bracket must come before the closing quote.
  toSqlValue (OpenPath ps)   = "'[" ++ showPoints ps ++ "]'"
  toSqlValue (ClosedPath ps) = "'(" ++ showPoints ps ++ ")'"

instance SqlBind Box where
  fromSqlValue SqlBox s = fmap (\((x1,y1),(x2,y2)) -> Box x1 y1 x2 y2) (readFull ("(" ++ s ++ ")"))
  fromSqlValue _      _ = Nothing
  toSqlValue (Box x1 y1 x2 y2) = '\'' : shows ((x1,y1),(x2,y2)) "'"

instance SqlBind Polygon where
  fromSqlValue SqlPolygon s = fmap (Polygon . map (uncurry Point)) (readFull ("[" ++ stripOuter s ++ "]"))
  fromSqlValue _          _ = Nothing
  toSqlValue (Polygon ps) = "'(" ++ showPoints ps ++ ")'"

instance SqlBind Circle where
  -- "<(x,y),r>"
  fromSqlValue SqlCircle s = fmap (\((x,y),r) -> Circle (Point x y) r) (readFull ("(" ++ stripOuter s ++ ")"))
  fromSqlValue _         _ = Nothing
  toSqlValue (Circle (Point x y) r) = "'<" ++ show (x,y) ++ "," ++ show r ++ ">'"

-- | Read a column of the current row by name. 'Nothing' for SQL NULL.
-- Throws 'SqlNoData' when there is no current row. Throws 'SqlBadTypeCast'
-- when the value cannot be converted to the requested type.
getFieldValueMB :: SqlBind a => Statement -> String -> IO (Maybe a)
getFieldValueMB (Statement {pRes=pRes, fields=fieldDefs, countTuples=countTuples, tupleIndex=tupleIndex}) name = do
  index <- readIORef tupleIndex
  -- The index is -1 until the first successful fetch. Reading then would pass
  -- row -1 to libpq. The upper bound check is defensive.
  when (index < 0 || index >= countTuples) (throwDyn SqlNoData)
  let (sqlType, colNumber) = findFieldInfo name fieldDefs 0
  isnull <- pqGetisnull pRes index colNumber
  if isnull == 1
    then return Nothing
    else do
      value <- pqGetvalue pRes index colNumber >>= peekCString
      case fromSqlValue sqlType value of
        Just v  -> return (Just v)
        Nothing -> throwDyn (SqlBadTypeCast name sqlType)

-- | Like 'getFieldValueMB', but throws 'SqlFetchNull' for SQL NULL.
getFieldValue :: SqlBind a => Statement -> String -> IO a
getFieldValue stmt name = do
  mb_v <- getFieldValueMB stmt name
  case mb_v of
    Nothing -> throwDyn (SqlFetchNull name)
    Just a  -> return a

-- | Like 'getFieldValueMB', but substitutes a default for SQL NULL.
getFieldValue' :: SqlBind a => Statement -> String -> a -> IO a
getFieldValue' stmt name def = do
  mb_v <- getFieldValueMB stmt name
  return (case mb_v of { Nothing -> def; Just a -> a })

-- | The SQL type of the named column.
getFieldValueType :: Statement -> String -> SqlType
getFieldValueType stmt name = fst (findFieldInfo name (fields stmt) 0)

-- | All column names and types, in column order.
getFieldsTypes :: Statement -> [(String, SqlType)]
getFieldsTypes = fields

-- | Look up a column by name. Returns its type and its position, counted from
-- the given starting index. Calls 'error' for an unknown name.
findFieldInfo :: String -> [FieldDef] -> Int -> (SqlType,Int)
findFieldInfo name [] _ = error ("Undefined column name \"" ++ name ++ "\"")
findFieldInfo name ((name', sqlType) : rest) colNumber
  | name == name' = (sqlType, colNumber)
  | otherwise     = findFieldInfo name rest (colNumber + 1)

-----------------------------------------------------------------------------------------
-- helpers
-----------------------------------------------------------------------------------------

-- | Fold an accumulator over every remaining row of the statement.
-- The statement is closed once 'fetch' reports no more rows, and the final
-- accumulator is returned.
forEachRow :: (Statement -> s -> IO s) -> Statement -> s -> IO s
forEachRow f stmt = step
  where
    step acc = do
      more <- fetch stmt
      if not more
        then closeStatement stmt >> return acc
        else f stmt acc >>= step

-- | Run an action for every remaining row, then close the statement.
forEachRow' :: (Statement -> IO ()) -> Statement -> IO ()
forEachRow' f stmt = step
  where
    step = fetch stmt >>= \more ->
             if more
               then f stmt >> step
               else closeStatement stmt

-- | Collect the result of an action on every remaining row, in row order,
-- closing the statement after the last row.
collectRows :: (Statement -> IO a) -> Statement -> IO [a]
collectRows f stmt = step
  where
    step = do
      more <- fetch stmt
      if not more
        then closeStatement stmt >> return []
        else do
          row  <- f stmt
          rest <- step
          return (row : rest)