/**
 *  Copyright (C) 2004 Alo Sarv <madcat_@users.sourceforge.net>
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

#ifndef __MSOCKET_H__
#define __MSOCKET_H__

#include <cassert>                 // assert()
#include <deque>                   // std::deque
#include <map>                     // std::map, std::multimap
#include <set>                     // std::set
#include <string>                  // std::string
#include <utility>                 // std::pair, std::make_pair
#include <vector>                  // std::vector

#include <hn/sockets.h>            // xplatform socket api
#include <hn/gettickcount.h>       // getTick()
#include <boost/smart_ptr.hpp>     // smart pointers
#include <boost/function.hpp>      // function objects
#include <boost/bind.hpp>          // function object binder

// Forward declaration: SchedBase (below) befriends every Scheduler
// specialization before the Scheduler template itself is defined.
template<typename SocketT, typename ImplPtrT, typename HandlerT>
class Scheduler;

/**
 * SchedBase implements the third level of the HydraNode networking scheduler.
 * This class performs the actual dividing of bandwidth and connections between
 * requests in its main scheduler loop. It is derived from EventTableBase, and
 * as such will be run during each event loop, performing bandwidth request
 * handling.
 *
 * SchedBase is a singleton: access it through instance(). It cannot be
 * constructed by outside code (the Scheduler friend excepted) nor copied.
 */
class SchedBase : public EventTableBase {
public:
	//! Returns the single instance, created on first use.
	static SchedBase& instance() {
		static SchedBase s;
		return s;
	}

	//! @name Limits and traffic totals
	//! NOTE(review): units are not visible here (presumably bytes) -
	//! confirm against the implementation file.
	//@{
	void     setUpLimit(uint32_t amount)         { m_upLimit = amount;   }
	void     setDownLimit(uint32_t amount)       { m_downLimit = amount; }
	uint32_t getUpLimit() const                  { return m_upLimit;     }
	uint32_t getDownLimit() const                { return m_downLimit;   }
	uint64_t getTotalUpstream() const            { return m_totalUp;     }
	uint64_t getTotalDownstream() const          { return m_totalDown;   }
	//@}
private:
	//! Singleton: only instance() (and friends) may construct. All
	//! counters start at zero - the same values the static instance
	//! would get from zero-initialization.
	SchedBase()
	: m_curTick(0), m_upLimit(0), m_downLimit(0), m_connLimit(0),
	  m_connCnt(0), m_totalUp(0), m_totalDown(0) {}

	//! Non-copyable (declared, intentionally never defined)
	SchedBase(const SchedBase&);
	SchedBase& operator=(const SchedBase&);

	// need it to access our internals
	template<typename, typename, typename>
	friend class Scheduler;

	//! Request base, only contains the priority score of the request
	class ReqBase {
	public:
		explicit ReqBase(float score) : m_score(score) {}
		virtual ~ReqBase() {}
		float getScore() const { return m_score; }
	private:
		float m_score;
	};

	//! Request of type upload; doSend() is given an allowance amount
	class UploadReqBase : public ReqBase {
	public:
		explicit UploadReqBase(float score) : ReqBase(score) {}
		virtual uint32_t doSend(uint32_t amount) = 0;
	};
	//! Request of type download; doRecv() may read at most amount
	class DownloadReqBase : public ReqBase {
	public:
		explicit DownloadReqBase(float score) : ReqBase(score) {}
		virtual uint32_t doRecv(uint32_t amount) = 0;
	};
	//! Request of type connection (outgoing connect or accept)
	class ConnReqBase : public ReqBase {
	public:
		explicit ConnReqBase(float score) : ReqBase(score) {}
		virtual void doConn() = 0;
	};

	// Pending request sets. NOTE(review): these hold raw pointers; who
	// deletes them after processing is not visible in this header.
	std::set<UploadReqBase*>   m_uploadReqs;
	std::set<DownloadReqBase*> m_downloadReqs;
	std::set<ConnReqBase*>     m_connReqs;

	// iterator typedefs for the above sets to make life easier
	typedef std::set<UploadReqBase*>::iterator UIter;
	typedef std::set<DownloadReqBase*>::iterator DIter;
	typedef std::set<ConnReqBase*>::iterator CIter;

	//! Main networking loop and helper functions (defined elsewhere;
	//! handleEvents presumably overrides an EventTableBase hook)
	void handleEvents();
	void handleDownloads();
	void handleUploads();
	void handleConnections();

	// get free bandwidth / connection slots for requests (defined elsewhere)
	uint32_t getFreeDown();
	uint32_t getFreeUp();
	bool     getConnection();

	// Keeps current tick - to reduce getTick() calls somewhat
	uint32_t m_curTick;

	// various limits and counts
	uint32_t m_upLimit;      //!< upstream limit
	uint32_t m_downLimit;    //!< downstream limit
	uint32_t m_connLimit;    //!< open connections limit
	uint32_t m_connCnt;      //!< open connections count
	uint64_t m_totalUp;      //!< overall total uploaded
	uint64_t m_totalDown;    //!< overall total downloaded

	// last sent data amounts, along with timestamps
	// (which pair member is the timestamp is not visible here - confirm)
	std::deque<std::pair<uint32_t, uint32_t> > m_lastSent;

	// last received data amounts, along with timestamps
	std::deque<std::pair<uint32_t, uint32_t> > m_lastRecv;
};

/**
 * Scheduler class, implementing second level of HydraNode Networking Scheduling
 * API, abstracts away modules part of the sockets by generating a priority
 * score (PS) for each of the pending requests. All requests are received from
 * the frontend, wrapped into generic containers, and buffered internally for
 * later processing. No direct action shall be taken in the functions directly
 * or indirectly called from frontend.
 *
 * @param SocketType       Type of socket implemented in this scheduler
 * @param ImplPtr          Actual object type passed to methods
 *                         Autodetected from SocketType parameter
 * @param HandlerFunc      Function object in frontend to be called on events
 *                         Autodetected from SocketType parameter
 */
template<
	typename SocketType,
	typename ImplPtr     = typename SocketType::ImplPtr,
	typename HandlerFunc = typename SocketType::HandlerFunc
>
class Scheduler {
public:
	// Write call
	// First check if we already have pending data on the pointed socket.
	// If that is so, the additional data must be added to the existing
	// upload request. Otherwise, a new request must be created and added
	// to pending requests list
	template<class Module>
	static void write(ImplPtr ptr, const std::string &buf, PacketType pt) {
		CHECK_THROW(ptr->isConnected());
	}
	// Read call
	// Copy all pending data on pointer designated by ptr to buf, and erase
	// the buffered data. No additional operations are permitted.
	template<class Module>
	static void read(ImplPtr ptr, std::string *buf) {
		CHECK_THROW(ptr->isConnected());
		BIter i = s_bufferedPackets.find(ptr);
		if (i != s_bufferedPackets.end()) {
			// append to designated buffer
			buf->append(*(*i).second);
			// clean up internal buffer
			delete (*i).second;
			s_bufferedPackets.erase(ptr);
		}
	}

	// Add a socket into scheduler
	template<class Module>
	static void addSocket(ImplPtr s, HandlerFunc h) {
		s_sockets.insert(SSocketWrapper(s, h, &Module::getScore));
		s->getEventTable().addHandler(s, &Scheduler::onSocketEvent);
	}
	// Remove a socket from scheduler
	static void delSocket(ImplPtr s) {
		s_sockets.erase(SSocketWrapper(s));
		s->getEventTable().delHandler(s, &Scheduler::onSocketEvent);
	}
private:
	// Function that can be used to retrieve a module's score
	typedef typename boost::function<float ()> ScoreFunc;

	// Wrapper object for scheduled socket, contains all the useful
	// information we need, e.g. score, frontend event handler, and
	// underlying socket object
	class SSocketWrapper {
	public:
		SSocketWrapper(ImplPtr s, HandlerFunc h, ScoreFunc f)
		: m_socket(s), m_handler(h), m_scoreFunc(f) {}
		SSocketWrapper(ImplPtr s) : m_socket(s) {}
		ImplPtr     getSocket()  const { return m_socket;    }
		ScoreFunc   getScore()   const { return m_scoreFunc; }
		Handlerfunc getHandler() const { return m_handler;   }
		// pass notification to frontend
		void notify(typename SocketType::EventType evt) {
			m_handler(m_socket, evt);
		}
		friend bool operator<(
			const SSocketWrapper &x, const SSocketWrapper &y
		) {
			return x.m_socket < y.m_socket;
		}
	private:
		ImplPtr     m_socket;     //!< Underlying socket
		HandlerFunc m_handler;    //!< Frontend event handler
		ScoreFunc   m_scoreFunc;  //!< Function to retrieve the score
	};

	// upload request
	class UploadReq : public SchedBase::UploadReqBase {
	public:
		UploadReq(SSocketWrapper s, const std::string &buf)
		) : UploadReqBase(s.getScore()()), m_buffer(buf) {}
		virtual uint32_t doSend(uint32_t amount) {
			int ret = m_obj.getSocket()->write(
				m_buffer.c_str(), m_buffer.size()
			);
			if (ret < m_buffer.size()) {
				m_buffer = m_buffer.substr(ret + 1);
				return true;
			}
			m_obj.notify(SOCK_WRITE);
			return false;
		}
	private:
		SSocketWrapper m_obj;
		std::string m_buffer;
	};

	// Download reqest
	class DownloadReq : public SchedBase::DownloadReqBase {
	public:
		DownloadReq(SSocketWrapper s, std::string *buf)
		: DownloadReqBase(s.getScore()()), m_obj(s) {}
		// read out as much data from socket as allowed by parameter
		// return the amount of actually read
		virtual uint32_t doRecv(uint32_t amount) {
			char buf[amount];
			int ret = m_obj.getSocket()->read(buf, amount);
			// Got no data - mh ?
			if (ret == 0) {
				return 0;
			}
			// check if we already have buffered data for this
			// socket, and if so, append to existing buffer.
			// otherwise, create new buffer to store the data
			// in (and also notify client code)
			BIter i = s_bufferedPackets.find(m_obj->getSocket());
			if (i != s_bufferedPackets.end()) {
				(*i).second->append(buf, ret);
			} else {
				std::string *s = new std::string(buf, ret);
				s_bufferedPackets[m_obj.getSocket()] = s;
				m_obj.notify(SOCK_READ);
			}
			return ret;

			m_buf->append(buf, ret);
			if (m_buf->size() > 0) {
				m_obj.getHandler()(
					m_obj.getSocket(), SOCK_READ
				);
			}
			return ret;
		}
	private:
		SSocketWrapper m_obj;
	};

	// accept request
	class AcceptReq : public SchedBase::ConnReqBase {
	public:
		AcceptReq(SSocketWrapper s) : ConnReqBase(s.getScore()()),
		m_obj(s) {}
		virtual bool doConn() {
			ImplPtr::AcceptType s = m_obj.getSocket()->accept();
			s_accepted.insert(std::make_pair(m_obj, s));
			m_handler(m_obj.getSocket(), SOCK_ACCEPT);
		}
	private:
		SSocketWrapper m_obj;
	};
	// connection request
	class ConnReq : public SchedBase::ConnReqBase {
	public:
		ConnReq(SSocketWrapper s, IPV4Address addr, uint32_t timeout)
		: ConnReqBase(s.getScore()(), m_obj(s), m_addr(addr),
		m_timeout(timeout) {}
		virtual bool doConn() {
			m_obj.getSocet()->connect(m_addr, timeout);
		}
	private:
		SSocketWrapper m_obj;
		IPV4Address    m_addr;
		uint32_t       m_timeout;
	};

	// map of all scheduled sockets
	static std::set<SSocketWrapper> s_sockets;

	// the following two maps are filled by the scheduler and are waiting
	// for the user to retrieve the data within. Whenever elements are
	// added to these maps, user should be notified.

	// buffered incoming data
	static std::map<ImplPtr, std::string*> s_bufferedPackets;

	// buffered incoming connections
	static std::map<ImplPtr, typename ImplPtr::AcceptType> s_accepted;

	// Event handler for socket events for all scheduled sockets
	// no, this doesn't work, since this is runtime detection stuff,
	// which doesn't play along well with our static polymorphism
	// approach - we'd end up instanciating members here which we can't
	// really handle (e.g. SOCK_ACCEPT would instanciate server-related
	// methods/member objects, even in case of Client type underlying
	// socket, and this would lead to compile-time errors).
	// Most obvious solution would be to move the event handling closer
	// to the actual underlying implementation. Or - perhaps a better
	// solution - tempalate specializations, based on the underlying
	// implementation type, in which we only handle the events we are
	// interested in...
	static void onSocketEvent(ImplPtr ptr, SocketEvent evt) {
		switch (evt) {
			case SOCK_READ:
				SchedBase::addDloadReq(SSocketWrapper(ptr));
				break;
			case SOCK_WRITE:
				s_writable.insert(ptr);
				break;
			case SOCK_LOST:
				s_sockets[ptr](ptr, SOCK_LOST);
				break;
			case SOCK_ERR:
				s_sockets[ptr](ptr, SOCK_ERR);
				break;
			case SOCK_ACCEPT:
				s_incoming.insert(ptr);
				break;
			case SOCK_CONNECTED:
				s_sockets[ptr](ptr, SOCK_CONNECTED);
				break;
		}
	}

	// calculate the relative PS for the given module
	// this is based on how high % of the total upstream and how much
	// of the total downstream has the module used. The result is that
	// modules who use relativly more downstream than upstream get higher
	// PS, thus also more bandwidth
	template<class Module>
	static float getScore() {
		float up_perc = Module::getUploaded()*100.0/s_totalUp;
		float dn_perc = Module::getDownloaded()*100.0/s_totalDown;
		return Module::getPriority() + up_perc - dn_perc;
	}

};


//! Socket types and protocols for easier SSocket class usage. These are tag
//! types only (declared, never defined), used as template arguments.
namespace Socket {
	class TCP;        //!< Protocol:  TCP
	class UDP;        //!< Protocol:  UDP
	class Client;     //!< Semantics: Client
	class Server;     //!< Semantics: Server
}

// Implement<Type, Proto> maps a (semantics, protocol) pair onto the concrete
// socket implementation class, exposed as the nested typedef Impl. The
// primary template is declared but never defined, so an unsupported pairing
// (e.g. Server over UDP) fails to compile.
template<typename Type, typename Proto>
struct Implement;

// TCP server -> SocketServer
template<>
struct Implement<Socket::Server, Socket::TCP> { typedef SocketServer Impl; };

// TCP client -> SocketClient
template<>
struct Implement<Socket::Client, Socket::TCP> { typedef SocketClient Impl; };

// UDP client -> UDPSocket
template<>
struct Implement<Socket::Client, Socket::UDP> { typedef UDPSocket Impl; };

/**
 * SSocket template represents a socket that can be serve as communication
 * medium between two remote parties. The exact implementation is chosen
 * based on the template parameters.
 *
 * @param Module        Required, the module governing this socket.
 * @param Type          Type of the socket. Can be either Client or Server
 * @param Protocol      Protocol to be used in the socket, e.g. TCP or UDP
 * @param Impl          Implementation class. By default, this is chosen based
 *                      on Type and Protocol.
 * @param ImplPtr       Exact type on how to handle the implementation object.
 *                      Default resolves to ref-counted shared pointer.
 * @param _Scheduler    Scheduler to be used for socket. This should not be
 *                      changed if used within HydraNode, and is provided only
 *                      for completeness here.
 */
template<
	typename Module,
	typename Type,
	typename Protocol   = Socket::TCP,
	typename Impl       = Implement<Type, Protocol>,
	typename ImplPtr    = boost::shared_ptr<Impl>,
	typename _Scheduler = Scheduler<ImplPtr>
>
class SSocket {
public:
	//! Type of events
	typedef typename Impl::EventType EventType;

	//! User-defined event handler functor prototype
	typedef boost::function<void (SSocket&, EventType)> HandlerFunc;

	/**
	 * Construct and initialize, optionally setting event handler
	 *
	 * @param h      Optional event handler
	 */
	SSocket(HandlerFunc h = 0) : m_impl(new Impl), m_handler(h) {
		_Scheduler::template addSocket<Module>(
			m_impl, boost::bind(&SSocket::onEvent, this, _1, _2)
		);
	}

	/**
	 * Convenience constructor - performs event handler functor binding
	 * internally.
	 *
	 * @param obj      Object to receive event notifications
	 * @param func     Function to receive event notifications
	 */
	template<typename T>
	SSocket(T *obj, void (T::*func)(SSocket&, EventType))
	: m_impl(new Impl), m_handler(boost::bind(func, obj, _1, _2)) {
		_Scheduler::template addSocket<Module>(
			m_impl, boost::bind(&SSocket::onEvent, this, _1, _2)
		);
	}

	/**
	 * @name Input/Output
	 */
	//@{

	/**
	 * Write data into socket
	 *
	 * @param buf   Data to be written
	 * @param pt    Optional specification of data type. PACKET_OVERHEAD
	 *              type packets get slightly higher priority.
	 */
	void write(const std::string &buf, PacketType pt = PACKET_DATA) {
		_Scheduler::template write<Module>(m_impl, buf, pt);
	}

	/**
	 * Read data from socket
	 *
	 * @param buf   Buffer to read data into. The data is appended to the
	 *              specified string.
	 */
	void read(std::string *buf) {
		_Scheduler::template read<Module>(m_impl, buf);
	}

	/**
	 * Perform an outgoing connection
	 *
	 * @param addr      Address to connect to
	 * @param timeout   Optional timeout for connection attempt. Defaults
	 *                  to 5 seconds.
	 */
	void connect(IPV4Address addr, uint32_t timeout = 5000) {
		_Scheduler::template connect<Module>(m_impl, addr, timeout);
	}

	/**
	 * Disconnect a connected socket. If the socket is not connected, this
	 * function does nothing. Note that sockets are automatically
	 * disconnected when they are destroyed.
	 */
	void disconnect() {
		_Scheduler::template disconnect<Module>(m_impl);
	}

	/**
	 * Start a listener, waiting for incoming connections
	 *
	 * @param addr     Local address to listen on. If addr.ip is set to
	 *                 0, connections are accepted from all networks,
	 *                 otherwise connections are only accepted from the
	 *                 designated net. For example, if ip is 127.0.0.1,
	 *                 only loopback connections are accepted.
	 */
	void listen(IPV4Address addr) {
		m_impl->listen(addr);
	}

	/**
	 * Accept an incoming connection.
	 *
	 * @return         New socket, which is in connected state, ready to
	 *                 receive and transmit data. The return type depends
	 *                 on the underlying implementation. The returned socket
	 *                 is created in same module as the listening socket.
	 *
	 * \throws if there was no incoming connection pending at this moment.
	 */
	SSocket<Module, typename Impl::AcceptType> accept() {
		return SSocket<
			Module, typename Impl::AcceptType
		>(_Scheduler::template accept<Module>(m_impl));
	}

	/**
	 * Send data to specific address. This applies only to UDP sockets.
	 *
	 * @param to       Address to send data to
	 * @param buf      Buffer containing the data to be sent
	 * @param pt       Optional specification of packet type.
	 *                 PACKET_OVERHEAD type packets get slightly increased
	 *                 priority.
	 */
	void send(
		IPV4Address to, const std::string &buf,
		PacketType pt = PACKET_DATA
	) {
		_Scheduler::template send<Module>(to, buf, pt);
	}

	/**
	 * Receive data from designated address. This applies only to UDP type
	 * sockets.
	 *
	 * @param from      Address to receive data from
	 * @param buf       Buffer to write the retrieved data to. The data is
	 *                  appended to the designated string.
	 */
	void recv(IPV4Address from, std::string *buf) {
		_Scheduler::template recv<Module>(from, buf);
	}
	//@}

	/**
	 * @name Event handling
	 */
	//@{
	//! Set event handler, overwriting old handler
	void        setHandler(HandlerFunc handler) { m_handler = handler; }
	//! Set the event handler, performing functor binding internally
	template<typename T>
	void setHandler(T *obj, void (T::*func)(SSocket&, EventType)) {
		m_handler = boost::bind(func, obj, _1, _2);
	}
	//! Retrieve the handler function object
	HandlerFunc getHandler() const              { return m_handler;    }
	//! Clear the existing event handler
	void        clearHandler()                  { m_handler.clear();   }
	//@}
private:
	ImplPtr     m_impl;           //!< Implementation object
	HandlerFunc m_handler;        //!< External event handler function

	/**
	 * Constructer used only internally during incoming connections
	 * accepting.
	 *
	 * @param s        New socket
	 */
	SSocket(typename Impl::AcceptType s) : m_impl(s) {
		_Scheduler::template addSocket<Module>(
			s, boost::bind(&SSocket::onEvent, this, _1, _2)
		);
	}

	/**
	 * Internal event handler, called from scheduler. Forwards the event
	 * to user-defined event handler (if present).
	 *
	 * @param ptr       Implementation pointer generating this event. Must
	 *                  match m_impl member.
	 * @param evt       The event itself
	 */
	void onEvent(ImplPtr ptr, EventType evt) const {
		assert(ptr == m_impl);
		m_handler(*this, evt);
	}

};

#endif // !__MSOCKET_H__