Extract the socket message->notification flow into SubscriberService to reduce duplication

This commit is contained in:
Karmanyaah Malhotra 2023-12-16 22:14:20 -06:00
parent a09374b2a5
commit 9d38118dcb
3 changed files with 37 additions and 41 deletions

View file

@ -18,13 +18,11 @@ class JsonConnection(
private val api: ApiService, private val api: ApiService,
private val user: User?, private val user: User?,
private val sinceId: String?, private val sinceId: String?,
private val connectionOpenListener: (ConnectionId, String?) -> Unit,
private val stateChangeListener: (Collection<Long>, ConnectionState) -> Unit, private val stateChangeListener: (Collection<Long>, ConnectionState) -> Unit,
private val notificationListener: (Subscription, Notification) -> Unit, private val notificationListener: (ConnectionId, Message) -> String?,
private val serviceActive: () -> Boolean private val serviceActive: () -> Boolean
) : Connection { ) : Connection {
private val baseUrl = connectionId.baseUrl private val baseUrl = connectionId.baseUrl
private val parser = NotificationParser()
private val topicsToSubscriptionIds = connectionId.topicsToSubscriptionIds private val topicsToSubscriptionIds = connectionId.topicsToSubscriptionIds
private val topicIsUnifiedPush = connectionId.topicIsUnifiedPush private val topicIsUnifiedPush = connectionId.topicIsUnifiedPush
private val subscriptionIds = topicsToSubscriptionIds.values private val subscriptionIds = topicsToSubscriptionIds.values
@ -46,23 +44,7 @@ class JsonConnection(
Log.d(TAG, "[$url] (Re-)starting connection for subscriptions: $topicsToSubscriptionIds") Log.d(TAG, "[$url] (Re-)starting connection for subscriptions: $topicsToSubscriptionIds")
val startTime = System.currentTimeMillis() val startTime = System.currentTimeMillis()
val notify = notify@ { message : Message -> val notify = notify@ { message : Message ->
if (message.event == ApiService.EVENT_OPEN) { since = notificationListener(ConnectionId(baseUrl, topicsToSubscriptionIds, topicIsUnifiedPush), message)?: since
connectionOpenListener(ConnectionId(baseUrl, topicsToSubscriptionIds, topicIsUnifiedPush), message.message)
return@notify
}
val (topic, notification) = parser.parseWithTopic(
message,
notificationId = Random.nextInt(),
subscriptionId = 0
) ?: return@notify // subscriptionId to be set downstream
since = notification.id
val subscriptionId = topicsToSubscriptionIds[topic] ?: return@notify
val subscription =
repository.getSubscription(subscriptionId) ?: return@notify
val notificationWithSubscriptionId =
notification.copy(subscriptionId = subscription.id)
notificationListener(subscription, notificationWithSubscriptionId)
} }
val failed = AtomicBoolean(false) val failed = AtomicBoolean(false)
val fail = { _: Exception -> val fail = { _: Exception ->

View file

@ -17,7 +17,9 @@ import io.heckel.ntfy.db.ConnectionState
import io.heckel.ntfy.db.Repository import io.heckel.ntfy.db.Repository
import io.heckel.ntfy.db.Subscription import io.heckel.ntfy.db.Subscription
import io.heckel.ntfy.msg.ApiService import io.heckel.ntfy.msg.ApiService
import io.heckel.ntfy.msg.Message
import io.heckel.ntfy.msg.NotificationDispatcher import io.heckel.ntfy.msg.NotificationDispatcher
import io.heckel.ntfy.msg.NotificationParser
import io.heckel.ntfy.ui.Colors import io.heckel.ntfy.ui.Colors
import io.heckel.ntfy.ui.MainActivity import io.heckel.ntfy.ui.MainActivity
import io.heckel.ntfy.util.Log import io.heckel.ntfy.util.Log
@ -28,6 +30,7 @@ import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.Mutex
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import kotlin.random.Random
/** /**
* The subscriber service manages the foreground service for instant delivery. * The subscriber service manages the foreground service for instant delivery.
@ -67,6 +70,7 @@ class SubscriberService : Service() {
private var notificationManager: NotificationManager? = null private var notificationManager: NotificationManager? = null
private var serviceNotification: Notification? = null private var serviceNotification: Notification? = null
private val refreshMutex = Mutex() // Ensure refreshConnections() is only run one at a time private val refreshMutex = Mutex() // Ensure refreshConnections() is only run one at a time
private val parser = NotificationParser()
override fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int { override fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int {
Log.d(TAG, "onStartCommand executed with startId: $startId") Log.d(TAG, "onStartCommand executed with startId: $startId")
@ -204,9 +208,9 @@ class SubscriberService : Service() {
val user = repository.getUser(connectionId.baseUrl) val user = repository.getUser(connectionId.baseUrl)
val connection = if (repository.getConnectionProtocol() == Repository.CONNECTION_PROTOCOL_WS) { val connection = if (repository.getConnectionProtocol() == Repository.CONNECTION_PROTOCOL_WS) {
val alarmManager = getSystemService(ALARM_SERVICE) as AlarmManager val alarmManager = getSystemService(ALARM_SERVICE) as AlarmManager
WsConnection(connectionId, repository, user, since, ::onConnectionOpen, ::onStateChanged, ::onNotificationReceived, alarmManager) WsConnection(connectionId, repository, user, since, ::onStateChanged, ::onNotificationReceived, alarmManager)
} else { } else {
JsonConnection(connectionId, scope, repository, api, user, since, ::onConnectionOpen, ::onStateChanged, ::onNotificationReceived, serviceActive) JsonConnection(connectionId, scope, repository, api, user, since, ::onStateChanged, ::onNotificationReceived, serviceActive)
} }
connections[connectionId] = connection connections[connectionId] = connection
connection.start() connection.start()
@ -267,7 +271,22 @@ class SubscriberService : Service() {
repository.updateState(subscriptionIds, state) repository.updateState(subscriptionIds, state)
} }
private fun onNotificationReceived(subscription: Subscription, notification: io.heckel.ntfy.db.Notification) { // return successfully processed ID, else null
private fun onNotificationReceived(connectionId: ConnectionId, message: Message) : String? {
if (message.event == ApiService.EVENT_OPEN) {
onConnectionOpen(connectionId, message.message)
return null
}
val notificationWithTopic = parser.parseWithTopic(message, notificationId = Random.nextInt(), subscriptionId = 0
) ?: return null// subscriptionId to be set downstream
val (topic, notificationWoId) = notificationWithTopic
val subscriptionId = connectionId.topicsToSubscriptionIds[topic] ?: return null
val subscription =
repository.getSubscription(subscriptionId) ?: return null
val notification =
notificationWoId.copy(subscriptionId = subscription.id)
// Wakelock while notifications are being dispatched // Wakelock while notifications are being dispatched
// Wakelocks are reference counted by default so that should work neatly here // Wakelocks are reference counted by default so that should work neatly here
wakeLock?.acquire(NOTIFICATION_RECEIVED_WAKELOCK_TIMEOUT_MILLIS) wakeLock?.acquire(NOTIFICATION_RECEIVED_WAKELOCK_TIMEOUT_MILLIS)
@ -285,6 +304,7 @@ class SubscriberService : Service() {
} }
} }
} }
return notification.id
} }
private fun createNotificationChannel(): NotificationManager? { private fun createNotificationChannel(): NotificationManager? {

View file

@ -7,6 +7,7 @@ import android.os.Looper
import io.heckel.ntfy.db.* import io.heckel.ntfy.db.*
import io.heckel.ntfy.msg.ApiService import io.heckel.ntfy.msg.ApiService
import io.heckel.ntfy.msg.ApiService.Companion.requestBuilder import io.heckel.ntfy.msg.ApiService.Companion.requestBuilder
import io.heckel.ntfy.msg.Message
import io.heckel.ntfy.msg.NotificationParser import io.heckel.ntfy.msg.NotificationParser
import io.heckel.ntfy.util.Log import io.heckel.ntfy.util.Log
import io.heckel.ntfy.util.topicShortUrl import io.heckel.ntfy.util.topicShortUrl
@ -36,9 +37,8 @@ class WsConnection(
private val repository: Repository, private val repository: Repository,
private val user: User?, private val user: User?,
private val sinceId: String?, private val sinceId: String?,
private val connectionOpenListener: (ConnectionId, String?) -> Unit,
private val stateChangeListener: (Collection<Long>, ConnectionState) -> Unit, private val stateChangeListener: (Collection<Long>, ConnectionState) -> Unit,
private val notificationListener: (Subscription, Notification) -> Unit, private val notificationListener: (ConnectionId, Message) -> String?,
private val alarmManager: AlarmManager private val alarmManager: AlarmManager
) : Connection { ) : Connection {
private val parser = NotificationParser() private val parser = NotificationParser()
@ -61,7 +61,8 @@ class WsConnection(
private val topicIsUnifiedPush = connectionId.topicIsUnifiedPush private val topicIsUnifiedPush = connectionId.topicIsUnifiedPush
private val subscriptionIds = topicsToSubscriptionIds.values private val subscriptionIds = topicsToSubscriptionIds.values
private val topicsStr = topicsToSubscriptionIds.keys.joinToString(separator = ",") private val topicsStr = topicsToSubscriptionIds.keys.joinToString(separator = ",")
private val unifiedPushTopicsStr = topicIsUnifiedPush.filter { entry -> entry.value }.keys.joinToString(separator = ",") private val unifiedPushTopicsStr =
topicIsUnifiedPush.filter { entry -> entry.value }.keys.joinToString(separator = ",")
private val shortUrl = topicShortUrl(baseUrl, topicsStr) private val shortUrl = topicShortUrl(baseUrl, topicsStr)
init { init {
@ -140,22 +141,15 @@ class WsConnection(
synchronize("onMessage") { synchronize("onMessage") {
Log.d(TAG, "$shortUrl (gid=$globalId, lid=$id): Received message: $text") Log.d(TAG, "$shortUrl (gid=$globalId, lid=$id): Received message: $text")
val message = parser.parseMessage(text) ?: return@synchronize val message = parser.parseMessage(text) ?: return@synchronize
if (message.event == ApiService.EVENT_OPEN){ val id = notificationListener(
connectionOpenListener(ConnectionId(baseUrl, topicsToSubscriptionIds, topicIsUnifiedPush), message.message) ConnectionId(baseUrl, topicsToSubscriptionIds, topicIsUnifiedPush),
return@synchronize message
)
if (id != null) {
since.set(id)
} else {
Log.d(WsConnection.TAG,"$shortUrl (gid=$globalId, lid=$id): Irrelevant or unknown message. Discarding.")
} }
val notificationWithTopic = parser.parseWithTopic(message, subscriptionId = 0, notificationId = Random.nextInt())
if (notificationWithTopic == null) {
Log.d(TAG, "$shortUrl (gid=$globalId, lid=$id): Irrelevant or unknown message. Discarding.")
return@synchronize
}
val topic = notificationWithTopic.topic
val notification = notificationWithTopic.notification
val subscriptionId = topicsToSubscriptionIds[topic] ?: return@synchronize
val subscription = repository.getSubscription(subscriptionId) ?: return@synchronize
val notificationWithSubscriptionId = notification.copy(subscriptionId = subscription.id)
notificationListener(subscription, notificationWithSubscriptionId)
since.set(notification.id)
} }
} }