#include "compat.h"
#include "md4.h"
#include "net.h"

#include <errno.h>

#ifdef HAVE_SIGNAL_H
#include <signal.h>
#endif
#ifndef HAVE_LIMITS_H
#include <limits.h>
#endif

/*
 * Fallback for INT_MAX, which caps the size of a single read()/write() call below.
 * NOTE(review): the guard that includes <limits.h> earlier in this file is
 * `#ifndef HAVE_LIMITS_H`, which looks inverted. Confirm against the build config.
 */
#ifndef INT_MAX
#define INT_MAX (0x7fffffff)
#endif

/* Where the platform provides no O_LARGEFILE, define it as an empty (0) flag. */
#ifndef O_LARGEFILE
#define O_LARGEFILE 0
#endif

/*
 * Running totals of bytes sent/received on the socket, maintained by
 * write_large_net() and read_large_net() and printed by main() on exit.
 */
uint64_t net_bytesout = 0;
uint64_t net_bytesin = 0;

/* getopt(3) globals. NOTE(review): nothing in this file calls getopt(); these appear unused. */
extern char *optarg;
extern int optind, opterr, optopt;

/*
 * Write the command-line usage summary to stderr.
 * Returns `ret` unchanged so callers can write `return(print_help(EXIT_FAILURE));`.
 */
static int print_help(int ret) {
	static const char *const usage_lines[] = {
		"Usage: file_sync -t <host> <port> <file> [<blocksize>]\n",
		"Usage: file_sync -r <port> <file>\n",
	};
	size_t line;

	for (line = 0; line < sizeof(usage_lines) / sizeof(usage_lines[0]); line++) {
		fputs(usage_lines[line], stderr);
	}

	return(ret);
}

#if 0
/* Print a 16-byte digest to `fp` as 32 lowercase hex digits (no newline). */
static void print_hash(FILE *fp, const unsigned char *hash) {
	const unsigned char *end = hash + 16;

	while (hash < end) {
		fprintf(fp, "%02x", *hash++);
	}
}
#endif

/*
 * Read up to `count` bytes from `fd` into `buf`, looping over short reads
 * until the buffer is full or end-of-file is reached. A read() interrupted
 * by a signal (EINTR) is retried instead of being reported as a failure.
 * Each read() call is capped just below INT_MAX bytes.
 *
 * Returns the number of bytes stored in `buf` (fewer than `count` only at
 * EOF or when an error follows some progress), or -1 if the very first
 * read() failed.
 */
static ssize_t read_large(int fd, char *buf, size_t count) {
	size_t chunk;
	ssize_t nread;
	ssize_t total = 0;

	while (count > 0) {
		chunk = (count >= INT_MAX) ? (size_t)INT_MAX - 1 : count;

		nread = read(fd, buf, chunk);
		if (nread < 0) {
			if (errno == EINTR) {
				/* Interrupted before any data was transferred: retry. */
				continue;
			}

			/*
			 * Only indicate an error if we have not read anything
			 * into the buffer.  Subsequent calls to read() should
			 * generate the error again.  Hopefully.
			 */
			if (total == 0) {
				total = -1;
			}
			break;
		}

		if (nread == 0) {
			break;	/* end of file */
		}

		count -= (size_t)nread;
		buf += nread;
		total += nread;
	}

	return(total);
}

/*
 * Write all `count` bytes from `buf` to `fd`, looping over short writes.
 * A write() interrupted by a signal (EINTR) is retried instead of being
 * reported as a failure. Each write() call is capped just below INT_MAX bytes.
 *
 * Returns the number of bytes written (fewer than `count` only when an
 * error or a zero-length write stops the loop after some progress), or -1
 * if the very first write() failed.
 */
static ssize_t write_large(int fd, const char *buf, size_t count) {
	size_t chunk;
	ssize_t nwritten;
	ssize_t total = 0;

	while (count > 0) {
		chunk = (count >= INT_MAX) ? (size_t)INT_MAX - 1 : count;

		nwritten = write(fd, buf, chunk);
		if (nwritten < 0) {
			if (errno == EINTR) {
				/* Interrupted before any data was transferred: retry. */
				continue;
			}

			/* Report an error only if nothing at all was written. */
			if (total == 0) {
				total = -1;
			}
			break;
		}

		if (nwritten == 0) {
			break;
		}

		count -= (size_t)nwritten;
		buf += nwritten;
		total += nwritten;
	}

	return(total);
}

/*
 * read_large() wrapper for the socket: any bytes actually received are
 * added to the global net_bytesin counter. Returns read_large()'s result.
 */
static ssize_t read_large_net(int fd, void *buf, size_t count) {
	ssize_t got = read_large(fd, buf, count);

	if (got <= 0) {
		return(got);
	}

	net_bytesin += got;
	return(got);
}

/*
 * write_large() wrapper for the socket: any bytes actually sent are added
 * to the global net_bytesout counter. Returns write_large()'s result.
 */
static ssize_t write_large_net(int fd, void *buf, size_t count) {
	ssize_t sent = write_large(fd, buf, count);

	if (sent <= 0) {
		return(sent);
	}

	net_bytesout += sent;
	return(sent);
}

/*
 * Push `file` to a sync_receive() peer at host:port, sending only the
 * blocks whose MD4 digest differs from the receiver's copy.
 *
 * Wire protocol (integers in network byte order):
 *   -> uint64_t  file size
 *   -> uint32_t  block size
 *   -> 16-byte MD4 digest of each block (the last block may be short)
 *   <- bitmap, one bit per block, LSB first; a set bit means the
 *      receiver's copy of that block already matches
 *   -> raw contents of every block whose bit is clear, in file order
 *
 * `blocksize` must be in [1, UINT32_MAX]: it is sent as a 32-bit field
 * and used as a divisor. Returns 0 on success, -1 on any failure.
 */
int sync_transmit(const char *host, int port, const char *file, uint64_t blocksize) {
	rsaref_MD4_CTX mdctx;
	unsigned char md4buf[16];
	uint64_t filesize, filesize_net, nblocks, blockok_size, blocknum, offset;
	uint32_t blocksize_net;
	uint8_t *blockok = NULL;
	char *buf = NULL;
	size_t want;
	ssize_t read_ret, write_ret;
	off_t lseek_ret, curpos;
	int retval = -1;
	int sockfd = -1;
	int fd;

	if (blocksize == 0 || blocksize > UINT32_MAX) {
		CHECKPOINT;
		return(-1);
	}

	fd = open(file, O_RDONLY | O_LARGEFILE);
	if (fd < 0) {
		CHECKPOINT;
		return(-1);
	}

	lseek_ret = lseek(fd, 0, SEEK_END);
	if (lseek_ret < 0) {
		CHECKPOINT;
		goto out;
	}
	filesize = (uint64_t)lseek_ret;

	/* Number of blocks, then one bitmap bit per block rounded up to bytes. */
	nblocks = filesize / blocksize + ((filesize % blocksize) != 0);
	blockok_size = nblocks / 8 + ((nblocks % 8) != 0);
	if (blockok_size > UINT32_MAX) {
		/* The receiver counts bitmap bytes in 32 bits. */
		CHECKPOINT;
		goto out;
	}

	/* malloc(0) may legitimately return NULL; an empty file has no bitmap. */
	blockok = malloc(blockok_size ? (size_t)blockok_size : 1);
	buf = malloc((size_t)blocksize);
	if (!blockok || !buf) {
		CHECKPOINT;
		goto out;
	}

	sockfd = net_connect_tcp(host, port);
	if (sockfd < 0) {
		CHECKPOINT;
		goto out;
	}

	filesize_net = htonll(filesize);
	blocksize_net = htonl((uint32_t)blocksize);

	write_ret = write_large_net(sockfd, &filesize_net, sizeof(filesize_net));
	if (write_ret != (ssize_t)sizeof(filesize_net)) {
		CHECKPOINT;
		goto out;
	}

	/* Exactly 4 bytes: the receiver reads the block size as a uint32_t. */
	write_ret = write_large_net(sockfd, &blocksize_net, sizeof(blocksize_net));
	if (write_ret != (ssize_t)sizeof(blocksize_net)) {
		CHECKPOINT;
		goto out;
	}

	/*
	 * One digest per block, bounded by the size announced above so a file
	 * that grows while we run cannot push extra digests into the stream.
	 */
	if (lseek(fd, 0, SEEK_SET) != 0) {
		CHECKPOINT;
		goto out;
	}
	for (offset = 0; offset < filesize; offset += want) {
		want = (filesize - offset < blocksize) ? (size_t)(filesize - offset) : (size_t)blocksize;

		read_ret = read_large(fd, buf, want);
		if (read_ret < 0 || (size_t)read_ret != want) {
			/* Read error, or the file shrank after it was measured. */
			CHECKPOINT;
			goto out;
		}

		rsaref_MD4Init(&mdctx);
		rsaref_MD4Update(&mdctx, buf, want);
		rsaref_MD4Final(md4buf, &mdctx);

		write_ret = write_large_net(sockfd, md4buf, sizeof(md4buf));
		if (write_ret != (ssize_t)sizeof(md4buf)) {
			CHECKPOINT;
			goto out;
		}
	}

	read_ret = read_large_net(sockfd, blockok, (size_t)blockok_size);
	if (read_ret < 0 || (uint64_t)read_ret != blockok_size) {
		CHECKPOINT;
		goto out;
	}

	curpos = lseek(fd, 0, SEEK_SET);
	if (curpos != 0) {
		CHECKPOINT;
		goto out;
	}

	for (blocknum = 0; blocknum < nblocks; blocknum++) {
		if (blockok[blocknum / 8] & (1u << (blocknum % 8))) {
			continue;	/* receiver already has this block */
		}

		offset = blocknum * blocksize;
		want = (filesize - offset < blocksize) ? (size_t)(filesize - offset) : (size_t)blocksize;

		if (curpos != (off_t)offset) {
			lseek_ret = lseek(fd, (off_t)offset, SEEK_SET);
			if (lseek_ret != (off_t)offset) {
				CHECKPOINT;
				goto out;
			}
			curpos = lseek_ret;
		}

		read_ret = read_large(fd, buf, want);
		if (read_ret < 0 || (size_t)read_ret != want) {
			SPOTVAR_LLU(blocknum);
			SPOTVAR_LLU(offset);
			SPOTVAR_LLU(read_ret);
			SPOTVAR_LLU(filesize);
			CHECKPOINT;
			goto out;
		}
		curpos += read_ret;

		write_ret = write_large_net(sockfd, buf, want);
		if (write_ret < 0 || (size_t)write_ret != want) {
			SPOTVAR_LLU(write_ret);
			SPOTVAR_LLU(read_ret);
			CHECKPOINT;
			goto out;
		}
	}

	retval = 0;

out:
	free(buf);
	free(blockok);
	if (sockfd >= 0) {
		close(sockfd);
	}
	close(fd);

	return(retval);
}

/*
 * Accept one sync_transmit() connection on `port` and bring `file` up to
 * date with the sender's copy, creating it (mode 0600) if necessary.
 *
 * For each block, the local MD4 digest is compared with the digest the
 * peer sends. A bitmap of matching blocks (one bit per block, LSB first)
 * is sent back. Then the contents of every mismatching block are received
 * and written in place. Finally the file is resized to the sender's size.
 *
 * The header values (file size, block size) come from the network and
 * are validated before they are used for sizing or division.
 * Returns 0 on success, -1 on any failure.
 */
int sync_receive(int port, const char *file) {
	rsaref_MD4_CTX mdctx;
	unsigned char md4buf[16], check_md4buf[16];
	uint64_t filesize, nblocks, blockok_size, blocknum, offset;
	uint32_t blocksize;
	uint8_t *blockok = NULL;
	char *buf = NULL;
	size_t want;
	ssize_t read_ret, write_ret;
	off_t lseek_ret, curpos;
	int sockfd = -1, master_sockfd = -1;
	int retval = -1;
	int fd;

	fd = open(file, O_RDWR | O_CREAT | O_LARGEFILE, 0600);
	if (fd < 0) {
		CHECKPOINT;
		return(-1);
	}

	master_sockfd = net_listen(port);
	if (master_sockfd < 0) {
		CHECKPOINT;
		goto out;
	}

	sockfd = accept(master_sockfd, NULL, NULL);
	if (sockfd < 0) {
		CHECKPOINT;
		goto out;
	}

	read_ret = read_large_net(sockfd, &filesize, sizeof(filesize));
	if (read_ret != (ssize_t)sizeof(filesize)) {
		CHECKPOINT;
		goto out;
	}

	read_ret = read_large_net(sockfd, &blocksize, sizeof(blocksize));
	if (read_ret != (ssize_t)sizeof(blocksize)) {
		CHECKPOINT;
		goto out;
	}

	filesize = ntohll(filesize);
	blocksize = ntohl(blocksize);

	/* Reject a zero divisor and a size that off_t cannot represent. */
	if (blocksize == 0 || (off_t)filesize < 0 || (uint64_t)(off_t)filesize != filesize) {
		CHECKPOINT;
		goto out;
	}

	/*
	 * Compute in 64 bits: `8 * blocksize` in 32-bit arithmetic can wrap to
	 * 0, and a truncated bitmap size would let the loop below write past
	 * the allocation.
	 */
	nblocks = filesize / blocksize + ((filesize % blocksize) != 0);
	blockok_size = nblocks / 8 + ((nblocks % 8) != 0);
	if (blockok_size > UINT32_MAX) {
		CHECKPOINT;
		goto out;
	}

	/* Zero-filled, so only matching blocks need their bit set. */
	blockok = calloc(blockok_size ? (size_t)blockok_size : 1, 1);
	buf = malloc(blocksize);
	if (!blockok || !buf) {
		CHECKPOINT;
		goto out;
	}

	for (blocknum = 0, offset = 0; blocknum < nblocks; blocknum++, offset += want) {
		want = (filesize - offset < blocksize) ? (size_t)(filesize - offset) : (size_t)blocksize;

		read_ret = read_large(fd, buf, want);
		if (read_ret < 0) {
			CHECKPOINT;
			goto out;
		}

		/*
		 * A short read means the local copy ends inside this block.  Hash
		 * the missing tail as zeros: that is what the file holds there
		 * after the final ftruncate() if this block is not resent.
		 */
		if ((size_t)read_ret < want) {
			memset(buf + read_ret, '\0', want - (size_t)read_ret);
		}

		rsaref_MD4Init(&mdctx);
		rsaref_MD4Update(&mdctx, buf, want);
		rsaref_MD4Final(md4buf, &mdctx);

		read_ret = read_large_net(sockfd, check_md4buf, sizeof(check_md4buf));
		if (read_ret != (ssize_t)sizeof(check_md4buf)) {
			CHECKPOINT;
			goto out;
		}

		if (memcmp(check_md4buf, md4buf, sizeof(md4buf)) == 0) {
			blockok[blocknum / 8] |= (uint8_t)(1u << (blocknum % 8));
		}
	}

	write_ret = write_large_net(sockfd, blockok, (size_t)blockok_size);
	if (write_ret < 0 || (uint64_t)write_ret != blockok_size) {
		CHECKPOINT;
		goto out;
	}

	curpos = lseek(fd, 0, SEEK_SET);
	if (curpos != 0) {
		CHECKPOINT;
		goto out;
	}

	for (blocknum = 0; blocknum < nblocks; blocknum++) {
		if (blockok[blocknum / 8] & (1u << (blocknum % 8))) {
			continue;	/* local block already matches */
		}

		offset = blocknum * blocksize;
		want = (filesize - offset < blocksize) ? (size_t)(filesize - offset) : (size_t)blocksize;

		if (curpos != (off_t)offset) {
			lseek_ret = lseek(fd, (off_t)offset, SEEK_SET);
			if (lseek_ret != (off_t)offset) {
				CHECKPOINT;
				goto out;
			}
			curpos = lseek_ret;
		}

		/* Read exactly this block's length; the last block may be short. */
		read_ret = read_large_net(sockfd, buf, want);
		if (read_ret < 0 || (size_t)read_ret != want) {
			SPOTVAR_LLU(blocknum);
			SPOTVAR_LLU(offset);
			SPOTVAR_LLU(read_ret);
			SPOTVAR_LLU(filesize);
			CHECKPOINT;
			goto out;
		}

		write_ret = write_large(fd, buf, want);
		if (write_ret < 0 || (size_t)write_ret != want) {
			CHECKPOINT;
			goto out;
		}

		curpos += write_ret;
	}

	/*
	 * Match the sender's size: drops stale bytes when the local copy was
	 * longer, and zero-extends (as hashed above) when it was shorter.
	 */
	if (ftruncate(fd, (off_t)filesize) != 0) {
		CHECKPOINT;
		goto out;
	}

	retval = 0;

out:
	free(buf);
	free(blockok);
	if (sockfd >= 0) {
		close(sockfd);
	}
	if (master_sockfd >= 0) {
		close(master_sockfd);
	}
	/* close() can report a deferred write error on the output file. */
	if (close(fd) != 0) {
		retval = -1;
	}

	return(retval);
}

/*
 * Parse `str` as an unsigned integer (strtoul() syntax in `base`).
 * Succeeds only if the whole string is a number within [min, max].
 * Returns 0 and stores the value in *out, or -1 on invalid input.
 */
static int parse_ulong_arg(const char *str, int base, unsigned long min, unsigned long max, unsigned long *out) {
	char *end;
	unsigned long val;

	/* strtoul() silently negates a leading '-', so reject signs outright. */
	if (*str == '\0' || strchr(str, '-') != NULL) {
		return(-1);
	}

	errno = 0;
	val = strtoul(str, &end, base);
	if (errno != 0 || *end != '\0' || val < min || val > max) {
		return(-1);
	}

	*out = val;
	return(0);
}

/*
 * Entry point:
 *   file_sync -t <host> <port> <file> [<blocksize>]   send <file>
 *   file_sync -r <port> <file>                        receive into <file>
 * Invalid arguments print usage and exit with EXIT_FAILURE.
 */
int main(int argc, char **argv) {
	const char *mode, *host, *file;
	unsigned long port;
	unsigned long blocksize = 256 * 1024;
	int func_ret;
	int retval = EXIT_SUCCESS;

	if (argc < 2) {
		return(print_help(EXIT_FAILURE));
	}

#ifdef HAVE_SIGNAL
#ifdef SIGPIPE
	/* Ignore SIGPIPE so a vanished peer makes write() fail instead of killing us. */
	signal(SIGPIPE, SIG_IGN);
#endif
#endif

	mode = argv[1];
	if (strcmp(mode, "-t") == 0) {
		if (argc < 5 || argc > 6) {
			return(print_help(EXIT_FAILURE));
		}

		host = argv[2];
		file = argv[4];

		if (parse_ulong_arg(argv[3], 10, 1, 65535, &port) != 0) {
			return(print_help(EXIT_FAILURE));
		}

		/* Sent as a 32-bit field and used as a divisor, so 0 is invalid. */
		if (argc == 6 && parse_ulong_arg(argv[5], 0, 1, UINT32_MAX, &blocksize) != 0) {
			return(print_help(EXIT_FAILURE));
		}

		func_ret = sync_transmit(host, (int)port, file, blocksize);
	} else if (strcmp(mode, "-r") == 0) {
		if (argc != 4) {
			return(print_help(EXIT_FAILURE));
		}

		file = argv[3];

		if (parse_ulong_arg(argv[2], 10, 1, 65535, &port) != 0) {
			return(print_help(EXIT_FAILURE));
		}

		func_ret = sync_receive((int)port, file);
	} else {
		return(print_help(EXIT_FAILURE));
	}

	if (func_ret < 0) {
		fprintf(stderr, "Failed.\n");
		retval = EXIT_FAILURE;
	}

	/* uint64_t is not necessarily unsigned long long; cast to match %llu. */
	printf("Bytes In: %llu, Bytes Out: %llu\n",
	    (unsigned long long)net_bytesin, (unsigned long long)net_bytesout);

	return(retval);
}
