/**
 *  Copyright (C) 2004 Alo Sarv <madcat_@users.sourceforge.net>
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

#include <algorithm>                // std::sort, std::min
#include <deque>                    // transfer history
#include <map>                      // socket -> handler map
#include <set>                      // writable sockets
#include <string>                   // data buffers
#include <utility>                  // std::pair, std::make_pair
#include <vector>                   // packet queues

#include <boost/noncopyable.hpp>   // scheduler isn't copyable
#include <boost/smart_ptr.hpp>     // smart pointers
#include <boost/function.hpp>      // function objects
#include <boost/bind.hpp>          // function object binder

#include <hn/sockets.h>            // xplatform socket api
#include <hn/gettickcount.h>       // getTick()

template<typename, typename, typename>
class Scheduler;

// Contains the common data required by all schedulers.
//
// All state is static and therefore shared by every Scheduler
// instantiation: the bandwidth and connection limits apply across all
// socket types at once. The static data members are only declared here;
// their definitions are not part of this header section.
class SchedBase {
public:
	static void     setUpLimit(uint32_t amount)   { s_upLimit = amount;   }
	static void     setDownLimit(uint32_t amount) { s_downLimit = amount; }
	static uint32_t getUpLimit()                  { return s_upLimit;     }
	static uint32_t getDownLimit()                { return s_downLimit;   }
	static uint64_t getTotalUpstream()            { return s_totalUp;     }
	static uint64_t getTotalDownstream()          { return s_totalDown;   }
private:
	template<typename, typename, typename>
	friend class Scheduler;

	//! (timestamp from getTick(), byte count) records of recent transfers
	typedef std::deque<std::pair<uint32_t, uint32_t> > History;

	//! Length of the rate-limiting window, in getTick() units
	enum { HISTORY_WINDOW = 100 };

	// Keeps current tick - to reduce getTick() calls during sorting
	static uint32_t s_curTick;
	// last sent data amounts, along with timestamps
	static History s_lastSent;
	// last received data amounts, along with timestamps
	static History s_lastRecv;
	static uint32_t s_upLimit;      //!< upstream limit
	static uint32_t s_downLimit;    //!< downstream limit
	static uint32_t s_connLimit;    //!< open connections limit
	static uint32_t s_connCnt;      //!< open connections count
	static uint64_t s_totalUp;      //!< overall total uploaded
	static uint64_t s_totalDown;    //!< overall total downloaded

	/**
	 * Refreshes s_curTick, drops records older than HISTORY_WINDOW ticks
	 * from @p history and reports how much of @p limit is left.
	 *
	 * Records are expected newest-first (pushed to the front), so expired
	 * ones are removed from the back until a fresh one is found.
	 *
	 * @param history  Transfer records to prune and sum up
	 * @param limit    Amount allowed within one window
	 * @return         limit minus the amount recorded in the window, or 0
	 *                 if the limit is already used up (never wraps around)
	 */
	static uint32_t getAvailable(History &history, uint32_t limit) {
		s_curTick = getTick();
		// Unsigned difference keeps working across a getTick() wrap-around;
		// the emptiness check avoids calling back() on an empty deque.
		while (
			!history.empty() &&
			s_curTick - history.back().first
				> static_cast<uint32_t>(HISTORY_WINDOW)
		) {
			history.pop_back();
		}
		// 64-bit sum so many large records cannot overflow the total
		uint64_t used = 0;
		for (
			History::const_iterator it = history.begin();
			it != history.end(); ++it
		) {
			used += it->second;
		}
		return used >= limit ? 0 : limit - static_cast<uint32_t>(used);
	}
	//! Upstream amount still allowed in the current window
	static uint32_t getUpstream() {
		return getAvailable(s_lastSent, s_upLimit);
	}
	//! Downstream amount still allowed in the current window
	static uint32_t getDownStream() {
		return getAvailable(s_lastRecv, s_downLimit);
	}
	//! True while fewer than s_connLimit connections are open
	static bool getConnection() {
		return s_connCnt < s_connLimit;
	}

	/**
	 * Base sending score of @p Module: its priority plus its share of the
	 * overall upload, minus its share of the overall download (percent).
	 */
	template<class Module>
	static float getScore() {
		// The totals are zero until traffic has been accounted; dividing by
		// them would give inf/NaN, and NaN scores break the strict weak
		// ordering needed when queued packets are sorted by score.
		float up_perc = s_totalUp
			? Module::getUploaded() * 100.0 / s_totalUp : 0.0f;
		float dn_perc = s_totalDown
			? Module::getDownloaded() * 100.0 / s_totalDown : 0.0f;
		return Module::getPriority() + up_perc - dn_perc;
	}

	//! Polymorphic interface of a queued outgoing packet
	class PacketBase {
	public:
		// virtual so derived packets can be destroyed through a base pointer
		virtual ~PacketBase() {}
		virtual void doSend() = 0;
	};

	//! Pointer wrapper meant to let packets be sorted through PacketBase.
	//! NOTE(review): PacketBase above declares no operator<; unless one is
	//! provided elsewhere the comparison below cannot compile. It likely
	//! needs a virtual score accessor / comparison on PacketBase.
	class PacketWrapper {
	public:
		PacketBase *m_obj;
		friend bool operator<(
			const PacketWrapper &x, const PacketWrapper &y
		) {
			return *x.m_obj < *y.m_obj;
		}
	};

	// NOTE(review): not referenced in this header; each Scheduler keeps its
	// own queue of packets.
	static std::vector<PacketWrapper> s_queued;

};

// scheduler class, one per each socket type
template<
	typename SocketType,
	typename ImplPtr     = typename SocketType::ImplPtr,
	typename HandlerFunc = typename SocketType::HandlerFunc
>
class Scheduler {
public:
	template<class Module>
	static void write(ImplPtr ptr, const std::string &buf, PacketType pt) {
		uint32_t amount = 0;
		if (amount = SchedBase::getUpstream() > buf.size()) {
			amount = buf.size();
		}
		if (amount == 0) {
			Wrapper w(SchedBase::getScore<Module>(), ptr, buf);
			s_queued.push_back(w);
			return;
		}
		amount = ptr->write(buf.c_str(), buf.size());
		SchedBase::s_lastSent.push_front(
			std::make_pair(amount, getTick())
		);
		if (amount < buf.size()) {
			Wrapper w(
				SchedBase::getScore<Module>(), ptr,
				buf.substr(amount+1)
			);
			s_queued.push_back(w);
		}
		Module::addUpstream(amount);
	}
	template<class Module>
	static void read(ImplPtr ptr, std::string *buf) {
		while (uint32_t amount = SchedBase::getDownStream()) {
			if (amount > 1024) {
				amount = 1024;
			}
			char tmp[amount];
			if (uint32_t cnt = ptr->read(tmp, amount)) {
				buf->append(tmp, cnt);
				SchedBase::s_lastRecv.push_back(
					std::make_pair(getTick(), cnt)
				);
			} else {
				break;
			}
		}
	}
	template<class Module>
	static void connect(ImplPtr ptr, IPV4Address addr, uint32_t timeout) {
		if (getConnection()) {
			ptr->connect(addr, timeout);
		} else {
			// grrr
			// s_pendingConns.push_back(ptr, addr, timeout);
		}
	}

	template<class Module>
	static void disconnect(ImplPtr ptr) {
		ptr->disconnect();
		--s_connCnt;
	}
	static void addSocket(ImplPtr socket, HandlerFunc h) {
		s_sockets[socket] = h;
	}
	static void delSocket(ImplPtr socket, HandlerFunc h) {
		s_sockets.erase(socket);
	}
private:
	class Packet : public SchedBase::PacketBase {
	public:
		Wrapper(
			float baseScore, ImplPtr socket,
			const std::string &buffer
		) : m_score(baseScore), m_entry(getTick()), m_buffer(buffer),
		m_socket(socket) {}
		friend bool operator<(const Wrapper &x, const Wrapper &y) {
			uint8_t xt = x.getWaitBonus();
			uint8_t yt = x.getWaitBonus();
			return x.m_score + xt < y.m_score + yt;
		}
		/**
		 * Send out some data to it's socket
		 *
		 * @param amount   Amount of data allowed to be sent
		 * @return         True if all packet's data was sent out
		 */
		virtual bool doSend(uint32_t amount) {
			m_socket->write(m_buffer.c_str(), m_buffer.size());
		}
	private:
		uint8_t getWaitBonus() const {
			return Scheduler::s_curTick - m_entry / 500000;
		}

		float m_score;
		uint32_t m_entry;
		std::string m_buffer;
		ImplPtr m_socket;
	};
	static std::vector<Wrapper> s_queued;
	static std::map<ImplPtr, HandlerFunc> s_sockets;

	static void onSocketEvent(ImplPtr ptr, SocketEvent evt) {
		switch (evt) {
			case SOCK_READ:
				if (SchedBase::getDownStream() > 0) {
					s_sockets[ptr](ptr, SOCK_READ);
				}
				break;
			case SOCK_WRITE:
				s_writable.insert(ptr);
				trySendPackets();
				break;
			case SOCK_LOST:
				s_sockets[ptr](ptr, SOCK_LOST);
				break;
			case SOCK_ERR:
				s_sockets[ptr](ptr, SOCK_ERR);
				break;
			case SOCK_ACCEPT:
				if (SchedBase::getConnection()) {
					s_sockets[ptr](ptr, SOCK_ACCEPT);
				}
				break;
		}
	}
	static std::set<ImplPtr> s_writable;
	typedef typename std::set<ImplPtr>::iterator WIter;
	static void trySendPackets() {
		std::sort(s_queued.begin(), s_queued.end());
		for (uint32_t i = 0; i < s_queued.size(); ++i) {
			WIter i = s_writable.find(s_queued[i].getSocket());
			if (i == s_writable.end()) {
				continue;
			}
			write(*i, s_queued[i].getData(), s_queued[i]->getType());
		}
	}
};

// Socket types and protocols. These tag classes are only declared in this
// header; they are used as template arguments (see Implement and SSocket).
namespace Socket {
	class TCP;     //!< TCP protocol tag
	class UDP;     //!< UDP protocol tag
	class Client;  //!< client socket type tag
	class Server;  //!< server socket type tag
}

/**
 * Compile-time selection of the implementation class for a given socket
 * Type and Protocol, exposed as the nested typedef `Impl`.
 *
 * Only the combinations specialized below exist; any other pair (for
 * instance Server + UDP) names the undefined primary template and is
 * therefore rejected at compile time.
 */
template<typename Type, typename Proto>
struct Implement;

// Server + TCP
template<>
struct Implement<Socket::Server, Socket::TCP> { typedef SocketServer Impl; };

// Client + TCP
template<>
struct Implement<Socket::Client, Socket::TCP> { typedef SocketClient Impl; };

// Client + UDP
template<>
struct Implement<Socket::Client, Socket::UDP> { typedef UDPSocket Impl; };

// Forward-declaration
template<typename, typename, typename, typename, typename, typename>
class SSocket;

/**
 * SSocket template represents a socket that can be serve as communication
 * medium between two remote parties. The exact implementation is chosen
 * based on the template parameters.
 *
 * @param Module        Required, the module governing this socket.
 * @param Type          Type of the socket. Can be either Client or Server
 * @param Protocol      Protocol to be used in the socket, e.g. TCP or UDP
 * @param Impl          Implementation class. By default, this is chosen based
 *                      on Type and Protocol.
 * @param ImplPtr       Exact type on how to handle the implementation object.
 *                      Default resolves to ref-counted shared pointer.
 * @param _Scheduler    Scheduler to be used for socket. This should not be
 *                      changed if used within HydraNode, and is provided only
 *                      for completeness here.
 */
template<
	typename Module,
	typename Type,
	typename Protocol   = Socket::TCP,
	typename Impl       = Implement<Type, Protocol>,
	typename ImplPtr    = boost::shared_ptr<Impl>,
	typename _Scheduler = Scheduler<ImplPtr>
>
class SSocket {
public:
	// Basic stuff
	typedef typename Impl::EventType EventType;
	typedef boost::function<void (SSocket&, EventType)> HandlerFunc;

	SSocket(HandlerFunc h = 0) : m_impl(new Impl), m_handler(h) {
		_Scheduler::template addSocket<Module>(
			m_impl, boost::bind(&SSocket::onEvent, this, _1, _2)
		);
	}
	// Convenience - construct the handler object ourselves
	template<typename T>
	SSocket(T *obj, void (T::*func)(SSocket&, EventType))
	: m_impl(new Impl), m_handler(boost::bind(func, obj, _1, _2)) {
		_Scheduler::template addSocket<Module>(
			m_impl, boost::bind(&SSocket::onEvent, this, _1, _2)
		);
	}

	// input / output
	// things relevant to connection-oriented socket clients
	void write(const std::string &buf, PacketType pt = PACKET_DATA) {
		_Scheduler::template write<Module>(m_impl, buf, pt);
	}
	void read(std::string *buf) {
		_Scheduler::template read<Module>(m_impl, buf);
	}
	void connect(IPV4Address addr, uint32_t timeout = 5000) {
		_Scheduler::template connect<Module>(m_impl, addr, timeout);
	}
	void disconnect() {
		_Scheduler::template disconnect<Module>(m_impl);
	}
	// things related to connection-oriented servers
	void listen(IPV4Address addr) {
		m_impl->listen(addr);
	}
	SSocket<Module, typename Impl::AcceptType> accept() {
		return SSocket<
			Module, typename Impl::AcceptType
		>(m_impl->accept());
	}
	// things related to connection-less sockets
	void send(
		IPV4Address to, const std::string &buf,
		PacketType pt = PACKET_DATA
	);
	void recv(IPV4Address from, std::string *buf);

	// events handling
	void        setHandler(HandlerFunc handler) { m_handler = handler; }
	void        delHandler(HandlerFunc handler) { m_handler.clear();   }
	HandlerFunc getHandler() const              { return m_handler;    }
	// Convenience methods - performs function object binding internally
	template<typename T>
	void setHandler(T *obj, void (T::*func)(SSocket&, EventType)) {
		setHandler(boost::bind(func, obj, _1, _2));
	}
	template<typename T>
	void delHandler(T *obj, void (T::*func)(SSocket&, EventType)) {
		delHandler(boost::bind(func, obj, _1, _2));
	}
private:
	ImplPtr m_impl;               //!< Implementation object
	HandlerFunc m_handler;        //!< External event handler function

	SSocket(typename Impl::AcceptType s) : m_impl(s) {
		_Scheduler::template addSocket<Module>(
			s, boost::bind(&SSocket::onEvent, this, _1, _2)
		);
	}

	void onEvent(ImplPtr ptr, EventType evt) {
		m_handler(*this, evt);
	}

};