/**
 *  Copyright (C) 2004-2005 Alo Sarv <madcat_@users.sourceforge.net>
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

#include <hn/partdata.h>
#include <hn/log.h>
#include <hn/lambda_placeholders.h>
#include <hn/metadata.h>
#include <boost/lambda/lambda.hpp>
#include <boost/lambda/if.hpp>
#include <boost/lambda/bind.hpp>
#include <fstream>

using namespace boost::lambda;
using namespace boost::multi_index;
using namespace CGComm;
namespace CGComm {
	/**
	 * Opcodes used within PartData object I/O between streams.
	 *
	 * The loading constructor of PartData in this file checks the
	 * leading byte against OP_PARTDATA, compares the stored version
	 * byte against OP_PD_VER, and handles the OP_PD_DESTINATION and
	 * OP_PD_COMPLETED tags. The remaining opcodes are not referenced
	 * by the code in this file.
	 */
	enum PartDataOpCodes {
		OP_PD_VER         = 0x01,  //!< uint8  File version
		OP_PARTDATA       = 0x90,  //!< uint8  PartData object
		OP_PD_DOWNLOADED  = 0x91,  //!< uint64 Downloaded data
		OP_PD_DESTINATION = 0x92,  //!< string Destination location
		OP_PD_COMPLETED   = 0x93,  //!< FullRangeList completed ranges
		OP_PD_HASHSET     = 0x94   //!< FullRangeList<HashBase*> hashset
	};
}

//! Threshold for buffered, not yet flushed data: PartData::doWrite() calls
//! flushBuffer() once m_toFlush reaches this many bytes.
static const uint32_t BUF_SIZE_LIMIT = 512*1024; //!< 512k buffer

// PartData::UsedRange class
// -------------------------
/**
 * Constructs a used range spanning the chunk that @a it points at, and
 * increments that chunk's use counter.
 *
 * @param parent  PartData object the chunk belongs to; checked to be non-null
 * @param it      Iterator to the chunk; it is projected to the ID_Avail index
 *
 * NOTE(review): @a parent is dereferenced in the initializer list before the
 * null check in the body runs.
 */
template<typename IterType>
PartData::UsedRange::UsedRange(PartData *parent, IterType it)
: Range64((*it).begin(), (*it).end()),
  m_parent(parent),
  m_chunk(parent->m_chunks.project<ID_Avail>(it))
{
	typedef ChunkMap::index<ID_Avail>::type AvailIndex;

	CHECK_THROW(parent != 0);
	AvailIndex &avail = m_parent->m_chunks.get<ID_Avail>();
	CHECK_THROW(m_chunk != avail.end());

	// The counter is changed through modify() on the container's index.
	avail.modify(m_chunk, ++bind(&PartData::Chunk::m_useCnt, __1(__1)));
}
//! Decrements the use counter of the chunk this range refers to.
PartData::UsedRange::~UsedRange() {
	typedef ChunkMap::index<ID_Avail>::type AvailIndex;

	AvailIndex &avail = m_parent->m_chunks.get<ID_Avail>();
	avail.modify(m_chunk, --bind(&PartData::Chunk::m_useCnt, __1(__1)));
}
/**
 * Requests a lock inside this used range by forwarding to the parent's
 * PartData::getLock().
 *
 * @param size  Requested lock size, passed on unchanged
 * @return      The locked range produced by the parent
 */
LockedRangePtr PartData::UsedRange::getLock(uint32_t size) {
	UsedRangePtr self = shared_from_this();
	return m_parent->getLock(self, size);
}
bool PartData::UsedRange::isComplete() const {
	return m_parent->isComplete(*this);
}

// PartData::LockedRange class
// ---------------------------
/**
 * Constructs a lock over @a r that is not bound to a chunk (m_chunk is set to
 * the end iterator of the ID_Avail index) and registers the range in the
 * parent's list of locked ranges.
 *
 * @param parent  PartData object the lock belongs to
 * @param r       Range to lock
 */
PartData::LockedRange::LockedRange(PartData *parent, Range64 r)
: Range64(r),
  m_parent(parent),
  m_chunk(parent->m_chunks.get<ID_Avail>().end())
{
	m_parent->m_locked.merge(*this);
}
/**
 * Constructs a lock over @a r bound to the chunk @a it points at, and
 * registers the range in the parent's list of locked ranges.
 *
 * @param parent  PartData object the lock belongs to; checked to be non-null
 * @param r       Range to lock
 * @param it      Iterator to the chunk; it is projected to the ID_Avail index
 *                and must not be the end iterator
 *
 * NOTE(review): @a parent is dereferenced in the initializer list before the
 * null check in the body runs.
 */
template<typename IterType>
PartData::LockedRange::LockedRange(PartData *parent, Range64 r, IterType it)
: Range64(r),
  m_parent(parent),
  m_chunk(parent->m_chunks.project<ID_Avail>(it))
{
	typedef ChunkMap::index<ID_Avail>::type AvailIndex;

	CHECK_THROW(parent != 0);
	AvailIndex &avail = m_parent->m_chunks.get<ID_Avail>();
	CHECK_THROW(m_chunk != avail.end());
	m_parent->m_locked.merge(*this);
}
//! Removes this range from the parent's list of locked ranges.
PartData::LockedRange::~LockedRange()
{
	m_parent->m_locked.erase(*this);
}
void PartData::LockedRange::write(uint64_t begin, const std::string &data) {
	if (begin > end() || begin + data.size() - 1 > end()) {
		throw PartData::LockError("Writing outside lock.");
	}
	if (m_chunk != m_parent->m_chunks.get<ID_Avail>().end()) {
		m_parent->m_chunks.get<ID_Avail>().modify(
			m_chunk, bind(&Chunk::write, __1, begin, data)
		);
	} else {
		m_parent->doWrite(begin, data);
	}
}

// PartData exception classes
// --------------------------
//! Constructs a LockError carrying the message @a msg.
PartData::LockError::LockError(const std::string &msg)
: std::runtime_error(msg)
{}
//! Constructs a RangeError carrying the message @a mg.
PartData::RangeError::RangeError(const std::string &mg)
: std::runtime_error(mg)
{}

// PartData::Chunk class
// ---------------------
/**
 * Constructs a chunk covering @a begin .. @a end. The verified and partial
 * flags and the availability and use counters are value-initialized.
 *
 * @param parent  PartData object the chunk belongs to
 * @param begin   First offset of the chunk
 * @param end     Last offset of the chunk
 * @param hash    Hash stored for this chunk; Chunk::write() only schedules
 *                verification when it is non-null
 */
PartData::Chunk::Chunk(
	PartData *parent, uint64_t begin, uint64_t end, const HashBase *hash
)
: Range64(begin, end),
  m_parent(parent),
  m_hash(hash),
  m_verified(),
  m_partial(),
  m_avail(),
  m_useCnt()
{}
/**
 * Constructs a chunk covering @a range. The verified and partial flags and
 * the availability and use counters are value-initialized.
 *
 * @param parent  PartData object the chunk belongs to
 * @param range   Range covered by the chunk
 * @param hash    Hash stored for this chunk; Chunk::write() only schedules
 *                verification when it is non-null
 */
PartData::Chunk::Chunk(PartData *parent, Range64 range, const HashBase *hash)
: Range64(range),
  m_parent(parent),
  m_hash(hash),
  m_verified(),
  m_partial(),
  m_avail(),
  m_useCnt()
{}

void PartData::Chunk::write(uint64_t begin, const std::string &data) {
	m_parent->doWrite(begin, data);
	if (isComplete()) {
		logMsg(
			boost::format("%s: Completed chunk %d..%d")
			% m_parent->m_destination.leaf() % this->begin()
			% this->end()
		);
		m_partial = false;
		m_verified = false;
		if (m_hash) {
			m_parent->flushBuffer();
			boost::shared_ptr<HashWork> c(
				new HashWork(
					m_parent->m_location.string(),
					this->begin(), this->end(), m_hash
				)
			);
			HashWork::getEventTable().addHandler(
				c, this, &Chunk::onHashEvent
			);
			WorkThread::instance().postWork(c);
			++m_parent->m_pendingHashes;
		}
	}
}

/**
 * Handles the result of a chunk hash job.
 *
 * - HASH_VERIFIED: marks the chunk verified and no longer partial.
 * - HASH_FAILED: logs the corruption, removes the chunk's range from the
 *   parent's completed ranges, and cancels and drops the parent's full-file
 *   job if one exists.
 * - HASH_FATAL_ERROR: logs the error.
 *
 * In every case the parent's pending-hash counter is decremented; when the
 * parent is then complete with no hashes pending, PartData::doComplete() is
 * called.
 *
 * @param c    The hash job that produced the event
 * @param evt  The event reported for that job
 */
void PartData::Chunk::onHashEvent(HashWorkPtr c, HashEvent evt){
	PartData *const owner = m_parent;

	if (evt == HASH_VERIFIED) {
		m_verified = true;
		m_partial = false;
	} else if (evt == HASH_FAILED) {
		logMsg(
			boost::format("%s: Corruption found at %d..%d")
			% owner->m_destination.leaf() % begin() % end()
		);
		owner->m_complete.erase(begin(), end());
		m_verified = false;
		if (owner->m_fullJob) {
			owner->m_fullJob->cancel();
			owner->m_fullJob.reset();
		}
	} else if (evt == HASH_FATAL_ERROR) {
		boost::format err("Fatal error hashing file `%s'");
		logError(err % c->getFileName());
	}

	--owner->m_pendingHashes;
	if (owner->isComplete() && !owner->m_pendingHashes) {
		owner->doComplete();
	}
}
bool PartData::Chunk::isComplete() const {
	return m_parent->isComplete(*this);
}

// PartData class
// --------------
//! Event table for PartData. Code in this file posts events through
//! getEventTable().postEvent(this, <event>), passing a PartData* and one of
//! PD_DATA_ADDED, PD_DATA_FLUSHED or PD_COMPLETE.
IMPLEMENT_EVENT_TABLE(PartData, PartData*, int);

/**
 * Constructs a new PartData of @a size bytes and creates the file at @a loc.
 * An existing file at @a loc is truncated by the default ofstream open.
 *
 * @param size  Total size of the file
 * @param loc   Location of the file the data is written to
 * @param dest  Destination location
 */
PartData::PartData(
	uint64_t size,
	const boost::filesystem::path &loc,
	const boost::filesystem::path &dest
)
: m_size(size),
  m_location(loc),
  m_destination(dest),
  m_toFlush(),
  m_md(),
  m_pendingHashes()
{
	const std::string path = loc.string();
	std::ofstream created(path.c_str());
	created.flush();
}

/**
 * Constructs a PartData by loading its state from the stream stored at @a p.
 *
 * Layout as read here: an OP_PARTDATA opcode byte, a 16-bit field that is
 * read and discarded, a version byte, the 64-bit file size, a 16-bit tag
 * count, then the tags. Each tag has a one-byte opcode and a 16-bit length.
 * An OP_METADATA object may follow the tags; when present, every hash set in
 * it that has a non-zero chunk size is added via addHashSet().
 *
 * @param p  Path of the file to load; also stored as m_location
 */
PartData::PartData(const boost::filesystem::path &p) : m_size(), m_location(p),
m_toFlush(), m_md(), m_pendingHashes() {
	using namespace Utils;
	// The stream holds raw integers, so it must be opened in binary mode;
	// text mode may translate bytes on some platforms.
	std::ifstream ifs(
		p.string().c_str(), std::ios::in | std::ios::binary
	);
	CHECK_THROW(Utils::getVal<uint8_t>(ifs) == OP_PARTDATA);
	Utils::getVal<uint16_t>(ifs); // value not needed here
	uint8_t ver = Utils::getVal<uint8_t>(ifs);
	if (ver != OP_PD_VER) {
		logWarning("Unknown partdata version.");
	}
	m_size = Utils::getVal<uint64_t>(ifs);
	uint64_t tagc = Utils::getVal<uint16_t>(ifs);
	while (tagc-- && ifs) {
		uint8_t   oc = getVal<uint8_t>(ifs);
		uint16_t len = getVal<uint16_t>(ifs);
		switch (oc) {
			case OP_PD_DESTINATION:
				m_destination = getVal<std::string>(ifs);
				break;
			case OP_PD_COMPLETED:
				if (Utils::getVal<uint8_t>(ifs)!=OP_RANGELIST) {
					logWarning("Invalid tag.");
					// One byte of the tag has already been
					// consumed, so skip only the rest, and
					// do not parse it as a range list.
					ifs.seekg(
						static_cast<std::streamoff>(len)
						- 1, std::ios::cur
					);
					break;
				}
				Utils::getVal<uint16_t>(ifs);
				m_complete = RangeList64(ifs);
				break;
			default:
				logWarning("Unhandled tag in PartData.");
				ifs.seekg(len, std::ios::cur);
				break;
		}
	}
	if (ifs && Utils::getVal<uint8_t>(ifs) == OP_METADATA) {
		Utils::getVal<uint16_t>(ifs); // value not needed here
		m_md = new MetaData(ifs);
		for (uint32_t i = 0; i < m_md->getHashSetCount(); ++i) {
			HashSetBase *hs = m_md->getHashSet(i);
			if (hs->getChunkSize()) {
				addHashSet(hs);
			}
		}
	}
}

/**
 * Registers a source that has only some of the chunks of size @a chunkSize.
 *
 * The chunk map for @a chunkSize is created if needed. Each chunk of that map
 * then has its availability counter increased by the matching mask entry
 * (0 or 1).
 *
 * @param chunkSize  Chunk size the mask refers to; must be non-zero
 * @param chunks     One entry per chunk, true where the source has it. Its
 *                   size must equal getChunkCount(chunkSize).
 */
void PartData::addSourceMask(
	uint32_t chunkSize, const std::vector<bool> &chunks
) {
	CHECK_THROW(chunkSize != 0);
	CHECK_THROW(chunks.size() == getChunkCount(chunkSize));
	checkAddChunkMap(chunkSize);
	typedef ChunkMap::index<ID_Pos>::type PosIndex;
	typedef PosIndex::iterator Iter;
	PosIndex& pi = m_chunks.get<ID_Pos>();
	for (Iter j = pi.begin(); j != pi.end(); ++j) {
		// The position index can also hold chunks of other sizes, so
		// a running counter would run past the mask. Pick out the
		// chunks laid out by checkAddChunkMap(chunkSize) and derive
		// the mask index from the chunk's offset.
		const uint64_t first = (*j).begin();
		if (first % chunkSize) {
			continue;
		}
		uint64_t expected = m_size - first;
		if (expected > chunkSize) {
			expected = chunkSize;
		}
		if ((*j).end() - first + 1 != expected) {
			continue;
		}
		const uint64_t idx = first / chunkSize;
		if (idx >= chunks.size()) {
			continue;
		}
		pi.modify(j, bind(&Chunk::m_avail, __1(__1)) += chunks[idx]);
	}
}

/**
 * Registers a source that has the whole file: makes sure the chunk map for
 * @a chunkSize exists, then increments the availability counter of every
 * chunk in the position index.
 *
 * @param chunkSize  Chunk size whose map is created if missing
 */
void PartData::addFullSource(uint32_t chunkSize) {
	checkAddChunkMap(chunkSize);

	typedef ChunkMap::index<ID_Pos>::type PosIndex;
	PosIndex& byPos = m_chunks.get<ID_Pos>();

	PosIndex::iterator cur = byPos.begin();
	while (cur != byPos.end()) {
		byPos.modify(cur, ++bind(&Chunk::m_avail, __1(__1)));
		++cur;
	}
}
//! Predicate that accepts every range; passed to doGetRange() by
//! PartData::getRange(uint32_t).
struct TruePred {
	bool operator()(const Range64 &) {
		return true;
	}
} truepred;
/**
 * Selects a range to download, with no restriction on which chunks may be
 * chosen.
 *
 * @param size  Forwarded to doGetRange()
 * @return      The selected used range
 */
UsedRangePtr PartData::getRange(uint32_t size) {
	UsedRangePtr picked = doGetRange(size, truepred);
	return picked;
}
//! Predicate that accepts a range only when it is fully contained in m_rl.
//! The single file-scope instance is refilled by
//! PartData::getRange(uint32_t, const std::vector<bool>&) on every call.
struct CheckPred {
	bool operator()(const Range64 &r) {
		return m_rl.containsFull(r);
	}

	RangeList64 m_rl; //!< Ranges the predicate accepts
} checkPred;
/**
 * Selects a range to download, restricted to chunks the source has.
 *
 * @param size    Chunk size the mask refers to; must be non-zero
 * @param chunks  One entry per chunk of @a size bytes, true where the source
 *                has that chunk
 * @return        The selected used range
 * @throws RangeError if @a size is zero or no suitable range is found
 */
UsedRangePtr PartData::getRange(uint32_t size, const std::vector<bool> &chunks){
	if (!size) {
		throw RangeError("Chunk size must be non-zero.");
	}
	checkPred.m_rl.clear();
	// 64-bit offset: a 32-bit one wraps for files larger than 4 GiB.
	uint64_t pos = 0;
	for (std::vector<bool>::size_type i = 0; i < chunks.size(); ++i) {
		// Only chunks the source has are allowed. Range ends are
		// inclusive, so a chunk covers pos .. pos + size - 1.
		if (chunks[i]) {
			checkPred.m_rl.push(Range64(pos, pos + size - 1));
		}
		pos += size;
	}
	return doGetRange(size, checkPred);
}
/**
 * Picks a chunk to hand out as a UsedRange, trying three strategies in order
 * and stopping at the first one that yields a chunk.
 *
 * The first parameter (unnamed) is not used by this function.
 *
 * @param pred  Callable taking a chunk; only chunks for which it returns
 *              true are eligible
 * @return      Used range wrapping the selected chunk
 * @throws RangeError if no round selected a chunk, or if the selected range
 *         does not satisfy @a pred
 */
template<typename Predicate>
UsedRangePtr PartData::doGetRange(uint64_t, Predicate &pred) {
	boost::shared_ptr<PartData::UsedRange> ret;
	if (!ret) { // Round 1: Incomplete chunks
		typedef ChunkMap::index<ID_Partial>::type PartialIndex;
		typedef PartialIndex::iterator PIter;
		PartialIndex &pi = m_chunks.get<ID_Partial>();
		// Entries of the ID_Partial index whose key compares equal
		// to true.
		std::pair<PIter, PIter> r = pi.range(__1 == true, unbounded);
		PIter i = pi.end();
		// No break: the last eligible chunk in the range is kept.
		for (PIter j = r.first; j != r.second; ++j) {
			if (!(*j).m_useCnt && pred(*j)) {
				i = j;
			}
		}
		// partial chunk with useCnt == 0
		if (i != pi.end()) {
			ret = UsedRangePtr(new UsedRange(this, i));
		}
	}
	if (!ret) { // Round 2: Least available unused chunk
		typedef ChunkMap::index<ID_Avail>::type AvailIndex;
		typedef AvailIndex::iterator AIter;
		AvailIndex &ai = m_chunks.get<ID_Avail>();
		AIter i = ai.end();
		// Starts after all entries whose ID_Avail key is 0; presumably
		// those are chunks no source has - TODO confirm.
		for (AIter r = ai.upper_bound(0); r != ai.end(); ++r) {
			if ((*r).isComplete() || (*r).isPartial() || !pred(*r)){
				continue;
			}
			// First eligible chunk nobody is using wins.
			if (!(*r).getUseCnt()) {
				i = r;
				break;
			}
		}
		if (i != ai.end()) {
			ret = UsedRangePtr(new UsedRange(this, i));
		}
	}
	if (!ret) { // Round 3: Least used chunk
		typedef ChunkMap::index<ID_UseCnt>::type UseIndex;
		typedef UseIndex::iterator UIter;
		UseIndex &ui = m_chunks.get<ID_UseCnt>();
		UIter r = ui.upper_bound(0);
		// NOTE(review): when pred rejects the chunk, the iterator is
		// stepped backwards (towards begin()), i.e. towards the
		// entries upper_bound(0) skipped, and completeness is not
		// checked in this round. If the index is ordered by ascending
		// use count, moving forward (++r) may have been intended -
		// confirm against the ID_UseCnt index definition.
		while (r != ui.end() && !pred(*r)) {
			r == ui.begin() ? r = ui.end() : --r;
		}
		if (r != ui.end()) {
			ret = UsedRangePtr(new UsedRange(this, r));
		}
	}
	if (!ret) {
		throw RangeError("Failed to generate chunk request.");
	} else if (!pred(*ret)) {
		// Sanity check: every round filters by pred already.
		throw RangeError(
			"Internal PartData error while "
			"generating chunk request."
		);
	}
	return ret;
}

/**
 * Calculates how many chunks of @a chunkSize bytes are needed to cover the
 * file; a trailing partial chunk counts as one.
 *
 * @param chunkSize  Size of one chunk; must be non-zero
 * @return           Number of chunks
 */
uint32_t PartData::getChunkCount(uint32_t chunkSize) const {
	// Guard against division by zero.
	CHECK_THROW(chunkSize != 0);
	return m_size / chunkSize + (m_size % chunkSize ? 1 : 0);
}

/**
 * Creates the chunks for chunk size @a cs unless the ID_Length index already
 * has an entry with key @a cs.
 *
 * Chunks are laid out back to back from offset 0; the last one is cut so that
 * it ends at m_size - 1.
 *
 * @param cs  Chunk size to create the map for
 */
void PartData::checkAddChunkMap(uint32_t cs) {
	typedef ChunkMap::index<ID_Length>::type::iterator LIter;
	std::pair<LIter, LIter> ret = m_chunks.get<ID_Length>().equal_range(cs);
	// An empty equal_range means "no chunk of this length". Its first
	// iterator is not necessarily end(): it points at the next larger
	// key when chunks of a larger size exist.
	if (ret.first != ret.second) {
		return;
	}
	const uint32_t count = getChunkCount(cs);
	for (uint32_t i = 0; i < count; ++i) {
		// Offsets are computed in 64 bits; i * cs in 32 bits wraps
		// for files larger than 4 GiB.
		const uint64_t first = static_cast<uint64_t>(i) * cs;
		Chunk c(this, first, first + cs - 1);
		if (c.end() >= m_size) {
			c.end(m_size - 1);
		}
		m_chunks.insert(c);
	}
}
/**
 * Attaches the chunk hashes of @a hs to the chunks of the matching size,
 * creating the chunk map for that size first when it does not exist.
 *
 * @param hs  Hash set to attach; its chunk size must be non-zero, and its
 *            chunk count must be one more than the number of chunks found
 *            under that size in the ID_Length index
 *
 * NOTE(review): only chunks found under key cs are assigned a hash. A chunk
 * of a different length (such as a shorter final chunk) is not visited by
 * this loop - confirm that is intended.
 */
void PartData::addHashSet(const HashSetBase *hs) {
	CHECK_THROW(hs->getChunkSize() > 0);
	typedef ChunkMap::index<ID_Length>::type::iterator LIter;
	uint32_t cs = hs->getChunkSize();
	std::pair<LIter, LIter> ret = m_chunks.get<ID_Length>().equal_range(cs);
	// An empty equal_range means "no chunk of this length". Its first
	// iterator is not necessarily end() when chunks of a larger size
	// exist.
	if (ret.first == ret.second) {
		checkAddChunkMap(cs);
		ret = m_chunks.get<ID_Length>().equal_range(cs);
	}
	int cc = hs->getChunkCnt();
	CHECK_THROW(std::distance(ret.first, ret.second) + 1 == cc);
	uint32_t j = 0;
	for (LIter i = ret.first; i != ret.second; ++i) {
		m_chunks.get<ID_Length>().modify(
			i, bind(&Chunk::m_hash, __1(__1)) = &(*hs)[j++]
		);
	}
}

/**
 * @return True when the completed ranges consist of exactly one range that
 *         starts at offset 0 and ends at m_size - 1.
 */
bool PartData::isComplete() const {
	if (m_complete.size() != 1) {
		return false;
	}
	if ((*m_complete.begin()).begin() != 0) {
		return false;
	}
	return (*m_complete.begin()).end() == m_size - 1;
}
//! @return Whether the completed ranges fully contain @a r.
bool PartData::isComplete(const Range64 &r) const {
	const bool covered = m_complete.containsFull(r);
	return covered;
}
//! @return Whether the completed ranges fully contain @a begin .. @a end.
bool PartData::isComplete(uint64_t begin, uint64_t end) const {
	const bool covered = m_complete.containsFull(begin, end);
	return covered;
}
/**
 * Finds a region inside @a used that is neither complete nor locked, and
 * returns a lock on it.
 *
 * @param used  Used range within which the lock is searched for
 * @param size  Requested lock size
 * @return      New LockedRange bound to the chunk of @a used
 * @throws LockError if the search leaves @a used without finding a free
 *         position
 */
LockedRangePtr PartData::getLock(UsedRangePtr used, uint32_t size) {
	// Candidate starts as a single offset at the beginning of the used
	// range.
	Range64 cand(used->begin(), used->begin());
	typedef RangeList64::CIter CIter;
	do {
		// getContains() is taken to return end() when no stored
		// range matches cand - TODO confirm against RangeList.
		CIter i = m_complete.getContains(cand);
		CIter j = m_locked.getContains(cand);
		if (i == m_complete.end() && j == m_locked.end()) {
			break;
		} else {
			// Move the candidate to just past each completed,
			// then each locked, range it hits.
			while (i != m_complete.end()) {
				cand = Range64((*i).end() + 1, (*i).end() + 1);
				i = m_complete.getContains(cand);
			}
			while (j != m_locked.end()) {
				cand = Range64((*j).end() + 1, (*j).end() + 1);
				j = m_locked.getContains(cand);
			}
		}
	} while (cand.end() - 1 <= m_size && used->contains(cand));
	if (!used->contains(cand)) {
		throw LockError("Unable to aquire lock within this UsedRange.");
	}
	// Grow the candidate from a single offset up to the requested size,
	// limited by the end of the used range.
	// NOTE(review): range ends are used inclusively elsewhere in this
	// file, so begin() + size would cover size + 1 bytes - confirm.
	if (cand.end() <= used->end()) {
		if (cand.end() + size > used->end()) {
			cand.end(used->end());
		} else {
			cand.end(cand.begin() + size);
		}
	}
	// If the grown candidate now hits a completed or locked range, cut it
	// to end just before that range.
	CIter i = m_complete.getContains(cand);
	if (i != m_complete.end()) {
		cand.end((*i).begin() - 1);
	}
	CIter j = m_locked.getContains(cand);
	if (j != m_locked.end()) {
		cand.end((*j).begin() - 1);
	}
	return LockedRangePtr(new LockedRange(this, cand, used->m_chunk));
}
void PartData::write(uint64_t begin, const std::string &data) {
	logDebug(boost::format("Safe-writing at offset %d.") % begin);
	CHECK_THROW(!m_locked.contains(begin, begin + data.size() - 1));
	CHECK_THROW(!m_complete.contains(begin, begin + data.size() - 1));
	doWrite(begin, data);
	typedef ChunkMap::index<ID_Pos>::type::iterator PIter;
	typedef ChunkMap::index<ID_Pos>::type PosIndex;
	PosIndex &pi = m_chunks.get<ID_Pos>();
	Range64 tmp(begin, begin + data.size() - 1);
	PIter i = pi.lower_bound(Chunk(this, tmp));
	if (i != pi.end() && (*i).contains(tmp)) {
		pi.modify(i, bind(&Chunk::m_partial, __1(__1)) = true);
	} else if (i != pi.begin() && (*--i).contains(tmp)) {
		pi.modify(i, bind(&Chunk::m_partial, __1(__1)) = true);
	} else if (++i != pi.end() && (*i).contains(tmp)) {
		pi.modify(i, bind(&Chunk::m_partial, __1(__1)) = true);
	}
	if (isComplete() && !m_fullJob) {
		doComplete();
	}
}
void PartData::doWrite(uint64_t begin, const std::string &data) {
	logDebug(
		boost::format("Writing at offset %d, datasize is %d")
		% begin % data.size()
	);
	CHECK_THROW(!m_complete.contains(begin, begin + data.size() - 1));
	m_buffer[begin] = data;
	m_complete.merge(begin, begin + data.size() - 1);
	m_toFlush += data.size();
	getEventTable().postEvent(this, PD_DATA_ADDED);
	if (m_toFlush >= BUF_SIZE_LIMIT) {
		flushBuffer();
	}
	if (isComplete() && !m_pendingHashes) {
		doComplete();
	}
}
void PartData::flushBuffer() {
	logDebug("Flushing buffers.");
	std::ofstream ofs(m_location.string().c_str(), std::ios::app);
	for (BIter i = m_buffer.begin(); i != m_buffer.end(); ++i) {
		ofs.seekp((*i).first);
		ofs.write((*i).second.c_str(), (*i).second.length());
	}
	m_buffer.clear();
	m_toFlush = 0;
	getEventTable().postEvent(this, PD_DATA_FLUSHED);
}
/**
 * Handles the result of the full-file hash job.
 *
 * On HASH_COMPLETE every hash set in the job's metadata is compared against
 * the stored metadata via verifyHashSet(), a summary is logged, and
 * PD_COMPLETE is posted when none of them failed. A fatal error is logged and
 * ends handling; any other event is logged at debug level and ignored.
 *
 * @param p    The hash job that produced the event
 * @param evt  The event reported for that job
 */
void PartData::onHashEvent(HashWorkPtr p, HashEvent evt) {
	using boost::logic::tribool;
	if (evt == HASH_FATAL_ERROR) {
		logError("Fatal error performing final rehash on PartData.");
		// A failed job has no usable result to verify.
		return;
	}
	if (evt != HASH_COMPLETE) {
		logDebug(boost::format(
			"PartData received unknown event %d.") % evt
		);
		return;
	}
	CHECK_THROW(p->getMetaData());
	MetaData *ref = p->getMetaData();
	uint32_t failed = 0;
	uint32_t ok = 0;
	uint32_t notfound = 0;
	for (uint32_t i = 0; i < ref->getHashSetCount(); ++i) {
		tribool ret = verifyHashSet(ref->getHashSet(i));
		if (boost::indeterminate(ret)) {
			++notfound;
		} else if (ret) {
			++ok;
		} else {
			++failed;
		}
	}
	boost::format fmt("Verifying file %s: %d Ok, %d Failed, %d NotFound.");
	logMsg(fmt % m_location.string() % ok % failed % notfound);
	if (!failed) {
		getEventTable().postEvent(this, PD_COMPLETE);
	}
}
/**
 * Compares @a ref against the stored hash set of the same kind.
 *
 * The first stored hash set whose file-hash and chunk-hash type ids both
 * match those of @a ref decides the result.
 *
 * @param ref  Hash set to verify
 * @return     true if the matching stored hash set equals @a ref, false if it
 *             differs, indeterminate if no stored hash set has matching type
 *             ids
 */
boost::logic::tribool PartData::verifyHashSet(const HashSetBase *ref) {
	CHECK_THROW(m_md);
	for (uint32_t idx = 0; idx < m_md->getHashSetCount(); ++idx) {
		const HashSetBase *stored = m_md->getHashSet(idx);
		const bool sameKind =
			ref->getFileHashTypeId() == stored->getFileHashTypeId()
			&& ref->getChunkHashTypeId()
				== stored->getChunkHashTypeId();
		if (!sameKind) {
			continue;
		}
		if (*stored == *ref) {
			return true;
		}
		return false;
	}
	return boost::indeterminate;
}
void PartData::printCompleted() {
	boost::format fmt("%s/%s (%.2f%%) complete of `%s'");
	uint64_t complete = 0;
	for_each(
		m_complete.begin(), m_complete.end(),
		complete += bind(&Range64::length, __1)
	);
	fmt % Utils::bytesToString(complete);
	fmt % Utils::bytesToString(m_size);
	fmt % (complete * 100.0 / m_size);
	logMsg(fmt % m_destination.leaf());
	logMsg("Completed ranges are:");
	for (RangeList64::CIter i = m_complete.begin();i!=m_complete.end();++i){
		logMsg(boost::format("%d..%d") % (*i).begin() % (*i).end());
	}
}
/**
 * Sets the metadata for this file and adds every hash set in it that has a
 * non-zero chunk size via addHashSet().
 *
 * @param md  Metadata to use; must not be null
 */
void PartData::setMetaData(MetaData *md) {
	CHECK_THROW(md);
	m_md = md;

	uint32_t idx = 0;
	while (idx < m_md->getHashSetCount()) {
		HashSetBase *hs = m_md->getHashSet(idx);
		if (hs->getChunkSize()) {
			addHashSet(hs);
		}
		++idx;
	}
}
void PartData::doComplete() {
	CHECK_THROW(isComplete());
	HashWorkPtr p(new HashWork(m_location.string()));
	HashWork::getEventTable().addHandler(p, this, &PartData::onHashEvent);
	WorkThread::instance().postWork(p);
	m_fullJob = p;
}