|
| 1 | +defmodule Bumblebee.Text.CrossEncoding do |
| 2 | + @moduledoc false |
| 3 | + |
| 4 | + alias Bumblebee.Shared |
| 5 | + |
| 6 | + def cross_encoding(model_info, tokenizer, opts \\ []) do |
| 7 | + %{model: model, params: params, spec: spec} = model_info |
| 8 | + Shared.validate_architecture!(spec, :for_sequence_classification) |
| 9 | + |
| 10 | + opts = |
| 11 | + Keyword.validate!(opts, [ |
| 12 | + :compile, |
| 13 | + defn_options: [], |
| 14 | + preallocate_params: false |
| 15 | + ]) |
| 16 | + |
| 17 | + preallocate_params = opts[:preallocate_params] |
| 18 | + defn_options = opts[:defn_options] |
| 19 | + |
| 20 | + compile = |
| 21 | + if compile = opts[:compile] do |
| 22 | + compile |
| 23 | + |> Keyword.validate!([:batch_size, :sequence_length]) |
| 24 | + |> Shared.require_options!([:batch_size, :sequence_length]) |
| 25 | + end |
| 26 | + |
| 27 | + batch_size = compile[:batch_size] |
| 28 | + sequence_length = compile[:sequence_length] |
| 29 | + |
| 30 | + tokenizer = |
| 31 | + Bumblebee.configure(tokenizer, length: sequence_length) |
| 32 | + |
| 33 | + {_init_fun, predict_fun} = Axon.build(model) |
| 34 | + |
| 35 | + scores_fun = fn params, input -> |
| 36 | + outputs = predict_fun.(params, input) |
| 37 | + Nx.squeeze(outputs.logits, axes: [-1]) |
| 38 | + end |
| 39 | + |
| 40 | + batch_keys = Shared.sequence_batch_keys(sequence_length) |
| 41 | + |
| 42 | + Nx.Serving.new( |
| 43 | + fn batch_key, defn_options -> |
| 44 | + params = Shared.maybe_preallocate(params, preallocate_params, defn_options) |
| 45 | + |
| 46 | + scope = {:cross_encoding, batch_key} |
| 47 | + |
| 48 | + scores_fun = |
| 49 | + Shared.compile_or_jit(scores_fun, scope, defn_options, compile != nil, fn -> |
| 50 | + {:sequence_length, sequence_length} = batch_key |
| 51 | + |
| 52 | + inputs = %{ |
| 53 | + "input_ids" => Nx.template({batch_size, sequence_length}, :u32), |
| 54 | + "attention_mask" => Nx.template({batch_size, sequence_length}, :u32), |
| 55 | + "token_type_ids" => Nx.template({batch_size, sequence_length}, :u32) |
| 56 | + } |
| 57 | + |
| 58 | + [params, inputs] |
| 59 | + end) |
| 60 | + |
| 61 | + fn inputs -> |
| 62 | + inputs = Shared.maybe_pad(inputs, batch_size) |
| 63 | + scores_fun.(params, inputs) |> Shared.serving_post_computation() |
| 64 | + end |
| 65 | + end, |
| 66 | + defn_options |
| 67 | + ) |
| 68 | + |> Nx.Serving.batch_size(batch_size) |
| 69 | + |> Nx.Serving.process_options(batch_keys: batch_keys) |
| 70 | + |> Nx.Serving.client_preprocessing(fn input -> |
| 71 | + {pairs, multi?} = Shared.validate_serving_input!(input, &validate_pair/1) |
| 72 | + |
| 73 | + inputs = |
| 74 | + Nx.with_default_backend(Nx.BinaryBackend, fn -> |
| 75 | + Bumblebee.apply_tokenizer(tokenizer, pairs) |
| 76 | + end) |
| 77 | + |
| 78 | + batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) |
| 79 | + batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) |
| 80 | + |
| 81 | + {batch, multi?} |
| 82 | + end) |
| 83 | + |> Nx.Serving.client_postprocessing(fn {scores, _metadata}, multi? -> |
| 84 | + scores |
| 85 | + |> Nx.to_list() |
| 86 | + |> Enum.map(&%{score: &1}) |
| 87 | + |> Shared.normalize_output(multi?) |
| 88 | + end) |
| 89 | + end |
| 90 | + |
| 91 | + defp validate_pair({text1, text2}) when is_binary(text1) and is_binary(text2), |
| 92 | + do: {:ok, {text1, text2}} |
| 93 | + |
| 94 | + defp validate_pair(value), |
| 95 | + do: {:error, "expected a {string, string} pair, got: #{inspect(value)}"} |
| 96 | +end |
0 commit comments