PageViewModel: use appropriate client certs

This commit is contained in:
dece 2022-02-13 23:56:09 +01:00
parent 3aea6f42c9
commit 50793d1440
5 changed files with 107 additions and 30 deletions

View file

@ -4,12 +4,16 @@ import android.security.keystore.KeyGenParameterSpec
import android.security.keystore.KeyProperties import android.security.keystore.KeyProperties
import android.util.Log import android.util.Log
import androidx.room.* import androidx.room.*
import java.lang.IllegalArgumentException
import java.security.KeyPairGenerator import java.security.KeyPairGenerator
import java.security.KeyStore import java.security.KeyStore
import javax.security.auth.x500.X500Principal import javax.security.auth.x500.X500Principal
object Identities { object Identities {
const val PROVIDER = "AndroidKeyStore"
private const val TAG = "Identities"
val keyStore by lazy { KeyStore.getInstance(PROVIDER).apply { load(null) } }
@Entity @Entity
data class Identity( data class Identity(
/** ID. */ /** ID. */
@ -49,6 +53,10 @@ object Identities {
suspend fun getAll(): List<Identity> = suspend fun getAll(): List<Identity> =
Database.INSTANCE.identityDao().getAll() Database.INSTANCE.identityDao().getAll()
suspend fun getForUrl(url: String): Identity? =
Database.INSTANCE.identityDao().getAll()
.find { it.urls.any { usedUrl -> url.startsWith(usedUrl) } }
suspend fun update(vararg identities: Identity) = suspend fun update(vararg identities: Identity) =
Database.INSTANCE.identityDao().update(*identities) Database.INSTANCE.identityDao().update(*identities)
@ -62,7 +70,7 @@ object Identities {
fun generateClientCert(alias: String, commonName: String) { fun generateClientCert(alias: String, commonName: String) {
val algo = KeyProperties.KEY_ALGORITHM_RSA val algo = KeyProperties.KEY_ALGORITHM_RSA
val kpg = KeyPairGenerator.getInstance(algo, "AndroidKeyStore") val kpg = KeyPairGenerator.getInstance(algo, PROVIDER)
val purposes = KeyProperties.PURPOSE_SIGN or KeyProperties.PURPOSE_VERIFY val purposes = KeyProperties.PURPOSE_SIGN or KeyProperties.PURPOSE_VERIFY
val spec = KeyGenParameterSpec.Builder(alias, purposes) val spec = KeyGenParameterSpec.Builder(alias, purposes)
.apply { .apply {
@ -74,7 +82,9 @@ object Identities {
} }
} }
} }
.setDigests(KeyProperties.DIGEST_SHA256) .setDigests(KeyProperties.DIGEST_NONE, KeyProperties.DIGEST_SHA256, KeyProperties.DIGEST_SHA512)
.setEncryptionPaddings(KeyProperties.ENCRYPTION_PADDING_NONE)
.setSignaturePaddings(KeyProperties.SIGNATURE_PADDING_RSA_PKCS1)
.build() .build()
kpg.initialize(spec) kpg.initialize(spec)
kpg.generateKeyPair() kpg.generateKeyPair()
@ -82,7 +92,6 @@ object Identities {
} }
private fun deleteClientCert(alias: String) { private fun deleteClientCert(alias: String) {
val keyStore = KeyStore.getInstance("AndroidKeyStore").also { it.load(null) }
if (keyStore.containsAlias(alias)) { if (keyStore.containsAlias(alias)) {
keyStore.deleteEntry(alias) keyStore.deleteEntry(alias)
Log.i(TAG, "deleteClientCert: deleted entry with alias \"$alias\"") Log.i(TAG, "deleteClientCert: deleted entry with alias \"$alias\"")
@ -90,6 +99,4 @@ object Identities {
Log.i(TAG, "deleteClientCert: no such alias \"$alias\"") Log.i(TAG, "deleteClientCert: no such alias \"$alias\"")
} }
} }
private const val TAG = "Identities"
} }

View file

@ -16,7 +16,6 @@ import androidx.activity.addCallback
import androidx.appcompat.app.AlertDialog import androidx.appcompat.app.AlertDialog
import androidx.fragment.app.Fragment import androidx.fragment.app.Fragment
import androidx.fragment.app.viewModels import androidx.fragment.app.viewModels
import androidx.preference.PreferenceManager
import androidx.recyclerview.widget.LinearLayoutManager import androidx.recyclerview.widget.LinearLayoutManager
import dev.lowrespalmtree.comet.databinding.FragmentPageViewBinding import dev.lowrespalmtree.comet.databinding.FragmentPageViewBinding
import dev.lowrespalmtree.comet.utils.isConnectedToNetwork import dev.lowrespalmtree.comet.utils.isConnectedToNetwork
@ -116,16 +115,7 @@ class PageFragment : Fragment(), PageAdapter.Listener {
} }
when (uri.scheme) { when (uri.scheme) {
"gemini" -> { "gemini" -> vm.sendGeminiRequest(uri, requireContext())
val prefs = PreferenceManager.getDefaultSharedPreferences(requireContext())
val protocol =
prefs.getString("tls_version", Request.DEFAULT_TLS_VERSION)!!
val connectionTimeout =
prefs.getInt("connection_timeout", Request.DEFAULT_CONNECTION_TIMEOUT_SEC)
val readTimeout =
prefs.getInt("read_timeout", Request.DEFAULT_READ_TIMEOUT_SEC)
vm.sendGeminiRequest(uri, protocol, connectionTimeout, readTimeout)
}
else -> openUnknownScheme(uri) else -> openUnknownScheme(uri)
} }
} }
@ -163,6 +153,7 @@ class PageFragment : Fragment(), PageAdapter.Listener {
when (event) { when (event) {
is PageViewModel.InputEvent -> { is PageViewModel.InputEvent -> {
askForInput(event.prompt, event.uri) askForInput(event.prompt, event.uri)
updateState(PageViewModel.State.IDLE)
} }
is PageViewModel.SuccessEvent -> { is PageViewModel.SuccessEvent -> {
vm.currentUrl = event.uri vm.currentUrl = event.uri
@ -200,7 +191,7 @@ class PageFragment : Fragment(), PageAdapter.Listener {
val newUri = uri.buildUpon().query(text).build() val newUri = uri.buildUpon().query(text).build()
openUrl(newUri.toString(), base = vm.currentUrl) openUrl(newUri.toString(), base = vm.currentUrl)
}, },
onDismiss = { updateState(PageViewModel.State.IDLE) } onDismiss = {}
) )
} }

View file

@ -1,11 +1,13 @@
package dev.lowrespalmtree.comet package dev.lowrespalmtree.comet
import android.content.Context
import android.net.Uri import android.net.Uri
import android.util.Log import android.util.Log
import androidx.lifecycle.MutableLiveData import androidx.lifecycle.MutableLiveData
import androidx.lifecycle.SavedStateHandle import androidx.lifecycle.SavedStateHandle
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import androidx.preference.PreferenceManager
import dev.lowrespalmtree.comet.utils.joinUrls import dev.lowrespalmtree.comet.utils.joinUrls
import kotlinx.coroutines.* import kotlinx.coroutines.*
import kotlinx.coroutines.channels.onSuccess import kotlinx.coroutines.channels.onSuccess
@ -18,18 +20,25 @@ class PageViewModel(@Suppress("unused") private val savedStateHandle: SavedState
ViewModel() { ViewModel() {
/** Currently viewed page URL. */ /** Currently viewed page URL. */
var currentUrl: String = "" var currentUrl: String = ""
/** Latest Uri requested using `sendGeminiRequest`. */ /** Latest Uri requested using `sendGeminiRequest`. */
var loadingUrl: Uri? = null var loadingUrl: Uri? = null
/** Observable page viewer state. */ /** Observable page viewer state. */
val state: MutableLiveData<State> by lazy { MutableLiveData<State>(State.IDLE) } val state: MutableLiveData<State> by lazy { MutableLiveData<State>(State.IDLE) }
/** Observable page viewer lines (backed up by `linesList` but updated less often). Left element is associated URL. */ /** Observable page viewer lines (backed up by `linesList` but updated less often). Left element is associated URL. */
val lines: MutableLiveData<Pair<String, List<Line>>> by lazy { MutableLiveData<Pair<String, List<Line>>>() } val lines: MutableLiveData<Pair<String, List<Line>>> by lazy { MutableLiveData<Pair<String, List<Line>>>() }
/** Observable page viewer latest event. */ /** Observable page viewer latest event. */
val event: MutableLiveData<Event> by lazy { MutableLiveData<Event>() } val event: MutableLiveData<Event> by lazy { MutableLiveData<Event>() }
/** A non-saved list of visited URLs. Not an history, just used for going back. */ /** A non-saved list of visited URLs. Not an history, just used for going back. */
val visitedUrls = mutableListOf<String>() val visitedUrls = mutableListOf<String>()
/** Latest request job created, stored to cancel it if needed. */ /** Latest request job created, stored to cancel it if needed. */
private var requestJob: Job? = null private var requestJob: Job? = null
/** Lines for the current page. */ /** Lines for the current page. */
private var linesList = ArrayList<Line>() private var linesList = ArrayList<Line>()
@ -53,14 +62,36 @@ class PageViewModel(@Suppress("unused") private val savedStateHandle: SavedState
* The URI must be valid, absolute and with a gemini scheme. * The URI must be valid, absolute and with a gemini scheme.
*/ */
@ExperimentalCoroutinesApi @ExperimentalCoroutinesApi
fun sendGeminiRequest(uri: Uri, protocol: String, connectionTimeout: Int, readTimeout: Int, redirects: Int = 0) { fun sendGeminiRequest(
Log.d(TAG, "sendRequest: URI \"$uri\"") uri: Uri,
context: Context,
redirects: Int = 0
) {
Log.i(TAG, "sendGeminiRequest: URI \"$uri\"")
loadingUrl = uri loadingUrl = uri
// Retrieve various request parameters from user preferences.
val prefs = PreferenceManager.getDefaultSharedPreferences(context)
val protocol =
prefs.getString("tls_version", Request.DEFAULT_TLS_VERSION)!!
val connectionTimeout =
prefs.getInt("connection_timeout", Request.DEFAULT_CONNECTION_TIMEOUT_SEC)
val readTimeout =
prefs.getInt("read_timeout", Request.DEFAULT_READ_TIMEOUT_SEC)
state.postValue(State.CONNECTING) state.postValue(State.CONNECTING)
requestJob?.apply { if (isActive) cancel() } requestJob?.apply { if (isActive) cancel() }
requestJob = viewModelScope.launch(Dispatchers.IO) { requestJob = viewModelScope.launch(Dispatchers.IO) {
// Look for a suitable identity to use with this URL.
val keyManager = Identities.getForUrl(uri.toString())?.let {
Log.d(TAG, "sendGeminiRequest coroutine: using identity with key ${it.key}")
Request.KeyManager.fromAlias(it.key)
}
// Connect to the server and proceed (no TOFU validation yet).
val response = try { val response = try {
val request = Request(uri) val request = Request(uri, keyManager = keyManager)
val socket = request.connect(protocol, connectionTimeout, readTimeout) val socket = request.connect(protocol, connectionTimeout, readTimeout)
val channel = request.proceed(socket, this) val channel = request.proceed(socket, this)
Response.from(channel, viewModelScope) Response.from(channel, viewModelScope)
@ -80,8 +111,10 @@ class PageViewModel(@Suppress("unused") private val savedStateHandle: SavedState
) )
return@launch return@launch
} }
if (!isActive) if (!isActive)
return@launch return@launch
if (response == null) { if (response == null) {
signalError("Can't parse server response.") signalError("Can't parse server response.")
return@launch return@launch

View file

@ -8,20 +8,25 @@ import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import java.io.BufferedInputStream import java.io.BufferedInputStream
import java.net.InetSocketAddress import java.net.InetSocketAddress
import java.net.Socket
import java.net.SocketTimeoutException import java.net.SocketTimeoutException
import java.security.KeyStore
import java.security.Principal
import java.security.PrivateKey
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import javax.net.ssl.SSLContext import javax.net.ssl.*
import javax.net.ssl.SSLProtocolException
import javax.net.ssl.SSLSocket
import javax.net.ssl.X509TrustManager
class Request(private val uri: Uri) { class Request(private val uri: Uri, private val keyManager: KeyManager? = null) {
private val port get() = if (uri.port > 0) uri.port else 1965 private val port get() = if (uri.port > 0) uri.port else 1965
fun connect(protocol: String, connectionTimeout: Int, readTimeout: Int): SSLSocket { fun connect(protocol: String, connectionTimeout: Int, readTimeout: Int): SSLSocket {
Log.d(TAG, "connect: $protocol, c.to. $connectionTimeout, r.to. $readTimeout") Log.d(
TAG,
"connect: $protocol, conn. timeout $connectionTimeout," +
" read timeout $readTimeout, key manager $keyManager"
)
val context = SSLContext.getInstance(protocol) val context = SSLContext.getInstance(protocol)
context.init(null, arrayOf(TrustManager()), null) context.init(arrayOf(keyManager), arrayOf(TrustManager()), null)
val socket = context.socketFactory.createSocket() as SSLSocket val socket = context.socketFactory.createSocket() as SSLSocket
socket.soTimeout = readTimeout * 1000 socket.soTimeout = readTimeout * 1000
socket.connect(InetSocketAddress(uri.host, port), connectionTimeout * 1000) socket.connect(InetSocketAddress(uri.host, port), connectionTimeout * 1000)
@ -58,6 +63,49 @@ class Request(private val uri: Uri) {
return channel return channel
} }
class KeyManager(
private val alias: String,
private val cert: X509Certificate,
private val privateKey: PrivateKey
) : X509ExtendedKeyManager() {
companion object {
fun fromAlias(alias: String): KeyManager? {
val cert = Identities.keyStore.getCertificate(alias) as X509Certificate?
?: return null.also { Log.e(TAG, "fromAlias: cert is null") }
val key = Identities.keyStore.getEntry(alias, null)?.let { entry ->
(entry as KeyStore.PrivateKeyEntry).privateKey
} ?: return null.also { Log.e(TAG, "fromAlias: private key is null") }
return KeyManager(alias, cert, key)
}
}
override fun chooseClientAlias(
keyType: Array<out String>?,
issuers: Array<out Principal>?,
socket: Socket?
): String = alias
override fun getCertificateChain(alias: String?): Array<out X509Certificate> = arrayOf(cert)
override fun getPrivateKey(alias: String?): PrivateKey = privateKey
override fun getServerAliases(
keyType: String?, issuers: Array<out Principal>?
): Array<String> = throw UnsupportedOperationException()
override fun chooseServerAlias(
keyType: String?,
issuers: Array<out Principal>?,
socket: Socket?
): String = throw UnsupportedOperationException()
override fun getClientAliases(
keyType: String?,
issuers: Array<out Principal>?
): Array<String> = throw UnsupportedOperationException()
}
/** TODO An X509TrustManager implementation for TOFU validation. */
@SuppressLint("CustomX509TrustManager") @SuppressLint("CustomX509TrustManager")
class TrustManager : X509TrustManager { class TrustManager : X509TrustManager {
@SuppressLint("TrustAllX509TrustManager") @SuppressLint("TrustAllX509TrustManager")

View file

@ -8,8 +8,6 @@
<color name="main_accent_dark">#073642</color> <color name="main_accent_dark">#073642</color>
<color name="second_accent">#2aa198</color> <color name="second_accent">#2aa198</color>
<color name="second_accent_dark">#073642</color> <color name="second_accent_dark">#073642</color>
<color name="link">#268bd2</color>
<color name="link_visited">#2aa198</color>
<color name="url_bar">#fdf6e3</color> <color name="url_bar">#fdf6e3</color>
<color name="url_bar_loading">#586e75</color> <color name="url_bar_loading">#586e75</color>
</resources> </resources>