Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Flex OP dependency and run inference against new IDDetector #5404

Merged
merged 5 commits into from Aug 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Expand Up @@ -6,6 +6,10 @@

* [FIXED][5422](https://github.com/stripe/stripe-android/pull/5422) Card expiration dates with a single-digit month are now preserved correctly when closing and re-opening the `PaymentSheet` via the `FlowController`.

### Identity
* [FIXED][5404](https://github.com/stripe/stripe-android/pull/5404) Remove Flex OP dependency from
Identity SDK and reduce its binary size.

## 20.9.0 - 2022-08-16
This release contains several bug fixes for Payments, PaymentSheet and Financial Connections.
Adds `IdentityVerificationSheet#rememberIdentityVerificationSheet` for Identity.
Expand Down
2 changes: 0 additions & 2 deletions identity/build.gradle
Expand Up @@ -25,8 +25,6 @@ dependencies {

// vanilla tflite library
implementation "org.tensorflow:tensorflow-lite:2.9.0"
// flex ops
implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:2.9.0'
// support library to reshape image to the input model shape
implementation 'org.tensorflow:tensorflow-lite-support:0.4.1'

Expand Down
Expand Up @@ -51,15 +51,15 @@ internal class IDDetectorAnalyzer(
ImageProcessor.Builder().add(
ResizeOp(INPUT_HEIGHT, INPUT_WIDTH, ResizeOp.ResizeMethod.BILINEAR)
).add(
NormalizeOp(NORMALIZE_MEAN, NORMALIZE_STD) // normalize to (-1, 1)
NormalizeOp(NORMALIZE_MEAN, NORMALIZE_STD) // normalize to [0, 1]
).build() // add normalization
tensorImage = imageProcessor.process(tensorImage)
preprocessStat.trackResult()

val inferenceStat = modelPerformanceTracker.trackInference()
// inference - input: (1, 224, 224, 3), output: (1, 4), (1, 5)
val boundingBoxes = Array(1) { FloatArray(OUTPUT_BOUNDING_BOX_TENSOR_SIZE) }
val categories = Array(1) { FloatArray(OUTPUT_CATEGORY_TENSOR_SIZE) }
// inference - input: (1, 224, 224, 3), output: (392, 4), (392, 4)
val boundingBoxes = Array(OUTPUT_SIZE) { FloatArray(OUTPUT_BOUNDING_BOX_TENSOR_SIZE) }
val categories = Array(OUTPUT_SIZE) { FloatArray(OUTPUT_CATEGORY_TENSOR_SIZE) }
tfliteInterpreter.runForMultipleInputsOutputs(
arrayOf(tensorImage.buffer),
mapOf(
Expand All @@ -69,31 +69,41 @@ internal class IDDetectorAnalyzer(
)
inferenceStat.trackResult()

// find the category with highest score and build output
val resultIndex = requireNotNull(categories[0].indices.maxByOrNull { categories[0][it] })

val resultCategory: Category
val resultScore: Float

// TODO(ccen) use idDetectorMinScore when server updates the value
if (categories[0][resultIndex] > THRESHOLD) {
resultCategory = requireNotNull(INDEX_CATEGORY_MAP[resultIndex])
resultScore = categories[0][resultIndex]
} else {
resultCategory = Category.NO_ID
resultScore = 0f
// To get more results, run nonMaxSuppressionMultiClass on the categories.
// But for IDDetector, we just need to find the highest score and return its
// corresponding box.
var bestIndex = 0
var bestScore = Float.MIN_VALUE
var bestCategoryIndex = INDEX_INVALID

// Find the highest score in the (392, 4) output array of category scores,
// and record its row index, within range [0, 392), as bestIndex.
for (currentOutputIndex in 0 until OUTPUT_SIZE) {
val currentScores = categories[currentOutputIndex]
val currentBestCategoryIndex = currentScores.indices.maxBy {
currentScores[it]
}
val currentBestScore = currentScores[currentBestCategoryIndex]
if (bestScore < currentBestScore && currentBestScore > idDetectorMinScore) {
bestScore = currentBestScore
bestIndex = currentOutputIndex
bestCategoryIndex = currentBestCategoryIndex
}
}

val bestCategory = INDEX_CATEGORY_MAP[bestCategoryIndex] ?: Category.INVALID
val bestBoundingBox = boundingBoxes[bestIndex]
return IDDetectorOutput(
BoundingBox(
left = boundingBoxes[0][0],
top = boundingBoxes[0][1],
width = boundingBoxes[0][2],
height = boundingBoxes[0][3]
bestBoundingBox[0],
bestBoundingBox[1],
bestBoundingBox[2],
bestBoundingBox[3]
),
resultCategory,
resultScore,
categories[0].map { it.roundToMaxDecimals(2) }
bestCategory,
bestScore,
LIST_OF_INDICES.map {
categories[bestIndex][it].roundToMaxDecimals(2)
}
)
}

Expand All @@ -111,28 +121,39 @@ internal class IDDetectorAnalyzer(
Analyzer<AnalyzerInput, IdentityScanState, AnalyzerOutput>
> {
override suspend fun newInstance(): Analyzer<AnalyzerInput, IdentityScanState, AnalyzerOutput> {
return IDDetectorAnalyzer(modelFile, idDetectorMinScore, modelPerformanceTracker)
return IDDetectorAnalyzer(
modelFile,
idDetectorMinScore,
modelPerformanceTracker
)
}
}

internal companion object {
const val OUTPUT_SIZE = 392

const val INPUT_WIDTH = 224
const val INPUT_HEIGHT = 224
const val NORMALIZE_MEAN = 127.5f
const val NORMALIZE_STD = 127.5f
const val THRESHOLD = 0.4f

// normalize pixel values from [0, 255] to [0, 1]
const val NORMALIZE_MEAN = 0f
const val NORMALIZE_STD = 255f
const val OUTPUT_BOUNDING_BOX_TENSOR_INDEX = 0
const val OUTPUT_CATEGORY_TENSOR_INDEX = 1
const val OUTPUT_BOUNDING_BOX_TENSOR_SIZE = 4
private const val INDEX_NO_ID = 0
const val INDEX_PASSPORT = 1
const val INDEX_ID_FRONT = 2
const val INDEX_ID_BACK = 3
const val INDEX_INVALID = 4
const val INDEX_PASSPORT = 0
const val INDEX_ID_FRONT = 1
const val INDEX_ID_BACK = 2
const val INDEX_INVALID = 3
val LIST_OF_INDICES = listOf(
INDEX_PASSPORT,
INDEX_ID_FRONT,
INDEX_ID_BACK,
INDEX_INVALID
)
val INPUT_TENSOR_TYPE: DataType = DataType.FLOAT32
val OUTPUT_CATEGORY_TENSOR_SIZE = Category.values().size
val OUTPUT_CATEGORY_TENSOR_SIZE = Category.values().size - 1 // no NO_ID
val INDEX_CATEGORY_MAP = mapOf(
INDEX_NO_ID to Category.NO_ID,
INDEX_PASSPORT to Category.PASSPORT,
INDEX_ID_FRONT to Category.ID_FRONT,
INDEX_ID_BACK to Category.ID_BACK,
Expand Down