diff --git a/.github/workflows/manual.yml b/.github/workflows/manual.yml index c2f52081..2e59a206 100644 --- a/.github/workflows/manual.yml +++ b/.github/workflows/manual.yml @@ -1,35 +1,140 @@ name: Android Build on: - # push: - # branches: [ main, develop ] - # pull_request: - # branches: [ main ] + push: + branches: [ human-operator, main ] workflow_dispatch: # Ermöglicht manuelle Ausführung des Workflows jobs: + detect-changes: + runs-on: ubuntu-latest + outputs: + app_changed: ${{ steps.changes.outputs.app }} + humanoperator_changed: ${{ steps.changes.outputs.humanoperator }} + shared_changed: ${{ steps.changes.outputs.shared }} + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 2 # Letzten 2 Commits holen für Diff + + - name: Detect changed files + id: changes + run: | + # Bei workflow_dispatch immer alles bauen + if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then + echo "app=true" >> $GITHUB_OUTPUT + echo "humanoperator=true" >> $GITHUB_OUTPUT + echo "shared=true" >> $GITHUB_OUTPUT + echo "Manual dispatch - building all modules" + exit 0 + fi + + # Geänderte Dateien im letzten Commit ermitteln + CHANGED_FILES=$(git diff --name-only HEAD~1 HEAD 2>/dev/null || echo "") + + # Falls kein vorheriger Commit existiert (erster Commit), alles bauen + if [ -z "$CHANGED_FILES" ]; then + echo "app=true" >> $GITHUB_OUTPUT + echo "humanoperator=true" >> $GITHUB_OUTPUT + echo "shared=true" >> $GITHUB_OUTPUT + echo "No previous commit found - building all modules" + exit 0 + fi + + echo "Changed files:" + echo "$CHANGED_FILES" + + # Prüfen ob shared/root files geändert wurden (build.gradle, settings.gradle, etc.) 
+ SHARED_CHANGED=false + if echo "$CHANGED_FILES" | grep -qE '^(build\.gradle|settings\.gradle|gradle\.properties|gradle/|buildSrc/)'; then + SHARED_CHANGED=true + fi + + # Prüfen ob app/ Dateien geändert wurden + APP_CHANGED=false + if echo "$CHANGED_FILES" | grep -q '^app/'; then + APP_CHANGED=true + fi + + # Prüfen ob humanoperator/ Dateien geändert wurden + HUMANOPERATOR_CHANGED=false + if echo "$CHANGED_FILES" | grep -q '^humanoperator/'; then + HUMANOPERATOR_CHANGED=true + fi + + echo "app=$APP_CHANGED" >> $GITHUB_OUTPUT + echo "humanoperator=$HUMANOPERATOR_CHANGED" >> $GITHUB_OUTPUT + echo "shared=$SHARED_CHANGED" >> $GITHUB_OUTPUT + + echo "Results: app=$APP_CHANGED, humanoperator=$HUMANOPERATOR_CHANGED, shared=$SHARED_CHANGED" + build: + needs: detect-changes runs-on: ubuntu-latest + env: + BUILD_APP: ${{ needs.detect-changes.outputs.app_changed == 'true' || needs.detect-changes.outputs.shared_changed == 'true' }} + BUILD_HUMANOPERATOR: ${{ needs.detect-changes.outputs.humanoperator_changed == 'true' || needs.detect-changes.outputs.shared_changed == 'true' }} steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - - name: Set up JDK - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: java-version: '17' distribution: 'temurin' cache: gradle + - name: Decode google-services.json (app) + env: + GOOGLE_SERVICES_JSON: ${{ secrets.GOOGLE_SERVICES_JSON_APP }} + run: printf '%s' "$GOOGLE_SERVICES_JSON" > app/google-services.json + + - name: Decode google-services.json (humanoperator) + env: + GOOGLE_SERVICES_JSON: ${{ secrets.GOOGLE_SERVICES_JSON_HUMANOPERATOR }} + run: printf '%s' "$GOOGLE_SERVICES_JSON" > humanoperator/google-services.json + + - name: Create local.properties + run: echo "sdk.dir=$ANDROID_HOME" > local.properties + + - name: Fix gradle.properties for CI + run: | + sed -i '/org.gradle.java.home=/d' gradle.properties + sed -i 's/org.gradle.jvmargs=.*/org.gradle.jvmargs=-Xmx2048m 
-XX:MaxMetaspaceSize=512m/' gradle.properties + sed -i 's/kotlin.daemon.jvmargs=.*/kotlin.daemon.jvmargs=-Xmx1536m -XX:MaxMetaspaceSize=512m/' gradle.properties + - name: Grant execute permission for gradlew run: chmod +x gradlew - - name: Build with Gradle - run: ./gradlew assembleRelease + - name: Build app module (debug) + if: env.BUILD_APP == 'true' + run: ./gradlew :app:assembleDebug - - name: Upload APK + - name: Build humanoperator module (debug) + if: env.BUILD_HUMANOPERATOR == 'true' + run: ./gradlew :humanoperator:assembleDebug + + - name: Upload app APK + if: env.BUILD_APP == 'true' uses: actions/upload-artifact@v4 with: - name: app-release - path: app/build/outputs/apk/release/app-release-unsigned.apk + name: app-debug + path: app/build/outputs/apk/debug/app-debug.apk + + - name: Upload humanoperator APK + if: env.BUILD_HUMANOPERATOR == 'true' + uses: actions/upload-artifact@v4 + with: + name: humanoperator-debug + path: humanoperator/build/outputs/apk/debug/humanoperator-debug.apk + + - name: Build summary + run: | + echo "### Build Summary" >> $GITHUB_STEP_SUMMARY + echo "| Module | Built |" >> $GITHUB_STEP_SUMMARY + echo "|--------|-------|" >> $GITHUB_STEP_SUMMARY + echo "| app | ${{ env.BUILD_APP }} |" >> $GITHUB_STEP_SUMMARY + echo "| humanoperator | ${{ env.BUILD_HUMANOPERATOR }} |" >> $GITHUB_STEP_SUMMARY diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..f3e59281 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "editor.maxTokenizationLineLength": 20000 +} \ No newline at end of file diff --git a/app/build.gradle.kts b/app/build.gradle.kts index 8659a5e7..86e25b34 100644 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -5,6 +5,12 @@ plugins { id("org.jetbrains.kotlin.plugin.serialization") version "1.9.20" id("com.google.android.libraries.mapsplatform.secrets-gradle-plugin") id("kotlin-parcelize") + id("com.google.gms.google-services") +} + +// Redirect build output to C: drive 
(NTFS) to avoid corrupted ExFAT build cache +if (System.getenv("CI") == null) { + layout.buildDirectory = file("C:/GradleBuild/app") } android { @@ -12,7 +18,7 @@ android { compileSdk = 35 defaultConfig { - applicationId = "com.google.ai.sample" + applicationId = "io.github.android_poweruser" minSdk = 26 targetSdk = 35 versionCode = 1 @@ -96,4 +102,14 @@ dependencies { // Camera Core to potentially fix missing JNI lib issue implementation("androidx.camera:camera-core:1.4.0") + + // WebRTC + implementation("io.getstream:stream-webrtc-android:1.1.1") + + // WebSocket for signaling + implementation("com.squareup.okhttp3:okhttp:4.12.0") + + // Firebase + implementation(platform("com.google.firebase:firebase-bom:32.7.2")) + implementation("com.google.firebase:firebase-database") } diff --git a/app/src/main/AndroidManifest.xml b/app/src/main/AndroidManifest.xml index 93b986a8..55627be5 100644 --- a/app/src/main/AndroidManifest.xml +++ b/app/src/main/AndroidManifest.xml @@ -1,6 +1,7 @@ + diff --git a/app/src/main/kotlin/com/google/ai/sample/ApiKeyDialog.kt b/app/src/main/kotlin/com/google/ai/sample/ApiKeyDialog.kt index 48332784..869b2e7d 100644 --- a/app/src/main/kotlin/com/google/ai/sample/ApiKeyDialog.kt +++ b/app/src/main/kotlin/com/google/ai/sample/ApiKeyDialog.kt @@ -88,7 +88,7 @@ fun ApiKeyDialog( ApiProvider.GOOGLE -> "https://makersuite.google.com/app/apikey" ApiProvider.CEREBRAS -> "https://cloud.cerebras.ai/" ApiProvider.VERCEL -> "https://vercel.com/ai-gateway" - ApiProvider.HUMAN_EXPERT -> return@Button // No API key needed + ApiProvider.HUMAN_EXPERT -> return@Button } val intent = Intent(Intent.ACTION_VIEW, Uri.parse(url)) context.startActivity(intent) diff --git a/app/src/main/kotlin/com/google/ai/sample/GenerativeAiViewModelFactory.kt b/app/src/main/kotlin/com/google/ai/sample/GenerativeAiViewModelFactory.kt index 40cb8aad..c9fd191e 100644 --- a/app/src/main/kotlin/com/google/ai/sample/GenerativeAiViewModelFactory.kt +++ 
b/app/src/main/kotlin/com/google/ai/sample/GenerativeAiViewModelFactory.kt @@ -9,6 +9,7 @@ import com.google.ai.client.generativeai.GenerativeModel import com.google.ai.client.generativeai.type.generationConfig import com.google.ai.sample.feature.live.LiveApiManager import com.google.ai.sample.feature.multimodal.PhotoReasoningViewModel +import com.google.ai.sample.util.GenerationSettingsPreferences // Model options enum class ApiProvider { @@ -44,7 +45,11 @@ enum class ModelOption( "https://huggingface.co/na5h13/gemma-3n-E4B-it-litert-lm/resolve/main/gemma-3n-E4B-it-int4.litertlm?download=true", "4.92 GB" ), - HUMAN_EXPERT("Human Expert", "human-expert", ApiProvider.HUMAN_EXPERT) + HUMAN_EXPERT("Human Expert", "human-expert", ApiProvider.HUMAN_EXPERT); + + /** Whether this model supports TopK/TopP/Temperature settings */ + val supportsGenerationSettings: Boolean + get() = this != HUMAN_EXPERT } val GenerativeViewModelFactory = object : ViewModelProvider.Factory { @@ -52,16 +57,20 @@ val GenerativeViewModelFactory = object : ViewModelProvider.Factory { viewModelClass: Class, extras: CreationExtras ): T { - val config = generationConfig { - temperature = 0.0f - } - // Get the application context from extras val application = checkNotNull(extras[ViewModelProvider.AndroidViewModelFactory.APPLICATION_KEY]) + val currentModel = GenerativeAiViewModelFactory.getCurrentModel() + + // Load per-model generation settings + val genSettings = GenerationSettingsPreferences.loadSettings(application.applicationContext, currentModel.modelName) + val config = generationConfig { + temperature = genSettings.temperature + topP = genSettings.topP + topK = genSettings.topK + } // Get the API key from MainActivity val mainActivity = MainActivity.getInstance() - val currentModel = GenerativeAiViewModelFactory.getCurrentModel() val apiKey = if (currentModel == ModelOption.GEMMA_3N_E4B_IT || currentModel == ModelOption.HUMAN_EXPERT) { "offline-no-key-needed" // Dummy key for offline/human 
expert models } else { @@ -75,8 +84,6 @@ val GenerativeViewModelFactory = object : ViewModelProvider.Factory { return with(viewModelClass) { when { isAssignableFrom(PhotoReasoningViewModel::class.java) -> { - val currentModel = GenerativeAiViewModelFactory.getCurrentModel() - if (currentModel.modelName.contains("live")) { // Live API models val liveApiManager = LiveApiManager(apiKey, currentModel.modelName) diff --git a/app/src/main/kotlin/com/google/ai/sample/MainActivity.kt b/app/src/main/kotlin/com/google/ai/sample/MainActivity.kt index adeb55ce..f44e9bbe 100644 --- a/app/src/main/kotlin/com/google/ai/sample/MainActivity.kt +++ b/app/src/main/kotlin/com/google/ai/sample/MainActivity.kt @@ -120,12 +120,14 @@ class MainActivity : ComponentActivity() { // MediaProjection private lateinit var mediaProjectionManager: MediaProjectionManager private lateinit var mediaProjectionLauncher: ActivityResultLauncher + private lateinit var webRtcMediaProjectionLauncher: ActivityResultLauncher private var currentScreenInfoForScreenshot: String? = null private lateinit var navController: NavHostController private var isProcessingExplicitScreenshotRequest: Boolean = false private var onMediaProjectionPermissionGranted: (() -> Unit)? = null + private var onWebRtcMediaProjectionResult: ((Int, Intent) -> Unit)? = null private val screenshotRequestHandler = object : BroadcastReceiver() { override fun onReceive(context: Context?, intent: Intent?) { @@ -187,15 +189,28 @@ class MainActivity : ComponentActivity() { // This should be guaranteed by its placement in onCreate. if (!::mediaProjectionManager.isInitialized) { Log.e(TAG, "requestMediaProjectionPermission: mediaProjectionManager not initialized!") - // Optionally, initialize it here as a fallback, though it indicates an issue with onCreate ordering - // mediaProjectionManager = getSystemService(Context.MEDIA_PROJECTION_SERVICE) as MediaProjectionManager - // Toast.makeText(this, "Error: Projection manager not ready. 
Please try again.", Toast.LENGTH_SHORT).show() return } val intent = mediaProjectionManager.createScreenCaptureIntent() mediaProjectionLauncher.launch(intent) } + /** + * Request a fresh MediaProjection permission specifically for WebRTC (Human Expert). + * This does NOT start ScreenCaptureService - the result is passed directly to the callback. + */ + fun requestMediaProjectionForWebRTC(onResult: (Int, Intent) -> Unit) { + Log.d(TAG, "Requesting MediaProjection permission for WebRTC") + onWebRtcMediaProjectionResult = onResult + + if (!::mediaProjectionManager.isInitialized) { + Log.e(TAG, "requestMediaProjectionForWebRTC: mediaProjectionManager not initialized!") + return + } + val intent = mediaProjectionManager.createScreenCaptureIntent() + webRtcMediaProjectionLauncher.launch(intent) + } + fun takeAdditionalScreenshot() { if (ScreenCaptureService.isRunning()) { Log.d(TAG, "MainActivity: Instructing ScreenCaptureService to take an additional screenshot.") @@ -286,7 +301,7 @@ class MainActivity : ComponentActivity() { when (currentTrialState) { TrialManager.TrialState.EXPIRED_INTERNET_TIME_CONFIRMED -> { - trialInfoMessage = "Your 30-minute trial period has ended. Please subscribe to the app to continue using it." + trialInfoMessage = "Please support the development of the app so that you can continue using it \uD83C\uDF89" showTrialInfoDialog = true Log.d(TAG, "updateTrialState: Set message to \'$trialInfoMessage\', showTrialInfoDialog = true (EXPIRED)") } @@ -444,6 +459,10 @@ class MainActivity : ComponentActivity() { if (result.resultCode == Activity.RESULT_OK && result.data != null) { val shouldTakeScreenshotOnThisStart = this@MainActivity.isProcessingExplicitScreenshotRequest Log.i(TAG, "MediaProjection permission granted. Starting ScreenCaptureService. 
Explicit request: $shouldTakeScreenshotOnThisStart") + + // Notify ViewModel about the permission grant (for Human Expert WebRTC) + photoReasoningViewModel?.onMediaProjectionPermissionGranted(result.resultCode, result.data!!) + val serviceIntent = Intent(this, ScreenCaptureService::class.java).apply { action = ScreenCaptureService.ACTION_START_CAPTURE putExtra(ScreenCaptureService.EXTRA_RESULT_CODE, result.resultCode) @@ -487,6 +506,21 @@ class MainActivity : ComponentActivity() { } } + // Separate WebRTC MediaProjection launcher - does NOT start ScreenCaptureService + webRtcMediaProjectionLauncher = registerForActivityResult( + ActivityResultContracts.StartActivityForResult() + ) { result -> + if (result.resultCode == Activity.RESULT_OK && result.data != null) { + Log.i(TAG, "WebRTC MediaProjection permission granted.") + onWebRtcMediaProjectionResult?.invoke(result.resultCode, result.data!!) + onWebRtcMediaProjectionResult = null + } else { + Log.w(TAG, "WebRTC MediaProjection permission denied.") + Toast.makeText(this, "Screen capture permission denied", Toast.LENGTH_SHORT).show() + onWebRtcMediaProjectionResult = null + } + } + // Keyboard visibility listener val rootView = findViewById(android.R.id.content) onGlobalLayoutListener = ViewTreeObserver.OnGlobalLayoutListener { @@ -1222,7 +1256,7 @@ fun TrialExpiredDialog( ) Spacer(modifier = Modifier.height(16.dp)) Text( - text = "Your 7-day trial period has ended. 
Please subscribe to the app to continue using it.", + text = "Please support the development of the app so that you can continue using it \uD83C\uDF89", style = MaterialTheme.typography.bodyMedium, modifier = Modifier.align(Alignment.CenterHorizontally) ) diff --git a/app/src/main/kotlin/com/google/ai/sample/MenuScreen.kt b/app/src/main/kotlin/com/google/ai/sample/MenuScreen.kt index 798a84ad..62f2cf61 100644 --- a/app/src/main/kotlin/com/google/ai/sample/MenuScreen.kt +++ b/app/src/main/kotlin/com/google/ai/sample/MenuScreen.kt @@ -28,6 +28,7 @@ import androidx.compose.ui.Modifier import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.res.stringResource import androidx.compose.ui.text.style.TextAlign +import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.unit.dp import androidx.compose.foundation.layout.Arrangement @@ -48,6 +49,7 @@ import android.util.Log import android.os.Environment import android.os.StatFs import com.google.ai.sample.feature.multimodal.ModelDownloadManager +import androidx.compose.runtime.collectAsState import java.io.File data class MenuItem( @@ -285,6 +287,111 @@ fun MenuScreen( } } + // Generation Settings (TopK, TopP, Temperature) for current model + if (selectedModel.supportsGenerationSettings) { + item { + val genSettings = remember(selectedModel) { + mutableStateOf( + com.google.ai.sample.util.GenerationSettingsPreferences.loadSettings( + context, selectedModel.modelName + ) + ) + } + + Card( + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 16.dp, vertical = 8.dp) + ) { + Column( + modifier = Modifier + .padding(all = 16.dp) + .fillMaxWidth() + ) { + Text( + text = "Generation Settings (${selectedModel.displayName})", + style = MaterialTheme.typography.titleMedium + ) + + Spacer(modifier = Modifier.height(12.dp)) + + // Temperature Slider (0.0 - 2.0) + Text( + text = "Temperature: ${"%.2f".format(genSettings.value.temperature)}", 
+ style = MaterialTheme.typography.bodyMedium + ) + androidx.compose.material3.Slider( + value = genSettings.value.temperature, + onValueChange = { newVal -> + genSettings.value = genSettings.value.copy(temperature = newVal) + }, + onValueChangeFinished = { + com.google.ai.sample.util.GenerationSettingsPreferences.saveSettings( + context, selectedModel.modelName, genSettings.value + ) + }, + valueRange = 0f..2f, + steps = 0, + modifier = Modifier.fillMaxWidth() + ) + + Spacer(modifier = Modifier.height(8.dp)) + + // TopP Slider (0.0 - 1.0) + Text( + text = "Top P: ${"%.2f".format(genSettings.value.topP)}", + style = MaterialTheme.typography.bodyMedium + ) + androidx.compose.material3.Slider( + value = genSettings.value.topP, + onValueChange = { newVal -> + genSettings.value = genSettings.value.copy(topP = newVal) + }, + onValueChangeFinished = { + com.google.ai.sample.util.GenerationSettingsPreferences.saveSettings( + context, selectedModel.modelName, genSettings.value + ) + }, + valueRange = 0f..1f, + steps = 0, + modifier = Modifier.fillMaxWidth() + ) + + Spacer(modifier = Modifier.height(8.dp)) + + // TopK Slider (0 - 100) + Text( + text = "Top K: ${genSettings.value.topK}", + style = MaterialTheme.typography.bodyMedium + ) + androidx.compose.material3.Slider( + value = genSettings.value.topK.toFloat(), + onValueChange = { newVal -> + genSettings.value = genSettings.value.copy(topK = Math.round(newVal)) + }, + onValueChangeFinished = { + com.google.ai.sample.util.GenerationSettingsPreferences.saveSettings( + context, selectedModel.modelName, genSettings.value + ) + }, + valueRange = 0f..100f, + steps = 0, + modifier = Modifier.fillMaxWidth() + ) + + if (selectedModel == ModelOption.GEMMA_3N_E4B_IT) { + Spacer(modifier = Modifier.height(4.dp)) + Text( + text = "Note: LlmInference (offline model) may not support all generation parameters.", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + } + } + } + 
// Menu Items items(menuItems) { menuItem -> Card( @@ -309,7 +416,7 @@ fun MenuScreen( TextButton( onClick = { if (isTrialExpired) { - Toast.makeText(context, "Please subscribe to the app to continue.", Toast.LENGTH_LONG).show() + Toast.makeText(context, "Please support the development of the app so that you can continue using it \uD83C\uDF89", Toast.LENGTH_LONG).show() } else { if (menuItem.routeId == "photo_reasoning") { val mainActivity = context as? MainActivity @@ -383,7 +490,7 @@ fun MenuScreen( ) } else { Text( - text = "Support improvements", + text = "Support Improvements \uD83C\uDF89", style = MaterialTheme.typography.titleMedium, modifier = Modifier.weight(1f) ) @@ -404,15 +511,28 @@ fun MenuScreen( .fillMaxWidth() .padding(horizontal = 16.dp, vertical = 8.dp) ) { + val boldStyle = SpanStyle(fontWeight = FontWeight.Bold) val annotatedText = buildAnnotatedString { - append("""• Preview models could be deactivated by Google without being handed over to the final release. -• GPT-oss 120b is a pure text model. -• Gemma 3n E4B it cannot handle screenshots in the API. -• GPT models (Vercel) have a free budget of $5 per month. -GPT-5.1 Input: $1.25/M Output: $10.00/M -GPT-5.1 mini Input: $0.25/ M Output: $2.00/M -GPT-5 nano Input: $0.05/M Output: $0.40/M -• There are rate limits for free use of Gemini models. The less powerful the models are, the more you can use them. The limits range from a maximum of 5 to 30 calls per minute. After each screenshot (every 2-3 seconds) the LLM must respond again. 
More information is available at """) + append("• ") + withStyle(boldStyle) { append("Preview Models") } + append(" could be deactivated by Google without being handed over to the final release.\n") + append("• ") + withStyle(boldStyle) { append("GPT-oss 120b") } + append(" is a pure text model.\n") + append("• ") + withStyle(boldStyle) { append("Gemma 27B IT") } + append(" cannot handle screenshots in the API.\n") + append("• GPT models (") + withStyle(boldStyle) { append("Vercel") } + append(") have a free budget of \$5 per month and a credit card is necessary.\n") + append("GPT-5.1 Input: \$1.25/M Output: \$10.00/M\n") + append("GPT-5.1 mini Input: \$0.25/M Output: \$2.00/M\n") + append("GPT-5 nano Input: \$0.05/M Output: \$0.40/M\n") + append("• There are ") + withStyle(boldStyle) { append("rate limits") } + append(" for free use of ") + withStyle(boldStyle) { append("Gemini models") } + append(". The less powerful the models are, the more you can use them. The limits range from a maximum of 5 to 30 calls per minute. After each screenshot (every 2-3 seconds) the LLM must respond again. More information is available at ") pushStringAnnotation(tag = "URL", annotation = "https://ai.google.dev/gemini-api/docs/rate-limits") withStyle(style = SpanStyle(color = MaterialTheme.colorScheme.primary, textDecoration = TextDecoration.Underline)) { @@ -489,31 +609,127 @@ GPT-5 nano Input: $0.05/M Output: $0.40/M val bytesAvailable = statFs.availableBlocksLong * statFs.blockSizeLong val gbAvailable = bytesAvailable.toDouble() / (1024 * 1024 * 1024) val formattedGbAvailable = String.format("%.2f", gbAvailable) + + val dlState by ModelDownloadManager.downloadState.collectAsState() AlertDialog( - onDismissRequest = { showDownloadDialog = false }, - title = { Text("Download Model? 
(4.92 GB)") }, - text = { Text("Should the Gemma 3n E4B be downloaded?\n\n$formattedGbAvailable GB of storage available.") }, - confirmButton = { - TextButton( - onClick = { - showDownloadDialog = false - downloadDialogModel?.downloadUrl?.let { url -> - ModelDownloadManager.downloadModel(context, url) - // We set the model, but the user will have to wait for download - selectedModel = downloadDialogModel!! - GenerativeAiViewModelFactory.setModel(downloadDialogModel!!) + onDismissRequest = { + if (dlState is ModelDownloadManager.DownloadState.Idle || dlState is ModelDownloadManager.DownloadState.Completed || dlState is ModelDownloadManager.DownloadState.Error) { + showDownloadDialog = false + } + // Don't dismiss while downloading/paused + }, + title = { Text("Download Model (4.92 GB)") }, + text = { + Column { + when (val state = dlState) { + is ModelDownloadManager.DownloadState.Idle -> { + Text("Should the Gemma 3n E4B be downloaded?\n\n$formattedGbAvailable GB of storage available.") + } + is ModelDownloadManager.DownloadState.Downloading -> { + Text("Downloading...") + Spacer(modifier = Modifier.height(8.dp)) + androidx.compose.material3.LinearProgressIndicator( + progress = { state.progress }, + modifier = Modifier.fillMaxWidth() + ) + Spacer(modifier = Modifier.height(4.dp)) + Text( + text = "${ModelDownloadManager.formatBytes(state.bytesDownloaded)} / ${if (state.totalBytes > 0) ModelDownloadManager.formatBytes(state.totalBytes) else "?"}", + style = MaterialTheme.typography.bodySmall + ) + Text( + text = "${"%.1f".format(state.progress * 100)}%", + style = MaterialTheme.typography.bodySmall + ) + } + is ModelDownloadManager.DownloadState.Paused -> { + Text("Download paused.") + Spacer(modifier = Modifier.height(8.dp)) + val progress = if (state.totalBytes > 0) state.bytesDownloaded.toFloat() / state.totalBytes else 0f + androidx.compose.material3.LinearProgressIndicator( + progress = { progress }, + modifier = Modifier.fillMaxWidth() + ) + Spacer(modifier = 
Modifier.height(4.dp)) + Text( + text = "${ModelDownloadManager.formatBytes(state.bytesDownloaded)} / ${if (state.totalBytes > 0) ModelDownloadManager.formatBytes(state.totalBytes) else "?"}", + style = MaterialTheme.typography.bodySmall + ) + } + is ModelDownloadManager.DownloadState.Completed -> { + Text("Download complete! ✅") + } + is ModelDownloadManager.DownloadState.Error -> { + Text("Error: ${state.message}") } } - ) { Text("OK") } + } + }, + confirmButton = { + when (dlState) { + is ModelDownloadManager.DownloadState.Idle -> { + TextButton( + onClick = { + downloadDialogModel?.downloadUrl?.let { url -> + ModelDownloadManager.downloadModel(context, url) + // Don't set model yet - wait for download to complete (Point 17) + } + } + ) { Text("Download") } + } + is ModelDownloadManager.DownloadState.Downloading -> { + TextButton(onClick = { ModelDownloadManager.pauseDownload() }) { Text("Pause") } + } + is ModelDownloadManager.DownloadState.Paused -> { + TextButton( + onClick = { + downloadDialogModel?.downloadUrl?.let { url -> + ModelDownloadManager.resumeDownload(context, url) + } + } + ) { Text("Resume") } + } + is ModelDownloadManager.DownloadState.Completed -> { + TextButton(onClick = { + // Set model only after download is completed (Point 17) + downloadDialogModel?.let { + selectedModel = it + GenerativeAiViewModelFactory.setModel(it) + } + showDownloadDialog = false + }) { Text("Close") } + } + is ModelDownloadManager.DownloadState.Error -> { + TextButton( + onClick = { + downloadDialogModel?.downloadUrl?.let { url -> + ModelDownloadManager.downloadModel(context, url) + } + } + ) { Text("Retry") } + } + } }, dismissButton = { - TextButton( - onClick = { - showDownloadDialog = false - // Do not change model + when (dlState) { + is ModelDownloadManager.DownloadState.Idle -> { + TextButton(onClick = { showDownloadDialog = false }) { Text("Cancel") } } - ) { Text("ABORT") } + is ModelDownloadManager.DownloadState.Downloading, + is 
ModelDownloadManager.DownloadState.Paused -> { + TextButton( + onClick = { + ModelDownloadManager.cancelDownload(context) + showDownloadDialog = false + } + ) { Text("Cancel Download") } + } + is ModelDownloadManager.DownloadState.Completed -> { /* No dismiss button */ } + is ModelDownloadManager.DownloadState.Error -> { + TextButton(onClick = { showDownloadDialog = false }) { Text("Close") } + } + } } ) } diff --git a/app/src/main/kotlin/com/google/ai/sample/ScreenCaptureService.kt b/app/src/main/kotlin/com/google/ai/sample/ScreenCaptureService.kt index 17f356e9..6d53b8df 100644 --- a/app/src/main/kotlin/com/google/ai/sample/ScreenCaptureService.kt +++ b/app/src/main/kotlin/com/google/ai/sample/ScreenCaptureService.kt @@ -289,21 +289,24 @@ class ScreenCaptureService : Service() { } } catch (e: MissingFieldException) { Log.e(TAG, "Serialization error, potentially a 503 error.", e) - // Check if the error message indicates a 503-like error - if (e.message?.contains("UNAVAILABLE") == true || + // Point 15: Check for missing 'parts' field (Gemma 27B issue) + if (e.message?.contains("parts") == true) { + errorMessage = "The model returned an incomplete response. This can happen with larger models. Please try again." + } else if (e.message?.contains("UNAVAILABLE") == true || e.message?.contains("503") == true || e.message?.contains("overloaded") == true) { - errorMessage = "Service Unavailable (503) - Retry with new key" + // Point 14: User-friendly high-demand message + errorMessage = "This model is currently experiencing high demand. Please try again later." 
} else { errorMessage = e.localizedMessage ?: "Serialization error" } } catch (e: Exception) { Log.e(TAG, "Direct error in AI call", e) - // Also check for 503 patterns in general exceptions + // Point 14: Check for high-demand 503 patterns if (e.message?.contains("503") == true || e.message?.contains("overloaded") == true || e.message?.contains("UNAVAILABLE") == true) { - errorMessage = "Service Unavailable (503) - Retry with new key" + errorMessage = "This model is currently experiencing high demand. Please try again later." } else { errorMessage = e.localizedMessage ?: "AI call failed" } @@ -319,11 +322,14 @@ class ScreenCaptureService : Service() { e.message?.contains("UNAVAILABLE") == true || e.message?.contains("503") == true || e.message?.contains("overloaded") == true)) { - errorMessage = "Service Unavailable (503) - Retry with new key" + errorMessage = "This model is currently experiencing high demand. Please try again later." + } else if (e is MissingFieldException && e.message?.contains("parts") == true) { + // Point 15: Gemma 27B incomplete response + errorMessage = "The model returned an incomplete response. This can happen with larger models. Please try again." } else if (e.message?.contains("503") == true || e.message?.contains("overloaded") == true || e.message?.contains("UNAVAILABLE") == true) { - errorMessage = "Service Unavailable (503) - Retry with new key" + errorMessage = "This model is currently experiencing high demand. Please try again later." 
} else { errorMessage = e.localizedMessage ?: "Unknown error" } diff --git a/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/ModelDownloadManager.kt b/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/ModelDownloadManager.kt index a66b9e07..fae14577 100644 --- a/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/ModelDownloadManager.kt +++ b/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/ModelDownloadManager.kt @@ -1,20 +1,61 @@ package com.google.ai.sample.feature.multimodal -import android.app.DownloadManager -import android.content.BroadcastReceiver +import android.app.NotificationChannel +import android.app.NotificationManager import android.content.Context -import android.content.Intent -import android.content.IntentFilter -import android.database.Cursor -import android.net.Uri +import android.os.Build import android.util.Log import android.widget.Toast +import androidx.core.app.NotificationCompat +import kotlin.coroutines.coroutineContext +import kotlinx.coroutines.* +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow import java.io.File +import java.io.FileOutputStream +import java.io.IOException +import java.net.HttpURLConnection +import java.net.URL +/** + * Custom download manager for the Gemma 3n model. + * Uses HttpURLConnection with Range-Request support for resume capability. + * Point 18: Includes Android notification for download progress. 
+ */ object ModelDownloadManager { private const val TAG = "ModelDownloadManager" const val MODEL_FILENAME = "gemma-3n-e4b-it-int4.litertlm" - private var downloadId: Long = -1 + private const val TEMP_SUFFIX = ".downloading" + private const val BUFFER_SIZE = 8192 + private const val MAX_RETRIES = 3 + private const val RETRY_DELAY_MS = 3000L + private const val PROGRESS_UPDATE_INTERVAL_MS = 500L + + // Notification constants + private const val DOWNLOAD_CHANNEL_ID = "model_download_channel" + private const val DOWNLOAD_NOTIFICATION_ID = 3001 + + sealed class DownloadState { + object Idle : DownloadState() + data class Downloading( + val progress: Float, // 0.0 - 1.0 + val bytesDownloaded: Long, + val totalBytes: Long + ) : DownloadState() + object Completed : DownloadState() + data class Error(val message: String) : DownloadState() + data class Paused( + val bytesDownloaded: Long, + val totalBytes: Long + ) : DownloadState() + } + + private val _downloadState = MutableStateFlow(DownloadState.Idle) + val downloadState: StateFlow = _downloadState.asStateFlow() + + private var downloadJob: Job? = null + private var isPaused = false fun isModelDownloaded(context: Context): Boolean { val file = getModelFile(context) @@ -31,52 +72,280 @@ object ModelDownloadManager { } } + private fun getTempFile(context: Context): File? 
{ + val externalFilesDir = context.getExternalFilesDir(null) + return if (externalFilesDir != null) { + File(externalFilesDir, MODEL_FILENAME + TEMP_SUFFIX) + } else { + null + } + } + + private fun createNotificationChannel(context: Context) { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) { + val channel = NotificationChannel( + DOWNLOAD_CHANNEL_ID, + "Model Download", + NotificationManager.IMPORTANCE_LOW + ).apply { + description = "Shows progress of model downloads" + setShowBadge(false) + } + val notificationManager = context.getSystemService(Context.NOTIFICATION_SERVICE) as NotificationManager + notificationManager.createNotificationChannel(channel) + } + } + + private fun showDownloadNotification(context: Context, progress: Float, bytesDownloaded: Long, totalBytes: Long) { + createNotificationChannel(context) + val notificationManager = context.getSystemService(Context.NOTIFICATION_SERVICE) as NotificationManager + val progressPercent = (progress * 100).toInt() + val notification = NotificationCompat.Builder(context, DOWNLOAD_CHANNEL_ID) + .setContentTitle("Downloading Model") + .setContentText("${formatBytes(bytesDownloaded)} / ${if (totalBytes > 0) formatBytes(totalBytes) else "?"} ($progressPercent%)") + .setSmallIcon(android.R.drawable.stat_sys_download) + .setPriority(NotificationCompat.PRIORITY_LOW) + .setOngoing(true) + .setProgress(100, progressPercent, totalBytes <= 0) + .setOnlyAlertOnce(true) + .build() + notificationManager.notify(DOWNLOAD_NOTIFICATION_ID, notification) + } + + private fun showDownloadCompleteNotification(context: Context) { + createNotificationChannel(context) + val notificationManager = context.getSystemService(Context.NOTIFICATION_SERVICE) as NotificationManager + val notification = NotificationCompat.Builder(context, DOWNLOAD_CHANNEL_ID) + .setContentTitle("Model Download Complete") + .setContentText("The model is ready to use.") + .setSmallIcon(android.R.drawable.stat_sys_download_done) + 
.setPriority(NotificationCompat.PRIORITY_LOW) + .setOngoing(false) + .setAutoCancel(true) + .build() + notificationManager.notify(DOWNLOAD_NOTIFICATION_ID, notification) + } + + private fun cancelDownloadNotification(context: Context) { + val notificationManager = context.getSystemService(Context.NOTIFICATION_SERVICE) as NotificationManager + notificationManager.cancel(DOWNLOAD_NOTIFICATION_ID) + } + fun downloadModel(context: Context, url: String) { if (isModelDownloaded(context)) { Toast.makeText(context, "Model already downloaded.", Toast.LENGTH_SHORT).show() return } - val file = getModelFile(context) - if (file != null && file.exists()) { - file.delete() // Clean up partial or old file + if (downloadJob?.isActive == true) { + Log.d(TAG, "Download already in progress.") + return } - try { - val request = DownloadManager.Request(Uri.parse(url)) - .setTitle("Downloading Gemma Model") - .setDescription("Downloading offline AI model (4.92 GB)...") - .setNotificationVisibility(DownloadManager.Request.VISIBILITY_VISIBLE_NOTIFY_COMPLETED) - .setDestinationInExternalFilesDir(context, null, MODEL_FILENAME) - .setAllowedOverMetered(true) - .setAllowedOverRoaming(true) - - val downloadManager = context.getSystemService(Context.DOWNLOAD_SERVICE) as? DownloadManager - - if (downloadManager != null) { - downloadId = downloadManager.enqueue(request) - Toast.makeText(context, "Download started. 
Please do not pause the download, as it cannot be resumed.", Toast.LENGTH_LONG).show() - Log.d(TAG, "Download started with ID: $downloadId") - } else { - Log.e(TAG, "DownloadManager service not available.") - Toast.makeText(context, "Download service unavailable.", Toast.LENGTH_SHORT).show() - } - } catch (e: Exception) { - Log.e(TAG, "Error starting download: ${e.message}") - Toast.makeText(context, "Failed to start download.", Toast.LENGTH_SHORT).show() + isPaused = false + downloadJob = CoroutineScope(Dispatchers.IO).launch { + downloadWithResume(context, url) + } + } + + fun pauseDownload() { + Log.d(TAG, "Pausing download...") + isPaused = true + } + + fun resumeDownload(context: Context, url: String) { + if (downloadJob?.isActive == true) { + Log.d(TAG, "Download is still active, not resuming.") + return + } + + isPaused = false + downloadJob = CoroutineScope(Dispatchers.IO).launch { + downloadWithResume(context, url) } } fun cancelDownload(context: Context) { - if (downloadId != -1L) { - val downloadManager = context.getSystemService(Context.DOWNLOAD_SERVICE) as? 
DownloadManager - if (downloadManager != null) { - downloadManager.remove(downloadId) - downloadId = -1 - Toast.makeText(context, "Download cancelled.", Toast.LENGTH_SHORT).show() - } else { - Log.e(TAG, "DownloadManager service not available for cancellation.") + Log.d(TAG, "Cancelling download...") + isPaused = false + downloadJob?.cancel() + downloadJob = null + + // Delete temp file + val tempFile = getTempFile(context) + if (tempFile != null && tempFile.exists()) { + tempFile.delete() + Log.d(TAG, "Temp file deleted.") + } + + _downloadState.value = DownloadState.Idle + cancelDownloadNotification(context) + CoroutineScope(Dispatchers.Main).launch { + Toast.makeText(context, "Download cancelled.", Toast.LENGTH_SHORT).show() + } + } + + private suspend fun downloadWithResume(context: Context, url: String) { + val tempFile = getTempFile(context) ?: run { + _downloadState.value = DownloadState.Error("Storage not available.") + return + } + val finalFile = getModelFile(context) ?: run { + _downloadState.value = DownloadState.Error("Storage not available.") + return + } + + var retryCount = 0 + var bytesDownloaded = if (tempFile.exists()) tempFile.length() else 0L + + while (retryCount <= MAX_RETRIES) { + if (!coroutineContext.isActive) return // Coroutine was cancelled + + var connection: HttpURLConnection? 
= null + try { + Log.d(TAG, "Starting download (attempt ${retryCount + 1}), resuming from byte $bytesDownloaded") + + connection = (URL(url).openConnection() as HttpURLConnection).apply { + connectTimeout = 30000 + readTimeout = 30000 + setRequestProperty("User-Agent", "ScreenOperator/1.0") + if (bytesDownloaded > 0) { + setRequestProperty("Range", "bytes=$bytesDownloaded-") + } + } + + val responseCode = connection.responseCode + Log.d(TAG, "Response code: $responseCode") + + val totalBytes: Long + + when (responseCode) { + HttpURLConnection.HTTP_OK -> { + // Server doesn't support range, restart from beginning + totalBytes = connection.contentLengthLong + bytesDownloaded = 0L + if (tempFile.exists()) tempFile.delete() + } + HttpURLConnection.HTTP_PARTIAL -> { + // Server supports range, resume + val contentRange = connection.getHeaderField("Content-Range") + totalBytes = if (contentRange != null && contentRange.contains("/")) { + contentRange.substringAfter("/").toLongOrNull() ?: -1L + } else { + bytesDownloaded + connection.contentLengthLong + } + } + else -> { + _downloadState.value = DownloadState.Error("Server error: $responseCode") + cancelDownloadNotification(context) + return + } + } + + // Acquire the stream only after the status check: getInputStream() throws on error codes, + // which previously made the "Server error" branch unreachable for 4xx/5xx responses. + val inputStream = connection.inputStream + + Log.d(TAG, "Total bytes: $totalBytes, already downloaded: $bytesDownloaded") + _downloadState.value = DownloadState.Downloading( + progress = if (totalBytes > 0) bytesDownloaded.toFloat() / totalBytes else 0f, + bytesDownloaded = bytesDownloaded, + totalBytes = totalBytes + ) + + val fos = FileOutputStream(tempFile, bytesDownloaded > 0) // append if resuming + val buffer = ByteArray(BUFFER_SIZE) + var lastProgressUpdate = System.currentTimeMillis() + + inputStream.use { input -> + fos.use { output -> + var bytesRead: Int + while (input.read(buffer).also { bytesRead = it } != -1) { + if (!coroutineContext.isActive) { + Log.d(TAG, "Download cancelled during read.") + cancelDownloadNotification(context) + return + } + + if (isPaused) 
{ + Log.d(TAG, "Download paused at $bytesDownloaded bytes.") + _downloadState.value = DownloadState.Paused( + bytesDownloaded = bytesDownloaded, + totalBytes = totalBytes + ) + // Keep notification showing paused state (guard against unknown/zero totalBytes) + showDownloadNotification(context, if (totalBytes > 0) bytesDownloaded.toFloat() / totalBytes else 0f, bytesDownloaded, totalBytes) + return + } + + output.write(buffer, 0, bytesRead) + bytesDownloaded += bytesRead + + // Rate-limit progress updates + val now = System.currentTimeMillis() + if (now - lastProgressUpdate >= PROGRESS_UPDATE_INTERVAL_MS) { + lastProgressUpdate = now + val progress = if (totalBytes > 0) bytesDownloaded.toFloat() / totalBytes else 0f + _downloadState.value = DownloadState.Downloading( + progress = progress, + bytesDownloaded = bytesDownloaded, + totalBytes = totalBytes + ) + // Point 18: Update notification with progress + showDownloadNotification(context, progress, bytesDownloaded, totalBytes) + } + } + } + + // Download complete - rename temp to final + if (tempFile.exists()) { + finalFile.delete() + if (tempFile.renameTo(finalFile)) { + Log.i(TAG, "Download complete! 
File: ${finalFile.absolutePath} (${finalFile.length()} bytes)") + _downloadState.value = DownloadState.Completed + showDownloadCompleteNotification(context) + withContext(Dispatchers.Main) { + Toast.makeText(context, "Model download complete!", Toast.LENGTH_SHORT).show() + } + } else { + _downloadState.value = DownloadState.Error("Failed to save model file.") + cancelDownloadNotification(context) + } + } + return // Success, exit retry loop + + } catch (e: IOException) { + Log.e(TAG, "Download error (attempt ${retryCount + 1}): ${e.message}") + retryCount++ + if (retryCount > MAX_RETRIES) { + _downloadState.value = DownloadState.Error("Download failed after $MAX_RETRIES retries: ${e.message}") + cancelDownloadNotification(context) + withContext(Dispatchers.Main) { + Toast.makeText(context, "Download failed: ${e.message}", Toast.LENGTH_LONG).show() + } + } else { + _downloadState.value = DownloadState.Downloading( + progress = 0f, // progress unknown while reconnecting (totalBytes reported as -1L) + bytesDownloaded = bytesDownloaded, + totalBytes = -1L + ) + Log.d(TAG, "Retrying in ${RETRY_DELAY_MS}ms...") + delay(RETRY_DELAY_MS) + } + } finally { + connection?.disconnect() } } } + + /** + * Format bytes to human-readable string (e.g. 
"1.23 GB") + */ + fun formatBytes(bytes: Long): String { + return when { + bytes >= 1_073_741_824 -> "%.2f GB".format(bytes.toDouble() / 1_073_741_824) + bytes >= 1_048_576 -> "%.1f MB".format(bytes.toDouble() / 1_048_576) + bytes >= 1024 -> "%.0f KB".format(bytes.toDouble() / 1024) + else -> "$bytes B" + } + } } + diff --git a/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningScreen.kt b/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningScreen.kt index 91433e52..db496dde 100644 --- a/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningScreen.kt +++ b/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningScreen.kt @@ -509,8 +509,9 @@ fun PhotoReasoningScreen( return@IconButton } - // Check MediaProjection for all models except gemma-3n-e4b-it - if (!isMediaProjectionPermissionGranted && modelName != "gemma-3n-e4b-it") { + // Check MediaProjection for all models except gemma-3n-e4b-it and human-expert + // Human Expert uses its own MediaProjection for WebRTC, not ScreenCaptureService + if (!isMediaProjectionPermissionGranted && modelName != "gemma-3n-e4b-it" && modelName != "human-expert") { mainActivity?.requestMediaProjectionPermission { // This block will be executed after permission is granted if (userQuestion.isNotBlank()) { diff --git a/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt b/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt index a8be898a..a851ff04 100644 --- a/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt +++ b/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt @@ -72,6 +72,9 @@ import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.jsonObject import kotlinx.serialization.json.jsonPrimitive +import com.google.ai.sample.webrtc.WebRTCSender +import 
com.google.ai.sample.webrtc.SignalingClient +import org.webrtc.IceCandidate class PhotoReasoningViewModel( application: Application, @@ -85,6 +88,14 @@ class PhotoReasoningViewModel( private var llmInference: LlmInference? = null private val TAG = "PhotoReasoningViewModel" + + // WebRTC & Signaling + private var webRTCSender: WebRTCSender? = null + private var signalingClient: SignalingClient? = null + private var lastMediaProjectionResultCode: Int = 0 + private var lastMediaProjectionResultData: Intent? = null + + private fun Bitmap.toBase64(): String { val outputStream = ByteArrayOutputStream() @@ -164,6 +175,12 @@ class PhotoReasoningViewModel( private var commandProcessingJob: Job? = null private val stopExecutionFlag = AtomicBoolean(false) + // Track how many commands have been executed incrementally during streaming + // to avoid re-executing already-executed commands + private var incrementalCommandCount = 0 + // Accumulated full text during streaming for incremental command parsing + private var streamingAccumulatedText = StringBuilder() + // Added for retry on quota exceeded private var currentRetryAttempt = 0 private var currentScreenInfoForPrompt: String? 
= null @@ -175,6 +192,9 @@ class PhotoReasoningViewModel( val chunk = intent.getStringExtra(ScreenCaptureService.EXTRA_AI_STREAM_CHUNK) if (chunk != null) { updateAiMessage(chunk, isPending = true) + // Real-time command execution during streaming + streamingAccumulatedText.append(chunk) + processCommandsIncrementally(streamingAccumulatedText.toString()) } } } @@ -202,7 +222,23 @@ class PhotoReasoningViewModel( val apiKeyManager = ApiKeyManager.getInstance(receiverContext) val isQuotaError = isQuotaExceededError(errorMessage) + val isHighDemand = isHighDemandError(errorMessage) val currentModel = com.google.ai.sample.GenerativeAiViewModelFactory.getCurrentModel() + + // Point 14: Don't switch keys for high-demand 503 errors + if (isHighDemand) { + Log.d(TAG, "High demand error detected - not switching API keys") + _chatState.addMessage( + PhotoReasoningMessage( + text = "This model is currently experiencing high demand. Please try again later.", + participant = PhotoParticipant.ERROR + ) + ) + _chatMessagesFlow.value = chatMessages + saveChatHistory(getApplication()) + return + } + if (isQuotaError && currentRetryAttempt < MAX_RETRY_ATTEMPTS) { val currentKey = apiKeyManager.getCurrentApiKey(currentModel.apiProvider) if (currentKey != null) { @@ -250,7 +286,20 @@ class PhotoReasoningViewModel( val context = getApplication().applicationContext if (currentModel == ModelOption.GEMMA_3N_E4B_IT) { if (ModelDownloadManager.isModelDownloaded(context)) { - initializeOfflineModel(context) + // Point 7 & 16: Initialize model asynchronously to not block UI + viewModelScope.launch(Dispatchers.IO) { + withContext(Dispatchers.Main) { + _uiState.value = PhotoReasoningUiState.Loading + } + val error = initializeOfflineModel(context) + withContext(Dispatchers.Main) { + if (error != null) { + _uiState.value = PhotoReasoningUiState.Error(error) + } else { + _uiState.value = PhotoReasoningUiState.Success("Model initialized.") + } + } + } } } @@ -278,7 +327,7 @@ class 
PhotoReasoningViewModel( val optionsBuilder = LlmInference.LlmInferenceOptions.builder() .setModelPath(modelFile.absolutePath) - .setMaxTokens(1024) + .setMaxTokens(4096) // Set preferred backend (CPU or GPU) if (backend == InferenceBackend.GPU) { @@ -309,14 +358,19 @@ class PhotoReasoningViewModel( fun reinitializeOfflineModel(context: Context) { viewModelScope.launch(Dispatchers.IO) { try { - // Close existing instance + // Point 3: Properly close existing instance and free GPU resources try { - (llmInference as? java.io.Closeable)?.close() + llmInference?.close() + Log.d(TAG, "LlmInference closed for reinit") } catch (e: Exception) { Log.w(TAG, "Error closing existing LlmInference for reinit", e) } llmInference = null + // Force garbage collection and wait for GPU resources to be freed + System.gc() + delay(500) + // Re-initialize with new settings val initError = initializeOfflineModel(context) @@ -345,12 +399,18 @@ class PhotoReasoningViewModel( Log.d(TAG, "AIResultStreamReceiver unregistered with LocalBroadcastManager.") try { - // Using reflection if specific method not known or standard cast - (llmInference as? java.io.Closeable)?.close() + // Point 3: Properly close LlmInference to free GPU/RAM + llmInference?.close() + Log.d(TAG, "LlmInference closed in onCleared") } catch (e: Exception) { Log.w(TAG, "Error closing LlmInference", e) } llmInference = null + System.gc() // Help free GPU resources + + // WebRTC cleanup + webRTCSender?.stop() + signalingClient?.disconnect() } private fun createChatWithSystemMessage(context: Context? 
= null): Chat { @@ -428,9 +488,12 @@ class PhotoReasoningViewModel( currentScreenInfoForPrompt = screenInfoForPrompt currentImageUrisForChat = imageUrisForChat - // Clear previous commands + // Clear previous commands and reset incremental tracking _detectedCommands.value = emptyList() _commandExecutionStatus.value = "" + incrementalCommandCount = 0 + streamingAccumulatedText.clear() + CommandParser.clearBuffer() // Add user message to chat history val userMessage = PhotoReasoningMessage( @@ -509,8 +572,36 @@ class PhotoReasoningViewModel( // Check for Human Expert model if (currentModel == ModelOption.HUMAN_EXPERT) { - _uiState.value = PhotoReasoningUiState.Error("Human Expert mode is not yet connected. The Human Operator app is required.") - return + // If we already have a specialized session running, maybe just send the text? + // For now, we assume the user hits "Send" to trigger the connection + task post. + + // Initial task post message + val userMessage = PhotoReasoningMessage( + text = userInput, + participant = PhotoParticipant.USER, + imageUris = imageUrisForChat ?: emptyList(), + isPending = false + ) + _chatState.addMessage(userMessage) + + _uiState.value = PhotoReasoningUiState.Loading + + // We need to ensure we have MediaProjection permission. + // The UI (PhotoReasoningScreen) calls requestMediaProjectionPermission before calling reason() + // if permission is missing. So here we should ideally rely on onMediaProjectionPermissionGranted + // having been called or already having the intent. + + // But valid intent handling happens in onMediaProjectionPermissionGranted. + // If reason() is called, it means we likely have permission or it was just granted. + + // Check if we are already connected? 
+ if (signalingClient == null) { + startHumanExpertSession(userInput) + } else { + // Already connected, just post the new task text or send via DataChannel if paired + postTaskToHumanExpert(userInput) + } + return } // Check for offline model (Gemma) @@ -525,6 +616,13 @@ class PhotoReasoningViewModel( // Ensure system message and DB are loaded ensureInitialized(context) + // Reset incremental command tracking for this new reasoning + incrementalCommandCount = 0 + streamingAccumulatedText.clear() + CommandParser.clearBuffer() + _detectedCommands.value = emptyList() + _commandExecutionStatus.value = "" + // Build the combined prompt with system message + DB entries + user input val systemMsg = _systemMessage.value val dbEntries = formatDatabaseEntriesAsText(context) @@ -570,7 +668,13 @@ class PhotoReasoningViewModel( // Initialize model if needed var initError: String? = null if (llmInference == null) { - initError = initializeOfflineModel(context) + withContext(Dispatchers.Main) { + replaceAiMessageText("Initializing offline model...", isPending = true) + } + // Use Default dispatcher for CPU-intensive model loading + initError = withContext(Dispatchers.Default) { + initializeOfflineModel(context) + } } if (llmInference == null) { @@ -599,6 +703,8 @@ class PhotoReasoningViewModel( viewModelScope.launch(Dispatchers.Main) { if (!done) { replaceAiMessageText(sb.toString(), isPending = true) + // Real-time command execution during offline streaming + processCommandsIncrementally(sb.toString()) } } }.get() @@ -1097,9 +1203,178 @@ class PhotoReasoningViewModel( } } - /** - * Update the AI message in chat history - */ + // === Human Expert / WebRTC Logic === + + fun onMediaProjectionPermissionGranted(resultCode: Int, data: Intent) { + Log.d(TAG, "onMediaProjectionPermissionGranted: Storing result. Code=$resultCode") + lastMediaProjectionResultCode = resultCode + lastMediaProjectionResultData = data + + // If we were waiting to start a session, we could start it here. 
+ // For now, if the user just clicked "Human Expert" and granted permission, + // they might expect the connection to start. + // But startHumanExpertSession is already called in reason() if permission was already there. + // If permission wasn't there, reason() wasn't called (MainActivity blocked it?). + // Actually MainActivity.requestMediaProjectionPermission callback invokes the lambda passed to it. + // That lambda calls reason(). So reason() will be called immediately after this. + } + + private fun startHumanExpertSession(taskText: String) { + if (signalingClient != null) { + // Already connected + postTaskToHumanExpert(taskText) + return + } + + _uiState.value = PhotoReasoningUiState.Loading + _chatState.addMessage(PhotoReasoningMessage(text = "Connecting to Human Expert network...", participant = PhotoParticipant.MODEL, isPending = true)) + _chatMessagesFlow.value = _chatState.getAllMessages() + + // Initialize WebRTC Sender + webRTCSender = WebRTCSender(getApplication(), object : WebRTCSender.WebRTCSenderListener { + override fun onLocalICECandidate(candidate: IceCandidate) { + signalingClient?.sendICECandidate(candidate.sdp, candidate.sdpMid, candidate.sdpMLineIndex) + } + + override fun onConnectionStateChanged(state: String) { + Log.d(TAG, "WebRTC State: $state") + viewModelScope.launch(Dispatchers.Main) { + if (state == "CONNECTED") { + _commandExecutionStatus.value = "Expert connected. Sharing screen." + replaceAiMessageText("Expert connected! They can now see your screen and control your device.", isPending = false) + } else if (state == "DISCONNECTED" || state == "FAILED") { + _commandExecutionStatus.value = "Expert disconnected." 
+ } + } + } + + override fun onTapReceived(x: Float, y: Float) { + dispatchTap(x, y) + } + + override fun onError(message: String) { + Log.e(TAG, "WebRTC Error: $message") + viewModelScope.launch(Dispatchers.Main) { + _uiState.value = PhotoReasoningUiState.Error("Video stream error: $message") + } + } + }) + webRTCSender?.initialize() + + // Initialize Signaling + signalingClient = SignalingClient(object : SignalingClient.SignalingListener { + override fun onTaskPosted(taskId: String) { + viewModelScope.launch(Dispatchers.Main) { + val msg = "Task posted. Waiting for an expert to claim it..." + replaceAiMessageText(msg, isPending = true) + } + } + + override fun onTaskClaimed(taskId: String) { + Log.d(TAG, "Task claimed! Requesting fresh MediaProjection for WebRTC.") + viewModelScope.launch(Dispatchers.Main) { + replaceAiMessageText("Expert found! Requesting screen capture permission...", isPending = true) + + // Request a fresh MediaProjection specifically for WebRTC + // This does NOT start ScreenCaptureService - avoids token reuse crash + val mainActivity = MainActivity.getInstance() + if (mainActivity != null) { + mainActivity.requestMediaProjectionForWebRTC { resultCode, resultData -> + Log.d(TAG, "WebRTC MediaProjection granted. 
Starting foreground service first, then screen capture.") + replaceAiMessageText("Establishing video connection...", isPending = true) + + // Point 11: Start ScreenCaptureService as foreground service FIRST + // This is required because MediaProjection needs an active foreground + // service of type MEDIA_PROJECTION on Android Q+ + val serviceIntent = Intent(mainActivity, ScreenCaptureService::class.java).apply { + action = ScreenCaptureService.ACTION_START_CAPTURE + putExtra(ScreenCaptureService.EXTRA_RESULT_CODE, resultCode) + putExtra(ScreenCaptureService.EXTRA_RESULT_DATA, resultData) + putExtra(ScreenCaptureService.EXTRA_TAKE_SCREENSHOT_ON_START, false) + } + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) { + mainActivity.startForegroundService(serviceIntent) + } else { + mainActivity.startService(serviceIntent) + } + + // Small delay to ensure foreground service is up before WebRTC capture + viewModelScope.launch { + delay(300) + // Start screen capture for WebRTC with fresh permission data + webRTCSender?.startScreenCapture(resultData) + webRTCSender?.createPeerConnection() + + // Create Offer + webRTCSender?.createOffer { sdp -> + signalingClient?.sendOffer(sdp) + } + } + } + } else { + Log.e(TAG, "MainActivity not available for MediaProjection request") + _uiState.value = PhotoReasoningUiState.Error("Cannot request screen capture - activity not available") + } + } + } + + override fun onSDPAnswer(sdp: String) { + webRTCSender?.setRemoteAnswer(sdp) + } + + override fun onICECandidate(candidate: String, sdpMid: String?, sdpMLineIndex: Int) { + webRTCSender?.addIceCandidate(candidate, sdpMid, sdpMLineIndex) + } + + override fun onPeerDisconnected() { + viewModelScope.launch(Dispatchers.Main) { + _commandExecutionStatus.value = "Expert disconnected." 
+ replaceAiMessageText("Expert disconnected.", isPending = false) + webRTCSender?.stop() + } + } + + override fun onError(message: String) { + viewModelScope.launch(Dispatchers.Main) { + _uiState.value = PhotoReasoningUiState.Error("Signaling error: $message") + } + } + }) + + // Post the task immediately + Log.d(TAG, "Signaling initialized. Posting task.") + postTaskToHumanExpert(taskText) + } + + private fun postTaskToHumanExpert(text: String) { + signalingClient?.postTask(text, hasScreenshot = false) // Capture live stream instead + } + + private fun dispatchTap(x: Float, y: Float) { + Log.d(TAG, "Dispatching tap: ($x, $y)") + // Convert normalized to screen coordinates? + // Command.TapCoordinates usually expects absolute pixels. + // ScreenOperatorAccessibilityService.executeCommand handles logic. + // But wait, the web client sends normalized (0-1). + + // We need the screen dimensions. + val displayMetrics = getApplication().resources.displayMetrics + val screenWidth = displayMetrics.widthPixels + val screenHeight = displayMetrics.heightPixels + + val absX = (x * screenWidth).toInt() + val absY = (y * screenHeight).toInt() + + val command = Command.TapCoordinates(absX.toString(), absY.toString()) + ScreenOperatorAccessibilityService.executeCommand(command) + + viewModelScope.launch(Dispatchers.Main) { + _commandExecutionStatus.value = "Expert tapped at ($absX, $absY)" + } + } + + + private fun finalizeAiMessage(finalText: String) { Log.d(TAG, "finalizeAiMessage: Finalizing AI message.") val messages = _chatState.getAllMessages().toMutableList() @@ -1216,8 +1491,19 @@ class PhotoReasoningViewModel( } private fun isQuotaExceededError(message: String): Boolean { - return message.contains("exceeded your current quota") || - message.contains("Service Unavailable (503)") + // Only match actual quota exceeded errors, not high-demand 503s + return message.contains("exceeded your current quota") + } + + /** + * Point 14: Check if error is a high-demand 503 
(UNAVAILABLE) error. + * These should NOT trigger API key switching. + */ + private fun isHighDemandError(message: String): Boolean { + return message.contains("Service Unavailable (503)") || + message.contains("UNAVAILABLE") || + message.contains("high demand") || + message.contains("overloaded") } /** @@ -1238,6 +1524,59 @@ class PhotoReasoningViewModel( return builder.toString() } + /** + * Incrementally process commands during streaming. + * Parses the full accumulated text but only executes NEW commands + * (i.e., commands beyond incrementalCommandCount). + * This avoids re-executing commands that were already executed in earlier chunks. + * + * Skips takeScreenshot() during streaming since we don't want to interrupt generation. + * The final processCommands() call after streaming ends handles takeScreenshot. + */ + private fun processCommandsIncrementally(accumulatedText: String) { + if (stopExecutionFlag.get()) return + + try { + // Parse all commands from the full accumulated text + // Use a fresh parse (not the buffer-based one) to get all commands in order + val allCommands = CommandParser.parseCommands(accumulatedText, clearBuffer = true) + + if (allCommands.size > incrementalCommandCount) { + // There are new commands to execute + val newCommands = allCommands.subList(incrementalCommandCount, allCommands.size) + Log.d(TAG, "Incremental: Found ${newCommands.size} new commands (total: ${allCommands.size}, already executed: $incrementalCommandCount)") + + for (command in newCommands) { + if (stopExecutionFlag.get()) break + + // Skip takeScreenshot during streaming - it will be handled by final processCommands + if (command is Command.TakeScreenshot) { + Log.d(TAG, "Incremental: Skipping takeScreenshot during streaming (will be handled at end)") + incrementalCommandCount++ + continue + } + + try { + Log.d(TAG, "Incremental: Executing command: $command") + _commandExecutionStatus.value = "Executing: $command" + 
ScreenOperatorAccessibilityService.executeCommand(command) + + // Track as executed + val currentCommands = _detectedCommands.value.toMutableList() + currentCommands.add(command) + _detectedCommands.value = currentCommands + } catch (e: Exception) { + Log.e(TAG, "Incremental: Error executing command: ${e.message}", e) + } + + incrementalCommandCount++ + } + } + } catch (e: Exception) { + Log.e(TAG, "Incremental command parsing error: ${e.message}", e) + } + } + /** * Process commands found in the AI response */ diff --git a/app/src/main/kotlin/com/google/ai/sample/util/GenerationSettingsPreferences.kt b/app/src/main/kotlin/com/google/ai/sample/util/GenerationSettingsPreferences.kt new file mode 100644 index 00000000..f16f0595 --- /dev/null +++ b/app/src/main/kotlin/com/google/ai/sample/util/GenerationSettingsPreferences.kt @@ -0,0 +1,37 @@ +package com.google.ai.sample.util + +import android.content.Context +import android.util.Log + +/** + * Persists TopK, TopP, and Temperature settings per model. 
+ */ +object GenerationSettingsPreferences { + private const val TAG = "GenSettingsPrefs" + private const val PREFS_NAME = "generation_settings" + + data class GenerationSettings( + val temperature: Float = 0.0f, + val topP: Float = 0.0f, + val topK: Int = 0 + ) + + fun saveSettings(context: Context, modelName: String, settings: GenerationSettings) { + val prefs = context.getSharedPreferences(PREFS_NAME, Context.MODE_PRIVATE) + prefs.edit() + .putFloat("${modelName}_temperature", settings.temperature) + .putFloat("${modelName}_topP", settings.topP) + .putInt("${modelName}_topK", settings.topK) + .apply() + Log.d(TAG, "Saved settings for $modelName: temp=${settings.temperature}, topP=${settings.topP}, topK=${settings.topK}") + } + + fun loadSettings(context: Context, modelName: String): GenerationSettings { + val prefs = context.getSharedPreferences(PREFS_NAME, Context.MODE_PRIVATE) + return GenerationSettings( + temperature = prefs.getFloat("${modelName}_temperature", 0.0f), + topP = prefs.getFloat("${modelName}_topP", 0.0f), + topK = prefs.getInt("${modelName}_topK", 0) + ) + } +} diff --git a/app/src/main/kotlin/com/google/ai/sample/webrtc/SignalingClient.kt b/app/src/main/kotlin/com/google/ai/sample/webrtc/SignalingClient.kt new file mode 100644 index 00000000..48e19569 --- /dev/null +++ b/app/src/main/kotlin/com/google/ai/sample/webrtc/SignalingClient.kt @@ -0,0 +1,170 @@ +package com.google.ai.sample.webrtc + +import android.util.Log +import com.google.firebase.database.ChildEventListener +import com.google.firebase.database.DataSnapshot +import com.google.firebase.database.DatabaseError +import com.google.firebase.database.DatabaseReference +import com.google.firebase.database.FirebaseDatabase +import com.google.firebase.database.ValueEventListener + +/** + * Firebase Realtime Database signaling client for the ScreenOperator (Requester). + * Posts tasks to the broker and handles waiting for a claim. 
+ */ +class SignalingClient( + private val listener: SignalingListener +) { + companion object { + private const val TAG = "SignalingClient" + } + + private val database: FirebaseDatabase = FirebaseDatabase.getInstance() + private val tasksRef: DatabaseReference = database.getReference("tasks") + + private var currentTaskId: String? = null + + // Listeners + private var taskStatusListener: ValueEventListener? = null + private var answerListener: ValueEventListener? = null + private var iceListener: ChildEventListener? = null + + interface SignalingListener { + fun onTaskPosted(taskId: String) + fun onTaskClaimed(taskId: String) + fun onSDPAnswer(sdp: String) + fun onICECandidate(candidate: String, sdpMid: String?, sdpMLineIndex: Int) + fun onPeerDisconnected() + fun onError(message: String) + } + + fun postTask(text: String, hasScreenshot: Boolean) { + // Create a new task entry + val taskId = tasksRef.push().key + if (taskId == null) { + listener.onError("Failed to generate task ID") + return + } + + currentTaskId = taskId + + val taskData = mapOf( + "text" to text, + "status" to "open", + "timestamp" to System.currentTimeMillis() + ) + + tasksRef.child(taskId).setValue(taskData) + .addOnSuccessListener { + Log.d(TAG, "Task posted successfully: $taskId") + listener.onTaskPosted(taskId) + listenForTaskStatus(taskId) + } + .addOnFailureListener { e -> + Log.e(TAG, "Failed to post task", e) + listener.onError("Failed to post task: ${e.message}") + } + } + + private fun listenForTaskStatus(taskId: String) { + taskStatusListener = object : ValueEventListener { + override fun onDataChange(snapshot: DataSnapshot) { + val status = snapshot.getValue(String::class.java) + if (status == "claimed") { + Log.d(TAG, "Task claimed by operator") + listener.onTaskClaimed(taskId) + listenForSignaling(taskId) + // We can stop listening for status changes now if we want, + // or keep it to detect cancellations/disconnects? 
+ } + } + + override fun onCancelled(error: DatabaseError) { + Log.e(TAG, "Task status listener cancelled", error.toException()) + } + } + tasksRef.child(taskId).child("status").addValueEventListener(taskStatusListener!!) + } + + private fun listenForSignaling(taskId: String) { + val taskRef = tasksRef.child(taskId) + + // Listen for SDP Answer from Operator + answerListener = object : ValueEventListener { + override fun onDataChange(snapshot: DataSnapshot) { + val type = snapshot.child("type").getValue(String::class.java) + val sdp = snapshot.child("sdp").getValue(String::class.java) + + if (type == "answer" && sdp != null) { + Log.d(TAG, "Received SDP Answer") + listener.onSDPAnswer(sdp) + } + } + + override fun onCancelled(error: DatabaseError) { + Log.e(TAG, "Answer listener cancelled", error.toException()) + } + } + taskRef.child("answer").addValueEventListener(answerListener!!) + + // Listen for ICE Candidates from Operator + iceListener = object : ChildEventListener { + override fun onChildAdded(snapshot: DataSnapshot, previousChildName: String?) { + val sender = snapshot.child("sender").getValue(String::class.java) + if (sender == "operator") { + val candidate = snapshot.child("candidate").getValue(String::class.java) + val sdpMid = snapshot.child("sdpMid").getValue(String::class.java) + val sdpMLineIndex = snapshot.child("sdpMLineIndex").getValue(Int::class.java) ?: 0 + + if (candidate != null) { + Log.d(TAG, "Received ICE candidate from operator") + listener.onICECandidate(candidate, sdpMid, sdpMLineIndex) + } + } + } + + override fun onChildChanged(snapshot: DataSnapshot, previousChildName: String?) {} + override fun onChildRemoved(snapshot: DataSnapshot) {} + override fun onChildMoved(snapshot: DataSnapshot, previousChildName: String?) {} + override fun onCancelled(error: DatabaseError) {} + } + taskRef.child("ice").addChildEventListener(iceListener!!) 
+ } + + fun sendOffer(sdp: String) { + val taskId = currentTaskId ?: return + Log.d(TAG, "Sending SDP Offer") + val offer = mapOf( + "type" to "offer", + "sdp" to sdp + ) + tasksRef.child(taskId).child("offer").setValue(offer) + } + + fun sendICECandidate(candidate: String, sdpMid: String?, sdpMLineIndex: Int) { + val taskId = currentTaskId ?: return + val ice = mapOf( + "candidate" to candidate, + "sdpMid" to sdpMid, + "sdpMLineIndex" to sdpMLineIndex, + "sender" to "requester" + ) + tasksRef.child(taskId).child("ice").push().setValue(ice) + } + + fun disconnect() { + Log.d(TAG, "Disconnecting SignalingClient") + currentTaskId?.let { taskId -> + // Optionally close the task or mark it as cancelled? + // For now, just stop listening. + taskStatusListener?.let { tasksRef.child(taskId).child("status").removeEventListener(it) } + answerListener?.let { tasksRef.child(taskId).child("answer").removeEventListener(it) } + iceListener?.let { tasksRef.child(taskId).child("ice").removeEventListener(it) } + } + + taskStatusListener = null + answerListener = null + iceListener = null + currentTaskId = null + } +} diff --git a/app/src/main/kotlin/com/google/ai/sample/webrtc/WebRTCSender.kt b/app/src/main/kotlin/com/google/ai/sample/webrtc/WebRTCSender.kt new file mode 100644 index 00000000..625a3d68 --- /dev/null +++ b/app/src/main/kotlin/com/google/ai/sample/webrtc/WebRTCSender.kt @@ -0,0 +1,200 @@ +package com.google.ai.sample.webrtc + +import android.content.Context +import android.content.Intent +import android.media.projection.MediaProjection +import android.util.Log +import com.google.gson.Gson +import org.webrtc.* + +/** + * Handles WebRTC PeerConnection for the sender (ScreenOperator). + * Captures screen video and sends it to the connected Human Operator. + * Receives touch events via DataChannel. 
+ */ +class WebRTCSender( + private val context: Context, + private val listener: WebRTCSenderListener +) { + companion object { + private const val TAG = "WebRTCSender" + private val STUN_SERVERS = listOf( + PeerConnection.IceServer.builder("stun:stun.l.google.com:19302").createIceServer(), + PeerConnection.IceServer.builder("stun:stun1.l.google.com:19302").createIceServer() + ) + } + + interface WebRTCSenderListener { + fun onLocalICECandidate(candidate: IceCandidate) + fun onConnectionStateChanged(state: String) + fun onTapReceived(x: Float, y: Float) + fun onError(message: String) + } + + private var peerConnectionFactory: PeerConnectionFactory? = null + private var peerConnection: PeerConnection? = null + private var videoCapturer: VideoCapturer? = null + private var videoSource: VideoSource? = null + private var videoTrack: VideoTrack? = null + private var dataChannel: DataChannel? = null + private val eglBase = EglBase.create() + private val gson = Gson() + + fun initialize() { + Log.d(TAG, "Initializing WebRTCSender") + val initOptions = PeerConnectionFactory.InitializationOptions.builder(context) + .setEnableInternalTracer(false) + .createInitializationOptions() + PeerConnectionFactory.initialize(initOptions) + + peerConnectionFactory = PeerConnectionFactory.builder() + .setOptions(PeerConnectionFactory.Options()) + .setVideoDecoderFactory(DefaultVideoDecoderFactory(eglBase.eglBaseContext)) + .setVideoEncoderFactory(DefaultVideoEncoderFactory(eglBase.eglBaseContext, true, true)) + .createPeerConnectionFactory() + } + + fun startScreenCapture(permissionResultData: Intent) { + Log.d(TAG, "Starting screen capture") + videoCapturer = ScreenCapturerAndroid(permissionResultData, object : MediaProjection.Callback() { + override fun onStop() { + Log.e(TAG, "MediaProjection stopped") + listener.onError("Screen capture stopped") + } + }) + + val factory = peerConnectionFactory ?: return + videoSource = factory.createVideoSource(videoCapturer!!.isScreencast) + + // 
Initialize capturer + (videoCapturer as ScreenCapturerAndroid).initialize( + SurfaceTextureHelper.create("CaptureThread", eglBase.eglBaseContext), + context, + videoSource!!.capturerObserver + ) + (videoCapturer as ScreenCapturerAndroid).startCapture(720, 1280, 30) // Adjust resolution/fps as needed + + videoTrack = factory.createVideoTrack("ARDAMSv0", videoSource) + videoTrack?.setEnabled(true) + } + + fun createPeerConnection() { + Log.d(TAG, "Creating PeerConnection") + val rtcConfig = PeerConnection.RTCConfiguration(STUN_SERVERS) + rtcConfig.sdpSemantics = PeerConnection.SdpSemantics.UNIFIED_PLAN + rtcConfig.continualGatheringPolicy = PeerConnection.ContinualGatheringPolicy.GATHER_CONTINUALLY + + peerConnection = peerConnectionFactory?.createPeerConnection(rtcConfig, object : PeerConnection.Observer { + override fun onIceCandidate(candidate: IceCandidate) { + listener.onLocalICECandidate(candidate) + } + override fun onDataChannel(dc: DataChannel) { + // Sender typically creates the channel, but handling incoming if peer creates it + setupDataChannel(dc) + } + override fun onIceConnectionChange(state: PeerConnection.IceConnectionState) { + Log.d(TAG, "ICE State: $state") + listener.onConnectionStateChanged(state.name) + } + override fun onConnectionChange(newState: PeerConnection.PeerConnectionState) { + Log.d(TAG, "PeerConnection State: $newState") + } + // Unused + override fun onIceCandidatesRemoved(candidates: Array?) 
{} + override fun onAddStream(stream: MediaStream) {} + override fun onTrack(transceiver: RtpTransceiver) {} + override fun onRenegotiationNeeded() {} + override fun onIceConnectionReceivingChange(receiving: Boolean) {} + override fun onIceGatheringChange(state: PeerConnection.IceGatheringState) {} + override fun onSignalingChange(state: PeerConnection.SignalingState) {} + override fun onRemoveStream(stream: MediaStream) {} + override fun onAddTrack(receiver: RtpReceiver, streams: Array) {} + }) + + // Add video track + if (videoTrack != null) { + peerConnection?.addTrack(videoTrack, listOf("ARDAMS")) + } + + // Create DataChannel (Sender creates it usually) + val dcInit = DataChannel.Init() + dataChannel = peerConnection?.createDataChannel("task_channel", dcInit) + setupDataChannel(dataChannel) + } + + private fun setupDataChannel(dc: DataChannel?) { + dc?.registerObserver(object : DataChannel.Observer { + override fun onBufferedAmountChange(previous: Long) {} + override fun onStateChange() { + Log.d(TAG, "DataChannel State: ${dc.state()}") + } + override fun onMessage(buffer: DataChannel.Buffer) { + try { + val data = ByteArray(buffer.data.remaining()) + buffer.data.get(data) + val message = String(data) + Log.d(TAG, "Received DataChannel message: $message") + + val json = com.google.gson.JsonParser.parseString(message).asJsonObject + if (json.has("type") && json.get("type").asString == "tap") { + val x = json.get("x").asFloat + val y = json.get("y").asFloat + listener.onTapReceived(x, y) + } + } catch (e: Exception) { + Log.e(TAG, "Error parsing DataChannel message", e) + } + } + }) + } + + fun createOffer(callback: (String) -> Unit) { + val constraints = MediaConstraints() + constraints.mandatory.add(MediaConstraints.KeyValuePair("OfferToReceiveVideo", "false")) + constraints.mandatory.add(MediaConstraints.KeyValuePair("OfferToReceiveAudio", "false")) + + peerConnection?.createOffer(object : SdpObserver { + override fun onCreateSuccess(sdp: SessionDescription) 
{ + Log.d(TAG, "Offer created") + peerConnection?.setLocalDescription(object : SdpObserver { + override fun onSetSuccess() { + Log.d(TAG, "Local description set") + callback(sdp.description) + } + override fun onSetFailure(s: String) { Log.e(TAG, "SetLocal failure: $s") } + override fun onCreateSuccess(p0: SessionDescription?) {} + override fun onCreateFailure(p0: String?) {} + }, sdp) + } + override fun onCreateFailure(s: String) { Log.e(TAG, "CreateOffer failure: $s") } + override fun onSetSuccess() {} + override fun onSetFailure(p0: String?) {} + }, constraints) + } + + fun setRemoteAnswer(sdp: String) { + val desc = SessionDescription(SessionDescription.Type.ANSWER, sdp) + peerConnection?.setRemoteDescription(object : SdpObserver { + override fun onSetSuccess() { Log.d(TAG, "Remote answer set") } + override fun onSetFailure(s: String) { Log.e(TAG, "SetRemote failure: $s") } + override fun onCreateSuccess(p0: SessionDescription?) {} + override fun onCreateFailure(p0: String?) {} + }, desc) + } + + fun addIceCandidate(candidate: String, sdpMid: String?, sdpMLineIndex: Int) { + peerConnection?.addIceCandidate(IceCandidate(sdpMid ?: "", sdpMLineIndex, candidate)) + } + + fun stop() { + try { + videoCapturer?.stopCapture() + videoCapturer?.dispose() + peerConnection?.close() + peerConnectionFactory?.dispose() + eglBase.release() + } catch (e: Exception) { + Log.e(TAG, "Error stopping WebRTCSender", e) + } + } +} diff --git a/build.gradle.kts b/build.gradle.kts index 9a4caa05..a4a00dd5 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -23,4 +23,5 @@ plugins { id("com.android.application") version "8.1.3" apply false id("org.jetbrains.kotlin.android") version "1.9.20" apply false id("com.google.android.libraries.mapsplatform.secrets-gradle-plugin") version "2.0.1" apply false + id("com.google.gms.google-services") version "4.4.2" apply false } diff --git a/build_log_retry.txt b/build_log_retry.txt new file mode 100644 index 00000000..93a60408 --- /dev/null +++ 
b/build_log_retry.txt @@ -0,0 +1,60 @@ +Configuration on demand is an incubating feature. +> Task :app:preBuild UP-TO-DATE +> Task :app:preDebugBuild UP-TO-DATE +> Task :app:mergeDebugNativeDebugMetadata NO-SOURCE +> Task :app:checkKotlinGradlePluginConfigurationErrors +> Task :app:generateDebugBuildConfig UP-TO-DATE +> Task :app:checkDebugAarMetadata UP-TO-DATE +> Task :app:generateDebugResValues UP-TO-DATE +> Task :app:mapDebugSourceSetPaths UP-TO-DATE +> Task :app:generateDebugResources UP-TO-DATE +> Task :app:mergeDebugResources UP-TO-DATE +> Task :app:packageDebugResources UP-TO-DATE +> Task :app:parseDebugLocalResources UP-TO-DATE +> Task :app:createDebugCompatibleScreenManifests UP-TO-DATE +> Task :app:extractDeepLinksDebug UP-TO-DATE +> Task :app:processDebugMainManifest UP-TO-DATE +> Task :app:processDebugManifest UP-TO-DATE +> Task :app:processDebugManifestForPackage UP-TO-DATE +> Task :app:processDebugResources UP-TO-DATE +> Task :app:javaPreCompileDebug UP-TO-DATE +> Task :app:mergeDebugShaders UP-TO-DATE +> Task :app:compileDebugShaders NO-SOURCE +> Task :app:generateDebugAssets UP-TO-DATE +> Task :app:mergeDebugAssets UP-TO-DATE +> Task :app:compressDebugAssets UP-TO-DATE +> Task :app:checkDebugDuplicateClasses UP-TO-DATE +> Task :app:desugarDebugFileDependencies UP-TO-DATE +> Task :app:mergeExtDexDebug UP-TO-DATE +> Task :app:mergeLibDexDebug UP-TO-DATE +> Task :app:mergeDebugJniLibFolders UP-TO-DATE +> Task :app:mergeDebugNativeLibs UP-TO-DATE +> Task :app:stripDebugDebugSymbols UP-TO-DATE +> Task :app:validateSigningDebug UP-TO-DATE +> Task :app:writeDebugAppMetadata UP-TO-DATE +> Task :app:writeDebugSigningConfigVersions UP-TO-DATE +> Task :app:compileDebugKotlin +e: file:///D:/Neuer%20Ordner%20(2)/ScreenOperator/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt:99:13 Const 'val' are only allowed on top level, in named objects, or in companion objects +e: 
file:///D:/Neuer%20Ordner%20(2)/ScreenOperator/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt:350:5 Conflicting overloads: protected open fun onCleared(): Unit defined in com.google.ai.sample.feature.multimodal.PhotoReasoningViewModel, protected open fun onCleared(): Unit defined in com.google.ai.sample.feature.multimodal.PhotoReasoningViewModel +e: file:///D:/Neuer%20Ordner%20(2)/ScreenOperator/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt:1169:94 Type mismatch: inferred type is PhotoParticipant but String was expected +e: file:///D:/Neuer%20Ordner%20(2)/ScreenOperator/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt:1283:46 Type mismatch: inferred type is Int but String was expected +e: file:///D:/Neuer%20Ordner%20(2)/ScreenOperator/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt:1283:52 Type mismatch: inferred type is Int but String was expected +e: file:///D:/Neuer%20Ordner%20(2)/ScreenOperator/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt:1291:5 Conflicting overloads: protected open fun onCleared(): Unit defined in com.google.ai.sample.feature.multimodal.PhotoReasoningViewModel, protected open fun onCleared(): Unit defined in com.google.ai.sample.feature.multimodal.PhotoReasoningViewModel + +> Task :app:compileDebugKotlin FAILED + +FAILURE: Build failed with an exception. + +* What went wrong: +Execution failed for task ':app:compileDebugKotlin'. +> A failure occurred while executing org.jetbrains.kotlin.compilerRunner.GradleCompilerRunnerWithWorkers$GradleKotlinCompilerWorkAction + > Compilation error. See log for more details + +* Try: +> Run with --stacktrace option to get the stack trace. +> Run with --info or --debug option to get more log output. +> Run with --scan to get full insights. +> Get more help at https://help.gradle.org. 
+ +BUILD FAILED in 1m 45s +29 actionable tasks: 2 executed, 27 up-to-date diff --git a/gradle.properties b/gradle.properties index c9cc7fd9..c87abea7 100644 --- a/gradle.properties +++ b/gradle.properties @@ -2,17 +2,20 @@ # http://www.gradle.org/docs/current/userguide/build_environment.html # # Specifies the JVM arguments used for the daemon process. -org.gradle.jvmargs=-Xmx6g +org.gradle.jvmargs=-Xmx768m -XX:MaxMetaspaceSize=256m -XX:+UseSerialGC -XX:CICompilerCount=2 +kotlin.daemon.jvmargs=-Xmx512m -XX:MaxMetaspaceSize=256m -XX:+UseSerialGC +kotlin.compiler.execution.strategy=in-process +org.gradle.workers.max=1 # JDK 17 path for AGP 8.1.3 compatibility -org.gradle.java.home=C:/Program Files/Microsoft/jdk-17.0.18.8-hotspot +# org.gradle.java.home=C:/Program Files/Microsoft/jdk-17.0.18.8-hotspot # # When configured, Gradle will run in incubating parallel mode. # This option should only be used with decoupled projects. More details, visit # http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects -org.gradle.parallel=true +org.gradle.parallel=false # Enable Gradle build cache -org.gradle.caching=true +org.gradle.caching=false # Enable configuration on demand org.gradle.configureondemand=true diff --git a/ho_build_result.txt b/ho_build_result.txt new file mode 100644 index 00000000..70aa862c Binary files /dev/null and b/ho_build_result.txt differ diff --git a/humanoperator/build.gradle.kts b/humanoperator/build.gradle.kts new file mode 100644 index 00000000..9091bc13 --- /dev/null +++ b/humanoperator/build.gradle.kts @@ -0,0 +1,74 @@ +plugins { + id("com.android.application") + id("org.jetbrains.kotlin.android") + id("com.google.gms.google-services") +} + +android { + namespace = "com.screenoperator.humanoperator" + compileSdk = 35 + + defaultConfig { + applicationId = "com.screenoperator.humanoperator" + minSdk = 26 + targetSdk = 35 + versionCode = 1 + versionName = "1.0" + + testInstrumentationRunner = 
"androidx.test.runner.AndroidJUnitRunner" + vectorDrawables { + useSupportLibrary = true + } + } + + buildTypes { + release { + isMinifyEnabled = false + proguardFiles(getDefaultProguardFile("proguard-android-optimize.txt"), "proguard-rules.pro") + } + } + + compileOptions { + sourceCompatibility = JavaVersion.VERSION_1_8 + targetCompatibility = JavaVersion.VERSION_1_8 + } + kotlinOptions { + jvmTarget = "1.8" + } + buildFeatures { + compose = true + } + composeOptions { + kotlinCompilerExtensionVersion = "1.5.4" + } +} + +dependencies { + implementation("androidx.core:core-ktx:1.9.0") + implementation("androidx.lifecycle:lifecycle-runtime-ktx:2.6.2") + implementation("androidx.lifecycle:lifecycle-viewmodel-compose:2.6.2") + implementation("androidx.activity:activity-compose:1.8.1") + + implementation(platform("androidx.compose:compose-bom:2024.02.01")) + implementation("androidx.compose.ui:ui") + implementation("androidx.compose.ui:ui-graphics") + implementation("androidx.compose.ui:ui-tooling-preview") + implementation("androidx.compose.material3:material3") + implementation("androidx.compose.material:material-icons-extended") + + // WebRTC + implementation("io.getstream:stream-webrtc-android:1.1.1") + + // WebSocket for signaling + implementation("com.squareup.okhttp3:okhttp:4.12.0") + + // JSON + implementation("com.google.code.gson:gson:2.10.1") + + debugImplementation("androidx.compose.ui:ui-tooling") + debugImplementation("androidx.compose.ui:ui-test-manifest") + + // Firebase + implementation(platform("com.google.firebase:firebase-bom:32.7.2")) + implementation("com.google.firebase:firebase-database") +} diff --git a/humanoperator/proguard-rules.pro b/humanoperator/proguard-rules.pro new file mode 100644 index 00000000..be44800a --- /dev/null +++ b/humanoperator/proguard-rules.pro @@ -0,0 +1,3 @@ +# Add project specific ProGuard rules here. 
+-keep class org.webrtc.** { *; } +-dontwarn org.webrtc.** diff --git a/humanoperator/src/main/AndroidManifest.xml b/humanoperator/src/main/AndroidManifest.xml new file mode 100644 index 00000000..940448b7 --- /dev/null +++ b/humanoperator/src/main/AndroidManifest.xml @@ -0,0 +1,26 @@ + + + + + + + + + + + + + + + + + diff --git a/humanoperator/src/main/kotlin/com/screenoperator/humanoperator/MainActivity.kt b/humanoperator/src/main/kotlin/com/screenoperator/humanoperator/MainActivity.kt new file mode 100644 index 00000000..8a3d5860 --- /dev/null +++ b/humanoperator/src/main/kotlin/com/screenoperator/humanoperator/MainActivity.kt @@ -0,0 +1,438 @@ +package com.screenoperator.humanoperator + +import android.app.NotificationChannel +import android.app.NotificationManager +import android.os.Build +import android.os.Bundle +import android.util.Log +import android.widget.Toast +import androidx.activity.ComponentActivity +import androidx.activity.compose.setContent +import androidx.compose.foundation.background +import androidx.compose.foundation.gestures.detectTapGestures +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.lazy.LazyColumn +import androidx.compose.foundation.lazy.items +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.Check +import androidx.compose.material.icons.filled.Link +import androidx.compose.material.icons.filled.LinkOff +import androidx.compose.material.icons.filled.TouchApp +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.input.pointer.pointerInput +import androidx.compose.ui.platform.LocalContext +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.style.TextAlign +import 
androidx.compose.ui.text.style.TextOverflow +import androidx.compose.ui.unit.dp +import androidx.compose.ui.unit.sp +import androidx.compose.ui.viewinterop.AndroidView +import androidx.core.app.NotificationCompat +import androidx.core.app.NotificationManagerCompat +import org.webrtc.EglBase +import org.webrtc.IceCandidate +import org.webrtc.RendererCommon +import org.webrtc.SurfaceViewRenderer +import org.webrtc.VideoTrack + +class MainActivity : ComponentActivity() { + companion object { + private const val TAG = "HumanOperator" + private const val NOTIFICATION_CHANNEL_ID = "human_operator_tasks" + } + + private var webRTCClient: WebRTCClient? = null + private var signalingClient: SignalingClient? = null + + override fun onCreate(savedInstanceState: Bundle?) { + super.onCreate(savedInstanceState) + createNotificationChannel() + setContent { + HumanOperatorTheme { + HumanOperatorScreen() + } + } + } + + private fun createNotificationChannel() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) { + val channel = NotificationChannel( + NOTIFICATION_CHANNEL_ID, + "Incoming Tasks", + NotificationManager.IMPORTANCE_HIGH + ).apply { + description = "Notifications for new tasks from Screen Operator" + } + val manager = getSystemService(NotificationManager::class.java) + manager.createNotificationChannel(channel) + } + } + + private fun showTaskNotification(taskId: String, text: String) { + try { + val notification = NotificationCompat.Builder(this, NOTIFICATION_CHANNEL_ID) + .setSmallIcon(android.R.drawable.ic_dialog_info) + .setContentTitle("New Task Available") + .setContentText(if (text.isNotBlank()) text.take(100) else "A new task is waiting") + .setPriority(NotificationCompat.PRIORITY_HIGH) + .setAutoCancel(true) + .build() + NotificationManagerCompat.from(this).notify(taskId.hashCode(), notification) + } catch (e: SecurityException) { + Log.w(TAG, "No notification permission", e) + } + } + + private fun cancelTaskNotification(taskId: String) { + 
NotificationManagerCompat.from(this).cancel(taskId.hashCode()) + } + + override fun onDestroy() { + super.onDestroy() + webRTCClient?.dispose() + signalingClient?.disconnect() + } + + @Composable + fun HumanOperatorTheme(content: @Composable () -> Unit) { + val darkColorScheme = darkColorScheme( + primary = Color(0xFF7C4DFF), + secondary = Color(0xFF00E5FF), + background = Color(0xFF0D1117), + surface = Color(0xFF161B22), + surfaceVariant = Color(0xFF21262D), + onPrimary = Color.White, + onBackground = Color(0xFFE6EDF3), + onSurface = Color(0xFFE6EDF3), + error = Color(0xFFFF6B6B) + ) + MaterialTheme(colorScheme = darkColorScheme, content = content) + } + + data class TaskInfo(val taskId: String, val text: String) + + @OptIn(ExperimentalMaterial3Api::class) + @Composable + fun HumanOperatorScreen() { + var connectionState by remember { mutableStateOf("Disconnected") } + var isConnected by remember { mutableStateOf(false) } + var isPaired by remember { mutableStateOf(false) } + var hasVideoTrack by remember { mutableStateOf(false) } + var dataChannelOpen by remember { mutableStateOf(false) } + var videoTrack by remember { mutableStateOf(null) } + var eglContext by remember { mutableStateOf(null) } + val availableTasks = remember { mutableStateListOf() } + var claimedTaskId by remember { mutableStateOf(null) } + val context = LocalContext.current + + fun connectToServer() { + connectionState = "Connecting..." 
+ + // Initialize WebRTC + val rtcClient = WebRTCClient(context, object : WebRTCClient.WebRTCListener { + override fun onLocalICECandidate(candidate: IceCandidate) { + signalingClient?.sendICECandidate(candidate.sdp, candidate.sdpMid, candidate.sdpMLineIndex) + } + override fun onVideoTrackReceived(track: VideoTrack) { + Log.d(TAG, "Video track received") + videoTrack = track + hasVideoTrack = true + } + override fun onDataChannelMessage(message: String) { + Log.d(TAG, "Message: ${message.take(100)}") + } + override fun onConnectionStateChanged(state: String) { + Log.d(TAG, "WebRTC state: $state") + if (state == "CONNECTED" || state == "COMPLETED") { + isPaired = true + connectionState = "Paired - viewing screen" + } else if (state == "DISCONNECTED" || state == "FAILED") { + isPaired = false + hasVideoTrack = false + connectionState = "Peer disconnected" + } + } + override fun onDataChannelOpen() { + dataChannelOpen = true + } + }) + rtcClient.initialize() + rtcClient.createPeerConnection() + webRTCClient = rtcClient + eglContext = rtcClient.getEglBaseContext() + + // Connect signaling + val signaling = SignalingClient(object : SignalingClient.SignalingListener { + override fun onNewTask(taskId: String, text: String) { + Log.d(TAG, "New task: $taskId") + availableTasks.add(TaskInfo(taskId, text)) + showTaskNotification(taskId, text) + if (!isConnected) { + isConnected = true + connectionState = "Waiting for tasks..." + } + } + override fun onTaskRemoved(taskId: String) { + Log.d(TAG, "Task removed: $taskId") + availableTasks.removeAll { it.taskId == taskId } + cancelTaskNotification(taskId) + } + override fun onClaimed(taskId: String) { + Log.d(TAG, "Successfully claimed: $taskId") + claimedTaskId = taskId + connectionState = "Task claimed, connecting..." 
+ // Clear all other tasks + availableTasks.clear() + NotificationManagerCompat.from(context).cancelAll() + } + override fun onClaimFailed(reason: String) { + Log.w(TAG, "Claim failed: $reason") + Toast.makeText(context, "Someone was faster: $reason", Toast.LENGTH_SHORT).show() + } + override fun onSDPOffer(sdp: String) { + rtcClient.setRemoteOffer(sdp) + // Send answer after a short delay + android.os.Handler(mainLooper).postDelayed({ + val answer = rtcClient.getLocalDescription() + if (answer != null) { + signalingClient?.sendAnswer(answer.description) + } + }, 1000) + } + override fun onICECandidate(candidate: String, sdpMid: String?, sdpMLineIndex: Int) { + rtcClient.addICECandidate(candidate, sdpMid, sdpMLineIndex) + } + override fun onPeerDisconnected() { + isPaired = false + hasVideoTrack = false + dataChannelOpen = false + claimedTaskId = null + connectionState = "Peer disconnected - waiting for tasks..." + // Resume listening for tasks + signalingClient?.startListeningForTasks() + } + override fun onError(message: String) { + connectionState = "Error: $message" + } + }) + signaling.startListeningForTasks() + // Set initial state + isConnected = true + connectionState = "Waiting for tasks..." 
+ signalingClient = signaling + } + + // Auto-connect on launch + LaunchedEffect(Unit) { + connectToServer() + } + + Scaffold( + topBar = { + TopAppBar( + title = { Text("Human Operator", fontWeight = FontWeight.Bold) }, + colors = TopAppBarDefaults.topAppBarColors( + containerColor = MaterialTheme.colorScheme.surface, + titleContentColor = MaterialTheme.colorScheme.onSurface + ) + ) + } + ) { padding -> + Column( + modifier = Modifier + .fillMaxSize() + .padding(padding) + .background(MaterialTheme.colorScheme.background) + .padding(16.dp), + horizontalAlignment = Alignment.CenterHorizontally + ) { + // Connection status bar + Card( + modifier = Modifier.fillMaxWidth(), + colors = CardDefaults.cardColors(containerColor = MaterialTheme.colorScheme.surfaceVariant), + shape = RoundedCornerShape(12.dp) + ) { + Row( + modifier = Modifier.fillMaxWidth().padding(12.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + imageVector = if (isConnected) Icons.Default.Link else Icons.Default.LinkOff, + contentDescription = null, + tint = if (isPaired) Color(0xFF4CAF50) else if (isConnected) Color(0xFFFFA726) else Color(0xFFFF6B6B), + modifier = Modifier.size(20.dp) + ) + Spacer(modifier = Modifier.width(8.dp)) + Text( + text = connectionState, + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurface + ) + } + } + + Spacer(modifier = Modifier.height(16.dp)) + + if (isPaired && hasVideoTrack && videoTrack != null && eglContext != null) { + // === PAIRED VIEW: Video + Tap Overlay === + Box( + modifier = Modifier + .fillMaxWidth() + .weight(1f) + .clip(RoundedCornerShape(12.dp)) + ) { + AndroidView( + factory = { ctx -> + SurfaceViewRenderer(ctx).apply { + init(eglContext, null) + setScalingType(RendererCommon.ScalingType.SCALE_ASPECT_FIT) + setEnableHardwareScaler(true) + videoTrack?.addSink(this) + } + }, + modifier = Modifier.fillMaxSize() + ) + // Tap overlay + Box( + modifier = Modifier + .fillMaxSize() + .pointerInput(Unit) 
{ + detectTapGestures { offset -> + val normalizedX = offset.x / size.width + val normalizedY = offset.y / size.height + Log.d(TAG, "Tap: ($normalizedX, $normalizedY)") + webRTCClient?.sendTap(normalizedX, normalizedY) + } + } + ) + } + + Spacer(modifier = Modifier.height(8.dp)) + + // Tap hint + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.Center, + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + Icons.Default.TouchApp, + contentDescription = null, + tint = MaterialTheme.colorScheme.primary, + modifier = Modifier.size(16.dp) + ) + Spacer(modifier = Modifier.width(4.dp)) + Text( + "Tap on the screen to interact", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurface.copy(alpha = 0.6f) + ) + } + + } else if (isPaired) { + // Paired but waiting for video + Box( + modifier = Modifier.fillMaxWidth().weight(1f) + .clip(RoundedCornerShape(12.dp)) + .background(MaterialTheme.colorScheme.surfaceVariant), + contentAlignment = Alignment.Center + ) { + Column(horizontalAlignment = Alignment.CenterHorizontally) { + CircularProgressIndicator(color = MaterialTheme.colorScheme.primary, modifier = Modifier.size(40.dp)) + Spacer(modifier = Modifier.height(16.dp)) + Text("Waiting for screen stream...", color = MaterialTheme.colorScheme.onSurface.copy(alpha = 0.6f)) + } + } + + } else { + // === TASK LIST VIEW === + if (availableTasks.isEmpty()) { + Box( + modifier = Modifier.fillMaxWidth().weight(1f), + contentAlignment = Alignment.Center + ) { + Column(horizontalAlignment = Alignment.CenterHorizontally) { + if (isConnected) { + CircularProgressIndicator( + color = MaterialTheme.colorScheme.primary.copy(alpha = 0.5f), + modifier = Modifier.size(32.dp), + strokeWidth = 2.dp + ) + Spacer(modifier = Modifier.height(16.dp)) + Text( + "Waiting for tasks...", + style = MaterialTheme.typography.titleMedium, + color = MaterialTheme.colorScheme.onSurface.copy(alpha = 0.5f) + ) + Spacer(modifier = 
Modifier.height(4.dp)) + Text( + "Tasks appear here when someone needs help", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurface.copy(alpha = 0.3f), + textAlign = TextAlign.Center + ) + } else { + Text( + "Connecting to server...", + style = MaterialTheme.typography.titleMedium, + color = MaterialTheme.colorScheme.onSurface.copy(alpha = 0.5f) + ) + } + } + } + } else { + Text( + "${availableTasks.size} task${if (availableTasks.size != 1) "s" else ""} available", + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.SemiBold, + modifier = Modifier.fillMaxWidth() + ) + Spacer(modifier = Modifier.height(8.dp)) + + LazyColumn( + modifier = Modifier.fillMaxWidth().weight(1f), + verticalArrangement = Arrangement.spacedBy(8.dp) + ) { + items(availableTasks, key = { it.taskId }) { task -> + Card( + modifier = Modifier.fillMaxWidth(), + colors = CardDefaults.cardColors(containerColor = MaterialTheme.colorScheme.surfaceVariant), + shape = RoundedCornerShape(12.dp) + ) { + Column(modifier = Modifier.padding(16.dp)) { + if (task.text.isNotBlank()) { + Text( + text = task.text, + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurface, + maxLines = 3, + overflow = TextOverflow.Ellipsis + ) + Spacer(modifier = Modifier.height(12.dp)) + } + Button( + onClick = { signalingClient?.claimTask(task.taskId) }, + modifier = Modifier.fillMaxWidth(), + shape = RoundedCornerShape(8.dp), + colors = ButtonDefaults.buttonColors(containerColor = Color(0xFF4CAF50)) + ) { + Icon(Icons.Default.Check, contentDescription = null, modifier = Modifier.size(18.dp)) + Spacer(modifier = Modifier.width(6.dp)) + Text("Claim this task", fontWeight = FontWeight.SemiBold) + } + } + } + } + } + } + } + } + } + } +} diff --git a/humanoperator/src/main/kotlin/com/screenoperator/humanoperator/SignalingClient.kt b/humanoperator/src/main/kotlin/com/screenoperator/humanoperator/SignalingClient.kt new file mode 100644 
index 00000000..65da4bb6
--- /dev/null
+++ b/humanoperator/src/main/kotlin/com/screenoperator/humanoperator/SignalingClient.kt
@@ -0,0 +1,198 @@
+package com.screenoperator.humanoperator
+
+import android.util.Log
+import com.google.firebase.database.ChildEventListener
+import com.google.firebase.database.DataSnapshot
+import com.google.firebase.database.DatabaseError
+import com.google.firebase.database.DatabaseReference
+import com.google.firebase.database.FirebaseDatabase
+import com.google.firebase.database.ValueEventListener
+
+/**
+ * Firebase Realtime Database signaling client for the task broker.
+ * Operators listen for open tasks, claim them, and exchange WebRTC signals.
+ */
+class SignalingClient(
+    private val listener: SignalingListener
+) {
+    companion object {
+        private const val TAG = "SignalingClient"
+    }
+
+    private val database: FirebaseDatabase = FirebaseDatabase.getInstance()
+    private val tasksRef: DatabaseReference = database.getReference("tasks")
+
+    // Listener references for cleanup
+    private var tasksListener: ChildEventListener? = null
+    private var offerListener: ValueEventListener? = null
+    private var iceListener: ChildEventListener? = null
+    private var currentTaskId: String? = null
+
+    interface SignalingListener {
+        fun onNewTask(taskId: String, text: String)
+        fun onTaskRemoved(taskId: String)
+        fun onClaimed(taskId: String)
+        fun onClaimFailed(reason: String)
+        fun onSDPOffer(sdp: String)
+        fun onICECandidate(candidate: String, sdpMid: String?, sdpMLineIndex: Int)
+        fun onPeerDisconnected()
+        fun onError(message: String)
+    }
+
+    fun startListeningForTasks() {
+        Log.d(TAG, "Starting to listen for open tasks...")
+
+        if (tasksListener != null) return
+
+        tasksListener = object : ChildEventListener {
+            override fun onChildAdded(snapshot: DataSnapshot, previousChildName: String?) {
+                val status = snapshot.child("status").getValue(String::class.java)
+                if (status == "open") {
+                    val taskId = snapshot.key ?: return
+                    val text = snapshot.child("text").getValue(String::class.java) ?: ""
+                    Log.d(TAG, "New open task found: $taskId")
+                    listener.onNewTask(taskId, text)
+                }
+            }
+
+            override fun onChildChanged(snapshot: DataSnapshot, previousChildName: String?) {
+                val status = snapshot.child("status").getValue(String::class.java)
+                val taskId = snapshot.key ?: return
+
+                // If task is no longer open (claimed by someone else or cancelled), remove it from list
+                if (status != "open") {
+                    listener.onTaskRemoved(taskId)
+                }
+            }
+
+            override fun onChildRemoved(snapshot: DataSnapshot) {
+                val taskId = snapshot.key ?: return
+                listener.onTaskRemoved(taskId)
+            }
+
+            override fun onChildMoved(snapshot: DataSnapshot, previousChildName: String?) {}
+            override fun onCancelled(error: DatabaseError) {
+                Log.e(TAG, "Tasks listener cancelled: ${error.message}")
+                listener.onError("Failed to listen for tasks: ${error.message}")
+            }
+        }
+
+        tasksRef.orderByChild("status").equalTo("open").addChildEventListener(tasksListener!!)
+    }
+
+    fun claimTask(taskId: String) {
+        Log.d(TAG, "Attempting to claim task: $taskId")
+        val taskStatusRef = tasksRef.child(taskId).child("status")
+
+        taskStatusRef.runTransaction(object : com.google.firebase.database.Transaction.Handler {
+            override fun doTransaction(currentData: com.google.firebase.database.MutableData): com.google.firebase.database.Transaction.Result {
+                // null status is claimable: Firebase runs the handler once with local (possibly
+                // absent) data before retrying with server data — aborting on null would break that.
+                val status = currentData.getValue(String::class.java)
+                if (status == null || status == "open") {
+                    currentData.value = "claimed"
+                    return com.google.firebase.database.Transaction.success(currentData)
+                }
+                return com.google.firebase.database.Transaction.abort()
+            }
+
+            override fun onComplete(
+                error: DatabaseError?,
+                committed: Boolean,
+                currentData: DataSnapshot?
+            ) {
+                if (error != null) {
+                    Log.e(TAG, "Claim transaction error: ${error.message}")
+                    listener.onClaimFailed(error.message)
+                } else if (committed) {
+                    Log.d(TAG, "Task claimed successfully: $taskId")
+                    currentTaskId = taskId
+                    listener.onClaimed(taskId)
+                    listenForSignaling(taskId)
+                } else {
+                    Log.d(TAG, "Claim failed: Task already claimed or invalid.")
+                    listener.onClaimFailed("Task already taken")
+                }
+            }
+        })
+    }
+
+    private fun listenForSignaling(taskId: String) {
+        val taskRef = tasksRef.child(taskId)
+
+        // Listen for SDP Offer from Requester
+        offerListener = object : ValueEventListener {
+            override fun onDataChange(snapshot: DataSnapshot) {
+                val type = snapshot.child("type").getValue(String::class.java)
+                val sdp = snapshot.child("sdp").getValue(String::class.java)
+
+                if (type == "offer" && sdp != null) {
+                    Log.d(TAG, "Received SDP Offer")
+                    listener.onSDPOffer(sdp)
+                }
+            }
+
+            override fun onCancelled(error: DatabaseError) {
+                Log.e(TAG, "Offer listener cancelled", error.toException())
+            }
+        }
+        taskRef.child("offer").addValueEventListener(offerListener!!)
+
+        // Listen for ICE Candidates from Requester
+        iceListener = object : ChildEventListener {
+            override fun onChildAdded(snapshot: DataSnapshot, previousChildName: String?) {
+                val sender = snapshot.child("sender").getValue(String::class.java)
+                if (sender == "requester") {
+                    val candidate = snapshot.child("candidate").getValue(String::class.java)
+                    val sdpMid = snapshot.child("sdpMid").getValue(String::class.java)
+                    val sdpMLineIndex = snapshot.child("sdpMLineIndex").getValue(Int::class.java) ?: 0
+
+                    if (candidate != null) {
+                        Log.d(TAG, "Received ICE candidate from requester")
+                        listener.onICECandidate(candidate, sdpMid, sdpMLineIndex)
+                    }
+                }
+            }
+
+            override fun onChildChanged(snapshot: DataSnapshot, previousChildName: String?) {}
+            override fun onChildRemoved(snapshot: DataSnapshot) {}
+            override fun onChildMoved(snapshot: DataSnapshot, previousChildName: String?) {}
+            override fun onCancelled(error: DatabaseError) { Log.e(TAG, "ICE listener cancelled: ${error.message}") }
+        }
+        taskRef.child("ice").addChildEventListener(iceListener!!)
+    }
+
+    fun sendAnswer(sdp: String) {
+        val taskId = currentTaskId ?: return
+        Log.d(TAG, "Sending SDP Answer")
+        val answer = mapOf(
+            "type" to "answer",
+            "sdp" to sdp
+        )
+        tasksRef.child(taskId).child("answer").setValue(answer)
+    }
+
+    fun sendICECandidate(candidate: String, sdpMid: String?, sdpMLineIndex: Int) {
+        val taskId = currentTaskId ?: return
+        val ice = mapOf(
+            "candidate" to candidate,
+            "sdpMid" to sdpMid,
+            "sdpMLineIndex" to sdpMLineIndex,
+            "sender" to "operator"
+        )
+        tasksRef.child(taskId).child("ice").push().setValue(ice)
+    }
+
+    fun disconnect() {
+        Log.d(TAG, "Disconnecting SignalingClient")
+        tasksListener?.let { tasksRef.orderByChild("status").equalTo("open").removeEventListener(it) } // must detach from the same Query it was attached to
+
+        currentTaskId?.let { taskId ->
+            offerListener?.let { tasksRef.child(taskId).child("offer").removeEventListener(it) }
+            iceListener?.let { tasksRef.child(taskId).child("ice").removeEventListener(it) }
+        }
+
+        tasksListener = null
+        offerListener = null
+        iceListener = null
+        currentTaskId = null
+    }
+}
diff --git a/humanoperator/src/main/kotlin/com/screenoperator/humanoperator/WebRTCClient.kt b/humanoperator/src/main/kotlin/com/screenoperator/humanoperator/WebRTCClient.kt
new file mode 100644
index 00000000..8baad7a7
--- /dev/null
+++ b/humanoperator/src/main/kotlin/com/screenoperator/humanoperator/WebRTCClient.kt
@@ -0,0 +1,173 @@
+package com.screenoperator.humanoperator
+
+import android.content.Context
+import android.util.Log
+import com.google.gson.Gson
+import org.webrtc.*
+
+/**
+ * Manages WebRTC peer connection for the Human Operator side.
+ * Receives video stream from ScreenOperator and sends tap coordinates back.
+ */
+class WebRTCClient(
+    private val context: Context,
+    private val listener: WebRTCListener
+) {
+    companion object {
+        private const val TAG = "WebRTCClient"
+        private val STUN_SERVERS = listOf(
+            PeerConnection.IceServer.builder("stun:stun.l.google.com:19302").createIceServer(),
+            PeerConnection.IceServer.builder("stun:stun1.l.google.com:19302").createIceServer()
+        )
+    }
+
+    interface WebRTCListener {
+        fun onLocalICECandidate(candidate: IceCandidate)
+        fun onVideoTrackReceived(track: VideoTrack)
+        fun onDataChannelMessage(message: String)
+        fun onConnectionStateChanged(state: String)
+        fun onDataChannelOpen()
+    }
+
+    private var peerConnectionFactory: PeerConnectionFactory? = null
+    private var peerConnection: PeerConnection? = null
+    private var dataChannel: DataChannel? = null
+    private val eglBase = EglBase.create()
+    private val gson = Gson()
+
+    fun initialize() {
+        Log.d(TAG, "Initializing WebRTC")
+        val initOptions = PeerConnectionFactory.InitializationOptions.builder(context)
+            .setEnableInternalTracer(false)
+            .createInitializationOptions()
+        PeerConnectionFactory.initialize(initOptions)
+
+        peerConnectionFactory = PeerConnectionFactory.builder()
+            .setOptions(PeerConnectionFactory.Options())
+            .setVideoDecoderFactory(DefaultVideoDecoderFactory(eglBase.eglBaseContext))
+            .createPeerConnectionFactory()
+
+        Log.d(TAG, "WebRTC initialized")
+    }
+
+    fun createPeerConnection() {
+        val rtcConfig = PeerConnection.RTCConfiguration(STUN_SERVERS).apply {
+            sdpSemantics = PeerConnection.SdpSemantics.UNIFIED_PLAN
+            continualGatheringPolicy = PeerConnection.ContinualGatheringPolicy.GATHER_CONTINUALLY
+        }
+
+        peerConnection = peerConnectionFactory?.createPeerConnection(rtcConfig, object : PeerConnection.Observer {
+            override fun onIceCandidate(candidate: IceCandidate) {
+                Log.d(TAG, "ICE candidate: ${candidate.sdp?.take(50)}")
+                listener.onLocalICECandidate(candidate)
+            }
+            override fun onIceCandidatesRemoved(candidates: Array<out IceCandidate>?) {}
+            override fun onAddStream(stream: MediaStream) {
+                Log.d(TAG, "Stream added with ${stream.videoTracks.size} video tracks")
+                if (stream.videoTracks.isNotEmpty()) {
+                    listener.onVideoTrackReceived(stream.videoTracks[0])
+                }
+            }
+            override fun onTrack(transceiver: RtpTransceiver) {
+                val track = transceiver.receiver.track()
+                if (track is VideoTrack) {
+                    Log.d(TAG, "Video track received via onTrack")
+                    listener.onVideoTrackReceived(track)
+                }
+            }
+            override fun onDataChannel(dc: DataChannel) {
+                Log.d(TAG, "Data channel received: ${dc.label()}")
+                dataChannel = dc
+                dc.registerObserver(object : DataChannel.Observer {
+                    override fun onBufferedAmountChange(previous: Long) {}
+                    override fun onStateChange() {
+                        Log.d(TAG, "DataChannel state: ${dc.state()}")
+                        if (dc.state() == DataChannel.State.OPEN) {
+                            listener.onDataChannelOpen()
+                        }
+                    }
+                    override fun onMessage(buffer: DataChannel.Buffer) {
+                        val data = ByteArray(buffer.data.remaining())
+                        buffer.data.get(data)
+                        listener.onDataChannelMessage(String(data))
+                    }
+                })
+            }
+            override fun onIceConnectionChange(state: PeerConnection.IceConnectionState) {
+                Log.d(TAG, "ICE connection state: $state")
+                listener.onConnectionStateChanged(state.name)
+            }
+            override fun onIceConnectionReceivingChange(receiving: Boolean) {}
+            override fun onIceGatheringChange(state: PeerConnection.IceGatheringState) {}
+            override fun onSignalingChange(state: PeerConnection.SignalingState) {}
+            override fun onRemoveStream(stream: MediaStream) {}
+            override fun onRenegotiationNeeded() {}
+            override fun onAddTrack(receiver: RtpReceiver, streams: Array<out MediaStream>) {}
+        }) ?: throw IllegalStateException("Failed to create PeerConnection")
+
+        Log.d(TAG, "PeerConnection created")
+    }
+
+    fun setRemoteOffer(sdp: String) {
+        val desc = SessionDescription(SessionDescription.Type.OFFER, sdp)
+        peerConnection?.setRemoteDescription(object : SdpObserver {
+            override fun onSetSuccess() {
+                Log.d(TAG, "Remote offer set, creating answer")
+                createAnswer()
+            }
+            override fun onSetFailure(error: String) { Log.e(TAG, "Set remote offer failed: $error") }
+            override fun onCreateSuccess(p0: SessionDescription?) {}
+            override fun onCreateFailure(p0: String?) {}
+        }, desc)
+    }
+
+    private fun createAnswer() {
+        val constraints = MediaConstraints().apply {
+            mandatory.add(MediaConstraints.KeyValuePair("OfferToReceiveVideo", "true"))
+            mandatory.add(MediaConstraints.KeyValuePair("OfferToReceiveAudio", "false"))
+        }
+        peerConnection?.createAnswer(object : SdpObserver {
+            override fun onCreateSuccess(sdp: SessionDescription) {
+                peerConnection?.setLocalDescription(object : SdpObserver {
+                    override fun onSetSuccess() { Log.d(TAG, "Local description set") }
+                    override fun onSetFailure(error: String) { Log.e(TAG, "Set local desc failed: $error") }
+                    override fun onCreateSuccess(p0: SessionDescription?) {}
+                    override fun onCreateFailure(p0: String?) {}
+                }, sdp)
+                listener.onConnectionStateChanged("ANSWER_CREATED")
+            }
+            override fun onCreateFailure(error: String) { Log.e(TAG, "Create answer failed: $error") }
+            override fun onSetSuccess() {}
+            override fun onSetFailure(p0: String?) {}
+        }, constraints)
+    }
+
+    fun addICECandidate(candidate: String, sdpMid: String?, sdpMLineIndex: Int) {
+        peerConnection?.addIceCandidate(IceCandidate(sdpMid ?: "", sdpMLineIndex, candidate))
+    }
+
+    fun getLocalDescription(): SessionDescription? = peerConnection?.localDescription
+
+    fun sendTap(x: Float, y: Float) {
+        sendDataChannelMessage(gson.toJson(mapOf("type" to "tap", "x" to x, "y" to y)))
+    }
+
+    fun sendClaim() = sendDataChannelMessage("{\"type\":\"claim\"}")
+    fun sendReject() = sendDataChannelMessage("{\"type\":\"reject\"}")
+
+    private fun sendDataChannelMessage(message: String) {
+        val buffer = DataChannel.Buffer(java.nio.ByteBuffer.wrap(message.toByteArray()), false)
+        if (dataChannel?.send(buffer) != true) {
+            Log.w(TAG, "Failed to send DataChannel message: ${message.take(50)}")
+        }
+    }
+
+    fun getEglBaseContext(): EglBase.Context = eglBase.eglBaseContext
+
+    fun dispose() {
+        dataChannel?.close()
+        peerConnection?.close()
+        peerConnectionFactory?.dispose()
+        eglBase.release()
+    }
+}
diff --git a/humanoperator/src/main/res/drawable/ic_launcher_foreground.xml b/humanoperator/src/main/res/drawable/ic_launcher_foreground.xml
new file mode 100644
index 00000000..26132905
--- /dev/null
+++ b/humanoperator/src/main/res/drawable/ic_launcher_foreground.xml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
diff --git a/humanoperator/src/main/res/mipmap-anydpi-v26/ic_launcher.xml b/humanoperator/src/main/res/mipmap-anydpi-v26/ic_launcher.xml
new file mode 100644
index 00000000..5ed0a2df
--- /dev/null
+++ b/humanoperator/src/main/res/mipmap-anydpi-v26/ic_launcher.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
+    <background android:drawable="@color/ic_launcher_background"/>
+    <foreground android:drawable="@drawable/ic_launcher_foreground"/>
+</adaptive-icon>
diff --git a/humanoperator/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml b/humanoperator/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml
new file mode 100644
index 00000000..5ed0a2df
--- /dev/null
+++ b/humanoperator/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
+    <background android:drawable="@color/ic_launcher_background"/>
+    <foreground android:drawable="@drawable/ic_launcher_foreground"/>
+</adaptive-icon>
diff --git a/humanoperator/src/main/res/values/colors.xml b/humanoperator/src/main/res/values/colors.xml
new file mode 100644
index 00000000..cb422ab6
--- /dev/null
+++ b/humanoperator/src/main/res/values/colors.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources>
+    <color name="ic_launcher_background">#1a1a2e</color><!-- name reconstructed (tags were stripped); verify -->
+    <color name="primary">#7C4DFF</color><!-- name reconstructed (tags were stripped); verify -->
+</resources>
diff --git a/humanoperator/src/main/res/values/themes.xml
b/humanoperator/src/main/res/values/themes.xml new file mode 100644 index 00000000..c2bcead5 --- /dev/null +++ b/humanoperator/src/main/res/values/themes.xml @@ -0,0 +1,7 @@ + + + + diff --git a/settings.gradle.kts b/settings.gradle.kts index a290317f..c5d7c797 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -33,3 +33,4 @@ dependencyResolutionManagement { rootProject.name = "GenerativeAiSampleApp" include(":app") +include(":humanoperator") diff --git a/signaling-server/index.js b/signaling-server/index.js new file mode 100644 index 00000000..cda32a5a --- /dev/null +++ b/signaling-server/index.js @@ -0,0 +1,138 @@ +const WebSocket = require("ws"); + +const PORT = process.env.PORT || 8080; +const wss = new WebSocket.Server({ port: PORT }); + +// Waiting operators (not currently paired) +const availableOperators = new Set(); +// Active tasks waiting for an operator to claim +// Map +const pendingTasks = new Map(); +// Active pairs (after claim) +// Map — bidirectional mapping +const pairs = new Map(); + +let taskIdCounter = 1; + +console.log(`Signaling server (task broker) running on port ${PORT}`); + +wss.on("connection", (ws) => { + let role = null; // "operator" or "requester" + + ws.on("message", (data) => { + try { + const msg = JSON.parse(data); + + // Operator registers as available + if (msg.type === "register_operator") { + role = "operator"; + availableOperators.add(ws); + console.log(`Operator registered (${availableOperators.size} total)`); + ws.send(JSON.stringify({ type: "registered", message: "Waiting for tasks..." 
})); + // Send any pending tasks immediately + pendingTasks.forEach((task, taskId) => { + ws.send(JSON.stringify({ + type: "new_task", + taskId: taskId, + text: task.task.text, + hasScreenshot: !!task.task.screenshot + })); + }); + return; + } + + // ScreenOperator posts a new task + if (msg.type === "post_task") { + role = "requester"; + const taskId = "task_" + (taskIdCounter++); + pendingTasks.set(taskId, { requester: ws, task: msg }); + console.log(`Task posted: ${taskId} (${availableOperators.size} operators available)`); + + // Broadcast to all available operators + availableOperators.forEach((op) => { + if (op.readyState === WebSocket.OPEN) { + op.send(JSON.stringify({ + type: "new_task", + taskId: taskId, + text: msg.text || "", + hasScreenshot: !!msg.screenshot + })); + } + }); + + // Tell requester how many operators are available + ws.send(JSON.stringify({ + type: "task_posted", + taskId: taskId, + operatorsAvailable: availableOperators.size + })); + return; + } + + // Operator claims a task + if (msg.type === "claim" && msg.taskId) { + const task = pendingTasks.get(msg.taskId); + if (!task) { + ws.send(JSON.stringify({ type: "claim_failed", reason: "Task already claimed or expired" })); + return; + } + // Pair them + const requester = task.requester; + pendingTasks.delete(msg.taskId); + availableOperators.delete(ws); + pairs.set(ws, requester); + pairs.set(requester, ws); + + console.log(`Task ${msg.taskId} claimed. 
Pair established.`); + + // Notify the claiming operator + ws.send(JSON.stringify({ type: "claimed", taskId: msg.taskId })); + // Notify the requester + requester.send(JSON.stringify({ type: "task_claimed", taskId: msg.taskId })); + + // Notify all other operators that the task is gone + availableOperators.forEach((op) => { + if (op.readyState === WebSocket.OPEN) { + op.send(JSON.stringify({ type: "task_taken", taskId: msg.taskId })); + } + }); + return; + } + + // Forward WebRTC signaling between paired peers + const peer = pairs.get(ws); + if (peer && peer.readyState === WebSocket.OPEN) { + peer.send(JSON.stringify(msg)); + } + + } catch (e) { + console.error("Failed to process message:", e.message); + } + }); + + ws.on("close", () => { + availableOperators.delete(ws); + // Clean up any pending tasks from this requester + pendingTasks.forEach((task, taskId) => { + if (task.requester === ws) { + pendingTasks.delete(taskId); + // Notify operators task is gone + availableOperators.forEach((op) => { + if (op.readyState === WebSocket.OPEN) { + op.send(JSON.stringify({ type: "task_taken", taskId: taskId })); + } + }); + } + }); + // Clean up pair + const peer = pairs.get(ws); + if (peer) { + pairs.delete(peer); + pairs.delete(ws); + if (peer.readyState === WebSocket.OPEN) { + peer.send(JSON.stringify({ type: "peer_disconnected" })); + } + } + console.log(`Client disconnected (${availableOperators.size} operators remaining)`); + }); +}); diff --git a/signaling-server/package.json b/signaling-server/package.json new file mode 100644 index 00000000..137abaa8 --- /dev/null +++ b/signaling-server/package.json @@ -0,0 +1,12 @@ +{ + "name": "screenoperator-signaling", + "version": "1.0.0", + "description": "Minimal WebSocket relay for WebRTC signaling", + "main": "index.js", + "scripts": { + "start": "node index.js" + }, + "dependencies": { + "ws": "^8.16.0" + } +}