Skip to content

Commit

Permalink
update api server.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Sep 18, 2023
1 parent aa87312 commit 9d82044
Showing 1 changed file with 9 additions and 19 deletions.
28 changes: 9 additions & 19 deletions examples/fastapi_server_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,24 @@


import argparse
import uvicorn
import sys
import os
from fastapi import FastAPI, Query
from starlette.middleware.cors import CORSMiddleware
import sys

import torch
import uvicorn
from fastapi import FastAPI
from loguru import logger
from typing import List
from pydantic import BaseModel, Field
import numpy as np

from starlette.middleware.cors import CORSMiddleware

sys.path.append('..')
from text2vec import SentenceModel


class Item(BaseModel):
input: str = Field(..., max_length=512)


pwd_path = os.path.abspath(os.path.dirname(__file__))
use_cuda = torch.cuda.is_available()
logger.info(f'use_cuda:{use_cuda}')
Expand All @@ -30,13 +30,6 @@ class Item(BaseModel):
args = parser.parse_args()
s_model = SentenceModel(args.model_name_or_path)

def _normalize_embedding_2D(vec: np.ndarray) -> np.ndarray:
vec = np.ascontiguousarray(vec)
norm = np.sqrt(vec.dot(vec))
if norm != 0.0:
vec /= norm
return vec

# define the app
app = FastAPI()
app.add_middleware(
Expand All @@ -55,10 +48,8 @@ async def index():
@app.post('/emb')
async def emb(item: Item):
try:
embeddings = s_model.encode(item.input)
embeddings = np.array(embeddings)
normalized_embeddings = _normalize_embedding_2D(embeddings)
result_dict = {'emb': normalized_embeddings.tolist()}
embeddings = s_model.encode(item.input, normalize_embeddings=True)
result_dict = {'emb': embeddings.tolist()}
logger.debug(f"Successfully get sentence embeddings, q:{item.input}")
return result_dict
except Exception as e:
Expand All @@ -68,4 +59,3 @@ async def emb(item: Item):

if __name__ == '__main__':
uvicorn.run(app=app, host='0.0.0.0', port=8001)

0 comments on commit 9d82044

Please sign in to comment.