/*
** JNetLib
** Copyright (C) 2000-2007 Nullsoft, Inc.
** Author: Justin Frankel
** File: connection.cpp - JNL TCP connection implementation
** License: see jnetlib.h
*/

#include "netinc.h"
#include "util.h"

#include "wac_network_connection.h"
#include "wac_network_dns.h"

#include "foundation\error.h"

#include "../nu/strsafe.h"
#pragma intrinsic(memcpy, memset)

/*
** Creates an idle connection with every member in its "no connection" state.
** No DNS resolver exists and no buffer space is reserved until open() is called.
*/
WAC_Network_Connection::WAC_Network_Connection()
{
	this->init();
}

/*
** Creates a connection and immediately open()s it.
**  dns         - caller-owned resolver, or API_DNS_AUTODNS / NULL to have this
**                object create (and own) one.
**  sendbufsize - capacity to reserve for the outgoing buffer.
**  recvbufsize - capacity to reserve for the incoming buffer.
*/
WAC_Network_Connection::WAC_Network_Connection( api_dns *dns, size_t sendbufsize, size_t recvbufsize )
{
	this->init();
	this->open( dns, sendbufsize, recvbufsize );
}


/*
** Resets every member to the "no connection" state: no socket, no resolver,
** no target host, no resolved or peer address and no error string.
*/
void WAC_Network_Connection::init()
{
	// Socket and connection state.
	m_socket      = -1;
	m_state       = STATE_NOCONNECTION;

	// Remote endpoint.
	m_host[ 0 ]   = 0;
	m_remote_port = 0;
	saddr         = 0;
	address       = 0;

	// Resolver: none yet, and nothing owned.
	m_dns         = NULL;
	m_dns_owned   = false;

	// Last error message.
	m_errorstr    = 0;
}

/*
** Destroys the connection. It closes the socket, frees the peer-address copy if
** this object owns it, and deletes the DNS resolver if open() created it.
*/
WAC_Network_Connection::~WAC_Network_Connection()
{
	// Joshua Teitelbaum 1/27/2006: shutdown goes through the virtual
	// socket_shutdown() added for SSL. NOTE(review): because this runs inside the
	// destructor, the call always resolves to this class's version and never to
	// a derived override. Derived classes must tear down their own layer.
	socket_shutdown();

	// 'address' is a heap copy only when no DNS result (saddr) is held, i.e. it
	// was passed to us via connect( SOCKET, ... ) (by a listener, presumably).
	// When saddr is set, 'address' is expected to alias saddr->ai_addr and is not
	// ours to free.
	// TODO: change this if we ever do round-robin DNS connecting or otherwise
	// change how 'address' is handled.
	const bool owns_address = ( saddr == 0 );
	if ( owns_address )
		free( address );

	// Only delete a resolver that open() allocated on our behalf.
	if ( m_dns_owned )
	{
		wa::Components::WAC_Network_AsyncDNS *own_resolver = static_cast<wa::Components::WAC_Network_AsyncDNS *>( m_dns );
		delete own_resolver;
	}

	// NOTE(review): m_errorstr is filled via _strdup() elsewhere in this file but
	// is not released here. Confirm who owns it.
}

/*
** Replaces the DNS resolver used by run() with a caller-owned one.
** If the current resolver was created by open() (owned), it is deleted first.
** Passing the resolver that is already installed is a no-op.
*/
void WAC_Network_Connection::set_dns( api_dns *dns )
{
	// Without this guard, re-installing an owned resolver would delete it and
	// then keep the dangling pointer as the "caller-owned" resolver.
	if ( dns == m_dns )
		return;

	if ( m_dns_owned )
		delete static_cast<wa::Components::WAC_Network_AsyncDNS *>( m_dns );

	m_dns       = dns;
	m_dns_owned = false;
}

/*
** Prepares the connection for use.
**  dns         - a caller-owned resolver to use. API_DNS_AUTODNS or NULL means
**                "create one if none is installed yet"; that resolver is then
**                owned by this object and deleted in the destructor.
**  sendbufsize - capacity to reserve for the outgoing buffer.
**  recvbufsize - capacity to reserve for the incoming buffer.
*/
void WAC_Network_Connection::open( api_dns *dns, size_t sendbufsize, size_t recvbufsize )
{
	if ( dns != API_DNS_AUTODNS && dns )
	{
		// Switching to a caller-supplied resolver: release one we created
		// earlier instead of leaking it. Re-passing the resolver already
		// installed leaves it (and its ownership) untouched.
		if ( dns != m_dns )
		{
			if ( m_dns_owned )
				delete static_cast<wa::Components::WAC_Network_AsyncDNS *>( m_dns );

			m_dns       = dns;
			m_dns_owned = false;
		}
	}
	else if ( !m_dns )
	{
		m_dns       = new wa::Components::WAC_Network_AsyncDNS;
		m_dns_owned = true;
	}

	recv_buffer.reserve( recvbufsize );
	send_buffer.reserve( sendbufsize );
}

/*
** Adopts an already-connected socket (presumably one accepted by a listener).
**  s      - the socket. This object takes it over and closes it on close()/destruction.
**  addr   - the peer address. A private heap copy is kept for get_remote().
**  length - size of *addr in bytes.
** Any previous connection is closed first. The socket is switched to
** non-blocking mode and the state becomes STATE_CONNECTED, or STATE_ERROR
** if s is -1.
*/
void WAC_Network_Connection::connect( SOCKET s, sockaddr *addr, socklen_t length )
{
	close( 1 );

	// Release the address copy from a previously adopted socket (heap-owned
	// only while saddr is null) and forget any earlier DNS result. Otherwise a
	// stale saddr would stop the destructor from freeing the copy made below.
	if ( !saddr )
		free( address );
	address = 0;
	saddr   = 0;

	m_socket = s;
	if ( addr && length > 0 )
	{
		address = (sockaddr *)malloc( length );
		if ( address )
			memcpy( address, addr, length );
	}

	m_remote_port = 0;
	if ( m_socket != -1 )
	{
		SET_SOCK_BLOCK( m_socket, 0 );
		m_state = STATE_CONNECTED;
	}
	else
	{
		m_errorstr = _strdup( "invalid socket passed to connect" );
		m_state = STATE_ERROR;
	}
}

void WAC_Network_Connection::connect( const char *hostname, int port )
{
	close( 1 );
	m_remote_port = (unsigned short)port;

#ifdef _WIN32
	lstrcpynA( m_host, hostname, sizeof( m_host ) );
#elif defined(__APPLE__)
	strlcpy( m_host, hostname, sizeof( m_host ) );
#else
	strncpy( m_host, hostname, sizeof( m_host ) - 1 );
	m_host[ sizeof( m_host ) - 1 ] = 0;
#endif


	//memset(&m_saddr,0,sizeof(m_saddr));
	if ( !m_host[ 0 ] )
	{
		m_errorstr = _strdup( "empty hostname" );
		m_state = STATE_ERROR;
	}
	else
	{
		m_state = STATE_RESOLVING;
	}
}

/*
** socket_shutdown (virtual)
** Shuts down and closes the underlying socket if one is open, then marks it
** as closed (-1). Made virtual by Joshua Teitelbaum (1/27/2006) so the SSL
** connection can substitute its own shutdown.
*/
void WAC_Network_Connection::socket_shutdown()
{
	// Compare against -1 like the rest of this file does. If m_socket is the
	// Winsock SOCKET type (unsigned), ">= 0" is always true and would hand
	// INVALID_SOCKET to shutdown()/closesocket().
	if ( m_socket != -1 )
	{
		::shutdown( m_socket, SHUT_RDWR );
		::closesocket( m_socket );

		m_socket = -1;
	}
}
/*
** socket_recv (virtual)
** Receives up to len bytes from the socket into buf using ::recv() with the
** given flags. Made virtual by Joshua Teitelbaum (1/27/2006) so an SSL layer
** can provide its own read. Returns ::recv()'s result unchanged: the number
** of bytes read, 0 on orderly shutdown, or a negative value on error.
*/
ssize_t WAC_Network_Connection::socket_recv( char *buf, size_t len, int options )
{
	const int request = (int)len; // Winsock's recv() takes an int length
	const ssize_t result = ::recv( m_socket, buf, request, options );
	return result;
}
/*
** socket_send (virtual)
** Sends up to len bytes from buf on the socket using ::send() with the given
** flags. Made virtual by Joshua Teitelbaum (1/27/2006) so an SSL layer can
** provide its own write. Returns ::send()'s result unchanged: the number of
** bytes accepted, or a negative value on error.
*/
ssize_t WAC_Network_Connection::socket_send( const char *buf, size_t len, int options )
{
	const int request = (int)len; // Winsock's send() takes an int length
	const ssize_t result = ::send( m_socket, buf, request, options );
	return result;
}

/*
** Issues ::connect() on m_socket towards the resolved address in saddr.
** Returns ::connect()'s result: 0 on immediate success, otherwise an error
** (run() treats EINPROGRESS as "connection pending").
*/
int WAC_Network_Connection::socket_connect()
{
	struct sockaddr *target = saddr->ai_addr;
	const int target_len = (int)saddr->ai_addrlen;
	return ::connect( m_socket, target, target_len );
}

/*
** Advances the connection state machine by one non-blocking step:
**  RESOLVING  -> polls the resolver. Once resolved, falls through to RESOLVED.
**  RESOLVED   -> creates a non-blocking socket, sizes its kernel buffers and
**                starts connect(). Ends CONNECTED, CONNECTING or ERROR.
**  CONNECTING -> polls select() for completion (CONNECTED or ERROR).
**  CONNECTED / CLOSING -> drains send_buffer to the socket, then fills
**                recv_buffer from it. CLOSING becomes CLOSED once nothing is
**                left to send.
**  max_send_bytes / max_recv_bytes - per-call caps. (size_t)-1 means "as much
**                as the buffers allow".
**  bytes_sent / bytes_rcvd - optional out-params (may be NULL). Set to the
**                bytes moved during this call.
*/
void WAC_Network_Connection::run( size_t max_send_bytes, size_t max_recv_bytes, size_t *bytes_sent, size_t *bytes_rcvd )
{
	socklen_t socket_buffer_size     = 0;
	socklen_t socket_buffer_size_len = sizeof( socket_buffer_size );
	socklen_t send_buffer_size;
	socklen_t recv_buffer_size;

	size_t bytes_allowed_to_send = ( max_send_bytes == (size_t)-1 ) ? send_buffer.size() : max_send_bytes;
	size_t bytes_allowed_to_recv = ( max_recv_bytes == (size_t)-1 ) ? recv_buffer.avail() : max_recv_bytes;

	if ( bytes_sent )
		*bytes_sent = 0;

	if ( bytes_rcvd )
		*bytes_rcvd = 0;

	switch ( m_state )
	{
		case STATE_RESOLVING:
			if ( saddr == 0 )
			{
				// The branches below treat 0 as resolved, 1 as still pending and
				// anything else as failure.
				int a = m_dns->resolve( m_host, m_remote_port, &saddr, SOCK_STREAM );
				if ( !a )
				{
					m_state = STATE_RESOLVED;
				}
				else if ( a == 1 )
				{
					m_state = STATE_RESOLVING;
					break;
				}
				else
				{
					m_errorstr = _strdup( "resolving hostname" );
					m_state = STATE_ERROR;

					return;
				}
			}
			// fall through
		case STATE_RESOLVED:
			m_socket = ::socket( saddr->ai_family, saddr->ai_socktype, saddr->ai_protocol );
			if ( m_socket == -1 )
			{
				m_errorstr = _strdup( "creating socket" );
				m_state = STATE_ERROR;

				// Nothing below can work without a socket. Continuing would call
				// getsockopt()/connect() on -1 and overwrite this error message.
				break;
			}

			SET_SOCK_BLOCK( m_socket, 0 );

			// Grow (never shrink) the kernel send buffer towards the size of our
			// own send buffer (size() + avail(), presumably its capacity),
			// capped at 64 KiB. Failures are ignored: this is best effort.
			socket_buffer_size     = 0;
			socket_buffer_size_len = sizeof( socket_buffer_size );
			getsockopt( m_socket, SOL_SOCKET, SO_SNDBUF, (char *)&socket_buffer_size, &socket_buffer_size_len );
			send_buffer_size = (int)( send_buffer.avail() + send_buffer.size() );
			if ( send_buffer_size > 65536 )
				send_buffer_size = 65536;
			if ( socket_buffer_size < send_buffer_size )
				setsockopt( m_socket, SOL_SOCKET, SO_SNDBUF, (char *)&send_buffer_size, sizeof( send_buffer_size ) );

			// Same for the kernel receive buffer, sized from recv_buffer's free space.
			socket_buffer_size     = 0;
			socket_buffer_size_len = sizeof( socket_buffer_size );
			getsockopt( m_socket, SOL_SOCKET, SO_RCVBUF, (char *)&socket_buffer_size, &socket_buffer_size_len );
			recv_buffer_size = (int)recv_buffer.avail();
			if ( recv_buffer_size > 65536 )
				recv_buffer_size = 65536;
			if ( socket_buffer_size < recv_buffer_size )
				setsockopt( m_socket, SOL_SOCKET, SO_RCVBUF, (char *)&recv_buffer_size, sizeof( recv_buffer_size ) );

			/*
			**  Joshua Teitelbaum 1/27/2006
			**  virtualization for ssl
			*/
			if ( !socket_connect() )
			{
				address = saddr->ai_addr;
				m_state = STATE_CONNECTED;

				on_socket_connected();
			}
			else if ( ERRNO != EINPROGRESS )
			{
				m_errorstr = _strdup( "Connecting to host" );
				m_state = STATE_ERROR;
			}
			else
			{
				m_state = STATE_CONNECTING;
			}
			break;
		case STATE_CONNECTING:
		{
			fd_set f[ 3 ];
			FD_ZERO( &f[ 0 ] );
			FD_ZERO( &f[ 1 ] );
			FD_ZERO( &f[ 2 ] );
			FD_SET( m_socket, &f[ 0 ] );
			FD_SET( m_socket, &f[ 1 ] );
			FD_SET( m_socket, &f[ 2 ] );
			struct timeval tv;
			memset( &tv, 0, sizeof( tv ) );
			if ( select( (int)m_socket + 1, &f[ 0 ], &f[ 1 ], &f[ 2 ], &tv ) == -1 )
			{
				m_errorstr = _strdup( "Connecting to host (calling select())" );
				m_state = STATE_ERROR;
			}
			else if ( FD_ISSET( m_socket, &f[ 1 ] ) )
			{
				// Writable means the connect attempt has finished. On POSIX a
				// *failed* non-blocking connect() is also reported as writable,
				// so SO_ERROR is what tells success from failure.
				int connect_error = 0;
				socklen_t connect_error_len = sizeof( connect_error );
				if ( getsockopt( m_socket, SOL_SOCKET, SO_ERROR, (char *)&connect_error, &connect_error_len ) == 0 && connect_error != 0 )
				{
					m_errorstr = _strdup( "Connecting to host" );
					m_state = STATE_ERROR;
				}
				else
				{
					// Record the peer address as the immediate-connect path does,
					// so get_remote() also works for asynchronous connects.
					if ( saddr )
						address = saddr->ai_addr;

					m_state = STATE_CONNECTED;
					on_socket_connected();
				}
			}
			else if ( FD_ISSET( m_socket, &f[ 2 ] ) )
			{
				m_errorstr = _strdup( "Connecting to host" );
				m_state = STATE_ERROR;
			}
		}
		break;
		case STATE_CONNECTED:
		case STATE_CLOSING:
			/* --- send --- */
		{
			size_t sent = send_buffer.drain( this, bytes_allowed_to_send );
			if ( bytes_sent )
				*bytes_sent += sent;

			// Write() may have detected a closed peer. Stop before reading.
			if ( m_state == STATE_CLOSED )
				break;

			/* --- receive --- */
			size_t received = recv_buffer.fill( this, bytes_allowed_to_recv );
			if ( bytes_rcvd )
				*bytes_rcvd += received;
		}

		// A graceful close() completes once all queued data has been sent.
		if ( m_state == STATE_CLOSING )
		{
			if ( send_buffer.empty() ) m_state = STATE_CLOSED;
		}
		break;
		default:
			break;
	}
}

/*
** Hook called by run() as soon as the TCP connection is established.
** The base implementation does nothing. It is presumably overridden by
** subclasses (e.g. to start an SSL handshake). Confirm against the header.
*/
void WAC_Network_Connection::on_socket_connected( void )
{
	// Intentionally empty.
}

/*
** Closes the connection.
**  If quick is non-zero, or a connect is still in progress (RESOLVING or
**  CONNECTING), the connection is torn down immediately. The socket is shut,
**  both buffers are discarded, the remote host/port are forgotten and the
**  state becomes STATE_CLOSED.
**  Otherwise a CONNECTED connection moves to STATE_CLOSING, and run() flushes
**  queued sends before marking it closed. Any other state is left unchanged.
*/
void WAC_Network_Connection::close( int quick )
{
	const bool close_now = quick || m_state == STATE_RESOLVING || m_state == STATE_CONNECTING;

	if ( !close_now )
	{
		// Graceful close: run() finishes draining send_buffer, then sets CLOSED.
		if ( m_state == STATE_CONNECTED )
			m_state = STATE_CLOSING;
		return;
	}

	m_state = STATE_CLOSED;

	// Joshua Teitelbaum 1/27/2006: goes through the virtual socket_shutdown()
	// so an SSL layer can shut itself down as well.
	socket_shutdown();
	m_socket = -1;

	recv_buffer.clear();
	send_buffer.clear();

	m_remote_port = 0;
	m_host[ 0 ]   = 0;
}

/*
** Returns how many bytes queued by send() are still waiting in send_buffer
** for run() to drain to the socket.
*/
size_t WAC_Network_Connection::send_bytes_in_queue( void )
{
	const size_t queued = send_buffer.size();
	return queued;
}

/*
** Returns how many more bytes send() can currently accept (free space in
** send_buffer).
*/
size_t WAC_Network_Connection::send_bytes_available( void )
{
	const size_t room = send_buffer.avail();
	return room;
}

/*
** Queues length bytes from data for transmission by run(). The write is
** all-or-nothing: returns 0 when the bytes were queued, or -1 (queuing
** nothing) when send_buffer lacks room for all of them.
*/
int WAC_Network_Connection::send( const void *data, size_t length )
{
	const bool fits = length <= send_bytes_available();
	if ( fits )
	{
		send_buffer.write( data, length );
		return 0;
	}

	return -1;
}

/*
** Queues a NUL-terminated string (without its terminator) via send().
** Returns send()'s result: 0 on success, -1 if it does not fit.
** line must not be NULL.
*/
int WAC_Network_Connection::send_string( const char *line )
{
	const size_t line_length = strlen( line );
	return send( line, line_length );
}

/*
** Returns how many received bytes are buffered and ready to be read.
*/
size_t WAC_Network_Connection::recv_bytes_available( void )
{
	const size_t buffered = recv_buffer.size();
	return buffered;
}

/*
** Copies up to maxlength buffered bytes into data without consuming them.
** If data is NULL, nothing is copied. The return value is then how many bytes
** could have been peeked (the smaller of maxlength and the buffered count).
*/
size_t WAC_Network_Connection::peek_bytes( void *data, size_t maxlength )
{
	if ( !data )
		return min( maxlength, recv_bytes_available() );

	return recv_buffer.peek( data, maxlength );
}

/*
** Consumes up to maxlength buffered bytes. With a non-NULL data they are
** copied out. With NULL they are skipped (presumably discarded) via
** recv_buffer.advance(). Returns the count reported by the ring buffer.
*/
size_t WAC_Network_Connection::recv_bytes( void *data, size_t maxlength )
{
	if ( !data )
		return recv_buffer.advance( maxlength );

	return recv_buffer.read( data, maxlength );
}

/*
** Counts the complete lines currently buffered, matching how recv_line()
** splits them. Every '\r' or '\n' ends a line, except that a terminator
** directly following a *different* terminator ("\r\n" or "\n\r") belongs to
** the same line ending. Repeated identical terminators ("\n\n") count as
** separate (empty) lines. Scanning stops early if a byte cannot be read.
*/
int WAC_Network_Connection::recv_lines_available( void )
{
	const int total = (int)recv_bytes_available();
	int lines = 0;
	int previous = 0;

	for ( int offset = 0; offset < total; ++offset )
	{
		char ch;
		if ( recv_buffer.at( offset, &ch, 1 ) != 1 )
			break;

		const bool is_eol      = ch == '\r' || ch == '\n';
		const bool prev_is_eol = previous == '\r' || previous == '\n';
		if ( is_eol && ( !prev_is_eol || previous == ch ) )
			++lines;

		previous = ch;
	}

	return lines;
}

/*
** Reads one line from the receive buffer into line, without its terminator.
** A line ends at '\r' or '\n'. A directly following *different* terminator
** ("\r\n" or "\n\r") is consumed as part of the same line ending.
** Returns 0 when line was NUL-terminated: either a full line was read, or the
** buffer ran out of data (line then holds the partial data that was consumed).
** Returns 1 when maxlength characters were stored without reaching a line end.
** In that case line is NOT NUL-terminated and the rest stays buffered.
*/
int WAC_Network_Connection::recv_line( char *line, size_t maxlength )
{
	size_t remaining = maxlength;
	while ( remaining > 0 )
	{
		--remaining;

		char ch;
		if ( recv_buffer.read( &ch, 1 ) == 0 )
		{
			*line = 0;
			return 0;
		}

		if ( ch != '\r' && ch != '\n' )
		{
			*line++ = ch;
			continue;
		}

		// Swallow the second half of a "\r\n" / "\n\r" pair.
		char next;
		if ( recv_buffer.peek( &next, 1 ) != 0 && ( next == '\r' || next == '\n' ) && next != ch )
			recv_buffer.advance( 1 );

		*line = 0;
		return 0;
	}

	return 1;
}

/*
** Returns the local IPv4 address the socket is bound to, in network byte
** order (as stored in sin_addr.s_addr). Returns 0 if there is no socket, the
** query fails, or the socket is not IPv4.
*/
unsigned long WAC_Network_Connection::get_interface( void )
{
	if ( m_socket == -1 )
		return 0;

	// Query into a sockaddr_storage. Otherwise an IPv6 local address would be
	// truncated to a sockaddr_in and its bytes misread as an IPv4 address.
	struct sockaddr_storage local;
	memset( &local, 0, sizeof( local ) );
	socklen_t len = sizeof( local );

	if ( ::getsockname( m_socket, (struct sockaddr *)&local, &len ) )
		return 0;

	// Only an IPv4 address fits the unsigned long return type. TODO: IPv6
	if ( local.ss_family != AF_INET )
		return 0;

	return (unsigned long)( (struct sockaddr_in *)&local )->sin_addr.s_addr;
}

/*
** Returns the peer's IPv4 address in network byte order (as stored in
** sin_addr.s_addr), or 0 if no peer address is known or it is not IPv4.
*/
unsigned long WAC_Network_Connection::get_remote()
{
	// TODO: IPv6
	// Check the family first. Reinterpreting a sockaddr_in6 as sockaddr_in
	// would return its flowinfo field instead of an address.
	if ( address && address->sa_family == AF_INET )
	{
		sockaddr_in *ipv4 = (sockaddr_in *)address;
		return ipv4->sin_addr.s_addr;
	}

	return 0;
}

/*
** Returns the port given to connect( hostname, port ). It is 0 for a socket
** adopted via connect( SOCKET, ... ) and is reset to 0 by an immediate close().
*/
unsigned short WAC_Network_Connection::get_remote_port()
{
	const unsigned short port = m_remote_port;
	return port;
}

/*
** RingBuffer client callback: recv_buffer.fill( this, ... ) in run() calls
** this to pull up to len bytes from the socket into dest.
** Returns the number of bytes received, or 0 if none.
** A peer shutdown (recv() == 0) or any error other than EWOULDBLOCK marks the
** connection STATE_CLOSED.
*/
size_t WAC_Network_Connection::Read( void *dest, size_t len )
{
	if ( len == 0 )
		return 0;

	const int received = (int)socket_recv( (char *)dest, len, 0 );
	if ( received > 0 )
		return received;

	// received <= 0 here: 0 is an orderly shutdown. A negative value is fatal
	// unless the socket merely had nothing to read yet.
	if ( received == 0 || ERRNO != EWOULDBLOCK )
		m_state = STATE_CLOSED;

	return 0;
}

/*
** RingBuffer client callback: send_buffer.drain( this, ... ) in run() calls
** this to push up to len bytes from dest to the socket.
** Returns the number of bytes the socket accepted, or 0 if none (would-block
** or error).
** NOTE: unlike Read(), a hard send error does not mark the connection closed.
** The state change was commented out in the original code. A dead peer is
** instead detected by Read() on the receive side.
*/
size_t WAC_Network_Connection::Write( const void *dest, size_t len )
{
	if ( len == 0 )
		return 0;

	const int written = (int)socket_send( (const char *)dest, len, 0 );
	return written > 0 ? (size_t)written : 0;
}

/*
** Asks the receive ring buffer to expand to new_buffer_size bytes and returns
** recv_buffer.expand()'s result. NOTE(review): the meaning of that value is
** defined by the ring buffer. Confirm against its declaration.
*/
int WAC_Network_Connection::set_recv_buffer_size( size_t new_buffer_size )
{
	const int result = recv_buffer.expand( new_buffer_size );
	return result;
}

/*
** If the connection is STATE_CLOSED, flips it back to STATE_CONNECTED and
** discards any buffered received data. Any other state is left unchanged.
** NOTE(review): the socket itself is not touched. This only makes sense when
** the socket is still open (e.g. after Read() marked it CLOSED). Confirm
** the intended use with callers.
*/
void WAC_Network_Connection::reuse()
{
	if ( m_state != STATE_CLOSED )
		return;

	m_state = STATE_CONNECTED;
	recv_buffer.clear();
}

/*
** Dispatch table: routes the API_CONNECTION_* method IDs to the member
** functions of WAC_Network_Connection. VCB is used for the void methods (open,
** run, close, ...) and CB for the ones whose result is returned to the caller
** (presumably; see the dispatch macro definitions to confirm).
*/
#define CBCLASS WAC_Network_Connection
START_DISPATCH;
VCB( API_CONNECTION_OPEN,                     open )
// connect() has two overloads in this file, so this entry is written out by
// hand to call connect( const char *hostname, int port ) with the unpacked params.
case API_CONNECTION_CONNECT:                  connect( *(char **)( params[ 0 ] ), *(int *)( params[ 1 ] ) ); return 1;
VCB( API_CONNECTION_RUN,                      run )
CB(  API_CONNECTION_GETSTATE,                 get_state )
CB(  API_CONNECTION_GETERROR,                 get_errstr )
VCB( API_CONNECTION_CLOSE,                    close )
VCB( API_CONNECTION_FLUSHSEND,                flush_send )
CB(  API_CONNECTION_GETSENDBYTESINQUEUE,      send_bytes_in_queue )
CB(  API_CONNECTION_GETSENDBYTESAVAILABLE,    send_bytes_available )
CB(  API_CONNECTION_SEND,                     send )
CB(  API_CONNECTION_SENDBYTES,                send_bytes )
CB(  API_CONNECTION_SENDSTRING,               send_string )
CB(  API_CONNECTION_GETRECEIVEBYTESAVAILABLE, recv_bytes_available )
CB(  API_CONNECTION_RECEIVEBYTES,             recv_bytes )
CB(  API_CONNECTION_GETRECEIVELINESAVAILABLE, recv_lines_available )
CB(  API_CONNECTION_RECEIVELINE,              recv_line )
CB(  API_CONNECTION_PEEKBYTES,                peek_bytes )
CB(  API_CONNECTION_GETINTERFACE,             get_interface )
CB(  API_CONNECTION_GETREMOTEADDRESS,         get_remote )
CB(  API_CONNECTION_GETREMOTEPORT,            get_remote_port )
END_DISPATCH;
#undef CBCLASS