| 
      
      
      From: <kr_...@us...> - 2003-09-06 19:59:10
      
     | 
| Update of /cvsroot/htoolkit/HSQL/PostgreSQL
In directory sc8-pr-cvs1:/tmp/cvs-serv11008/PostgreSQL
Added Files:
	HSQL.hsc 
Log Message:
Add PostgreSQL support. The new implementation has better support for translating data between SQL and Haskell.
--- NEW FILE: HSQL.hsc ---
-----------------------------------------------------------------------------------------
{-| Module      :  Database.PostgreSQL.HSQL
    Copyright   :  (c) Krasimir Angelov 2003
    License     :  BSD-style
    Maintainer  :  ka2...@ya...
    Stability   :  provisional
    Portability :  portable
    This module provides an interface to the PostgreSQL database through the
    libpq C library.
-}
-----------------------------------------------------------------------------------------
module Database.PostgreSQL.HSQL
		( SqlBind(..), SqlError(..), SqlType(..), Connection, Statement
		, catchSql		-- :: IO a -> (SqlError -> IO a) -> IO a
		, handleSql		-- :: (SqlError -> IO a) -> IO a -> IO a
		, sqlExceptions		-- :: Exception -> Maybe SqlError
		, connect		-- :: String -> String -> String -> String -> IO Connection
		, disconnect		-- :: Connection -> IO ()
		, execute		-- :: Connection -> String -> IO ()
		, query			-- :: Connection -> String -> IO Statement
		, closeStatement	-- :: Statement -> IO ()
		, fetch			-- :: Statement -> IO Bool
		, inTransaction 	-- :: Connection -> (Connection -> IO a) -> IO a
		, getFieldValueMB	-- :: SqlBind a => Statement -> String -> IO (Maybe a)
		, getFieldValue		-- :: SqlBind a => Statement -> String -> IO a
		, getFieldValue'	-- :: SqlBind a => Statement -> String -> a -> IO a
		, getFieldValueType	-- :: Statement -> String -> SqlType
		, getFieldsTypes	-- :: Statement -> [(String, SqlType)]
		, forEachRow		-- :: (Statement -> s -> IO s) -> Statement -> s -> IO s
		, forEachRow'		-- :: (Statement -> IO ()) -> Statement -> IO ()
		, collectRows		-- :: (Statement -> IO s) -> Statement -> IO [s]
		, Point(..), Line(..), Path(..), Box(..), Circle(..), Polygon(..)
		) where
import Data.Dynamic
import Data.IORef
import Foreign
import Foreign.C
import Control.Exception (throwDyn, catchDyn, dynExceptions, Exception(..))
import Control.Monad(when,unless,mplus)
import System.Time
import System.IO.Unsafe
import Text.ParserCombinators.ReadP
import Text.Read
# include <time.h>
#include <libpq-fe.h>
#include <postgres.h>
#include <catalog/pg_type.h>
-- Opaque libpq handles (PGconn * and PGresult *), passed around as untyped pointers.
type PGconn = Ptr ()
type PGresult = Ptr ()
-- Haskell representations of libpq's C types, filled in by hsc2hs from the headers.
type ConnStatusType = #type ConnStatusType
type ExecStatusType = #type ExecStatusType
type Oid = #type Oid
-- Raw libpq bindings.
--
-- libpq declares the row/column counts, column indices, type modifiers and
-- NULL flags below as C @int@.  Haskell's 'Int' is not guaranteed to have the
-- same width (it is 64 bits on 64-bit GHC), so those functions are imported
-- with 'CInt' under a @c_@ name.  Wrappers with the original names then
-- convert with 'fromIntegral', so the rest of the module keeps using 'Int'.
-- Negative results such as a -1 type modifier are also preserved correctly.
foreign import ccall "libpq-fe.h PQsetdbLogin" pqSetdbLogin :: CString -> CString -> CString -> CString -> CString -> CString -> CString -> IO PGconn
foreign import ccall "libpq-fe.h PQstatus" pqStatus :: PGconn -> IO ConnStatusType
foreign import ccall "libpq-fe.h PQerrorMessage" pqErrorMessage :: PGconn -> IO CString
foreign import ccall "libpq-fe.h PQfinish" pqFinish :: PGconn -> IO ()
foreign import ccall "libpq-fe.h PQexec" pqExec :: PGconn -> CString -> IO PGresult
foreign import ccall "libpq-fe.h PQresultStatus" pqResultStatus :: PGresult -> IO ExecStatusType
foreign import ccall "libpq-fe.h PQresStatus" pqResStatus :: ExecStatusType -> IO CString
foreign import ccall "libpq-fe.h PQresultErrorMessage" pqResultErrorMessage :: PGresult -> IO CString

foreign import ccall "libpq-fe.h PQnfields" c_PQnfields :: PGresult -> IO CInt
foreign import ccall "libpq-fe.h PQntuples" c_PQntuples :: PGresult -> IO CInt
foreign import ccall "libpq-fe.h PQfname" c_PQfname :: PGresult -> CInt -> IO CString
foreign import ccall "libpq-fe.h PQftype" c_PQftype :: PGresult -> CInt -> IO Oid
foreign import ccall "libpq-fe.h PQfmod" c_PQfmod :: PGresult -> CInt -> IO CInt
foreign import ccall "libpq-fe.h PQfnumber" c_PQfnumber :: PGresult -> CString -> IO CInt
foreign import ccall "libpq-fe.h PQgetvalue" c_PQgetvalue :: PGresult -> CInt -> CInt -> IO CString
foreign import ccall "libpq-fe.h PQgetisnull" c_PQgetisnull :: PGresult -> CInt -> CInt -> IO CInt

-- | Number of columns in a result.
pgNFields :: PGresult -> IO Int
pgNFields res = fmap fromIntegral (c_PQnfields res)

-- | Number of rows in a result.
pqNTuples :: PGresult -> IO Int
pqNTuples res = fmap fromIntegral (c_PQntuples res)

-- | Name of the given (0-based) column.
pgFName :: PGresult -> Int -> IO CString
pgFName res col = c_PQfname res (fromIntegral col)

-- | Type OID of the given column.
pqFType :: PGresult -> Int -> IO Oid
pqFType res col = c_PQftype res (fromIntegral col)

-- | Type modifier of the given column.
pqFMod :: PGresult -> Int -> IO Int
pqFMod res col = fmap fromIntegral (c_PQfmod res (fromIntegral col))

-- | Column index for a column name.
pqFNumber :: PGresult -> CString -> IO Int
pqFNumber res name = fmap fromIntegral (c_PQfnumber res name)

-- | Text value of the field at (row, column).
pqGetvalue :: PGresult -> Int -> Int -> IO CString
pqGetvalue res row col = c_PQgetvalue res (fromIntegral row) (fromIntegral col)

-- | NULL flag (1 for NULL) of the field at (row, column).
pqGetisnull :: PGresult -> Int -> Int -> IO Int
pqGetisnull res row col = fmap fromIntegral (c_PQgetisnull res (fromIntegral row) (fromIntegral col))
-- | A handle to an open libpq connection.
newtype Connection = Connection PGconn

-- | The result of 'query' together with a cursor over its rows.
data Statement = Statement
    { pRes        :: !PGresult    -- ^ libpq result handle
    , tupleIndex  :: IORef Int    -- ^ current row; starts at -1, advanced by 'fetch'
    , countTuples :: !Int         -- ^ number of rows in the result
    , connection  :: !Connection  -- ^ connection the query was run on
    , fields      :: ![FieldDef]  -- ^ name and type of every result column
    }

-- | A result column: its name and its mapped 'SqlType'.
type FieldDef = (String, SqlType)
-- | Column types reported for query results.  'query' derives them from each
-- column's PostgreSQL type OID and type modifier.
data SqlType
    = SqlChar Int              -- ^ @char(n)@
    | SqlVarChar Int           -- ^ @varchar(n)@; also reported for @name@
    | SqlText                  -- ^ @text@
    | SqlNumeric Int Int       -- ^ @numeric@ with the two values decoded from its modifier
    | SqlSmallInt              -- ^ @int2@
    | SqlInteger               -- ^ @int4@
    | SqlReal                  -- ^ @float4@
    | SqlDouble                -- ^ @float8@
    | SqlBool                  -- ^ @bool@
    | SqlBit Int               -- ^ @bit(n)@
    | SqlVarBit Int            -- ^ @varbit(n)@
    | SqlTinyInt               -- ^ reported for @bytea@ columns (NOTE(review): bytea is binary data)
    | SqlBigInt                -- ^ @int8@
    | SqlDate                  -- ^ @date@
    | SqlTime                  -- ^ @time@
    | SqlAbsTime               -- ^ @abstime@
    | SqlRelTime               -- ^ @reltime@
    | SqlTimeTZ                -- ^ @time with time zone@
    | SqlTimeInterval          -- ^ @interval@
    | SqlAbsTimeInterval       -- ^ @tinterval@
    | SqlTimeStamp             -- ^ @timestamp@
    | SqlMoney                 -- ^ @money@ (CASHOID)
    | SqlINetAddr              -- ^ @inet@
    | SqlCIDRAddr              -- ^ @cidr@
    | SqlMacAddr               -- ^ @macaddr@
    | SqlPoint                 -- ^ @point@
    | SqlLSeg                  -- ^ @lseg@
    | SqlPath                  -- ^ @path@
    | SqlBox                   -- ^ @box@
    | SqlPolygon               -- ^ @polygon@
    | SqlLine                  -- ^ @line@
    | SqlCircle                -- ^ @circle@
    | SqlUnknown               -- ^ @unknown@
    deriving (Eq, Show)
-- | Errors raised by this module.  They are thrown with 'throwDyn' and can be
-- caught with 'catchSql' or 'handleSql'.
data SqlError
    -- An error reported by libpq.
    = SqlError
        { seState       :: String   -- "C" from 'connect', "E" from 'execute' / 'query'
        , seNativeError :: Int      -- libpq connection or result status code
        , seErrorMsg    :: String   -- message text obtained from libpq
        }
    -- There is no current row to read from.
    | SqlNoData
    -- A column value could not be converted to the requested Haskell type.
    | SqlBadTypeCast
        { seFieldName :: String
        , seFieldType :: SqlType
        }
    -- 'getFieldValue' found SQL NULL in the named column.
    | SqlFetchNull
        { seFieldName :: String
        }
    deriving (Show, Typeable)
-----------------------------------------------------------------------------------------
-- routines for handling exceptions
-----------------------------------------------------------------------------------------
-- | Run an action, handing any 'SqlError' it throws to the handler.
catchSql :: IO a -> (SqlError -> IO a) -> IO a
catchSql action handler = catchDyn action handler
-- | 'catchSql' with its arguments flipped (handler first).
handleSql :: (SqlError -> IO a) -> IO a -> IO a
handleSql = flip catchDyn
-- | Select 'SqlError's out of a general 'Exception' (for use with @catchJust@ and similar).
sqlExceptions :: Exception -> Maybe SqlError
sqlExceptions e = fromDynamic =<< dynExceptions e
-----------------------------------------------------------------------------------------
-- Connect/Disconnect
-----------------------------------------------------------------------------------------
-- | Open a connection to a PostgreSQL server.
--
-- The four arguments are passed to @PQsetdbLogin@ in its host, database,
-- user and password positions.  The port, options and tty arguments are NULL.
-- Throws 'SqlError' (with 'seState' @\"C\"@ and the libpq connection status as
-- 'seNativeError') when the connection is not established.  In that case the
-- half-open connection is closed with @PQfinish@ first.
connect :: String -> String -> String -> String -> IO Connection
connect server database user authentication =
    -- withCString releases every buffer when the call returns, including the
    -- database name, which used to be allocated and never freed.
    withCString server $ \pServer ->
    withCString database $ \pDatabase ->
    withCString user $ \pUser ->
    withCString authentication $ \pAuthentication -> do
        pConn <- pqSetdbLogin pServer nullPtr nullPtr nullPtr pDatabase pUser pAuthentication
        status <- pqStatus pConn
        unless (status == (#const CONNECTION_OK)) $ do
            errMsg <- pqErrorMessage pConn >>= peekCString
            pqFinish pConn
            throwDyn (SqlError {seState="C", seNativeError=fromIntegral status, seErrorMsg=errMsg})
        return (Connection pConn)
-- | Close a connection with @PQfinish@.
disconnect :: Connection -> IO ()
disconnect conn = case conn of
    Connection handle -> pqFinish handle
-----------------------------------------------------------------------------------------
-- queries
-----------------------------------------------------------------------------------------
foreign import ccall "libpq-fe.h PQclear" pqClear :: PGresult -> IO ()

-- | Run a SQL command and discard its result.
--
-- Throws 'SqlError' with 'seState' @\"E\"@ when libpq returns no result, or
-- when the result status is neither @PGRES_COMMAND_OK@ nor @PGRES_TUPLES_OK@.
-- The @PGresult@ is released with @PQclear@ on both the success path and the
-- bad-status path.
execute :: Connection -> String -> IO ()
execute (Connection pConn) sqlExpr = do
    pRes <- withCString sqlExpr (pqExec pConn)
    when (pRes == nullPtr) $ do
        errMsg <- pqErrorMessage pConn >>= peekCString
        throwDyn (SqlError {seState="E", seNativeError=(#const PGRES_FATAL_ERROR), seErrorMsg=errMsg})
    status <- pqResultStatus pRes
    unless (status == (#const PGRES_COMMAND_OK) || status == (#const PGRES_TUPLES_OK)) $ do
        -- Copy the message out before clearing: it lives inside the result.
        errMsg <- pqResultErrorMessage pRes >>= peekCString
        pqClear pRes
        throwDyn (SqlError {seState="E", seNativeError=fromIntegral status, seErrorMsg=errMsg})
    pqClear pRes
-- | Run a SQL query and return a 'Statement' positioned before its first row.
--
-- Throws 'SqlError' with 'seState' @\"E\"@ when libpq returns no result, or
-- when the result status is neither @PGRES_COMMAND_OK@ nor @PGRES_TUPLES_OK@.
-- Column definitions are collected only for @PGRES_TUPLES_OK@ results.
--
-- NOTE(review): the @PGresult@ is never passed to @PQclear@, neither here on
-- the error path nor when the statement is closed, so it is leaked.
query :: Connection -> String -> IO Statement
query conn@(Connection pConn) sqlExpr = do
    pRes <- withCString sqlExpr (pqExec pConn)
    when (pRes == nullPtr) $ do
        errMsg <- pqErrorMessage pConn >>= peekCString
        throwDyn (SqlError {seState="E", seNativeError=(#const PGRES_FATAL_ERROR), seErrorMsg=errMsg})
    status <- pqResultStatus pRes
    unless (status == (#const PGRES_COMMAND_OK) || status == (#const PGRES_TUPLES_OK)) $ do
        errMsg <- pqResultErrorMessage pRes >>= peekCString
        throwDyn (SqlError {seState="E", seNativeError=fromIntegral status, seErrorMsg=errMsg})
    defs <- if status == (#const PGRES_TUPLES_OK)
                then pgNFields pRes >>= getFieldDefs pRes 0
                else return []
    countTuples <- pqNTuples pRes
    tupleIndex <- newIORef (-1)
    return (Statement {pRes=pRes, connection=conn, fields=defs, countTuples=countTuples, tupleIndex=tupleIndex})
  where
    -- Collect (name, type) for columns i .. n-1.
    getFieldDefs res i n
        | i >= n = return []
        | otherwise = do
            name <- pgFName res i >>= peekCString
            dataType <- pqFType res i
            modifier <- pqFMod res i
            defs <- getFieldDefs res (i+1) n
            return ((name, mkSqlType dataType modifier) : defs)

    -- Map a column's type OID and type modifier to an 'SqlType'.  The modifier
    -- is reduced by 4 for the length-carrying types (presumably libpq's
    -- varlena header size).  For NUMERIC, what remains is split into its
    -- upper and lower 16-bit halves.
    mkSqlType :: Oid -> Int -> SqlType
    mkSqlType (#const BPCHAROID)    size = SqlChar (size-4)
    mkSqlType (#const VARCHAROID)   size = SqlVarChar (size-4)
    mkSqlType (#const NAMEOID)      _    = SqlVarChar 31
    mkSqlType (#const TEXTOID)      _    = SqlText
    mkSqlType (#const NUMERICOID)   size = SqlNumeric ((size-4) `div` 0x10000) ((size-4) `mod` 0x10000)
    mkSqlType (#const INT2OID)      _    = SqlSmallInt
    mkSqlType (#const INT4OID)      _    = SqlInteger
    mkSqlType (#const FLOAT4OID)    _    = SqlReal
    mkSqlType (#const FLOAT8OID)    _    = SqlDouble
    mkSqlType (#const BOOLOID)      _    = SqlBool
    mkSqlType (#const BITOID)       size = SqlBit size
    mkSqlType (#const VARBITOID)    size = SqlVarBit size
    mkSqlType (#const BYTEAOID)     _    = SqlTinyInt
    mkSqlType (#const INT8OID)      _    = SqlBigInt
    mkSqlType (#const DATEOID)      _    = SqlDate
    mkSqlType (#const TIMEOID)      _    = SqlTime
    mkSqlType (#const TIMETZOID)    _    = SqlTimeTZ
    mkSqlType (#const ABSTIMEOID)   _    = SqlAbsTime
    mkSqlType (#const RELTIMEOID)   _    = SqlRelTime
    mkSqlType (#const INTERVALOID)  _    = SqlTimeInterval
    mkSqlType (#const TINTERVALOID) _    = SqlAbsTimeInterval
    mkSqlType (#const TIMESTAMPOID) _    = SqlTimeStamp
    mkSqlType (#const CASHOID)      _    = SqlMoney
    mkSqlType (#const INETOID)      _    = SqlINetAddr
    mkSqlType (#const 829)          _    = SqlMacAddr       -- hack: macaddr OID has no named constant here
    mkSqlType (#const CIDROID)      _    = SqlCIDRAddr
    mkSqlType (#const POINTOID)     _    = SqlPoint
    mkSqlType (#const LSEGOID)      _    = SqlLSeg
    mkSqlType (#const PATHOID)      _    = SqlPath
    mkSqlType (#const BOXOID)       _    = SqlBox
    mkSqlType (#const POLYGONOID)   _    = SqlPolygon
    mkSqlType (#const LINEOID)      _    = SqlLine
    mkSqlType (#const CIRCLEOID)    _    = SqlCircle
    mkSqlType (#const UNKNOWNOID)   _    = SqlUnknown
    -- Any other OID (e.g. timestamptz or user-defined types) used to hit an
    -- incomplete pattern match; report it as unknown instead.
    mkSqlType _                     _    = SqlUnknown
-- | Advance to the next row.  Returns 'False', leaving the cursor where it is,
-- once there are no more rows.
fetch :: Statement -> IO Bool
fetch stmt = do
    current <- readIORef (tupleIndex stmt)
    let next = current + 1
    if next < countTuples stmt
        then do
            writeIORef (tupleIndex stmt) next
            return True
        else return False
-- | Close a statement.  Currently this does nothing.
--
-- NOTE(review): the underlying @PGresult@ is not released here.
closeStatement :: Statement -> IO ()
closeStatement = const (return ())
-----------------------------------------------------------------------------------------
-- transactions
-----------------------------------------------------------------------------------------
-- | Run an action inside a transaction.
--
-- Issues @begin@, runs the action, then issues @commit@.  If the action throws
-- an 'SqlError', @rollback@ is issued and the same error is re-thrown.
--
-- NOTE(review): only 'SqlError' triggers the rollback; any other exception
-- escapes with the transaction still open.
inTransaction :: Connection -> (Connection -> IO a) -> IO a
inTransaction conn action = do
    execute conn "begin"
    -- Sequence with '>>': the rollback's () result is not an argument to the
    -- re-throw.
    r <- catchSql (action conn) (\err -> execute conn "rollback" >> throwDyn err)
    execute conn "commit"
    return r
-----------------------------------------------------------------------------------------
-- binding
-----------------------------------------------------------------------------------------
-- | Conversion between Haskell values and the text form PostgreSQL uses for
-- column values and SQL literals.
class SqlBind a where
	-- | Convert a column value, given as text together with its 'SqlType'.
	-- 'Nothing' means the value cannot be converted to this type;
	-- 'getFieldValueMB' then throws 'SqlBadTypeCast'.
	fromSqlValue :: SqlType -> String -> Maybe a
	-- | Render a value as a SQL literal that can be spliced into a query string.
	toSqlValue :: a -> String
-- | 'Int' is read from @int4@ and @int2@ columns and rendered with 'show'.
instance SqlBind Int where
    fromSqlValue sqlType s = case sqlType of
        SqlInteger  -> Just (read s)
        SqlSmallInt -> Just (read s)
        _           -> Nothing
    toSqlValue = show
-- | 'Integer' is read from @int4@, @int2@ and @int8@ columns and rendered with 'show'.
instance SqlBind Integer where
    fromSqlValue sqlType s = case sqlType of
        SqlInteger  -> Just (read s)
        SqlSmallInt -> Just (read s)
        SqlBigInt   -> Just (read s)
        _           -> Nothing
    toSqlValue = show
-- | Strings accept the text of any column type and are rendered as a
-- single-quoted SQL literal.
--
-- Single quotes are doubled (the SQL-standard escape) and backslashes are
-- escaped, so no input can end the literal early and inject SQL.  Newline,
-- carriage return and tab are written as backslash escapes.
--
-- NOTE(review): the backslash escapes assume the server processes backslashes
-- in ordinary string literals (standard_conforming_strings off).  Otherwise
-- they are stored literally, but the quoting is still injection-safe.
instance SqlBind String where
    fromSqlValue _ = Just
    toSqlValue s = '\'' : foldr mapChar "'" s
      where
        mapChar '\'' acc = '\'' : '\'' : acc
        mapChar '\\' acc = '\\' : '\\' : acc
        mapChar '\n' acc = '\\' : 'n' : acc
        mapChar '\r' acc = '\\' : 'r' : acc
        mapChar '\t' acc = '\\' : 't' : acc
        mapChar c    acc = c : acc
-- | Booleans: the column text @t@ is 'True', anything else 'False'.
-- Rendered as @'t'@ or @'f'@.
instance SqlBind Bool where
    fromSqlValue sqlType s = case sqlType of
        SqlBool -> Just (s == "t")
        _       -> Nothing
    toSqlValue b = if b then "'t'" else "'f'"
-- | 'Double' is read from @numeric@, @float8@ and @float4@ columns and
-- rendered with 'show'.
instance SqlBind Double where
    fromSqlValue sqlType s = case sqlType of
        SqlNumeric _ _ -> Just (read s)
        SqlDouble      -> Just (read s)
        SqlReal        -> Just (read s)
        _              -> Nothing
    toSqlValue = show
-- | Build a 'ClockTime' from calendar fields.
--
-- @mon@ is 1-based (January = 1).  C @mktime@ first interprets the fields as
-- local time, i.e. at UTC offset 'currTZ'.  The result is then moved by
-- @currTZ - tz@ seconds, so it denotes the instant whose wall-clock time at UTC
-- offset @tz@ equals the given fields.  @tz@ uses the same sign convention as
-- 'currTZ' ('parseTZ' turns @+hh@ into a positive value).  Passing
-- @tz == currTZ@ therefore gives plain local time.
--
-- 'unsafePerformIO' is used because @mktime@ only touches the locally
-- allocated @struct tm@ (and reads the process time-zone setting).
mkClockTime :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> ClockTime
mkClockTime year mon mday hour minute sec tz =
    unsafePerformIO $
        allocaBytes (#const sizeof(struct tm)) $ \p_tm -> do
            (#poke struct tm, tm_sec)   p_tm (fromIntegral sec          :: CInt)
            (#poke struct tm, tm_min)   p_tm (fromIntegral minute       :: CInt)
            (#poke struct tm, tm_hour)  p_tm (fromIntegral hour         :: CInt)
            (#poke struct tm, tm_mday)  p_tm (fromIntegral mday         :: CInt)
            (#poke struct tm, tm_mon)   p_tm (fromIntegral (mon-1)      :: CInt)
            (#poke struct tm, tm_year)  p_tm (fromIntegral (year-1900)  :: CInt)
            (#poke struct tm, tm_isdst) p_tm (-1 :: CInt)
            t <- mktime p_tm
            -- mktime assumed offset currTZ; the fields really are at offset tz,
            -- so the true instant is (currTZ - tz) seconds later.  The previous
            -- code applied (tz - currTZ), i.e. the correction with the wrong sign.
            return (TOD (fromIntegral t + fromIntegral (currTZ - tz)) 0)

foreign import ccall unsafe mktime :: Ptr () -> IO CTime
-- | UTC offset of the local time zone ('ctTZ'), sampled once on first use.
--
-- NOTE(review): a single sample cannot follow later DST changes.  Hack.
{-# NOINLINE currTZ #-}
currTZ :: Int
currTZ = unsafePerformIO (fmap ctTZ (getClockTime >>= toCalendarTime))
-- | Parse a signed whole-hour offset such as @+02@ or @-05@ into hours.
parseTZ :: ReadP Int
parseTZ = do
    sign <- (char '+' >> return id) `mplus` (char '-' >> return negate)
    fmap sign (readS_to_P reads)
-- | Run a parser and return its result when it yields exactly one parse.
-- Unconsumed input is ignored.
--
-- No parse or an ambiguous parse gives 'Nothing'.  Previously these cases hit
-- an incomplete @case@ and crashed; now callers can report the value as an
-- unconvertible column instead.
f_read :: ReadP a -> String -> Maybe a
f_read f s = case readP_to_S f s of
    [(x, _)] -> Just x
    _        -> Nothing
-- | 'ClockTime' values.
--
-- Parsing:
--
-- * 'SqlTimeTZ': @hh:mm:ss@ followed by a @+hh@/@-hh@ offset, with the date
--   set to 1970-01-01.
-- * 'SqlTime': @hh:mm:ss@ in local time, with the date set to 1970-01-01.
-- * 'SqlDate': @yyyy-mm-dd@ at local midnight.
-- * 'SqlTimeStamp': @yyyy-mm-dd hh:mm:ss@ followed by a @+hh@/@-hh@ offset.
--
-- Rendering produces a quoted literal from the UTC calendar fields of the value.
--
-- NOTE(review): no UTC offset is written, and 'ctMonth' is a 'Month', so the
-- literal contains the month name (e.g. @'2003.September.6 19:59:10'@).
-- Confirm the server reads this as intended.
instance SqlBind ClockTime where
    fromSqlValue SqlTimeTZ s = f_read getTime s
      where
        getTime :: ReadP ClockTime
        getTime = do
            hour <- readS_to_P reads
            char ':'
            minutes <- readS_to_P reads
            char ':'
            seconds <- readS_to_P reads
            tz <- parseTZ
            -- mkClockTime's month is 1-based.  Month 0 became tm_mon = -1,
            -- which mktime normalises to December 1969.
            return (mkClockTime 1970 1 1 hour minutes seconds (tz*3600))
    fromSqlValue SqlTime s = f_read getTime s
      where
        getTime :: ReadP ClockTime
        getTime = do
            hour <- readS_to_P reads
            char ':'
            minutes <- readS_to_P reads
            char ':'
            seconds <- readS_to_P reads
            return (mkClockTime 1970 1 1 hour minutes seconds currTZ)
    fromSqlValue SqlDate s = f_read getDate s
      where
        getDate :: ReadP ClockTime
        getDate = do
            year <- readS_to_P reads
            satisfy (=='-')
            month <- readS_to_P reads
            satisfy (=='-')
            day <- readS_to_P reads
            return (mkClockTime year month day 0 0 0 currTZ)
    fromSqlValue SqlTimeStamp s = f_read getTimeStamp s
      where
        getTimeStamp :: ReadP ClockTime
        getTimeStamp = do
            year <- readS_to_P reads
            satisfy (=='-')
            month <- readS_to_P reads
            satisfy (=='-')
            day <- readS_to_P reads
            skipSpaces
            hour <- readS_to_P reads
            satisfy (==':')
            minutes <- readS_to_P reads
            satisfy (==':')
            seconds <- readS_to_P reads
            tz <- parseTZ
            return (mkClockTime year month day hour minutes seconds (tz*3600))
    fromSqlValue _ _ = Nothing
    toSqlValue ct = '\'' : (shows (ctYear t) .
                            score .
                            shows (ctMonth t) .
                            score .
                            shows (ctDay t) .
                            space .
                            shows (ctHour t) .
                            colon .
                            shows (ctMin t) .
                            colon .
                            shows (ctSec t)) "'"
      where
        t = toUTCTime ct
        score = showChar '.'
        space = showChar ' '
        colon = showChar ':'
-- | A point @(x, y)@.
data Point = Point Double Double
    deriving (Eq, Show)

-- | A line segment between two points.
data Line = Line Point Point
    deriving (Eq, Show)

-- | A path through a list of points, either open or closed.
data Path
    = OpenPath [Point]
    | ClosedPath [Point]
    deriving (Eq, Show)

-- | A box given by two corners, as @x1 y1 x2 y2@.
data Box = Box Double Double Double Double
    deriving (Eq, Show)

-- | A circle given by its centre and radius.
data Circle = Circle Point Double
    deriving (Eq, Show)

-- | A polygon given by its vertices.
data Polygon = Polygon [Point]
    deriving (Eq, Show)
-- | Points, read from and rendered as a @(x,y)@ tuple.
instance SqlBind Point where
    fromSqlValue SqlPoint s =
        case read s of
            (x, y) -> Just (Point x y)
    fromSqlValue _ _ = Nothing
    toSqlValue (Point x y) = "'" ++ show (x, y) ++ "'"
-- | Line segments (@lseg@ columns), read from and rendered as @[(x1,y1),(x2,y2)]@.
instance SqlBind Line where
    fromSqlValue SqlLSeg s = case read s of
        [(x1,y1),(x2,y2)] -> Just (Line (Point x1 y1) (Point x2 y2))
        -- A list of any other length is not a segment.  Report a conversion
        -- failure rather than crashing on an incomplete case.
        _                 -> Nothing
    fromSqlValue _ _ = Nothing
    toSqlValue (Line (Point x1 y1) (Point x2 y2)) = '\'' : shows [(x1,y1),(x2,y2)] "'"
-- | Paths.
--
-- A @path@ value starting with @(@ is read as a closed path
-- @((x,y),...)@; anything else is read as an open path @[(x,y),...]@.  An
-- @lseg@ becomes a two-point open path and a @point@ a one-point closed path.
-- Open paths are rendered as @'[(x,y),...]'@ and closed paths as
-- @'((x,y),...)'@.
instance SqlBind Path where
    fromSqlValue SqlPath ('(':s) =   -- closed path
        Just (ClosedPath (map (uncurry Point) (read ("[" ++ init s ++ "]"))))
    fromSqlValue SqlPath s =         -- open path
        Just (OpenPath (map (uncurry Point) (read s)))
    fromSqlValue SqlLSeg s = case read s of
        [(x1,y1),(x2,y2)] -> Just (OpenPath [Point x1 y1, Point x2 y2])
        _                 -> Nothing
    fromSqlValue SqlPoint s = case read s of
        (x,y) -> Just (ClosedPath [Point x y])
    fromSqlValue _ _ = Nothing
    -- Points are converted to tuples before 'show'.  The derived Show for
    -- Point prints "Point x y", which is not PostgreSQL path syntax.
    toSqlValue (OpenPath ps) = '\'' : shows (map (\(Point x y) -> (x, y)) ps) "'"
    -- The closing parenthesis belongs inside the quoted literal.
    toSqlValue (ClosedPath ps) = "'(" ++ init (tail (show (map (\(Point x y) -> (x, y)) ps))) ++ ")'"
-- | Boxes: the column text (two corner points) is wrapped in parentheses and
-- read as a pair of tuples.  Rendered as @'((x1,y1),(x2,y2))'@.
instance SqlBind Box where
    fromSqlValue SqlBox s =
        case read ('(' : s ++ ")") of
            ((x1, y1), (x2, y2)) -> Just (Box x1 y1 x2 y2)
    fromSqlValue _ _ = Nothing
    toSqlValue (Box x1 y1 x2 y2) = "'" ++ show ((x1, y1), (x2, y2)) ++ "'"
-- | Polygons, read from @((x,y),...)@ and rendered as @'((x,y),...)'@.
instance SqlBind Polygon where
    fromSqlValue SqlPolygon s =
        -- Drop the outer parentheses and read the rest as a list of tuples.
        Just (Polygon (map (uncurry Point) (read ("[" ++ init (tail s) ++ "]"))))
    fromSqlValue _ _ = Nothing
    toSqlValue (Polygon ps) =
        -- Convert to tuples first: the derived Show for Point prints
        -- "Point x y".  The closing parenthesis must come before the quote.
        "'(" ++ init (tail (show (map (\(Point x y) -> (x, y)) ps))) ++ ")'"
-- | Circles: the outer delimiters of the column text (e.g. @<(x,y),r>@) are
-- stripped and the rest is read as @((x,y),r)@.  Rendered as @'<(x,y),r>'@.
instance SqlBind Circle where
    fromSqlValue SqlCircle s = case read ("(" ++ init (tail s) ++ ")") of
        ((x,y),r) -> Just (Circle (Point x y) r)
    fromSqlValue _ _ = Nothing
    -- The closing '>' belongs inside the quoted literal.
    toSqlValue (Circle (Point x y) r) = "'<" ++ show (x,y) ++ "," ++ show r ++ ">'"
-- | Read the named column of the current row.
--
-- Returns 'Nothing' for SQL NULL.  Throws 'SqlNoData' when there is no current
-- row, i.e. before the first successful 'fetch'.  Throws 'SqlBadTypeCast' when
-- 'fromSqlValue' cannot convert the value.  An unknown column name is reported
-- with 'error' by 'findFieldInfo'.
getFieldValueMB :: SqlBind a => Statement -> String -> IO (Maybe a)
getFieldValueMB (Statement {pRes=pRes, fields=fieldDefs, countTuples=countTuples, tupleIndex=tupleIndex}) name = do
    index <- readIORef tupleIndex
    -- tupleIndex starts at -1, so a negative index means no row has been
    -- fetched yet.  Reading row -1 used to silently yield Nothing.
    when (index < 0 || index >= countTuples) (throwDyn SqlNoData)
    let (sqlType, colNumber) = findFieldInfo name fieldDefs 0
    isnull <- pqGetisnull pRes index colNumber
    if isnull == 1
        then return Nothing
        else do
            value <- pqGetvalue pRes index colNumber >>= peekCString
            case fromSqlValue sqlType value of
                Just v  -> return (Just v)
                Nothing -> throwDyn (SqlBadTypeCast name sqlType)
-- | Like 'getFieldValueMB', but SQL NULL raises 'SqlFetchNull'.
getFieldValue :: SqlBind a => Statement -> String -> IO a
getFieldValue stmt name =
    getFieldValueMB stmt name >>= maybe (throwDyn (SqlFetchNull name)) return
-- | Like 'getFieldValueMB', but SQL NULL is replaced by the given default.
getFieldValue' :: SqlBind a => Statement -> String -> a -> IO a
getFieldValue' stmt name def = fmap (maybe def id) (getFieldValueMB stmt name)
-- | The 'SqlType' of the named column.  An unknown name is reported with 'error'.
getFieldValueType :: Statement -> String -> SqlType
getFieldValueType stmt name = fst (findFieldInfo name (fields stmt) 0)
-- | Name and type of every result column, in column order.
getFieldsTypes :: Statement -> [(String, SqlType)]
getFieldsTypes stmt = fields stmt
-- | Look up a column by name and return its type and index.  Columns are
-- numbered from the given starting index, and the first matching name wins.
-- An unknown name is reported with 'error'.
findFieldInfo :: String -> [FieldDef] -> Int -> (SqlType,Int)
findFieldInfo name defs firstCol =
    case lookup name [ (fieldName, (fieldType, col)) | ((fieldName, fieldType), col) <- zip defs [firstCol ..] ] of
        Just info -> info
        Nothing   -> error ("Undefined column name \"" ++ name ++ "\"")
-----------------------------------------------------------------------------------------
-- helpers
-----------------------------------------------------------------------------------------
-- | Fold over the remaining rows.  Each fetched row is passed to @f@ together
-- with the running state.  The statement is closed once 'fetch' reports no
-- more rows.
forEachRow :: (Statement -> s -> IO s) -> Statement -> s -> IO s
forEachRow f stmt = loop
  where
    loop acc = do
        more <- fetch stmt
        if more
            then f stmt acc >>= loop
            else closeStatement stmt >> return acc
-- | Run @f@ on every remaining row, then close the statement.
forEachRow' :: (Statement -> IO ()) -> Statement -> IO ()
forEachRow' f stmt = go
  where
    go = fetch stmt >>= \more ->
        if more then f stmt >> go else closeStatement stmt
-- | Apply @f@ to every remaining row and collect the results in row order.
-- The statement is closed after the last row.
collectRows :: (Statement -> IO a) -> Statement -> IO [a]
collectRows f stmt = do
    more <- fetch stmt
    if not more
        then closeStatement stmt >> return []
        else do
            x <- f stmt
            fmap (x :) (collectRows f stmt)
 |