Merge pull request #441 from RelevanceAI/development
v0.32.1
jtwinrelevanceai authored Jun 13, 2023
2 parents f9a32cd + 0992528 commit d7db0f2
Showing 15 changed files with 131 additions and 64 deletions.
2 changes: 1 addition & 1 deletion ai_transform/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.32.0"
__version__ = "0.32.1"

from ai_transform.timer import Timer

7 changes: 4 additions & 3 deletions ai_transform/api/api.py
@@ -162,9 +162,10 @@ def _list_datasets(self):
def _create_dataset(
self, dataset_id: str, schema: Optional[Schema] = None, upsert: bool = True, expire: bool = False
) -> Any:
response = self.post(
suffix=f"/datasets/create", json=dict(id=dataset_id, schema=schema, upsert=upsert, expire=expire)
)
obj = dict(id=dataset_id, upsert=upsert, expire=expire)
if schema:
obj["schema"] = schema
response = self.post(suffix=f"/datasets/create", json=obj)
return get_response(response)

def _delete_dataset(self, dataset_id: str) -> Any:
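
This change means the create payload only carries a "schema" key when a schema is actually supplied, rather than always sending "schema": null. A minimal sketch of the two payloads, with a hypothetical dataset id and schema (not taken from the diff):

payload_without_schema = {"id": "sample_dataset", "upsert": True, "expire": False}
payload_with_schema = {
    "id": "sample_dataset",
    "upsert": True,
    "expire": False,
    "schema": {"text": "text", "text_vector_": {"vector": 768}},  # hypothetical schema value
}
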
7 changes: 3 additions & 4 deletions ai_transform/dataset/field.py
@@ -107,11 +107,10 @@ def contains(self, other: str) -> Filter:

def exists(self) -> Filter:
if "_chunk_" in self._field:
count = self._field.count(".")
if count:
parent_field = self._field.split(".")[0]
else:
if self._field.endswith("_chunk_"):
parent_field = self._field
else:
parent_field = self._field.split(".")[0]

return [{"chunk": {"path": parent_field, "filters": [{"fieldExists": {"field": self._field}}]}}]
return [{"field": self._field, "filter_type": "exists", "condition": "==", "condition_value": " "}]
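
For chunk fields, exists() wraps a fieldExists filter inside a chunk filter rooted at the parent path; the parent is now the field itself when it ends with "_chunk_", and the prefix before the first "." otherwise. A sketch of the filter returned for a hypothetical field "sentences_chunk_.label":

# Hypothetical result of exists() on the field "sentences_chunk_.label" under the new logic.
[{"chunk": {"path": "sentences_chunk_", "filters": [{"fieldExists": {"field": "sentences_chunk_.label"}}]}}]
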
6 changes: 3 additions & 3 deletions ai_transform/engine/abstract_engine.py
@@ -63,7 +63,7 @@ def __init__(
chunk_index = field.index("_chunk_") + len("_chunk_")
chunk_field = field[:chunk_index]
fields_to_add += [chunk_field]
select_fields += fields_to_add
select_fields = select_fields + fields_to_add
select_fields = list(set(select_fields))
else:
select_fields = []
@@ -109,8 +109,8 @@ def __init__(
self._refresh = refresh
self._after_id = after_id

filters += self._get_refresh_filter()
filters += self._get_workflow_filter()
filters = filters + self._get_refresh_filter()
filters = filters + self._get_workflow_filter()

self._filters = filters

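
Replacing += with explicit concatenation stops the engine from mutating the select_fields and filters lists the caller passed in: += extends the existing list object in place, while x = x + y rebinds the local name to a new list. A small plain-Python illustration (not repo code):

caller_filters = [{"field": "text_field", "filter_type": "exists"}]

def mutating(filters):
    filters += [{"field": "extra"}]  # extends the caller's list in place

def non_mutating(filters):
    filters = filters + [{"field": "extra"}]  # builds a new list; the caller's list is untouched

mutating(caller_filters)      # caller_filters now holds 2 items
non_mutating(caller_filters)  # caller_filters still holds 2 items
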
12 changes: 11 additions & 1 deletion ai_transform/engine/dense_output_engine.py
@@ -21,6 +21,8 @@


class DenseOutputEngine(AbstractEngine):
operator: DenseOperator

def __init__(
self,
dataset: Dataset = None,
@@ -72,6 +74,7 @@ def apply(self) -> None:
for mega_batch in self.api_progress(iterator):
for mini_batch in AbstractEngine.chunk_documents(self._transform_chunksize, mega_batch):
document_mapping = self._operate(mini_batch)

for dataset_id, documents in document_mapping.items():
output_dataset_ids.append(dataset_id)
dataset = Dataset.from_details(dataset_id, self.token)
@@ -81,7 +84,14 @@
self.operator.post_hooks(self._dataset)

output_datasets = self.datasets_from_ids(output_dataset_ids)
self.operator.store_dataset_relationship(self.dataset, output_datasets)
self.store_dataset_relationship(output_datasets)

def datasets_from_ids(self, dataset_ids: Sequence[str]) -> Sequence[Dataset]:
return [Dataset.from_details(dataset_id, self.token) for dataset_id in dataset_ids]

def store_dataset_relationship(self, output_datasets: Sequence[Dataset]):
self.dataset.update_metadata(
{"_child_datasets_": [output_dataset.dataset_id for output_dataset in output_datasets]}
)
for output_dataset in output_datasets:
output_dataset.update_metadata({"_parent_dataset_": self.dataset.dataset_id})
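
store_dataset_relationship, previously defined on DenseOperator, now lives on the engine together with the new datasets_from_ids helper, and records lineage in dataset metadata: the input dataset lists its children and every output dataset points back to its parent. Roughly, the metadata written looks like this (dataset ids are hypothetical):

# Metadata on the input dataset after store_dataset_relationship runs:
{"_child_datasets_": ["output_dataset_a", "output_dataset_b"]}
# Metadata on each output dataset:
{"_parent_dataset_": "input_dataset"}
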
9 changes: 7 additions & 2 deletions ai_transform/operator/abstract_operator.py
@@ -75,7 +75,6 @@ def __init__(
output_fields: Optional[Union[Dict[str, str], List[str]]] = None,
enable_postprocess: Optional[bool] = True,
):

if input_fields is not None and output_fields is not None:
if any(input_field in output_fields for input_field in input_fields):
detected_fields = [input_field for input_field in input_fields if input_field in output_fields]
@@ -171,6 +170,12 @@ def transform_for_playground(
from ai_transform.api.client import Client

output = self.transform(documents=documents)
if hasattr(documents, "to_json"):
output = output.to_json()
else:
for index in range(len(output)):
if hasattr(output[index], "to_json"):
output[index] = output[index].to_json()
client = Client(authorization_token)
return client.api._set_workflow_status(
job_id=job_id,
@@ -180,7 +185,7 @@
status=status,
send_email=send_email,
worker_number=worker_number,
output=output,
output={"output": output},
)

def pre_hooks(self, dataset: Dataset):
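
transform_for_playground now converts its result to plain JSON-compatible data (via to_json, where available) and nests it under an "output" key before reporting workflow status. A rough sketch of the argument shape, with invented field values:

# Hypothetical value now passed as the output argument of _set_workflow_status.
output = {"output": [{"_id": "0", "sentiment_label": "positive"}]}
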
14 changes: 4 additions & 10 deletions ai_transform/operator/dense_operator.py
@@ -5,7 +5,6 @@
from typing import Dict, Sequence

from ai_transform.operator.abstract_operator import AbstractOperator
from ai_transform.dataset.dataset import Dataset
from ai_transform.utils.document import Document
from ai_transform.utils.document_list import DocumentList

@@ -24,18 +23,13 @@
class DenseOperator(AbstractOperator):
def __call__(self, old_documents: DocumentList) -> DenseOperatorOutput:
datum = self.transform(old_documents)
assert isinstance(datum, dict), BAD_OPERATOR_MESSAGE
if not isinstance(datum, dict):
raise ValueError(BAD_OPERATOR_MESSAGE)
for _, documents in datum.items():
assert isinstance(documents, Sequence)
if not isinstance(documents, Sequence):
raise ValueError(BAD_OPERATOR_MESSAGE)
return datum

@abstractmethod
def transform(self, documents: DocumentList) -> DenseOperatorOutput:
raise NotImplementedError

def store_dataset_relationship(self, input_dataset: Dataset, output_datasets: Sequence[Dataset]):
input_dataset.update_metadata(
{"_child_datasets_": [output_dataset.dataset_id for output_dataset in output_datasets]}
)
for output_dataset in output_datasets:
output_dataset.update_metadata({"_parent_dataset_": input_dataset.dataset_id})
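
Raising ValueError instead of asserting keeps the validation active even when Python runs with -O, which strips assert statements, and both failure paths now reuse BAD_OPERATOR_MESSAGE. For reference, a DenseOperator.transform is expected to return a mapping from output dataset id to a sequence of documents; a minimal hypothetical example:

# Hypothetical return value of a DenseOperator.transform implementation.
{
    "output_dataset_a": [{"_id": "0", "text_default_vector_": [0.1, 0.2, 0.3]}],
    "output_dataset_b": [{"_id": "0", "_order": 0}],
}
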
12 changes: 1 addition & 11 deletions ai_transform/utils/document.py
@@ -141,17 +141,7 @@ def list_chunks(self):
return [k for k in self.keys() if k.endswith("_chunk_")]

def get_chunk(self, chunk_field: str, field: str = None, default: str = None):
"""
Returns a list of values.
"""
# provide a recursive implementation for getting chunks
from ai_transform.utils.document_list import DocumentList

document_list = DocumentList(self.get(chunk_field, default=default))
# Get the field across chunks
if field is None:
return document_list
return [d.get(field, default=default) for d in document_list.data]
return [document.get(field, default) for document in self.get(chunk_field, default=default)]

def _create_chunk_documents(self, field: str, values: list, generate_id: bool = False):
"""
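
The rewritten get_chunk replaces the recursive DocumentList-based implementation with a single list comprehension over the chunk's entries. A short usage sketch, assuming the chunk field holds plain dicts (field names are hypothetical):

# Hypothetical behaviour for a document containing
# {"sentences_chunk_": [{"label": "a"}, {"label": "b"}]}:
doc.get_chunk("sentences_chunk_", field="label")                   # -> ["a", "b"]
doc.get_chunk("sentences_chunk_", field="missing", default="n/a")  # -> ["n/a", "n/a"]
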
12 changes: 10 additions & 2 deletions ai_transform/utils/example_documents.py
@@ -52,7 +52,6 @@ def create_id():


def generate_random_string(string_length: int = 5) -> str:

"""Generate a random string of letters and numbers"""
return "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(string_length))

@@ -92,7 +91,16 @@ def vector_document(vector_length: int) -> Document:


def mock_documents(n: int = 100, vector_length: int = 5) -> DocumentList:
return DocumentList([vector_document(vector_length) for _ in range(n)])
documents = [vector_document(vector_length) for _ in range(n)]
return DocumentList(documents)


def incomplete_documents(n: int = 100, vector_length: int = 5) -> DocumentList:
documents = [vector_document(vector_length).data for _ in range(n)]
for document in documents:
for key in random.sample(document.keys(), 3):
document.pop(key)
return DocumentList(documents)


def static_documents(n: int = 100) -> DocumentList:
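
The new incomplete_documents helper builds mock vector documents and removes three randomly chosen keys from each, giving the test suite documents with missing fields. A rough usage sketch (counts are illustrative):

# Hypothetical usage: 20 mock documents, each missing 3 randomly removed keys.
docs = incomplete_documents(n=20, vector_length=5)
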
2 changes: 1 addition & 1 deletion examples/workflows/clustering_example.py
@@ -71,7 +71,7 @@ def execute(token: str, logger: Callable, worker_number: int = 0, *args, **kwarg
engine = InMemoryEngine(
dataset=dataset,
operator=operator,
chunksize=16,
chunksize=8,
select_fields=[vector_field],
filters=filters,
worker_number=worker_number,
4 changes: 1 addition & 3 deletions examples/workflows/sentiment_example.py
@@ -2,15 +2,13 @@
# transformers[torch]==4.18.0
# relevance-workflows-core

import uuid
import torch

from typing import Callable, List, Optional
from typing import Callable, Optional

from transformers import pipeline
from ai_transform.api.client import Client
from ai_transform.engine.stable_engine import StableEngine
from ai_transform.workflow.helpers import decode_workflow_token
from ai_transform.workflow.abstract_workflow import AbstractWorkflow
from ai_transform.operator.abstract_operator import AbstractOperator
from ai_transform.utils.document_list import DocumentList
66 changes: 57 additions & 9 deletions tests/conftest.py
@@ -21,7 +21,14 @@

from ai_transform.utils.document import Document
from ai_transform.utils.document_list import DocumentList
from ai_transform.utils.example_documents import mock_documents, static_documents, tag_documents
from ai_transform.utils.example_documents import (
mock_documents,
static_documents,
tag_documents,
generate_random_vector,
incomplete_documents,
)


TEST_TOKEN = os.getenv("TEST_TOKEN")
test_creds = process_token(TEST_TOKEN)
@@ -44,11 +51,11 @@ def test_dataset_id() -> str:
return dataset_id


@pytest.fixture(scope="class")
@pytest.fixture(scope="function")
def empty_dataset(test_client: Client) -> Dataset:
salt = "".join(random.choices(string.ascii_lowercase, k=10))
dataset_id = f"_sample_dataset_{salt}"
dataset = test_client.Dataset(dataset_id, expire=True)
dataset = test_client.Dataset(dataset_id, expire=False)
yield dataset
test_client.delete_dataset(dataset_id)

@@ -137,8 +144,8 @@ def static_dataset(test_client: Client) -> Dataset:
test_client.delete_dataset(dataset_id)


@pytest.fixture(scope="class")
def dense_input_dataset(test_client: Client) -> Dataset:
@pytest.fixture(scope="function")
def dense_input_dataset1(test_client: Client) -> Dataset:
salt = "".join(random.choices(string.ascii_lowercase, k=10))
dataset_id = f"_sample_dataset_{salt}"
dataset = test_client.Dataset(dataset_id, expire=True)
@@ -147,7 +154,17 @@ def dense_input_dataset(test_client: Client) -> Dataset:
test_client.delete_dataset(dataset_id)


@pytest.fixture(scope="class")
@pytest.fixture(scope="function")
def dense_input_dataset2(test_client: Client) -> Dataset:
salt = "".join(random.choices(string.ascii_lowercase, k=10))
dataset_id = f"_sample_dataset_{salt}"
dataset = test_client.Dataset(dataset_id, expire=True)
dataset.insert_documents(mock_documents(100))
yield dataset
test_client.delete_dataset(dataset_id)


@pytest.fixture(scope="function")
def dense_output_dataset1(test_client: Client) -> Dataset:
salt = "".join(random.choices(string.ascii_lowercase, k=10))
dataset_id = f"_sample_dataset_{salt}"
@@ -156,7 +173,7 @@ def dense_output_dataset1(test_client: Client) -> Dataset:
test_client.delete_dataset(dataset_id)


@pytest.fixture(scope="class")
@pytest.fixture(scope="function")
def dense_output_dataset2(test_client: Client) -> Dataset:
salt = "".join(random.choices(string.ascii_lowercase, k=10))
dataset_id = f"_sample_dataset_{salt}"
@@ -264,6 +281,37 @@ def transform(self, documents: DocumentList) -> DocumentList:
return TestDenseOperator(output_dataset_ids)


@pytest.fixture(scope="function")
def test_chunk_dense_operator(dense_output_dataset1: Dataset, dense_output_dataset2: Dataset) -> DenseOperator:
class TestDenseOperator(DenseOperator):
def __init__(self, output_dataset_ids: Sequence[str]):
self.output_dataset_ids = output_dataset_ids
self._chunk_field = "_chunk_"
self._text_field = "label"
self._alias = "default"
super().__init__()

def transform(self, documents: List[Document]) -> List[Document]:
outputs = []
for document in documents:
texts = document.get_chunk(chunk_field=self._chunk_field, field=self._text_field)
text_vectors = [generate_random_vector() for _ in range(len(texts))]

for sent_index, text_vector in enumerate(text_vectors):
outputs.append(
{
"_id": document["_id"] + f":{sent_index}",
f"{self._text_field}_{self._alias}_vector_": text_vector,
"_order": sent_index,
}
)

return {dataset_id: documents for dataset_id in self.output_dataset_ids}

output_dataset_ids = (dense_output_dataset1.dataset_id, dense_output_dataset2.dataset_id)
return TestDenseOperator(output_dataset_ids)


@pytest.fixture(scope="function")
def test_engine(full_dataset: Dataset, test_operator: AbstractOperator) -> StableEngine:
return StableEngine(dataset=full_dataset, operator=test_operator)
@@ -284,7 +332,7 @@ def test_sentiment_workflow_token(test_client: Client) -> str:
salt = "".join(random.choices(string.ascii_lowercase, k=10))
dataset_id = f"_sample_dataset_{salt}"
dataset = test_client.Dataset(dataset_id, expire=True)
dataset.insert_documents(mock_documents(20))
dataset.insert_documents(incomplete_documents(20))
time.sleep(1)
job_id = str(uuid.uuid4())
config = dict(
@@ -361,7 +409,7 @@ def test_cluster_workflow_token(test_client: Client) -> str:
salt = "".join(random.choices(string.ascii_lowercase, k=10))
dataset_id = f"_sample_dataset_{salt}"
dataset = test_client.Dataset(dataset_id, expire=True)
dataset.insert_documents(mock_documents(20))
dataset.insert_documents(incomplete_documents(20))
job_id = str(uuid.uuid4())
print(job_id)
config = dict(
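
Several dataset fixtures here move from class scope to function scope, so each test now gets its own freshly created dataset and the teardown delete runs after every test rather than once per class. A minimal sketch of the difference, assuming standard pytest fixture semantics (names are illustrative):

import pytest

@pytest.fixture(scope="class")     # one instance shared by all tests in the class
def shared_resource():
    yield object()

@pytest.fixture(scope="function")  # a fresh instance for every test
def fresh_resource():
    yield object()
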
11 changes: 5 additions & 6 deletions tests/core/test_dataset/test_dataset.py
@@ -5,13 +5,7 @@
from ai_transform.utils.example_documents import mock_documents


@pytest.mark.usefixtures("empty_dataset")
class TestDataset1:
def test_create_delete(self, empty_dataset: Dataset):
empty_dataset.delete()
empty_dataset.create()
assert True

def test_insert(self, empty_dataset: Dataset):
documents = mock_documents(100)
result = empty_dataset.insert_documents(documents)
@@ -21,6 +15,11 @@ def test_get_all(self, full_dataset: Dataset):
res = full_dataset.get_all_documents()
assert len(res["documents"]) == 20

def test_create_delete(self, empty_dataset: Dataset):
empty_dataset.delete()
empty_dataset.create()
assert True


@pytest.mark.usefixtures("full_dataset")
class TestDataset2: