From 0bf1984d2c9fb3a9dc73303551c18906c3c9482b Mon Sep 17 00:00:00 2001
From: est31 <MTest31@outlook.com>
Date: Wed, 30 Sep 2015 00:38:05 +0200
Subject: [PATCH] Fix some SRP issues

-> Remove memory allocation bugs
-> Merge changes from upstream, enabling customizeable memory allocation
---
 src/client.cpp    |   6 +-
 src/util/auth.cpp |   5 +-
 src/util/srp.cpp  | 193 ++++++++++++++++++++++++++++------------------
 src/util/srp.h    |  31 +++++++-
 4 files changed, 153 insertions(+), 82 deletions(-)

diff --git a/src/client.cpp b/src/client.cpp
index 4a9398f70..b2e21743c 100644
--- a/src/client.cpp
+++ b/src/client.cpp
@@ -1059,8 +1059,10 @@ void Client::startAuth(AuthMechanism chosen_auth_mechanism)
 				m_password.length(), NULL, NULL);
 			char *bytes_A = 0;
 			size_t len_A = 0;
-			srp_user_start_authentication((struct SRPUser *) m_auth_data,
-				NULL, NULL, 0, (unsigned char **) &bytes_A, &len_A);
+			SRP_Result res = srp_user_start_authentication(
+				(struct SRPUser *) m_auth_data, NULL, NULL, 0,
+				(unsigned char **) &bytes_A, &len_A);
+			FATAL_ERROR_IF(res != SRP_OK, "Creating local SRP user failed.");
 
 			NetworkPacket resp_pkt(TOSERVER_SRP_BYTES_A, 0);
 			resp_pkt << std::string(bytes_A, len_A) << based_on;
diff --git a/src/util/auth.cpp b/src/util/auth.cpp
index df8940e87..0c17a9237 100644
--- a/src/util/auth.cpp
+++ b/src/util/auth.cpp
@@ -24,6 +24,7 @@ with this program; if not, write to the Free Software Foundation, Inc.,
 #include "sha1.h"
 #include "srp.h"
 #include "string.h"
+#include "debug.h"
 
 // Get an sha-1 hash of the player's name combined with
 // the password entered. That's what the server uses as
@@ -50,10 +51,11 @@ void getSRPVerifier(const std::string &name,
 	char **bytes_v, size_t *len_v)
 {
 	std::string n_name = lowercase(name);
-	srp_create_salted_verification_key(SRP_SHA256, SRP_NG_2048,
+	SRP_Result res = srp_create_salted_verification_key(SRP_SHA256, SRP_NG_2048,
 		n_name.c_str(), (const unsigned char *)password.c_str(),
 		password.size(), (unsigned char **)salt, salt_len,
 		(unsigned char **)bytes_v, len_v, NULL, NULL);
+	FATAL_ERROR_IF(res != SRP_OK, "Couldn't create salted SRP verifier");
 }
 
 // Get a db-ready SRP verifier
@@ -67,6 +69,7 @@ inline static std::string getSRPVerifier(const std::string &name,
 	size_t len_v;
 	getSRPVerifier(name, password, salt, &salt_len,
 		&bytes_v, &len_v);
+	assert(*salt); // usually, srp_create_salted_verification_key promises us to return SRP_ERR when *salt == NULL
 	std::string ret_val = encodeSRPVerifier(std::string(bytes_v, len_v),
 		std::string(*salt, salt_len));
 	free(bytes_v);
diff --git a/src/util/srp.cpp b/src/util/srp.cpp
index 94426db92..b4af58d62 100644
--- a/src/util/srp.cpp
+++ b/src/util/srp.cpp
@@ -69,6 +69,19 @@ static int g_initialized = 0;
 static unsigned int g_rand_idx;
 static unsigned char g_rand_buff[RAND_BUFF_MAX];
 
+void *(*srp_alloc) (size_t) = &malloc;
+void *(*srp_realloc) (void *, size_t) = &realloc;
+void (*srp_free) (void *) = &free;
+
+void srp_set_memory_functions(
+		void *(*new_srp_alloc) (size_t),
+		void *(*new_srp_realloc) (void *, size_t),
+		void (*new_srp_free) (void *)) {
+	srp_alloc = new_srp_alloc;
+	srp_realloc = new_srp_realloc;
+	srp_free = new_srp_free;
+}
+
 typedef struct
 {
 	mpz_t N;
@@ -171,13 +184,13 @@ static void delete_ng(NGConstant *ng)
 	if (ng) {
 		mpz_clear(ng->N);
 		mpz_clear(ng->g);
-		free(ng);
+		srp_free(ng);
 	}
 }
 
 static NGConstant *new_ng( SRP_NGType ng_type, const char *n_hex, const char *g_hex )
 {
-	NGConstant *ng = (NGConstant *) malloc(sizeof(NGConstant));
+	NGConstant *ng = (NGConstant *) srp_alloc(sizeof(NGConstant));
 	mpz_init(ng->N);
 	mpz_init(ng->g);
 
@@ -367,18 +380,18 @@ static int H_nn(mpz_t result, SRP_HashAlgorithm alg, const mpz_t N, const mpz_t
 	size_t len_n1 = mpz_num_bytes(n1);
 	size_t len_n2 = mpz_num_bytes(n2);
 	size_t nbytes = len_N + len_N;
-	unsigned char *bin = (unsigned char *) malloc(nbytes);
+	unsigned char *bin = (unsigned char *) srp_alloc(nbytes);
 	if (!bin)
 		return 0;
 	if (len_n1 > len_N || len_n2 > len_N) {
-		free(bin);
+		srp_free(bin);
 		return 0;
 	}
 	memset(bin, 0, nbytes);
 	mpz_to_bin(n1, bin + (len_N - len_n1));
 	mpz_to_bin(n2, bin + (len_N + len_N - len_n2));
 	hash( alg, bin, nbytes, buff );
-	free(bin);
+	srp_free(bin);
 	mpz_from_bin(buff, hash_length(alg), result);
 	return 1;
 }
@@ -387,13 +400,13 @@ static int H_ns(mpz_t result, SRP_HashAlgorithm alg, const unsigned char *n, siz
 {
 	unsigned char buff[SHA512_DIGEST_LENGTH];
 	size_t nbytes = len_n + len_bytes;
-	unsigned char *bin = (unsigned char *) malloc(nbytes);
+	unsigned char *bin = (unsigned char *) srp_alloc(nbytes);
 	if (!bin)
 		return 0;
 	memcpy(bin, n, len_n);
 	memcpy(bin + len_n, bytes, len_bytes);
 	hash(alg, bin, nbytes, buff);
-	free(bin);
+	srp_free(bin);
 	mpz_from_bin(buff, hash_length(alg), result);
 	return 1;
 }
@@ -418,23 +431,23 @@ static int calculate_x(mpz_t result, SRP_HashAlgorithm alg, const unsigned char
 static void update_hash_n(SRP_HashAlgorithm alg, HashCTX *ctx, const mpz_t n)
 {
 	size_t len = mpz_num_bytes(n);
-	unsigned char* n_bytes = (unsigned char *) malloc(len);
+	unsigned char* n_bytes = (unsigned char *) srp_alloc(len);
 	if (!n_bytes)
 		return;
 	mpz_to_bin(n, n_bytes);
 	hash_update(alg, ctx, n_bytes, len);
-	free(n_bytes);
+	srp_free(n_bytes);
 }
 
 static void hash_num( SRP_HashAlgorithm alg, const mpz_t n, unsigned char *dest )
 {
 	int nbytes = mpz_num_bytes(n);
-	unsigned char *bin = (unsigned char *) malloc(nbytes);
+	unsigned char *bin = (unsigned char *) srp_alloc(nbytes);
 	if(!bin)
 		return;
 	mpz_to_bin(n, bin);
 	hash(alg, bin, nbytes, dest);
-	free(bin);
+	srp_free(bin);
 }
 
 static void calculate_M(SRP_HashAlgorithm alg, NGConstant *ng, unsigned char *dest,
@@ -510,7 +523,7 @@ static void srp_pcgrandom_seed(srp_pcgrandom *r, unsigned long long int state,
 }
 
 
-static int fill_buff()
+static SRP_Result fill_buff()
 {
 	g_rand_idx = 0;
 
@@ -526,7 +539,7 @@ static int fill_buff()
 	CryptGenRandom(wctx, sizeof(g_rand_buff), (BYTE*) g_rand_buff);
 	CryptReleaseContext(wctx, 0);
 
-	return 1;
+	return SRP_OK;
 
 #else
 	fp = fopen("/dev/urandom", "r");
@@ -535,41 +548,48 @@ static int fill_buff()
 		fread(g_rand_buff, sizeof(g_rand_buff), 1, fp);
 		fclose(fp);
 	} else {
-		srp_pcgrandom *r = (srp_pcgrandom *) malloc(sizeof(srp_pcgrandom));
+		srp_pcgrandom *r = (srp_pcgrandom *) srp_alloc(sizeof(srp_pcgrandom));
+		if (!r)
+			return SRP_ERR;
 		srp_pcgrandom_seed(r, time(NULL) ^ clock(), 0xda3e39cb94b95bdbULL);
 		size_t i = 0;
 		for (i = 0; i < RAND_BUFF_MAX; i++) {
 			g_rand_buff[i] = srp_pcgrandom_next(r);
 		}
+		srp_free(r);
 	}
 #endif
-	return 1;
+	return SRP_OK;
 }
 
-static void mpz_fill_random(mpz_t num)
+static SRP_Result mpz_fill_random(mpz_t num)
 {
 	// was call: BN_rand(num, 256, -1, 0);
 	if (RAND_BUFF_MAX - g_rand_idx < 32)
-		fill_buff();
+		if (fill_buff() != SRP_OK)
+			return SRP_ERR;
 	mpz_from_bin((const unsigned char *) (&g_rand_buff[g_rand_idx]), 32, num);
 	g_rand_idx += 32;
+	return SRP_OK;
 }
 
-static void init_random()
+static SRP_Result init_random()
 {
 	if (g_initialized)
-		return;
-	g_initialized = fill_buff();
+		return SRP_OK;
+	SRP_Result ret = fill_buff();
+	g_initialized = (ret == SRP_OK);
+	return ret;
 }
 
 #define srp_dbg_num(num, text) ;
 /*void srp_dbg_num(mpz_t num, char * prevtext)
 {
 	int len_num = mpz_num_bytes(num);
-	char *bytes_num = (char*) malloc(len_num);
+	char *bytes_num = (char*) srp_alloc(len_num);
 	mpz_to_bin(num, (unsigned char *) bytes_num);
 	srp_dbg_data(bytes_num, len_num, prevtext);
-	free(bytes_num);
+	srp_free(bytes_num);
 
 }*/
 
@@ -579,35 +599,42 @@ static void init_random()
  *
  ***********************************************************************************************************/
 
-void srp_create_salted_verification_key( SRP_HashAlgorithm alg,
+SRP_Result srp_create_salted_verification_key( SRP_HashAlgorithm alg,
 	SRP_NGType ng_type, const char *username_for_verifier,
 	const unsigned char *password, size_t len_password,
 	unsigned char **bytes_s,  size_t *len_s,
 	unsigned char **bytes_v, size_t *len_v,
 	const char *n_hex, const char *g_hex )
 {
+	SRP_Result ret = SRP_OK;
+
 	mpz_t v; mpz_init(v);
 	mpz_t x; mpz_init(x);
 	NGConstant *ng = new_ng(ng_type, n_hex, g_hex);
 
-	if(!ng)
-		goto cleanup_and_exit;
+	if (!ng)
+		goto error_and_exit;
 
-	init_random(); /* Only happens once */
+	if (init_random() != SRP_OK) /* Only happens once */
+		goto error_and_exit;
 
 	if (*bytes_s == NULL) {
-		*len_s = 16;
-		if (RAND_BUFF_MAX - g_rand_idx < 16)
-			fill_buff();
-		*bytes_s = (unsigned char*)malloc(sizeof(char) * 16);
-		memcpy(*bytes_s, &g_rand_buff + g_rand_idx, sizeof(char) * 16);
-		g_rand_idx += 16;
+		size_t size_to_fill = 16;
+		*len_s = size_to_fill;
+		if (RAND_BUFF_MAX - g_rand_idx < size_to_fill)
+			if (fill_buff() != SRP_OK)
+				goto error_and_exit;
+		*bytes_s = (unsigned char*)srp_alloc(size_to_fill);
+		if (!*bytes_s)
+			goto error_and_exit;
+		memcpy(*bytes_s, &g_rand_buff + g_rand_idx, size_to_fill);
+		g_rand_idx += size_to_fill;
 	}
 
 
 	if (!calculate_x(x, alg, *bytes_s, *len_s, username_for_verifier,
 			password, len_password))
-		goto cleanup_and_exit;
+		goto error_and_exit;
 
 	srp_dbg_num(x, "Server calculated x: ");
 
@@ -615,10 +642,10 @@ void srp_create_salted_verification_key( SRP_HashAlgorithm alg,
 
 	*len_v = mpz_num_bytes(v);
 
-	*bytes_v = (unsigned char*)malloc(*len_v);
+	*bytes_v = (unsigned char*)srp_alloc(*len_v);
 
 	if (!bytes_v)
-		goto cleanup_and_exit;
+		goto error_and_exit;
 
 	mpz_to_bin(v, *bytes_v);
 
@@ -626,6 +653,10 @@ void srp_create_salted_verification_key( SRP_HashAlgorithm alg,
 	delete_ng( ng );
 	mpz_clear(v);
 	mpz_clear(x);
+	return ret;
+error_and_exit:
+	ret = SRP_ERR;
+	goto cleanup_and_exit;
 }
 
 
@@ -663,19 +694,23 @@ struct SRPVerifier *srp_verifier_new(SRP_HashAlgorithm alg,
 	if (!ng)
 		goto cleanup_and_exit;
 
-	ver = (struct SRPVerifier *) malloc( sizeof(struct SRPVerifier) );
+	ver = (struct SRPVerifier *) srp_alloc( sizeof(struct SRPVerifier) );
 
 	if (!ver)
 		goto cleanup_and_exit;
 
-	init_random(); /* Only happens once */
+	if (init_random() != SRP_OK) { /* Only happens once */
+		srp_free(ver);
+		ver = 0;
+		goto cleanup_and_exit;
+	}
 
-	ver->username = (char *) malloc(ulen);
+	ver->username = (char *) srp_alloc(ulen);
 	ver->hash_alg = alg;
 	ver->ng = ng;
 
 	if (!ver->username) {
-		free(ver);
+		srp_free(ver);
 		ver = 0;
 		goto cleanup_and_exit;
 	}
@@ -690,11 +725,15 @@ struct SRPVerifier *srp_verifier_new(SRP_HashAlgorithm alg,
 		if (bytes_b) {
 			mpz_from_bin(bytes_b, len_b, b);
 		} else {
-			mpz_fill_random(b);
+			if (mpz_fill_random(b) != SRP_OK) {
+				srp_free(ver);
+				ver = 0;
+				goto cleanup_and_exit;
+			}
 		}
 
 		if (!H_nn(k, alg, ng->N, ng->N, ng->g)) {
-			free(ver);
+			srp_free(ver);
 			ver = 0;
 			goto cleanup_and_exit;
 		}
@@ -705,7 +744,7 @@ struct SRPVerifier *srp_verifier_new(SRP_HashAlgorithm alg,
 		mpz_addm(B, tmp1, tmp2, ng->N, tmp3);
 
 		if (!H_nn(u, alg, ng->N, A, B)) {
-			free(ver);
+			srp_free(ver);
 			ver = 0;
 			goto cleanup_and_exit;
 		}
@@ -723,11 +762,11 @@ struct SRPVerifier *srp_verifier_new(SRP_HashAlgorithm alg,
 		calculate_H_AMK(alg, ver->H_AMK, A, ver->M, ver->session_key);
 
 		*len_B = mpz_num_bytes(B);
-		*bytes_B = (unsigned char*)malloc(*len_B);
+		*bytes_B = (unsigned char*)srp_alloc(*len_B);
 
 		if (!*bytes_B) {
-			free(ver->username);
-			free(ver);
+			srp_free(ver->username);
+			srp_free(ver);
 			ver = 0;
 			*len_B = 0;
 			goto cleanup_and_exit;
@@ -737,7 +776,7 @@ struct SRPVerifier *srp_verifier_new(SRP_HashAlgorithm alg,
 
 		ver->bytes_B = *bytes_B;
 	} else {
-		free(ver);
+		srp_free(ver);
 		ver = 0;
 	}
 
@@ -762,10 +801,10 @@ void srp_verifier_delete(struct SRPVerifier *ver)
 {
 	if (ver) {
 		delete_ng(ver->ng);
-		free(ver->username);
-		free(ver->bytes_B);
+		srp_free(ver->username);
+		srp_free(ver->bytes_B);
 		memset(ver, 0, sizeof(*ver));
-		free(ver);
+		srp_free(ver);
 	}
 }
 
@@ -814,14 +853,15 @@ struct SRPUser *srp_user_new(SRP_HashAlgorithm alg, SRP_NGType ng_type,
 	const unsigned char *bytes_password, size_t len_password,
 	const char *n_hex, const char *g_hex)
 {
-	struct SRPUser *usr = (struct SRPUser *) malloc(sizeof(struct SRPUser));
+	struct SRPUser *usr = (struct SRPUser *) srp_alloc(sizeof(struct SRPUser));
 	size_t ulen  = strlen(username) + 1;
 	size_t uvlen = strlen(username_for_verifier) + 1;
 
 	if (!usr)
 		goto err_exit;
 
-	init_random(); /* Only happens once */
+	if (init_random() != SRP_OK) /* Only happens once */
+		goto err_exit;
 
 	usr->hash_alg = alg;
 	usr->ng = new_ng(ng_type, n_hex, g_hex);
@@ -833,12 +873,12 @@ struct SRPUser *srp_user_new(SRP_HashAlgorithm alg, SRP_NGType ng_type,
 	if (!usr->ng)
 		goto err_exit;
 
-	usr->username = (char*)malloc(ulen);
-	usr->username_verifier = (char*)malloc(uvlen);
-	usr->password = (unsigned char*)malloc(len_password);
+	usr->username = (char*)srp_alloc(ulen);
+	usr->username_verifier = (char*)srp_alloc(uvlen);
+	usr->password = (unsigned char*)srp_alloc(len_password);
 	usr->password_len = len_password;
 
-	if (!usr->username || !usr->password)
+	if (!usr->username || !usr->password || !usr->username_verifier)
 		goto err_exit;
 
 	memcpy(usr->username, username, ulen);
@@ -858,15 +898,13 @@ struct SRPUser *srp_user_new(SRP_HashAlgorithm alg, SRP_NGType ng_type,
 		mpz_clear(usr->S);
 		if (usr->ng)
 			delete_ng(usr->ng);
-		if (usr->username)
-			free(usr->username);
-		if (usr->username_verifier)
-			free(usr->username_verifier);
+		srp_free(usr->username);
+		srp_free(usr->username_verifier);
 		if (usr->password) {
 			memset(usr->password, 0, usr->password_len);
-			free(usr->password);
+			srp_free(usr->password);
 		}
-		free(usr);
+		srp_free(usr);
 	}
 
 	return 0;
@@ -885,15 +923,15 @@ void srp_user_delete(struct SRPUser *usr)
 
 		memset(usr->password, 0, usr->password_len);
 
-		free(usr->username);
-		free(usr->username_verifier);
-		free(usr->password);
+		srp_free(usr->username);
+		srp_free(usr->username_verifier);
+		srp_free(usr->password);
 
 		if (usr->bytes_A)
-			free(usr->bytes_A);
+			srp_free(usr->bytes_A);
 
 		memset(usr, 0, sizeof(*usr));
-		free(usr);
+		srp_free(usr);
 	}
 }
 
@@ -926,33 +964,38 @@ size_t srp_user_get_session_key_length(struct SRPUser *usr)
 
 
 /* Output: username, bytes_A, len_A */
-void srp_user_start_authentication(struct SRPUser *usr, char **username,
+SRP_Result srp_user_start_authentication(struct SRPUser *usr, char **username,
 	const unsigned char *bytes_a, size_t len_a,
 	unsigned char **bytes_A, size_t *len_A)
 {
 	if (bytes_a) {
 		mpz_from_bin(bytes_a, len_a, usr->a);
 	} else {
-		mpz_fill_random(usr->a);
+		if (mpz_fill_random(usr->a) != SRP_OK)
+			goto error_and_exit;
 	}
 
 	mpz_powm(usr->A, usr->ng->g, usr->a, usr->ng->N);
 
 	*len_A = mpz_num_bytes(usr->A);
-	*bytes_A = (unsigned char*)malloc(*len_A);
+	*bytes_A = (unsigned char*)srp_alloc(*len_A);
 
-	if (!*bytes_A) {
-		*len_A = 0;
-		*bytes_A = 0;
-		*username = 0;
-		return;
-	}
+	if (!*bytes_A)
+		goto error_and_exit;
 
 	mpz_to_bin(usr->A, *bytes_A);
 
 	usr->bytes_A = *bytes_A;
 	if (username)
 		*username = usr->username;
+
+	return SRP_OK;
+
+error_and_exit:
+	*len_A = 0;
+	*bytes_A = 0;
+	*username = 0;
+	return SRP_ERR;
 }
 
 
diff --git a/src/util/srp.h b/src/util/srp.h
index 15a2b8a68..c876e70e6 100644
--- a/src/util/srp.h
+++ b/src/util/srp.h
@@ -78,6 +78,22 @@ typedef enum
 	SRP_SHA512*/
 } SRP_HashAlgorithm;
 
+typedef enum
+{
+	SRP_OK,
+	SRP_ERR,
+} SRP_Result;
+
+/* Sets the memory functions used by srp.
+ * Note: this doesn't set the memory functions used by gmp,
+ * but it is supported to have different functions for srp and gmp.
+ * Don't call this after you have already allocated srp structures.
+ */
+void srp_set_memory_functions(
+	void *(*new_srp_alloc) (size_t),
+	void *(*new_srp_realloc) (void *, size_t),
+	void (*new_srp_free) (void *));
+
 /* Out: bytes_v, len_v
  *
  * The caller is responsible for freeing the memory allocated for bytes_v
@@ -86,8 +102,11 @@ typedef enum
  * If provided, they must contain ASCII text of the hexidecimal notation.
  *
  * If bytes_s == NULL, it is filled with random data. The caller is responsible for freeing.
+ *
+ * Returns SRP_OK on success, and SRP_ERR on error.
+ * bytes_s might be in this case invalid, don't free it.
  */
-void srp_create_salted_verification_key( SRP_HashAlgorithm alg,
+SRP_Result srp_create_salted_verification_key( SRP_HashAlgorithm alg,
 	SRP_NGType ng_type, const char *username_for_verifier,
 	const unsigned char *password, size_t len_password,
 	unsigned char **bytes_s,  size_t *len_s,
@@ -101,6 +120,8 @@ void srp_create_salted_verification_key( SRP_HashAlgorithm alg,
  * The n_hex and g_hex parameters should be 0 unless SRP_NG_CUSTOM is used for ng_type
  *
  * If bytes_b == NULL, random data is used for b.
+ *
+ * Returns pointer to SRPVerifier on success, and NULL on error.
  */
 struct SRPVerifier* srp_verifier_new(SRP_HashAlgorithm alg, SRP_NGType ng_type,
 	const char *username,
@@ -114,7 +135,7 @@ struct SRPVerifier* srp_verifier_new(SRP_HashAlgorithm alg, SRP_NGType ng_type,
 
 void srp_verifier_delete( struct SRPVerifier* ver );
 
-
+// srp_verifier_verify_session must have been called before
 int srp_verifier_is_authenticated( struct SRPVerifier* ver );
 
 
@@ -128,7 +149,9 @@ const unsigned char* srp_verifier_get_session_key( struct SRPVerifier* ver,
 size_t srp_verifier_get_session_key_length(struct SRPVerifier* ver);
 
 
-/* user_M must be exactly srp_verifier_get_session_key_length() bytes in size */
+/* Verifies session, on success, it writes bytes_HAMK.
+ * user_M must be exactly srp_verifier_get_session_key_length() bytes in size
+ */
 void srp_verifier_verify_session( struct SRPVerifier* ver,
 	const unsigned char* user_M, unsigned char** bytes_HAMK );
 
@@ -154,7 +177,7 @@ size_t srp_user_get_session_key_length(struct SRPUser* usr);
 
 /* Output: username, bytes_A, len_A. If you don't want it get written, set username to NULL.
  * If bytes_a == NULL, random data is used for a. */
-void srp_user_start_authentication(struct SRPUser* usr, char** username,
+SRP_Result srp_user_start_authentication(struct SRPUser* usr, char** username,
 	const unsigned char* bytes_a, size_t len_a,
 	unsigned char** bytes_A, size_t* len_A);
 
-- 
GitLab