/**
 *  Copyright (C) 2004 Alo Sarv <madcat_@users.sourceforge.net>
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

#ifndef __MSOCKET_H__
#define __MSOCKET_H__

#include <hn/sockets.h>            // xplatform socket api
#include <hn/gettickcount.h>       // getTick()
#include <boost/smart_ptr.hpp>     // smart pointers
#include <boost/function.hpp>      // function objects
#include <boost/bind.hpp>          // function object binder
#include <cassert>                 // assert()
#include <cstdio>                  // fprintf()
#include <deque>                   // std::deque
#include <map>                     // std::map
#include <set>                     // std::set
#include <string>                  // std::string
#include <utility>                 // std::pair, std::make_pair
#include <vector>                  // std::vector

// Forward declaration of the Scheduler class template, so that SchedBase can
// name it in a friend declaration before the full definition is seen.
template<
	typename Impl, typename ImplPtr, typename EventType,
	typename HandlerFunc, typename EventHandler
>
class Scheduler;

/**
 * SchedBase implements the third level of hydranode networking scheduler.
 * This class performs the actual bandwidth and connections dividing between
 * requests in its main scheduler loop. It is derived from EventTableBase,
 * and as such will be run during each event loop, performing bandwidth requests
 * handling.
 */
class SchedBase : public EventTableBase {
public:
	//! Access the single SchedBase instance (constructed on first use)
	static SchedBase& instance() {
		static SchedBase s;
		return s;
	}

	// accessors for various limits and other internal variables
	void     setUpLimit(uint32_t amount)   { m_upLimit = amount;   }
	void     setDownLimit(uint32_t amount) { m_downLimit = amount; }
	uint32_t getUpLimit() const            { return m_upLimit;     }
	uint32_t getDownLimit() const          { return m_downLimit;   }
	uint64_t getTotalUpstream() const      { return m_totalUp;     }
	uint64_t getTotalDownstream() const    { return m_totalDown;   }

	/**
	 * Calculate the relative priority score (PS) for the given module.
	 *
	 * The score is the module's base priority, plus the percentage of
	 * the total upstream the module accounts for, minus the percentage
	 * of the total downstream it accounts for. As implemented, a module
	 * with a larger share of the upstream than of the downstream thus
	 * scores higher.
	 *
	 * NOTE(review): the previous comment claimed the opposite direction
	 * (more downstream => higher PS); confirm which one is intended.
	 *
	 * While no traffic has been accounted yet, the corresponding total
	 * is zero; the share is then taken as 0% instead of dividing by
	 * zero, which would yield an infinite or NaN score.
	 *
	 * @return Score of the module; equals Module::getPriority() as long
	 *         as both totals are zero.
	 */
	template<class Module>
	float getScore() const {
		float up_perc = m_totalUp
			? Module::getUploaded()*100.0/m_totalUp : 0.0;
		float dn_perc = m_totalDown
			? Module::getDownloaded()*100.0/m_totalDown : 0.0;
		return Module::getPriority() + up_perc - dn_perc;
	}

private:
	// singleton: construction, destruction and copying are restricted
	SchedBase();
	~SchedBase() {}
	SchedBase(const SchedBase&);
	SchedBase& operator=(const SchedBase&);

	// Scheduler derives its request types from the private request bases
	// below and therefore needs access to our internals
	template<typename, typename, typename, typename, typename>
	friend class Scheduler;

	// Request base, only contains score of the request
	class ReqBase {
	public:
		ReqBase(float score) : m_score(score) {}
		virtual ~ReqBase() {}
		//! Implemented by concrete requests (see Scheduler)
		virtual void notify() = 0;
		float getScore() const { return m_score; }
	private:
		float m_score;
	};

	// Request of type upload
	class UploadReqBase : public ReqBase {
	public:
		UploadReqBase(float score) : ReqBase(score) {}
		virtual uint32_t doSend(uint32_t amount) = 0;
		virtual uint32_t getPending() const = 0;
	};
	// Request of type download
	class DownloadReqBase : public ReqBase {
	public:
		DownloadReqBase(float score) : ReqBase(score) {}
		virtual uint32_t doRecv(uint32_t amount) = 0;
	};
	// Request of type connection
	class ConnReqBase : public ReqBase {
	public:
		ConnReqBase(float score) : ReqBase(score) {}
		virtual bool doConn() = 0;
	};

	// request sets
	std::set<UploadReqBase*>   m_uploadReqs;
	std::set<DownloadReqBase*> m_downloadReqs;
	std::set<ConnReqBase*>     m_connReqs;

	// iterator typedefs for the above sets to make life easier
	typedef std::set<UploadReqBase*>::iterator UIter;
	typedef std::set<DownloadReqBase*>::iterator DIter;
	typedef std::set<ConnReqBase*>::iterator CIter;

	//! Main networking loop and helper functions
	void handleEvents();
	void handleDownloads();
	void handleUploads();
	void handleConnections();

	// get free bandwidth for requests
	uint32_t getFreeDown();
	uint32_t getFreeUp();
	bool     getConnection();

	// Keeps current tick - to reduce getTick() calls somewhat
	uint32_t m_curTick;

	// various limits and counts
	uint32_t m_upLimit;      //!< upstream limit
	uint32_t m_downLimit;    //!< downstream limit
	uint32_t m_connLimit;    //!< open connections limit
	uint32_t m_connCnt;      //!< open connections count
	uint64_t m_totalUp;      //!< overall total uploaded
	uint64_t m_totalDown;    //!< overall total downloaded

	// last sent data amounts, along with timestamps
	std::deque<std::pair<uint32_t, uint32_t> > m_lastSent;

	// last received data amounts, along with timestamps
	std::deque<std::pair<uint32_t, uint32_t> > m_lastRecv;
public:
	// Internal stuff - add new requests. The pointers are stored as-is.
	// NOTE(review): callers in this header pass heap-allocated requests,
	// so the scheduler presumably takes ownership - confirm against the
	// implementation of the handle*() functions.
	void addUploadReq(UploadReqBase *r) {
		m_uploadReqs.insert(r);
	}
	void addDloadReq(DownloadReqBase *r) {
		m_downloadReqs.insert(r);
	}
	void addConnReq(ConnReqBase *r) {
		m_connReqs.insert(r);
	}

};

/**
 * Event handling policy for client-type sockets.
 *
 * None of the events is acted upon yet: every case below is a placeholder
 * whose comment describes the handling that is planned for it.
 */
template<typename Source>
class ClientEventHandler {
public:
	/**
	 * Event entry point for a socket of type Source.
	 *
	 * @param evt   The event that occurred. The source socket (first
	 *              argument) is currently unused.
	 */
	template<typename SchedulerT>
	static void onEvent(Source * /* src */, SocketEvent evt) {
		switch (evt) {
		case SOCK_READ:
			// Planned: if a download request is already
			// pending for this socket, do nothing; else
			// allocate a new DownloadRequest and hand it
			// over to SchedBase.
			break;
		case SOCK_WRITE:
			// Planned: if s_pendingPackets holds data for
			// this socket, allocate a new UploadRequest and
			// hand it over to SchedBase; else tell client
			// code that more data is needed for the socket.
			break;
		case SOCK_LOST:
			// Planned: tell client code the connection was
			// lost, purge everything in SchedBase related
			// to this socket, and decrement the open
			// connections count.
			break;
		case SOCK_ERR:
			// OOB data - can't really handle it.
			break;
		case SOCK_CONNECTED:
			// Planned: increment the open connections count
			// and notify client code.
			break;
		default:
			break;
		}
	}
};
/**
 * Event handling policy for server-type (listening) sockets.
 */
template<class Source>
class ServerEventHandler {
public:
	/**
	 * Event entry point for a listening socket.
	 *
	 * On SOCK_ACCEPT, a new accept request for the socket is allocated
	 * and queued in SchedBase. All other events are currently ignored.
	 *
	 * @param src   Socket that generated the event. It is looked up in
	 *              the scheduler's socket set; events from sockets that
	 *              are not (or no longer) registered there are dropped.
	 * @param evt   The event that occurred
	 */
	template<class _Scheduler>
	static void onEvent(Source *src, SocketEvent evt) {
		typedef typename _Scheduler::SSocketWrapper SWrapper;
		typedef typename _Scheduler::AcceptReq AccReq;
		typedef typename std::set<SWrapper>::const_iterator SIter;
		switch (evt) {
			case SOCK_ACCEPT: {
				// Dereferencing end() is undefined behaviour,
				// so make sure the socket is really known.
				SIter it = _Scheduler::s_sockets.find(src);
				if (it == _Scheduler::s_sockets.end()) {
					break;
				}
				SWrapper ss(*it);
				// The request pointer is handed over to
				// SchedBase and not freed here.
				AccReq *ar = new AccReq(ss);
				SchedBase::instance().addConnReq(ar);
				break;
			}
			case SOCK_LOST:
				// Server losing connection ?
				break;
			case SOCK_ERR:
				// Server became erroneous ?
				break;
			default:
				break;
		}
	}
};

// Compile-time selection of the event handling policy for a socket
// implementation type. Only the specializations below are defined; any other
// type leaves GetEventHandler incomplete.
template<typename SocketType>
class GetEventHandler;

// TCP client sockets use the client policy
template<>
class GetEventHandler<SocketClient> {
public:
	typedef ClientEventHandler<SocketClient> Handler;
};

// UDP sockets use the client policy as well
template<>
class GetEventHandler<UDPSocket> {
public:
	typedef ClientEventHandler<UDPSocket> Handler;
};

// Listening sockets use the server policy
template<>
class GetEventHandler<SocketServer> {
public:
	typedef ServerEventHandler<SocketServer> Handler;
};

/**
 * Scheduler class, implementing second level of HydraNode Networking Scheduling
 * API, abstracts away modules part of the sockets by generating a priority
 * score (PS) for each of the pending requests. All requests are received from
 * the frontend, wrapped into generic containers, and buffered internally for
 * later processing. No direct action shall be taken in the functions directly
 * or indirectly called from frontend.
 *
 * All state is held in static data members, i.e. it is shared by all users of
 * the same template instantiation.
 *
 * @param Impl             Implementation class type
 * @param ImplPtr          Pointer to implementation class
 * @param EventType        Type of events emitted from Impl
 * @param HandlerFunc      Function prototype for events emitted from Impl
 * @param EventHandler     Event handler object
 */
template<
	typename Impl,
	typename ImplPtr      = Impl*,
	typename EventType    = typename Impl::EventType,
	typename HandlerFunc  = boost::function<void(ImplPtr, EventType)>,
	typename EventHandler = typename GetEventHandler<Impl>::Handler
>
class Scheduler {
public:
	/**
	 * Write call.
	 *
	 * Intended behaviour: if there already is pending data for the
	 * pointed socket, the additional data is added to the existing
	 * upload request; otherwise a new request is created and added to
	 * the pending requests list.
	 *
	 * NOTE: the queueing is not implemented yet. Currently the call only
	 * verifies that the socket is connected; buf and pt are ignored.
	 *
	 * @param ptr   Socket to write to
	 * @param buf   Data to be written
	 * @param pt    Type of the data
	 */
	template<class Module>
	static void write(ImplPtr ptr, const std::string &buf, PacketType pt) {
		CHECK_THROW(ptr->isConnected());
		(void)buf;
		(void)pt;
	}

	/**
	 * Read call. Appends all data buffered for the socket to buf and
	 * drops the internal buffer. If nothing is buffered, buf is left
	 * untouched.
	 *
	 * @param ptr   Socket to read the buffered data of
	 * @param buf   String the data is appended to
	 */
	template<class Module>
	static void read(ImplPtr ptr, std::string *buf) {
		CHECK_THROW(ptr->isConnected());
		BIter i = s_bufferedPackets.find(ptr);
		if (i != s_bufferedPackets.end()) {
			// append to designated buffer
			buf->append(*(*i).second);
			// clean up internal buffer; erase through the iterator
			// we already have instead of looking the key up again
			delete (*i).second;
			s_bufferedPackets.erase(i);
		}
	}

	/**
	 * Add a socket into scheduler. The socket is stored together with
	 * the frontend handler and the score function of Module, and the
	 * scheduler's event handler policy is registered in the socket's
	 * event table.
	 *
	 * @param s     Socket to be scheduled
	 * @param h     Frontend handler receiving notifications for s
	 */
	template<class Module>
	static void addSocket(ImplPtr s, HandlerFunc h) {
		s_sockets.insert(SSocketWrapper(s, h, &Module::getScore));
		s->getEventTable().addHandler(
			s, &EventHandler::template onEvent<Scheduler>
		);
		fprintf(stderr, "Adding socket to scheduler.\n");
	}

	/**
	 * Remove a socket from scheduler. Besides unregistering the socket
	 * and the event handler, data that was received for the socket but
	 * never read is discarded; otherwise the heap-allocated buffer would
	 * leak, and could be served to a later socket object that happens to
	 * have the same address.
	 *
	 * @param s     Socket to be removed
	 */
	static void delSocket(ImplPtr s) {
		s_sockets.erase(SSocketWrapper(s));
		s->getEventTable().delHandler(
			s, &EventHandler::template onEvent<Scheduler>
		);
		BIter i = s_bufferedPackets.find(s);
		if (i != s_bufferedPackets.end()) {
			delete (*i).second;
			s_bufferedPackets.erase(i);
		}
	}
private:
	// Function that can be used to retrieve a module's score
	typedef typename boost::function<float ()> ScoreFunc;

// We need the internal classes public so the policy classes (namely,
// EventHandler) can access them. If only we could make the EventHandler a
// friend of ours, we could move this to private sector.
public:
	/**
	 * Wrapper object for scheduled socket, contains all the useful
	 * information we need, e.g. score, frontend event handler, and
	 * underlying socket object. Wrappers are ordered by the underlying
	 * socket pointer only.
	 */
	class SSocketWrapper {
	public:
		SSocketWrapper(ImplPtr s, HandlerFunc h, ScoreFunc f)
		: m_socket(s), m_handler(h), m_scoreFunc(f) {}
		// Key-only wrapper for lookups in s_sockets; handler and score
		// function are left empty. Intentionally not explicit, lookups
		// by plain socket pointer rely on the implicit conversion.
		SSocketWrapper(ImplPtr s) : m_socket(s) {}
		ImplPtr     getSocket()  const { return m_socket;    }
		ScoreFunc   getScore()   const { return m_scoreFunc; }
		HandlerFunc getHandler() const { return m_handler;   }
		// pass notification to frontend
		void notify(EventType evt) {
			m_handler(m_socket, evt);
		}
		friend bool operator<(
			const SSocketWrapper &x, const SSocketWrapper &y
		) {
			return x.m_socket < y.m_socket;
		}
	private:
		ImplPtr     m_socket;     //!< Underlying socket
		HandlerFunc m_handler;    //!< Frontend event handler
		ScoreFunc   m_scoreFunc;  //!< Function to retrieve the score
	};

	/**
	 * Upload request: data waiting to be written to a socket.
	 */
	class UploadReq : public SchedBase::UploadReqBase {
	public:
		/**
		 * @param s     Socket the data is to be sent to
		 * @param buf   Data to be sent; copied into the request
		 */
		UploadReq(SSocketWrapper s, const std::string &buf)
		: UploadReqBase(s.getScore()()), m_obj(s), m_buffer(buf) {}

		/**
		 * Send out at most amount bytes of the pending data. The sent
		 * bytes are removed from the front of the buffer. Once the
		 * buffer is empty, the frontend is notified with SOCK_WRITE.
		 *
		 * @param amount   Maximum number of bytes to send
		 * @return         Number of bytes actually sent; 0 if nothing
		 *                 was sent or the socket reported an error.
		 *
		 * NOTE(review): this used to return 1 while data remained and
		 * 0 when done. The byte count matches doRecv(); remaining
		 * data can be queried through getPending(). Confirm against
		 * SchedBase::handleUploads().
		 */
		virtual uint32_t doSend(uint32_t amount) {
			uint32_t sent = 0;
			// never exceed the bandwidth we were granted
			const std::string::size_type chunk =
				amount < m_buffer.size()
				? amount : m_buffer.size();
			if (chunk > 0) {
				int ret = m_obj.getSocket()->write(
					m_buffer.c_str(), chunk
				);
				// error or nothing written: keep the data
				if (ret <= 0) {
					return 0;
				}
				sent = static_cast<uint32_t>(ret);
				// drop exactly the bytes that went out
				m_buffer.erase(0, sent);
			}
			if (m_buffer.empty()) {
				m_obj.notify(SOCK_WRITE);
			}
			return sent;
		}
		virtual void notify() {
			m_obj.notify(SOCK_WRITE);
		}
		//! Number of bytes still waiting to be sent
		virtual uint32_t getPending() const { return m_buffer.size(); }
	private:
		SSocketWrapper m_obj;
		std::string m_buffer;
	};

	/**
	 * Download request: a socket with incoming data waiting to be read.
	 */
	class DownloadReq : public SchedBase::DownloadReqBase {
	public:
		/**
		 * @param s     Socket to read from
		 *
		 * The second parameter is unused; received data always goes
		 * to s_bufferedPackets.
		 */
		DownloadReq(SSocketWrapper s, std::string * /* buf */)
		: DownloadReqBase(s.getScore()()), m_obj(s) {}

		/**
		 * Read out as much data from socket as allowed by parameter
		 * and store it in s_bufferedPackets. The frontend is notified
		 * with SOCK_READ when a new buffer had to be created for the
		 * socket; data appended to an existing buffer does not cause
		 * another notification.
		 *
		 * @param amount   Maximum number of bytes to read
		 * @return         Number of bytes actually read; 0 if there
		 *                 was no data, amount was 0, or the socket
		 *                 reported an error.
		 */
		virtual uint32_t doRecv(uint32_t amount) {
			if (amount == 0) {
				return 0;
			}
			// heap buffer: variable-length arrays are not standard
			// C++, and amount may exceed what the stack can hold
			std::vector<char> buf(amount);
			int ret = m_obj.getSocket()->read(&buf[0], amount);
			// Got no data, or an error; a negative value must never
			// be used as a length below
			if (ret <= 0) {
				return 0;
			}
			// check if we already have buffered data for this
			// socket, and if so, append to existing buffer.
			// otherwise, create new buffer to store the data
			// in (and also notify client code)
			BIter i = s_bufferedPackets.find(m_obj.getSocket());
			if (i != s_bufferedPackets.end()) {
				(*i).second->append(&buf[0], ret);
			} else {
				std::string *s = new std::string(&buf[0], ret);
				s_bufferedPackets[m_obj.getSocket()] = s;
				m_obj.notify(SOCK_READ);
			}
			return ret;
		}
		virtual void notify() {
			m_obj.notify(SOCK_READ);
		}
	private:
		SSocketWrapper m_obj;
	};

	/**
	 * Accept request: a listening socket with an incoming connection.
	 */
	class AcceptReq : public SchedBase::ConnReqBase {
	public:
		AcceptReq(SSocketWrapper s) : ConnReqBase(s.getScore()()),
		m_obj(s) {}
		/**
		 * Accept the connection and store the new socket in
		 * s_accepted, keyed by the listening socket.
		 *
		 * NOTE(review): s_accepted holds one entry per listener; if
		 * an earlier accepted socket has not been retrieved yet, the
		 * insert is a no-op and the pointer to the new socket is
		 * lost. The result of accept() is not checked either.
		 *
		 * @return Always true
		 */
		virtual bool doConn() {
			typename Impl::AcceptType *s = m_obj.getSocket()->accept();
			s_accepted.insert(std::make_pair(m_obj.getSocket(), s));
			return true;
		}
		virtual void notify() {
			m_obj.notify(SOCK_ACCEPT);
		}
	private:
		SSocketWrapper m_obj;
	};

	/**
	 * Connection request: an outgoing connection attempt.
	 */
	class ConnReq : public SchedBase::ConnReqBase {
	public:
		/**
		 * @param s         Socket to connect
		 * @param addr      Address to connect to
		 * @param timeout   Timeout passed on to the socket's connect()
		 */
		ConnReq(SSocketWrapper s, IPV4Address addr, uint32_t timeout)
		: ConnReqBase(s.getScore()()), m_obj(s), m_addr(addr),
		m_timeout(timeout) {}
		//! Start the connection attempt; returns what connect() returns
		virtual bool doConn() {
			return m_obj.getSocket()->connect(m_addr, m_timeout);
		}
		// Required override, the class would be abstract without it.
		// NOTE(review): SOCK_CONNECTED is chosen by analogy with the
		// other request types - confirm the intended event.
		virtual void notify() {
			m_obj.notify(SOCK_CONNECTED);
		}
	private:
		SSocketWrapper m_obj;
		IPV4Address    m_addr;
		uint32_t       m_timeout;
	};

	// set of all scheduled sockets
	static std::set<SSocketWrapper> s_sockets;

	// the following two maps are filled by the scheduler and are waiting
	// for the user to retrieve the data within. Whenever elements are
	// added to these maps, user should be notified.

	// buffered incoming data; the strings are heap-allocated and owned by
	// the map (freed in read() and delSocket())
	static std::map<ImplPtr, std::string*> s_bufferedPackets;
	typedef typename std::map<ImplPtr, std::string*>::iterator BIter;

	// buffered incoming connections, keyed by listening socket
	static std::map<ImplPtr, typename Impl::AcceptType*> s_accepted;
};
// Out-of-class definitions of Scheduler's static data members. The template
// parameters carry the same names as in the class template itself.

// scheduled sockets
template<
	typename Impl, typename ImplPtr, typename EventType,
	typename HandlerFunc, typename EventHandler
>
std::set<
	typename Scheduler<
		Impl, ImplPtr, EventType, HandlerFunc, EventHandler
	>::SSocketWrapper
>
Scheduler<Impl, ImplPtr, EventType, HandlerFunc, EventHandler>::s_sockets;

// buffered incoming data
template<
	typename Impl, typename ImplPtr, typename EventType,
	typename HandlerFunc, typename EventHandler
>
std::map<ImplPtr, std::string*>
Scheduler<
	Impl, ImplPtr, EventType, HandlerFunc, EventHandler
>::s_bufferedPackets;

// buffered incoming connections
template<
	typename Impl, typename ImplPtr, typename EventType,
	typename HandlerFunc, typename EventHandler
>
std::map<ImplPtr, typename Impl::AcceptType*>
Scheduler<Impl, ImplPtr, EventType, HandlerFunc, EventHandler>::s_accepted;

//! Socket types and protocols for easier SSocket class usage. The classes are
//! only declared, never defined here; they serve as tags for selecting the
//! implementation through the Implement template.
namespace Socket {
	class TCP;        //!< Protocol:  TCP
	class UDP;        //!< Protocol:  UDP
	class Client;     //!< Semantics: Client
	class Server;     //!< Semantics: Server
}

// Maps a (socket type, protocol) tag pair to the class implementing it. The
// primary template is left undefined, so only the combinations specialized
// below can be used.
template<typename SocketKind, typename SocketProto>
struct Implement;

// TCP client
template<>
struct Implement<Socket::Client, Socket::TCP> {
	typedef SocketClient Impl;
};

// UDP client
template<>
struct Implement<Socket::Client, Socket::UDP> {
	typedef UDPSocket Impl;
};

// TCP server
template<>
struct Implement<Socket::Server, Socket::TCP> {
	typedef SocketServer Impl;
};

/**
 * SSocket template represents a socket that can serve as communication
 * medium between two remote parties. The exact implementation is chosen
 * based on the template parameters.
 *
 * Each SSocket registers itself with the scheduler on construction and
 * unregisters on destruction. Since the scheduler calls back into the
 * object, SSocket objects cannot be copied.
 *
 * NOTE(review): the implementation object allocated by the public
 * constructors is not freed by this class; who owns it is not visible from
 * this header - confirm.
 *
 * @param Module        Required, the module governing this socket.
 * @param Type          Type of the socket. Can be either Client or Server
 * @param Protocol      Protocol to be used in the socket, e.g. TCP or UDP
 * @param Impl          Implementation class. By default, this is chosen based
 *                      on Type and Protocol.
 * @param ImplPtr       Exact type on how to handle the implementation object.
 *                      Defaults to a plain pointer (Impl*).
 * @param _Scheduler    Scheduler to be used for socket. This should not be
 *                      changed if used within HydraNode, and is provided only
 *                      for completeness here.
 */
template<
	typename Module,
	typename Type,
	typename Protocol   = Socket::TCP,
	typename Impl       = typename Implement<Type, Protocol>::Impl,
	typename ImplPtr    = Impl*,
	typename _Scheduler = Scheduler<Impl>
>
class SSocket {
public:
	//! Type of events
	typedef typename Impl::EventType EventType;

	//! User-defined event handler functor prototype
	typedef boost::function<void (SSocket*, EventType)> HandlerFunc;

private:
	/**
	 * Adapter turning a member function taking the socket by reference
	 * into a functor matching HandlerFunc, which passes the socket by
	 * pointer. Binding such a member function directly would not be
	 * callable with the SSocket* argument HandlerFunc supplies.
	 */
	template<typename T>
	class MemberHandler {
	public:
		typedef void (T::*Method)(SSocket&, EventType);
		MemberHandler(T *obj, Method func)
		: m_obj(obj), m_func(func) {}
		void operator()(SSocket *sock, EventType evt) const {
			(m_obj->*m_func)(*sock, evt);
		}
	private:
		T      *m_obj;     //!< Object receiving the notification
		Method  m_func;    //!< Member function to be called
	};

	// Copying is forbidden (declared, never defined): the scheduler holds
	// a callback bound to the address of the registered object, and a
	// copy would share the implementation object and its registration.
	SSocket(const SSocket&);
	SSocket& operator=(const SSocket&);

public:
	/**
	 * Construct and initialize, optionally setting event handler
	 *
	 * @param h      Optional event handler
	 */
	SSocket(HandlerFunc h = 0) : m_impl(new Impl), m_handler(h) {
		_Scheduler::template addSocket<Module>(
			m_impl, boost::bind(&SSocket::onEvent, this, _1, _2)
		);
	}

	/**
	 * Convenience constructor - performs event handler functor binding
	 * internally.
	 *
	 * @param obj      Object to receive event notifications
	 * @param func     Function to receive event notifications
	 */
	template<typename T>
	SSocket(T *obj, void (T::*func)(SSocket&, EventType))
	: m_impl(new Impl), m_handler(MemberHandler<T>(obj, func)) {
		_Scheduler::template addSocket<Module>(
			m_impl, boost::bind(&SSocket::onEvent, this, _1, _2)
		);
	}

	/**
	 * Unregister from the scheduler, so it no longer calls back into
	 * this (then destroyed) object.
	 */
	~SSocket() {
		_Scheduler::delSocket(m_impl);
	}

	/**
	 * @name Input/Output
	 */
	//@{

	/**
	 * Write data into socket
	 *
	 * @param buf   Data to be written
	 * @param pt    Optional specification of data type. PACKET_OVERHEAD
	 *              type packets get slightly higher priority.
	 */
	void write(const std::string &buf, PacketType pt = PACKET_DATA) {
		_Scheduler::template write<Module>(m_impl, buf, pt);
	}

	/**
	 * Read data from socket
	 *
	 * @param buf   Buffer to read data into. The data is appended to the
	 *              specified string.
	 */
	void read(std::string *buf) {
		_Scheduler::template read<Module>(m_impl, buf);
	}

	/**
	 * Perform an outgoing connection
	 *
	 * NOTE(review): the Scheduler template in this header does not
	 * define connect(); the call only compiles with a scheduler that
	 * provides it.
	 *
	 * @param addr      Address to connect to
	 * @param timeout   Optional timeout for connection attempt. Defaults
	 *                  to 5000, presumably milliseconds (5 seconds).
	 */
	void connect(IPV4Address addr, uint32_t timeout = 5000) {
		_Scheduler::template connect<Module>(m_impl, addr, timeout);
	}

	/**
	 * Disconnect a connected socket. If the socket is not connected, this
	 * function does nothing.
	 *
	 * NOTE(review): the Scheduler template in this header does not
	 * define disconnect(). The earlier statement that sockets are
	 * disconnected automatically on destruction is not backed by code in
	 * this class; the destructor only unregisters from the scheduler.
	 */
	void disconnect() {
		_Scheduler::template disconnect<Module>(m_impl);
	}

	/**
	 * Start a listener, waiting for incoming connections
	 *
	 * @param addr     Local address to listen on. If addr.ip is set to
	 *                 0, connections are accepted from all networks,
	 *                 otherwise connections are only accepted from the
	 *                 designated net. For example, if ip is 127.0.0.1,
	 *                 only loopback connections are accepted.
	 */
	void listen(IPV4Address addr) {
		m_impl->listen(addr);
	}

	/**
	 * Accept an incoming connection.
	 *
	 * NOTE(review): this cannot work as written. The Scheduler template
	 * in this header does not define accept(), and the constructor of
	 * the returned socket type that would take the accepted
	 * implementation object is private to that (different) type.
	 *
	 * @return         New socket, which is in connected state, ready to
	 *                 receive and transmit data. The return type depends
	 *                 on the underlying implementation. The returned socket
	 *                 is created in same module as the listening socket.
	 *                 It is allocated with new; the caller presumably has
	 *                 to delete it.
	 *
	 * \throws if there was no incoming connection pending at this moment.
	 */
	SSocket<
		Module, Socket::Client, Protocol,
		typename Implement<Type, Protocol>::Impl::AcceptType
	>* accept() {
		return new SSocket<
			Module, Socket::Client, Protocol,
			typename Implement<Type, Protocol>::Impl::AcceptType
		>(
			_Scheduler::template accept<Module>(m_impl)
		);
	}

	/**
	 * Send data to specific address. This applies only to UDP sockets.
	 *
	 * NOTE(review): unlike the other calls, the implementation object is
	 * not passed to the scheduler here, and the Scheduler template in
	 * this header does not define send().
	 *
	 * @param to       Address to send data to
	 * @param buf      Buffer containing the data to be sent
	 * @param pt       Optional specification of packet type.
	 *                 PACKET_OVERHEAD type packets get slightly increased
	 *                 priority.
	 */
	void send(
		IPV4Address to, const std::string &buf,
		PacketType pt = PACKET_DATA
	) {
		_Scheduler::template send<Module>(to, buf, pt);
	}

	/**
	 * Receive data from designated address. This applies only to UDP type
	 * sockets.
	 *
	 * NOTE(review): unlike the other calls, the implementation object is
	 * not passed to the scheduler here, and the Scheduler template in
	 * this header does not define recv().
	 *
	 * @param from      Address to receive data from
	 * @param buf       Buffer to write the retrieved data to. The data is
	 *                  appended to the designated string.
	 */
	void recv(IPV4Address from, std::string *buf) {
		_Scheduler::template recv<Module>(from, buf);
	}
	//@}

	/**
	 * @name Event handling
	 */
	//@{
	//! Set event handler, overwriting old handler
	void        setHandler(HandlerFunc handler) { m_handler = handler; }
	//! Set the event handler, performing functor binding internally
	template<typename T>
	void setHandler(T *obj, void (T::*func)(SSocket&, EventType)) {
		m_handler = MemberHandler<T>(obj, func);
	}
	//! Retrieve the handler function object
	HandlerFunc getHandler() const              { return m_handler;    }
	//! Clear the existing event handler
	void        clearHandler()                  { m_handler.clear();   }
	//@}
private:
	ImplPtr     m_impl;           //!< Implementation object
	HandlerFunc m_handler;        //!< External event handler function

	/**
	 * Constructor used only internally during incoming connections
	 * accepting. No event handler is set.
	 *
	 * NOTE(review): the parameter is an Impl::AcceptType passed by value
	 * but is stored into m_impl (an ImplPtr); the types presumably only
	 * match if AcceptType is itself a pointer type - confirm.
	 *
	 * @param s        New socket
	 */
	SSocket(typename Impl::AcceptType s) : m_impl(s) {
		_Scheduler::template addSocket<Module>(
			s, boost::bind(&SSocket::onEvent, this, _1, _2)
		);
	}

	/**
	 * Internal event handler, called from scheduler. Forwards the event
	 * to user-defined event handler (if present).
	 *
	 * @param ptr       Implementation pointer generating this event. Must
	 *                  match m_impl member.
	 * @param evt       The event itself
	 */
	void onEvent(ImplPtr ptr, EventType evt) {
		assert(ptr == m_impl);
		// Calling an empty boost::function throws, and having no
		// handler is the default state of a socket.
		if (!m_handler.empty()) {
			m_handler(this, evt);
		}
	}

};

#endif // !__MSOCKET_H__