mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-24 12:43:22 +08:00
chore: Simplify to_dict, from_dict with default (de-)serialization of objects (#10435)
* simplify retriever serde for document stores * adjust tests * add release note * simplify CacheChecker * simplify doc writer deserialization * remove test with invalid deserialized doc store * handle the case where policy is not explicitly set * test from_dict without policy
This commit is contained in:
@@ -6,7 +6,6 @@ from typing import Any
|
||||
|
||||
from haystack import Document, component, default_from_dict, default_to_dict
|
||||
from haystack.document_stores.types import DocumentStore
|
||||
from haystack.utils import deserialize_document_store_in_init_params_inplace
|
||||
|
||||
|
||||
@component
|
||||
@@ -58,7 +57,7 @@ class CacheChecker:
|
||||
:returns:
|
||||
Dictionary with serialized data.
|
||||
"""
|
||||
return default_to_dict(self, document_store=self.document_store.to_dict(), cache_field=self.cache_field)
|
||||
return default_to_dict(self, document_store=self.document_store, cache_field=self.cache_field)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "CacheChecker":
|
||||
@@ -70,9 +69,6 @@ class CacheChecker:
|
||||
:returns:
|
||||
Deserialized component.
|
||||
"""
|
||||
# deserialize the document store
|
||||
deserialize_document_store_in_init_params_inplace(data)
|
||||
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
@component.output_types(hits=list[Document], misses=list)
|
||||
|
||||
@@ -5,10 +5,8 @@
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from haystack import Document, component, default_to_dict
|
||||
from haystack.core.serialization import default_from_dict
|
||||
from haystack import Document, component, default_from_dict, default_to_dict
|
||||
from haystack.document_stores.types import DocumentStore
|
||||
from haystack.utils import deserialize_document_store_in_init_params_inplace
|
||||
|
||||
|
||||
@component
|
||||
@@ -86,8 +84,7 @@ class AutoMergingRetriever:
|
||||
:returns:
|
||||
Dictionary with serialized data.
|
||||
"""
|
||||
docstore = self.document_store.to_dict()
|
||||
return default_to_dict(self, document_store=docstore, threshold=self.threshold)
|
||||
return default_to_dict(self, document_store=self.document_store, threshold=self.threshold)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "AutoMergingRetriever":
|
||||
@@ -99,7 +96,6 @@ class AutoMergingRetriever:
|
||||
:returns:
|
||||
An instance of the component.
|
||||
"""
|
||||
deserialize_document_store_in_init_params_inplace(data)
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Any
|
||||
|
||||
from haystack import Document, component, default_from_dict, default_to_dict
|
||||
from haystack.document_stores.types import DocumentStore
|
||||
from haystack.utils import deserialize_document_store_in_init_params_inplace
|
||||
|
||||
|
||||
@component
|
||||
@@ -62,8 +61,7 @@ class FilterRetriever:
|
||||
:returns:
|
||||
Dictionary with serialized data.
|
||||
"""
|
||||
docstore = self.document_store.to_dict()
|
||||
return default_to_dict(self, document_store=docstore, filters=self.filters)
|
||||
return default_to_dict(self, document_store=self.document_store, filters=self.filters)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "FilterRetriever":
|
||||
@@ -75,9 +73,6 @@ class FilterRetriever:
|
||||
:returns:
|
||||
The deserialized component.
|
||||
"""
|
||||
# deserialize the document store
|
||||
deserialize_document_store_in_init_params_inplace(data)
|
||||
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
@component.output_types(documents=list[Document])
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict
|
||||
from haystack import Document, component, default_from_dict, default_to_dict
|
||||
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
||||
from haystack.document_stores.types import FilterPolicy
|
||||
|
||||
@@ -92,10 +92,9 @@ class InMemoryBM25Retriever:
|
||||
:returns:
|
||||
Dictionary with serialized data.
|
||||
"""
|
||||
docstore = self.document_store.to_dict()
|
||||
return default_to_dict(
|
||||
self,
|
||||
document_store=docstore,
|
||||
document_store=self.document_store,
|
||||
filters=self.filters,
|
||||
top_k=self.top_k,
|
||||
scale_score=self.scale_score,
|
||||
@@ -113,15 +112,8 @@ class InMemoryBM25Retriever:
|
||||
The deserialized component.
|
||||
"""
|
||||
init_params = data.get("init_parameters", {})
|
||||
if "document_store" not in init_params:
|
||||
raise DeserializationError("Missing 'document_store' in serialization data")
|
||||
if "type" not in init_params["document_store"]:
|
||||
raise DeserializationError("Missing 'type' in document store's serialization data")
|
||||
if "filter_policy" in init_params:
|
||||
init_params["filter_policy"] = FilterPolicy.from_str(init_params["filter_policy"])
|
||||
data["init_parameters"]["document_store"] = InMemoryDocumentStore.from_dict(
|
||||
data["init_parameters"]["document_store"]
|
||||
)
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
@component.output_types(documents=list[Document])
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict
|
||||
from haystack import Document, component, default_from_dict, default_to_dict
|
||||
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
||||
from haystack.document_stores.types import FilterPolicy
|
||||
|
||||
@@ -109,10 +109,9 @@ class InMemoryEmbeddingRetriever:
|
||||
:returns:
|
||||
Dictionary with serialized data.
|
||||
"""
|
||||
docstore = self.document_store.to_dict()
|
||||
return default_to_dict(
|
||||
self,
|
||||
document_store=docstore,
|
||||
document_store=self.document_store,
|
||||
filters=self.filters,
|
||||
top_k=self.top_k,
|
||||
scale_score=self.scale_score,
|
||||
@@ -131,15 +130,8 @@ class InMemoryEmbeddingRetriever:
|
||||
The deserialized component.
|
||||
"""
|
||||
init_params = data.get("init_parameters", {})
|
||||
if "document_store" not in init_params:
|
||||
raise DeserializationError("Missing 'document_store' in serialization data")
|
||||
if "type" not in init_params["document_store"]:
|
||||
raise DeserializationError("Missing 'type' in document store's serialization data")
|
||||
if "filter_policy" in init_params:
|
||||
init_params["filter_policy"] = FilterPolicy.from_str(init_params["filter_policy"])
|
||||
data["init_parameters"]["document_store"] = InMemoryDocumentStore.from_dict(
|
||||
data["init_parameters"]["document_store"]
|
||||
)
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
@component.output_types(documents=list[Document])
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Any
|
||||
|
||||
from haystack import Document, component, default_from_dict, default_to_dict, logging
|
||||
from haystack.document_stores.types import DocumentStore
|
||||
from haystack.utils import deserialize_document_store_in_init_params_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -159,10 +158,9 @@ class SentenceWindowRetriever:
|
||||
:returns:
|
||||
Dictionary with serialized data.
|
||||
"""
|
||||
docstore = self.document_store.to_dict()
|
||||
return default_to_dict(
|
||||
self,
|
||||
document_store=docstore,
|
||||
document_store=self.document_store,
|
||||
window_size=self.window_size,
|
||||
source_id_meta_field=self.source_id_meta_field,
|
||||
split_id_meta_field=self.split_id_meta_field,
|
||||
@@ -177,10 +175,6 @@ class SentenceWindowRetriever:
|
||||
:returns:
|
||||
Deserialized component.
|
||||
"""
|
||||
# deserialize the document store
|
||||
deserialize_document_store_in_init_params_inplace(data)
|
||||
|
||||
# deserialize the component
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
@component.output_types(context_windows=list[str], context_documents=list[Document])
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Any
|
||||
|
||||
from haystack import Document, component, default_from_dict, default_to_dict
|
||||
from haystack.document_stores.types import DocumentStore, DuplicatePolicy
|
||||
from haystack.utils import deserialize_document_store_in_init_params_inplace
|
||||
|
||||
|
||||
@component
|
||||
@@ -57,7 +56,7 @@ class DocumentWriter:
|
||||
:returns:
|
||||
Dictionary with serialized data.
|
||||
"""
|
||||
return default_to_dict(self, document_store=self.document_store.to_dict(), policy=self.policy.name)
|
||||
return default_to_dict(self, document_store=self.document_store, policy=self.policy.name)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "DocumentWriter":
|
||||
@@ -72,11 +71,9 @@ class DocumentWriter:
|
||||
:raises DeserializationError:
|
||||
If the document store is not properly specified in the serialization data or its type cannot be imported.
|
||||
"""
|
||||
# deserialize the document store
|
||||
deserialize_document_store_in_init_params_inplace(data)
|
||||
|
||||
data["init_parameters"]["policy"] = DuplicatePolicy[data["init_parameters"]["policy"]]
|
||||
|
||||
init_params = data.get("init_parameters", {})
|
||||
if "policy" in init_params:
|
||||
init_params["policy"] = DuplicatePolicy[init_params["policy"]]
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
@component.output_types(documents_written=int)
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Components no longer handle the (de-)serialization of init parameter objects explicitly if the objects define to_dict/from_dict themselves.
|
||||
Instead, the components rely on the behavior implemented in default_to_dict/default_from_dict.
|
||||
@@ -6,7 +6,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack import DeserializationError, Document
|
||||
from haystack import Document
|
||||
from haystack.components.caching.cache_checker import CacheChecker
|
||||
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
||||
from haystack.testing.factory import document_store_class
|
||||
@@ -54,23 +54,17 @@ class TestCacheChecker:
|
||||
|
||||
def test_from_dict_without_docstore(self):
|
||||
data = {"type": "haystack.components.caching.cache_checker.CacheChecker", "init_parameters": {}}
|
||||
with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
|
||||
CacheChecker.from_dict(data)
|
||||
|
||||
def test_from_dict_without_docstore_type(self):
|
||||
data = {
|
||||
"type": "haystack.components.caching.cache_checker.UrlCacheChecker",
|
||||
"init_parameters": {"document_store": {"init_parameters": {}}},
|
||||
}
|
||||
with pytest.raises(DeserializationError):
|
||||
with pytest.raises(
|
||||
TypeError, match="missing 2 required positional arguments: 'document_store' and 'cache_field'"
|
||||
):
|
||||
CacheChecker.from_dict(data)
|
||||
|
||||
def test_from_dict_nonexisting_docstore(self):
|
||||
data = {
|
||||
"type": "haystack.components.caching.cache_checker.UrlCacheChecker",
|
||||
"type": "haystack.components.caching.cache_checker.CacheChecker",
|
||||
"init_parameters": {"document_store": {"type": "Nonexisting.DocumentStore", "init_parameters": {}}},
|
||||
}
|
||||
with pytest.raises(DeserializationError):
|
||||
with pytest.raises(ImportError, match=r"Failed to deserialize 'document_store':.*Nonexisting\.DocumentStore"):
|
||||
CacheChecker.from_dict(data)
|
||||
|
||||
def test_run(self):
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack import DeserializationError, Pipeline
|
||||
from haystack import Pipeline
|
||||
from haystack.components.retrievers.filter_retriever import FilterRetriever
|
||||
from haystack.dataclasses import Document
|
||||
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
||||
@@ -93,8 +93,8 @@ class TestFilterRetriever:
|
||||
assert component.filters == {"lang": "en"}
|
||||
|
||||
def test_from_dict_without_docstore(self):
|
||||
data = {"type": "InMemoryBM25Retriever", "init_parameters": {}}
|
||||
with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
|
||||
data = {"type": "haystack.components.retrievers.filter_retriever.FilterRetriever", "init_parameters": {}}
|
||||
with pytest.raises(TypeError, match="missing 1 required positional argument: 'document_store'"):
|
||||
FilterRetriever.from_dict(data)
|
||||
|
||||
def test_retriever_init_filter(self, sample_document_store, sample_docs):
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack import DeserializationError, Pipeline
|
||||
from haystack import Pipeline
|
||||
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
|
||||
from haystack.dataclasses import Document
|
||||
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
||||
@@ -104,13 +104,19 @@ class TestMemoryBM25Retriever:
|
||||
assert component.filter_policy == FilterPolicy.REPLACE
|
||||
|
||||
def test_from_dict_without_docstore(self):
|
||||
data = {"type": "InMemoryBM25Retriever", "init_parameters": {}}
|
||||
with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
|
||||
data = {
|
||||
"type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
|
||||
"init_parameters": {},
|
||||
}
|
||||
with pytest.raises(TypeError, match="missing 1 required positional argument: 'document_store'"):
|
||||
InMemoryBM25Retriever.from_dict(data)
|
||||
|
||||
def test_from_dict_without_docstore_type(self):
|
||||
data = {"type": "InMemoryBM25Retriever", "init_parameters": {"document_store": {"init_parameters": {}}}}
|
||||
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
|
||||
data = {
|
||||
"type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
|
||||
"init_parameters": {"document_store": {"init_parameters": {}}},
|
||||
}
|
||||
with pytest.raises(ValueError, match="document_store must be an instance of InMemoryDocumentStore"):
|
||||
InMemoryBM25Retriever.from_dict(data)
|
||||
|
||||
def test_from_dict_nonexisting_docstore(self):
|
||||
@@ -118,7 +124,7 @@ class TestMemoryBM25Retriever:
|
||||
"type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
|
||||
"init_parameters": {"document_store": {"type": "Nonexisting.Docstore", "init_parameters": {}}},
|
||||
}
|
||||
with pytest.raises(DeserializationError):
|
||||
with pytest.raises(ImportError, match=r"Failed to deserialize 'document_store':.*Nonexisting\.Docstore"):
|
||||
InMemoryBM25Retriever.from_dict(data)
|
||||
|
||||
def test_retriever_valid_run(self, mock_docs):
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack import DeserializationError, Pipeline
|
||||
from haystack import Pipeline
|
||||
from haystack.components.retrievers.in_memory.embedding_retriever import InMemoryEmbeddingRetriever
|
||||
from haystack.dataclasses import Document
|
||||
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
||||
@@ -101,7 +101,7 @@ class TestMemoryEmbeddingRetriever:
|
||||
"type": "haystack.components.retrievers.in_memory.embedding_retriever.InMemoryEmbeddingRetriever",
|
||||
"init_parameters": {},
|
||||
}
|
||||
with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
|
||||
with pytest.raises(TypeError, match="missing 1 required positional argument: 'document_store'"):
|
||||
InMemoryEmbeddingRetriever.from_dict(data)
|
||||
|
||||
def test_from_dict_without_docstore_type(self):
|
||||
@@ -109,7 +109,7 @@ class TestMemoryEmbeddingRetriever:
|
||||
"type": "haystack.components.retrievers.in_memory.embedding_retriever.InMemoryEmbeddingRetriever",
|
||||
"init_parameters": {"document_store": {"init_parameters": {}}},
|
||||
}
|
||||
with pytest.raises(DeserializationError):
|
||||
with pytest.raises(ValueError, match="document_store must be an instance of InMemoryDocumentStore"):
|
||||
InMemoryEmbeddingRetriever.from_dict(data)
|
||||
|
||||
def test_from_dict_nonexisting_docstore(self):
|
||||
@@ -117,7 +117,7 @@ class TestMemoryEmbeddingRetriever:
|
||||
"type": "haystack.components.retrievers.in_memory.embedding_retriever.InMemoryEmbeddingRetriever",
|
||||
"init_parameters": {"document_store": {"type": "Nonexisting.Docstore", "init_parameters": {}}},
|
||||
}
|
||||
with pytest.raises(DeserializationError):
|
||||
with pytest.raises(ImportError, match=r"Failed to deserialize 'document_store':.*Nonexisting\.Docstore"):
|
||||
InMemoryEmbeddingRetriever.from_dict(data)
|
||||
|
||||
def test_valid_run(self):
|
||||
|
||||
@@ -109,8 +109,11 @@ class TestSentenceWindowRetriever:
|
||||
assert not component.raise_on_missing_meta_fields
|
||||
|
||||
def test_from_dict_without_docstore(self):
|
||||
data = {"type": "SentenceWindowRetriever", "init_parameters": {}}
|
||||
with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
|
||||
data = {
|
||||
"type": "haystack.components.retrievers.sentence_window_retriever.SentenceWindowRetriever",
|
||||
"init_parameters": {},
|
||||
}
|
||||
with pytest.raises(TypeError, match="missing 1 required positional argument: 'document_store'"):
|
||||
SentenceWindowRetriever.from_dict(data)
|
||||
|
||||
def test_from_dict_without_docstore_type(self):
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack import DeserializationError, Document
|
||||
from haystack import Document
|
||||
from haystack.components.writers.document_writer import DocumentWriter
|
||||
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
||||
from haystack.document_stores.types import DuplicatePolicy
|
||||
@@ -64,22 +64,31 @@ class TestDocumentWriter:
|
||||
assert isinstance(component.document_store, InMemoryDocumentStore)
|
||||
assert component.policy == DuplicatePolicy.SKIP
|
||||
|
||||
def test_from_dict_without_docstore(self):
|
||||
data = {"type": "DocumentWriter", "init_parameters": {}}
|
||||
with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
|
||||
DocumentWriter.from_dict(data)
|
||||
def test_from_dict_without_policy(self):
|
||||
data = {
|
||||
"type": "haystack.components.writers.document_writer.DocumentWriter",
|
||||
"init_parameters": {
|
||||
"document_store": {
|
||||
"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
|
||||
"init_parameters": {},
|
||||
}
|
||||
},
|
||||
}
|
||||
component = DocumentWriter.from_dict(data)
|
||||
assert isinstance(component.document_store, InMemoryDocumentStore)
|
||||
assert component.policy == DuplicatePolicy.NONE
|
||||
|
||||
def test_from_dict_without_docstore_type(self):
|
||||
data = {"type": "DocumentWriter", "init_parameters": {"document_store": {"init_parameters": {}}}}
|
||||
with pytest.raises(DeserializationError):
|
||||
def test_from_dict_without_docstore(self):
|
||||
data = {"type": "haystack.components.writers.document_writer.DocumentWriter", "init_parameters": {}}
|
||||
with pytest.raises(TypeError, match="missing 1 required positional argument: 'document_store'"):
|
||||
DocumentWriter.from_dict(data)
|
||||
|
||||
def test_from_dict_nonexisting_docstore(self):
|
||||
data = {
|
||||
"type": "DocumentWriter",
|
||||
"type": "haystack.components.writers.document_writer.DocumentWriter",
|
||||
"init_parameters": {"document_store": {"type": "Nonexisting.DocumentStore", "init_parameters": {}}},
|
||||
}
|
||||
with pytest.raises(DeserializationError):
|
||||
with pytest.raises(ImportError, match=r"Failed to deserialize 'document_store':.*Nonexisting\.DocumentStore"):
|
||||
DocumentWriter.from_dict(data)
|
||||
|
||||
def test_run(self, document_store):
|
||||
|
||||
Reference in New Issue
Block a user